diff --git a/modulus/experimental/resdiff/generate.py b/modulus/experimental/resdiff/generate.py
index 878cf90032..1f7b56ff2f 100644
--- a/modulus/experimental/resdiff/generate.py
+++ b/modulus/experimental/resdiff/generate.py
@@ -47,9 +47,10 @@
 try:
     from edmss import edm_sampler
 except ImportError:
-    raise ImportError("Please get the edm_sampler by running pip install git+https://github.com/mnabian/edmss.git")
+    raise ImportError("Please get the edm_sampler by running: pip install git+https://github.com/mnabian/edmss.git")

 from modulus.distributed import DistributedManager
+from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper

 def unet_regression(
     net, latents, img_lr, class_labels=None, randn_like=torch.randn_like,
@@ -213,17 +214,9 @@ def randint(self, *args, size, **kwargs):
         assert size[0] == len(self.generators)
         return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])

-
-def load_pickle(network_pkl):
-
-    # Instantiate distributed manager
-    dist = DistributedManager()
-
-    # Load network.
-    print(f'Loading network from "{network_pkl}"...') # TODO print on rank zero
-    with dnnlib.util.open_url(network_pkl, verbose=(dist.rank == 0)) as f:
-        print('torch.__version__', torch.__version__)
-
+def load_pickle(network_pkl, rank):
+    # Load network.
+    with dnnlib.util.open_url(network_pkl, verbose=(rank == 0)) as f:
         return pickle.load(f)['ema']


@@ -247,17 +240,10 @@ def parse_int_list(s):

 def get_config_file(data_type):
     config_root = os.getenv("CONFIG_ROOT", "configs")
     config_file = os.path.join(config_root, data_type + '.yaml')
-    print(config_file)
     return config_file


-def get_dataset_and_sampler(data_type, data_config, config_file=None):
-    root = os.getenv("DATA_ROOT", "")
-    if config_file is None:
-        config_file = get_config_file(data_type)
-    params = YParams(config_file, config_name=data_config)
-    params["train_data_path"] = os.path.join(root, params["train_data_path"])
-    print(params.train_data_path)
+def get_dataset_and_sampler(data_type, params):

     if data_type == 'cwb':
         dataset = CWBDataset(params, params.test_data_path, train=False, task=opts.task)
@@ -333,14 +319,15 @@ def get_dataset_and_sampler(data_type, data_config, config_file=None):

 def main(max_times: Optional[int], seeds: List[int], **kwargs):
     opts = dnnlib.EasyDict(kwargs)
+
+    # Initialize distributed manager
+    DistributedManager.initialize()
+    dist = DistributedManager()

-    # wrapper class for distributed manager for print0. This will be removed when Modulus logging is implemented.
-    class DistributedManagerWrapper(DistributedManager):
-        def print0(self, *message):
-            if self.rank == 0:
-                print(*message)
-
-    dist = DistributedManagerWrapper()
+    # Initialize logger.
+    logger = PythonLogger("generate")  # General python logger
+    logger0 = RankZeroLoggingWrapper(logger, dist)
+    logger.file_logging("generate.log")

     det_batch = None
     gen_batch = None
@@ -349,26 +336,32 @@ def print0(self, *message):
     if det_batch is None:
         det_batch = 1 #max(gen_batch, 64)
     assert det_batch % gen_batch == 0
-    print('opts.data_config', opts.data_config)
+    logger0.info(f'opts.data_config: {opts.data_config}')

     # Data
     config_file = get_config_file(opts.data_type)
+    logger0.info(f"Config file: {config_file}")
     params = YParams(config_file, config_name=opts.data_config)
     patch_size = params.patch_size
     crop_size_x = params.crop_size_x
     crop_size_y = params.crop_size_y

-    dataset, sampler = get_dataset_and_sampler(opts.data_type, opts.data_config)
+    root = os.getenv("DATA_ROOT", "")
+    params["train_data_path"] = os.path.join(root, params["train_data_path"])
+    logger0.info(f"Train data path: {params.train_data_path}")
+    dataset, sampler = get_dataset_and_sampler(opts.data_type, params)
+
     with nc.Dataset(opts.outdir.format(rank=dist.rank), "w") as f:
         # add attributes
         f.history = ' '.join(sys.argv)
         f.network_pkl = kwargs["network_pkl"]

         # Load network
-        dist.print0('Generating images...')
-
-        net = load_pickle(opts.network_pkl)
-        net_reg = load_pickle(opts.network_reg_pkl) if opts.res_edm else None
+        logger.info(f'torch.__version__: {torch.__version__}')
+        logger0.info(f'Loading network from "{opts.network_pkl}"...')
+        net = load_pickle(opts.network_pkl, dist.rank)
+        logger0.info(f'Loading network from "{opts.network_reg_pkl}"...')
+        net_reg = load_pickle(opts.network_reg_pkl, dist.rank) if opts.res_edm else None

         # move to device
         num_gpus = dist.world_size
@@ -397,22 +390,23 @@ def generate_fn(image_lr):
             sample_seeds = seeds

+            logger0.info(f'seeds: {sample_seeds}')
             if net_reg:
                 image_mean = generate(
                     net=net_reg, img_lr=image_lr_patch,
                     max_batch_size=image_lr_patch.shape[0], seeds=sample_seeds,
-                    pretext='reg', class_idx=class_idx,
+                    pretext='reg', class_idx=class_idx
                 )
                 image_out = image_mean + generate(
                     net=net, img_lr=image_lr_patch,
                     max_batch_size=image_lr_patch.shape[0], seeds=sample_seeds,
-                    pretext='gen', class_idx=class_idx,
+                    pretext='gen', class_idx=class_idx
                 )
             else:
                 image_out = generate(
                     net=net, img_lr=image_lr_patch,
                     max_batch_size=image_lr_patch.shape[0], seeds=sample_seeds,
-                    pretext=opts.pretext, class_idx=class_idx,
+                    pretext=opts.pretext, class_idx=class_idx
                 )

             #reshape: (1*9*9)x3x50x50 --> 1x3x450x450
@@ -423,12 +417,14 @@ def generate_fn(image_lr):

             return image_out

-        generate_and_save(dataset, sampler, f, generate_fn, batch_size)
+        # generate images
+        logger0.info('Generating images...')
+        generate_and_save(dataset, sampler, f, generate_fn, device, batch_size, logger0)

     # Done.
     if dist.world_size > 1:
         torch.distributed.barrier()
-    dist.print0('Done.')
+    logger0.info('Done.')


 def _get_name(channel_info):
@@ -505,8 +501,7 @@ def writer_from_input_dataset(f, dataset):
     return NetCDFWriter(f, lat=dataset.latitude(), lon=dataset.longitude(), input_channels=dataset.input_channels(), output_channels=dataset.output_channels())


-def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
-
+def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, device, batch_size, logger):
     # Instantiate distributed manager.
     dist = DistributedManager()
     device = dist.device
@@ -518,7 +513,7 @@ def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
     for image_tar, image_lr, index in iter(data_loader):
         time_index += 1
         if dist.rank == 0:
-            print("starting index", time_index) # TODO print on rank zero
+            logger.info(f"starting index: {time_index}")
         input_data = image_lr = image_lr.to(device=device).to(torch.float32)
         image_tar = image_tar.to(device=device).to(torch.float32)
         image_out = generate_fn(image_lr)
@@ -542,7 +537,6 @@ def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
         image_lr2 = image_lr2.cpu().numpy()
         image_lr2 = denormalize(image_lr2, mx, sx)

-
         my, sy = dataset.info()['target_normalization']
         my = my[dataset.out_channels]
         sy = sy[dataset.out_channels]
@@ -558,7 +552,6 @@ def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
         assert image_out2.ndim == 4

         # Denormalize the input and outputs
-
         image_out2 = image_out2.cpu().numpy()
         image_out2 = denormalize(image_out2, my, sy)

@@ -582,7 +575,6 @@ def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
             writer.write_input(channel_name, time_index, image_lr2[0, channel_idx])


-
 def generate(net, seeds, class_idx, max_batch_size, img_lr=None, pretext=None, **sampler_kwargs):
     """Generate random images using the techniques described in the paper
     "Elucidating the Design Space of Diffusion-Based Generative Models".
@@ -602,9 +594,7 @@ def generate(net, seeds, class_idx, max_batch_size, img_lr=None, pretext=None, *

     # Instantiate distributed manager.
     dist = DistributedManager()
-
     device = dist.device
-    dist.print0('seeds', seeds) # TODO fix in logging

     num_batches = ((len(seeds) - 1) // (max_batch_size * dist.world_size) + 1) * dist.world_size
     all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
@@ -632,7 +622,7 @@ def generate(net, seeds, class_idx, max_batch_size, img_lr=None, pretext=None, *

         #latents = rnd.randn([batch_size, net.img_in_channels, net.img_resolution, net.img_resolution], device=device)
         latents = rnd.randn([max_batch_size, net.img_out_channels, net.img_resolution, net.img_resolution], device=device)
-        
+
         class_labels = None
         if net.label_dim:
             class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
diff --git a/modulus/experimental/resdiff/train.py b/modulus/experimental/resdiff/train.py
index 821096a245..c5f0b57e07 100644
--- a/modulus/experimental/resdiff/train.py
+++ b/modulus/experimental/resdiff/train.py
@@ -27,6 +27,7 @@
 from training import training_loop

 from modulus.distributed import DistributedManager
+from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper

 try:
     from apex.optimizers import FusedAdam
@@ -120,14 +121,12 @@ def main(**kwargs):

     # Initialize distributed manager.
     DistributedManager.initialize()
+    dist = DistributedManager()

-    # wrapper class for distributed manager for print0. This will be removed when Modulus logging is implemented.
-    class DistributedManagerWrapper(DistributedManager):
-        def print0(self, *message):
-            if self.rank == 0:
-                print(*message)
-
-    dist = DistributedManagerWrapper()
+    # Initialize logger.
+    logger = PythonLogger(name="train")  # General python logger
+    logger0 = RankZeroLoggingWrapper(logger, dist)
+    logger.file_logging(file_name="train.log")

     # Initialize config dict.
     c = dnnlib.EasyDict()
@@ -258,7 +257,7 @@ def print0(self, *message):
         #print('opts.resume', opts.resume)
         f.close()

-    dist.print0('opts.resume', opts.resume)
+    logger0.info(f'opts.resume: {opts.resume}')

     # Transfer learning and resume.
     if opts.transfer is not None:
@@ -267,18 +266,18 @@ def print0(self, *message):
         c.resume_pkl = opts.transfer
         c.ema_rampup_ratio = None
     elif opts.resume is not None:
-        print('gets into elif opts.resume is not None ...')
+        logger.info('gets into elif opts.resume is not None ...')
         match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume))
-        print('match', match)
-        print('match.group(1)', match.group(1))
+        logger.info(f'match: {match}')
+        logger.info(f'match.group(1): {match.group(1)}')
         # if not match or not os.path.isfile(opts.resume):
         #     raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
         c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
         c.resume_kimg = int(match.group(1))
         c.resume_state_dump = opts.resume
-        dist.print0('c.resume_pkl', c.resume_pkl)
-        dist.print0('c.resume_kimg', c.resume_kimg)
-        dist.print0('c.resume_state_dump', c.resume_state_dump)
+        logger0.info(f'c.resume_pkl: {c.resume_pkl}')
+        logger0.info(f'c.resume_kimg: {c.resume_kimg}')
+        logger0.info(f'c.resume_state_dump: {c.resume_state_dump}')

     # import pdb; pdb.set_trace()

@@ -315,27 +314,24 @@ def print0(self, *message):
     c.task = opts.task

     # Print options.
-    dist.print0()
-    dist.print0('Training options:')
-    dist.print0(json.dumps(c, indent=2))
-    dist.print0()
-    dist.print0(f'Output directory: {c.run_dir}')
-    dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
-    dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
-    dist.print0(f'Network architecture: {opts.arch}')
-    dist.print0(f'Preconditioning & loss: {opts.precond}')
-    dist.print0(f'Number of GPUs: {dist.world_size}')
-    dist.print0(f'Batch size: {c.batch_size}')
-    dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
-    dist.print0()
+    logger0.info('Training options:')
+    logger0.info(json.dumps(c, indent=2))
+    logger0.info(f'Output directory: {c.run_dir}')
+    logger0.info(f'Dataset path: {c.dataset_kwargs.path}')
+    logger0.info(f'Class-conditional: {c.dataset_kwargs.use_labels}')
+    logger0.info(f'Network architecture: {opts.arch}')
+    logger0.info(f'Preconditioning & loss: {opts.precond}')
+    logger0.info(f'Number of GPUs: {dist.world_size}')
+    logger0.info(f'Batch size: {c.batch_size}')
+    logger0.info(f'Mixed-precision: {c.network_kwargs.use_fp16}')

     # Dry run?
     if opts.dry_run:
-        dist.print0('Dry run; exiting.')
+        logger0.info('Dry run; exiting.')
         return

     # Create output directory.
-    dist.print0('Creating output directory...')
+    logger0.info('Creating output directory...')
     if dist.rank == 0:
         os.makedirs(c.run_dir, exist_ok=True)
         with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
diff --git a/modulus/experimental/resdiff/training/training_loop.py b/modulus/experimental/resdiff/training/training_loop.py
index 888a13daac..b33b322734 100644
--- a/modulus/experimental/resdiff/training/training_loop.py
+++ b/modulus/experimental/resdiff/training/training_loop.py
@@ -29,6 +29,7 @@
 from torch_utils import misc

 from modulus.distributed import DistributedManager
+from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper

 #weather related
 from .YParams import YParams
@@ -67,16 +68,17 @@ def training_loop(
     data_config = None,
     task = None,
 ):
-    # Initialize.
-    start_time = time.time()
+
+    # Instantiate distributed manager.
+    dist = DistributedManager()

-    # wrapper class for distributed manager for print0. This will be removed when Modulus logging is implemented.
-    class DistributedManagerWrapper(DistributedManager):
-        def print0(self, *message):
-            if self.rank == 0:
-                print(*message)
+    # Initialize logger.
+    logger = PythonLogger(name="training_loop")  # General python logger
+    logger0 = RankZeroLoggingWrapper(logger, dist)
+    logger.file_logging(file_name="training_loop.log")

-    dist = DistributedManagerWrapper()
+    # Initialize.
+    start_time = time.time()
     device = dist.device
     np.random.seed((seed * dist.world_size + dist.rank) % (1 << 31))
@@ -88,7 +90,7 @@ def print0(self, *message):

     # Select batch size per GPU.
     batch_gpu_total = batch_size // dist.world_size
-    dist.print0('batch_gpu', batch_gpu)
+    logger0.info(f'batch_gpu: {batch_gpu}')
     if batch_gpu is None or batch_gpu > batch_gpu_total:
         batch_gpu = batch_gpu_total
     num_accumulation_rounds = batch_gpu_total // batch_gpu
@@ -103,7 +105,7 @@ def print0(self, *message):
     '''

     # Load dataset: weather
-    dist.print0('Loading dataset...')
+    logger0.info('Loading dataset...')

     yparams = YParams(data_type + '.yaml', config_name=data_config)

@@ -145,7 +147,7 @@ def print0(self, *message):
     # img_in_channels = img_in_channels + yparams.N_grid_channels + img_out_channels

     # Construct network.
-    dist.print0('Constructing network...')
+    logger0.info('Constructing network...')
     #interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim) #cifar10
     interface_kwargs = dict(img_resolution=yparams.crop_size_x, img_channels=img_out_channels, img_in_channels=img_in_channels, img_out_channels=img_out_channels, label_dim=0) #weather
     net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
@@ -175,7 +177,7 @@ def print0(self, *message):


     # Setup optimizer.
-    dist.print0('Setting up optimizer...')
+    logger0.info('Setting up optimizer...')
     loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss
     optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
     augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
@@ -195,11 +197,11 @@ def print0(self, *message):
         from userlib.auto_resume import AutoResume
         AutoResume.init()
     except ImportError:
-        print('AutoResume not imported')
+        logger0.warning('AutoResume not imported')

     # Resume training from previous snapshot.
     if resume_pkl is not None:
-        dist.print0(f'Loading network weights from "{resume_pkl}"...')
+        logger0.info(f'Loading network weights from "{resume_pkl}"...')
         if dist.rank != 0:
             torch.distributed.barrier() # rank 0 goes first
         with dnnlib.util.open_url(resume_pkl, verbose=(dist.rank == 0)) as f:
@@ -210,7 +212,7 @@ def print0(self, *message):
         misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
         del data # conserve memory
     if resume_state_dump:
-        dist.print0(f'Loading training state from "{resume_state_dump}"...')
+        logger0.info(f'Loading training state from "{resume_state_dump}"...')
         data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
         misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
         #dist.print0('data-optimizer', data['optimizer_state'])
@@ -228,7 +230,7 @@ def print0(self, *message):
     # import pdb; pdb.set_trace()

     # Train.
-    dist.print0(f'Training for {total_kimg} kimg...')
+    logger0.info(f'Training for {total_kimg} kimg...')
     cur_nimg = resume_kimg * 1000
     cur_tick = 0
     tick_start_nimg = cur_nimg
@@ -310,25 +312,34 @@ def print0(self, *message):
         fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
         fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
         torch.cuda.reset_peak_memory_stats()
-        dist.print0(' '.join(fields))
+        logger0.info(' '.join(fields))

         ckpt_dir = run_dir
-        print('AutoResume.termination_requested()', AutoResume.termination_requested())
-        print('AutoResume', AutoResume)
+        logger0.info(f'AutoResume.termination_requested(): {AutoResume.termination_requested()}')
+        logger0.info(f'AutoResume: {AutoResume}')
         if AutoResume.termination_requested():
             AutoResume.request_resume()
-            print("Training terminated. Returning")
+            logger0.info("Training terminated. Returning")
             done = True
             #print('dist.rank', dist.rank)
             #with open(os.path.join(os.path.split(ckpt_dir)[0],'resume.txt'), "w") as f:
             with open(os.path.join(ckpt_dir,'resume.txt'), "w") as f:
                 f.write(os.path.join(ckpt_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
-                print(os.path.join(ckpt_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
+                logger0.info(os.path.join(ckpt_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
                 f.close()
             #return 0

+        # Check for abort.
+        logger0.info(f'dist.should_stop(): {dist.should_stop()}')
+        logger0.info(f'done: {done}')
+
+        # if (not done) and dist.should_stop():
+        #     done = True
+        #     dist.print0()
+        #     dist.print0('Aborting...')
+
         # Save network snapshot.
         if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
             data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
@@ -367,6 +378,6 @@ def print0(self, *message):


     # Done.
-    dist.print0('Exiting...')
+    logger0.info('Exiting...')

 #----------------------------------------------------------------------------
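
Note: for reference, a minimal self-contained sketch of the rank-zero logging pattern adopted above (PythonLogger logs on every rank; RankZeroLoggingWrapper filters messages to rank 0). The helper name and log-file name below are illustrative only and are not part of the change:

    from modulus.distributed import DistributedManager
    from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper

    def setup_logging(name):
        # Initialize the distributed manager once per process, then grab the singleton.
        DistributedManager.initialize()
        dist = DistributedManager()

        logger = PythonLogger(name)                     # logs on every rank
        logger0 = RankZeroLoggingWrapper(logger, dist)  # logs on rank 0 only
        logger.file_logging(f"{name}.log")              # also write messages to a file
        return dist, logger, logger0

    dist, logger, logger0 = setup_logging("example")
    logger0.info(f"Number of GPUs: {dist.world_size}")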