diff --git a/dmlcloud/pipeline.py b/dmlcloud/pipeline.py index eae3ec3..73c3dc0 100644 --- a/dmlcloud/pipeline.py +++ b/dmlcloud/pipeline.py @@ -55,6 +55,7 @@ def register_model( name: str, model: torch.nn.Module, use_ddp: bool = True, + sync_bn: bool = False, save_latest: bool = True, save_interval: Optional[int] = None, save_best: bool = False, @@ -63,9 +64,11 @@ def register_model( ): if name in self.models: raise ValueError(f'Model with name {name} already exists') + model = model.to(self.device) # Doing it in this order is important for SyncBN + if sync_bn: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if use_ddp: - model = DistributedDataParallel(model, broadcast_buffers=False) - model = model.to(self.device) + model = DistributedDataParallel(model, broadcast_buffers=False, device_ids=[self.device], output_device=self.device) self.models[name] = model if verbose: