diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2b1607ad..110578f0 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -167,6 +167,7 @@ def __init__( drop_last=( self.cfg.trainer.skip_last_batch if view == "train" else False ), + pin_memory=self.cfg.trainer.pin_memory, sampler=sampler if view == "train" else None, ) for view in ["train", "val", "test"] diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py index 44c00637..e3e4c0fb 100644 --- a/luxonis_train/utils/config.py +++ b/luxonis_train/utils/config.py @@ -283,6 +283,7 @@ class TrainerConfig(BaseModelExtraForbid): validation_interval: Literal[-1] | PositiveInt = 1 num_log_images: NonNegativeInt = 4 skip_last_batch: bool = True + pin_memory: bool = True log_sub_losses: bool = True save_top_k: Literal[-1] | NonNegativeInt = 3