From 20bbc7565cb0ce0a75a951bcb9e872f2f4b8f4ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mats=20Sj=C3=B6berg?= Date: Thu, 9 Jan 2025 16:39:38 +0200 Subject: [PATCH] Further LUMI fixes --- benchmarks/pytorch_visionmodel_ddp.py | 10 +++++++--- pytorch-ddp.sh | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/benchmarks/pytorch_visionmodel_ddp.py b/benchmarks/pytorch_visionmodel_ddp.py index 2780e5a..10ac58f 100644 --- a/benchmarks/pytorch_visionmodel_ddp.py +++ b/benchmarks/pytorch_visionmodel_ddp.py @@ -1,6 +1,7 @@ # Based on multiprocessing example from # https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html +import multiprocessing from datetime import datetime import argparse import os @@ -91,7 +92,7 @@ def train(args): shuffle=False, num_workers=args.workers, pin_memory=True, sampler=train_sampler) - scaler = torch.cuda.amp.GradScaler(enabled=args.fp16) + scaler = torch.amp.GradScaler('cuda', enabled=args.fp16) if verbose and args.fp16: print(f"Using fp16 (PyTorch automatic mixed precision)") @@ -144,7 +145,7 @@ def train(args): images = images.cuda(non_blocking=True) labels = labels.cuda(non_blocking=True) - with torch.cuda.amp.autocast(enabled=args.fp16): + with torch.amp.autocast('cuda', enabled=args.fp16): outputs = model(images) loss = criterion(outputs, labels) @@ -198,6 +199,8 @@ def train(args): if args.steps is not None and tot_steps >= args.steps: break + dur = datetime.now() - avg_start + if args.profiler: if args.profiler_format == 'json' and verbose: trace_datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') @@ -211,13 +214,14 @@ def train(args): if avg_start is None: print("WARNING: stopped before warmup steps done, not printing stats.") else: - dur = datetime.now() - avg_start print(f"Training completed in: {dur}") print(f"Images/sec: {avg_images*world_size/dur.total_seconds():.2f} " f"(average, skipping {args.warmup_steps} warmup steps)") def main(): + multiprocessing.set_start_method('spawn') + parser = argparse.ArgumentParser() parser.add_argument('--epochs', default=1, type=int, metavar='N', help='number of total epochs to run') diff --git a/pytorch-ddp.sh b/pytorch-ddp.sh index 7dc22c5..6581f00 100644 --- a/pytorch-ddp.sh +++ b/pytorch-ddp.sh @@ -11,6 +11,7 @@ if [ -n "$SIF" ]; then fi echo "PYTHON3=$PYTHON3" +echo "NCCL_NET_GDR_LEVEL=$NCCL_NET_GDR_LEVEL" SCRIPT="benchmarks/pytorch_visionmodel_ddp.py" IMAGENET_DATA=/scratch/dac/data/ilsvrc2012-torch-resized-new.tar