Skip to content

Commit

Permalink
feat: sync batchnorm support
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed May 10, 2024
1 parent 7597366 commit 85e8443
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions dmlcloud/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def register_model(
name: str,
model: torch.nn.Module,
use_ddp: bool = True,
sync_bn: bool = False,
save_latest: bool = True,
save_interval: Optional[int] = None,
save_best: bool = False,
Expand All @@ -63,9 +64,11 @@ def register_model(
):
if name in self.models:
raise ValueError(f'Model with name {name} already exists')
model = model.to(self.device) # Doing it in this order is important for SyncBN
if sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if use_ddp:
model = DistributedDataParallel(model, broadcast_buffers=False)
model = model.to(self.device)
model = DistributedDataParallel(model, broadcast_buffers=False, device_ids=[self.device], output_device=self.device)
self.models[name] = model

if verbose:
Expand Down

0 comments on commit 85e8443

Please sign in to comment.