ResDiff integration: Switch to Modulus Logging (#251)
* change to modulus logging

* better handling of logging

* remove redundant logger arg in training

* fix minor bugs

* address review comments
mnabian authored Nov 30, 2023
1 parent 8ed3d41 commit 75ca204
Showing 3 changed files with 95 additions and 98 deletions.
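For context, here is a minimal sketch of the rank-zero logging pattern this PR switches to, reconstructed from the calls visible in the diffs below (PythonLogger, RankZeroLoggingWrapper, file_logging). The logger name "example" and the log file name are placeholders for illustration, not part of the commit; check the modulus.launch.logging documentation for the exact API surface.

from modulus.distributed import DistributedManager
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper

# Initialize the distributed manager once per process.
DistributedManager.initialize()
dist = DistributedManager()

# A general per-rank logger plus a wrapper that only emits on rank 0.
logger = PythonLogger("example")                # placeholder name
logger0 = RankZeroLoggingWrapper(logger, dist)
logger.file_logging("example.log")              # also write log records to a file (placeholder name)

logger.info("logged on every rank")
logger0.info("logged on rank 0 only")           # replaces the old dist.print0(...) calls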
84 changes: 37 additions & 47 deletions modulus/experimental/resdiff/generate.py
@@ -47,9 +47,10 @@
try:
from edmss import edm_sampler
except ImportError:
raise ImportError("Please get the edm_sampler by running pip install git+https://github.com/mnabian/edmss.git")
raise ImportError("Please get the edm_sampler by running: pip install git+https://github.com/mnabian/edmss.git")

from modulus.distributed import DistributedManager
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper

def unet_regression(
net, latents, img_lr, class_labels=None, randn_like=torch.randn_like,
@@ -213,17 +214,9 @@ def randint(self, *args, size, **kwargs):
assert size[0] == len(self.generators)
return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])


def load_pickle(network_pkl):

# Instantiate distributed manager
dist = DistributedManager()

# Load network.
print(f'Loading network from "{network_pkl}"...') # TODO print on rank zero
with dnnlib.util.open_url(network_pkl, verbose=(dist.rank == 0)) as f:
print('torch.__version__', torch.__version__)

def load_pickle(network_pkl, rank):
# Load network.
with dnnlib.util.open_url(network_pkl, verbose=(rank == 0)) as f:
return pickle.load(f)['ema']


@@ -247,17 +240,10 @@ def parse_int_list(s):
def get_config_file(data_type):
config_root = os.getenv("CONFIG_ROOT", "configs")
config_file = os.path.join(config_root, data_type + '.yaml')
print(config_file)
return config_file


def get_dataset_and_sampler(data_type, data_config, config_file=None):
root = os.getenv("DATA_ROOT", "")
if config_file is None:
config_file = get_config_file(data_type)
params = YParams(config_file, config_name=data_config)
params["train_data_path"] = os.path.join(root, params["train_data_path"])
print(params.train_data_path)
def get_dataset_and_sampler(data_type, params):

if data_type == 'cwb':
dataset = CWBDataset(params, params.test_data_path, train=False, task=opts.task)
@@ -333,14 +319,15 @@ def get_dataset_and_sampler(data_type, data_config, config_file=None):
def main(max_times: Optional[int], seeds: List[int], **kwargs):

opts = dnnlib.EasyDict(kwargs)

# Initialize distributed manager
DistributedManager.initialize()
dist = DistributedManager()

# wrapper class for distributed manager for print0. This will be removed when Modulus logging is implemented.
class DistributedManagerWrapper(DistributedManager):
def print0(self, *message):
if self.rank == 0:
print(*message)

dist = DistributedManagerWrapper()
# Initialize logger.
logger = PythonLogger("generate") # General python logger
logger0 = RankZeroLoggingWrapper(logger, dist)
logger.file_logging("generate.log")

det_batch = None
gen_batch = None
@@ -349,26 +336,32 @@ def print0(self, *message):
if det_batch is None: det_batch = 1 #max(gen_batch, 64)
assert det_batch % gen_batch == 0

print('opts.data_config', opts.data_config)
logger0.info(f'opts.data_config: {opts.data_config}')

# Data
config_file = get_config_file(opts.data_type)
logger0.info(f"Config file: {config_file}")
params = YParams(config_file, config_name=opts.data_config)
patch_size = params.patch_size
crop_size_x = params.crop_size_x
crop_size_y = params.crop_size_y

dataset, sampler = get_dataset_and_sampler(opts.data_type, opts.data_config)
root = os.getenv("DATA_ROOT", "")
params["train_data_path"] = os.path.join(root, params["train_data_path"])
logger0.info(f"Train data path: {params.train_data_path}")
dataset, sampler = get_dataset_and_sampler(opts.data_type, params)

with nc.Dataset(opts.outdir.format(rank=dist.rank), "w") as f:
# add attributes
f.history = ' '.join(sys.argv)
f.network_pkl = kwargs["network_pkl"]

# Load network
dist.print0('Generating images...')

net = load_pickle(opts.network_pkl)
net_reg = load_pickle(opts.network_reg_pkl) if opts.res_edm else None
logger.info(f'torch.__version__: {torch.__version__}')
logger0.info(f'Loading network from "{opts.network_pkl}"...')
net = load_pickle(opts.network_pkl, dist.rank)
logger0.info(f'Loading network from "{opts.network_reg_pkl}"...')
net_reg = load_pickle(opts.network_reg_pkl, dist.rank) if opts.res_edm else None

# move to device
num_gpus = dist.world_size
@@ -397,22 +390,23 @@ def generate_fn(image_lr):

sample_seeds = seeds

logger0.info(f'seeds: {sample_seeds}')
if net_reg:
image_mean = generate(
net=net_reg, img_lr=image_lr_patch,
max_batch_size=image_lr_patch.shape[0], seeds=sample_seeds,
pretext='reg', class_idx=class_idx,
pretext='reg', class_idx=class_idx
)
image_out = image_mean + generate(
net=net, img_lr=image_lr_patch,
max_batch_size=image_lr_patch.shape[0], seeds=sample_seeds,
pretext='gen', class_idx=class_idx,
pretext='gen', class_idx=class_idx
)
else:
image_out = generate(
net=net, img_lr=image_lr_patch,
max_batch_size=image_lr_patch.shape[0], seeds=sample_seeds,
pretext=opts.pretext, class_idx=class_idx,
pretext=opts.pretext, class_idx=class_idx
)

#reshape: (1*9*9)x3x50x50 --> 1x3x450x450
@@ -423,12 +417,14 @@ def generate_fn(image_lr):

return image_out

generate_and_save(dataset, sampler, f, generate_fn, batch_size)
# generate images
logger0.info('Generating images...')
generate_and_save(dataset, sampler, f, generate_fn, device, batch_size, logger0)

# Done.
if dist.world_size > 1:
torch.distributed.barrier()
dist.print0('Done.')
logger0.info('Done.')


def _get_name(channel_info):
@@ -505,8 +501,7 @@ def writer_from_input_dataset(f, dataset):
return NetCDFWriter(f, lat=dataset.latitude(), lon=dataset.longitude(), input_channels=dataset.input_channels(), output_channels=dataset.output_channels())


def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):

def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, device, batch_size, logger):
# Instantiate distributed manager.
dist = DistributedManager()
device = dist.device
@@ -518,7 +513,7 @@ def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
for image_tar, image_lr, index in iter(data_loader):
time_index += 1
if dist.rank == 0:
print("starting index", time_index) # TODO print on rank zero
logger.info(f"starting index: {time_index}") # TODO print on rank zero
input_data = image_lr = image_lr.to(device=device).to(torch.float32)
image_tar = image_tar.to(device=device).to(torch.float32)
image_out = generate_fn(image_lr)
@@ -542,7 +537,6 @@ def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
image_lr2 = image_lr2.cpu().numpy()
image_lr2 = denormalize(image_lr2, mx, sx)


my, sy = dataset.info()['target_normalization']
my = my[dataset.out_channels]
sy = sy[dataset.out_channels]
@@ -558,7 +552,6 @@ def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
assert image_out2.ndim == 4

# Denormalize the input and outputs

image_out2 = image_out2.cpu().numpy()
image_out2 = denormalize(image_out2, my, sy)

@@ -582,7 +575,6 @@ def generate_and_save(dataset, sampler, f: nc.Dataset, generate_fn, batch_size):
writer.write_input(channel_name, time_index, image_lr2[0, channel_idx])



def generate(net, seeds, class_idx, max_batch_size, img_lr=None, pretext=None, **sampler_kwargs):
"""Generate random images using the techniques described in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models".
@@ -602,9 +594,7 @@ def generate(net, seeds, class_idx, max_batch_size, img_lr=None, pretext=None, *

# Instantiate distributed manager.
dist = DistributedManager()

device = dist.device
dist.print0('seeds', seeds) # TODO fix in logging

num_batches = ((len(seeds) - 1) // (max_batch_size * dist.world_size) + 1) * dist.world_size
all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
@@ -632,7 +622,7 @@ def generate(net, seeds, class_idx, max_batch_size, img_lr=None, pretext=None, *
#latents = rnd.randn([batch_size, net.img_in_channels, net.img_resolution, net.img_resolution], device=device)
latents = rnd.randn([max_batch_size, net.img_out_channels, net.img_resolution, net.img_resolution], device=device)


class_labels = None
if net.label_dim:
class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
54 changes: 25 additions & 29 deletions modulus/experimental/resdiff/train.py
@@ -27,6 +27,7 @@
from training import training_loop

from modulus.distributed import DistributedManager
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper

try:
from apex.optimizers import FusedAdam
@@ -120,14 +121,12 @@ def main(**kwargs):

# Initialize distributed manager.
DistributedManager.initialize()
dist = DistributedManager()

# wrapper class for distributed manager for print0. This will be removed when Modulus logging is implemented.
class DistributedManagerWrapper(DistributedManager):
def print0(self, *message):
if self.rank == 0:
print(*message)

dist = DistributedManagerWrapper()
# Initialize logger.
logger = PythonLogger(name="train") # General python logger
logger0 = RankZeroLoggingWrapper(logger, dist)
logger.file_logging(file_name="train.log")

# Initialize config dict.
c = dnnlib.EasyDict()
@@ -258,7 +257,7 @@ def print0(self, *message):
#print('opts.resume', opts.resume)
f.close()

dist.print0('opts.resume', opts.resume)
logger0.info(f'opts.resume: {opts.resume}')

# Transfer learning and resume.
if opts.transfer is not None:
@@ -267,18 +266,18 @@ def print0(self, *message):
c.resume_pkl = opts.transfer
c.ema_rampup_ratio = None
elif opts.resume is not None:
print('gets into elif opts.resume is not None ...')
logger.info('gets into elif opts.resume is not None ...')
match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume))
print('match', match)
print('match.group(1)', match.group(1))
logger.info(f'match: {match}')
logger.info(f'match.group(1): {match.group(1)}')
# if not match or not os.path.isfile(opts.resume):
# raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
c.resume_kimg = int(match.group(1))
c.resume_state_dump = opts.resume
dist.print0('c.resume_pkl', c.resume_pkl)
dist.print0('c.resume_kimg', c.resume_kimg)
dist.print0('c.resume_state_dump', c.resume_state_dump)
logger0.info(f'c.resume_pkl: {c.resume_pkl}')
logger0.info(f'c.resume_kimg: {c.resume_kimg}')
logger0.info(f'c.resume_state_dump: {c.resume_state_dump}')
# import pdb; pdb.set_trace()


@@ -315,27 +314,24 @@ def print0(self, *message):
c.task = opts.task

# Print options.
dist.print0()
dist.print0('Training options:')
dist.print0(json.dumps(c, indent=2))
dist.print0()
dist.print0(f'Output directory: {c.run_dir}')
dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
dist.print0(f'Network architecture: {opts.arch}')
dist.print0(f'Preconditioning & loss: {opts.precond}')
dist.print0(f'Number of GPUs: {dist.world_size}')
dist.print0(f'Batch size: {c.batch_size}')
dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
dist.print0()
logger0.info('Training options:')
logger0.info(json.dumps(c, indent=2))
logger0.info(f'Output directory: {c.run_dir}')
logger0.info(f'Dataset path: {c.dataset_kwargs.path}')
logger0.info(f'Class-conditional: {c.dataset_kwargs.use_labels}')
logger0.info(f'Network architecture: {opts.arch}')
logger0.info(f'Preconditioning & loss: {opts.precond}')
logger0.info(f'Number of GPUs: {dist.world_size}')
logger0.info(f'Batch size: {c.batch_size}')
logger0.info(f'Mixed-precision: {c.network_kwargs.use_fp16}')

# Dry run?
if opts.dry_run:
dist.print0('Dry run; exiting.')
logger0.info('Dry run; exiting.')
return

# Create output directory.
dist.print0('Creating output directory...')
logger0.info('Creating output directory...')
if dist.rank == 0:
os.makedirs(c.run_dir, exist_ok=True)
with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f: