Update train script for newer torch and lightning versions.
Alexanders101 committed Aug 23, 2023
1 parent 8c2605b commit 37ebdfd
Showing 7 changed files with 96 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -244,4 +244,4 @@ ipython_config.py
# pytorch-lightning output
*_logs/
*_output/
!/environment/
environment/
12 changes: 6 additions & 6 deletions environment.yaml
@@ -1,28 +1,28 @@
name: spanet

channels:
- pytorch
- nvidia
- nvidia/label/cuda-11.7.0
- conda-forge
- nodefaults

dependencies:
- python=3.10.*
- conda
- pip
- cuda-toolkit

- numpy=1.23
- numpy=1.24
- sympy
- scikit-learn
- numba
- opt_einsum
- h5py
- cytoolz

- pytorch=1.13.1
- pytorch-lightning=1.9.0
- pytorch=2.0.1
- pytorch-lightning=2.0.7
- pytorch-cuda=11.7
- cudatoolkit=11.7
- cuda-nvcc
- torchvision
- torchaudio

34 changes: 34 additions & 0 deletions environment_cuda118.yaml
@@ -0,0 +1,34 @@
name: spanet

channels:
- pytorch
- nvidia/label/cuda-11.8.0
- conda-forge
- nodefaults

dependencies:
- python=3.10.*
- conda
- pip
- cuda-toolkit

- numpy=1.24
- sympy
- scikit-learn
- numba
- opt_einsum
- h5py
- cytoolz

- pytorch=2.0.1
- pytorch-lightning=2.0.7
- pytorch-cuda=11.8
- torchvision
- torchaudio

- tensorboard
- tensorboardx

- jupyterlab
- seaborn
- rich
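Both environment files pin the same Python stack and differ only in the CUDA label (11.7 vs. 11.8). As a quick sanity check after recreating the environment, the snippet below simply mirrors the pins above (a sketch, not part of the repository):

import numpy
import torch
import pytorch_lightning as pl

# Expected to match the pins in environment.yaml / environment_cuda118.yaml.
print(numpy.__version__)          # should start with "1.24"
print(torch.__version__)          # should start with "2.0.1"
print(pl.__version__)             # should start with "2.0.7"
print(torch.cuda.is_available())  # True once the matching pytorch-cuda build resolved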
4 changes: 2 additions & 2 deletions spanet/dataset/evaluator.py
@@ -81,7 +81,7 @@ def cluster_purity(self, predictions, target_jets, target_masks):
cluster_predictions = np.stack([predictions[i] for i in cluster_indices])

# Keep track of the best accuracy achieved for each event
best_accuracy = np.zeros(cluster_target_masks.shape[1], dtype=np.int)
best_accuracy = np.zeros(cluster_target_masks.shape[1], dtype=np.int64)

for target_permutation in permutations(range(len(cluster_indices))):
target_permutation = list(target_permutation)
@@ -107,7 +107,7 @@ def event_purity(self, predictions, target_jets, target_masks):
target_masks = np.stack(target_masks)

# Keep track of the best accuracy achieved for each event
best_accuracy = np.zeros(target_masks.shape[1], dtype=np.int)
best_accuracy = np.zeros(target_masks.shape[1], dtype=np.int64)

for target_permutation in self.event_info.event_permutation_group:
permuted_targets = self.permute_arrays(target_jets, target_permutation)
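The dtype changes in evaluator.py track the numpy=1.24 bump above: NumPy 1.24 removed the top-level aliases np.int, np.bool, and np.object that had been deprecated since 1.20. A minimal illustration of the failure mode and of the replacements used in this commit, assuming NumPy 1.24 or newer:

import numpy as np

# The removed aliases now raise AttributeError, e.g.:
#   np.zeros(4, dtype=np.int)   # AttributeError: module 'numpy' has no attribute 'int'
# The fixes switch to explicit or builtin dtypes:
best_accuracy = np.zeros(4, dtype=np.int64)   # was dtype=np.int
masks = np.zeros((2, 4), dtype=bool)          # was dtype=np.bool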
@@ -13,7 +13,7 @@


def numpy_tensor_array(tensor_list):
output = np.empty(len(tensor_list), dtype=np.object)
output = np.empty(len(tensor_list), dtype=object)
output[:] = tensor_list

return output
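dtype=object is the direct replacement for the removed np.object alias. The pre-allocate-then-assign pattern in numpy_tensor_array is what lets an object array hold per-event arrays of different lengths without NumPy trying to stack them into one block; a small sketch with illustrative names:

import numpy as np

ragged = [np.arange(3), np.arange(5)]          # per-event arrays with different lengths
holder = np.empty(len(ragged), dtype=object)   # dtype=object replaces the removed np.object
holder[:] = ragged                             # each element keeps its own shape and dtype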
14 changes: 3 additions & 11 deletions spanet/network/jet_reconstruction/jet_reconstruction_validation.py
@@ -41,8 +41,8 @@ def compute_metrics(self, jet_predictions, particle_scores, stacked_targets, sta
# Compute all possible target permutations and take the best performing permutation
# First compute raw_old accuracy so that we can get an accuracy score for each event
# This will also act as the method for choosing the best permutation to compare for the other metrics.
jet_accuracies = np.zeros((num_permutations, num_targets, batch_size), dtype=np.bool)
particle_accuracies = np.zeros((num_permutations, num_targets, batch_size), dtype=np.bool)
jet_accuracies = np.zeros((num_permutations, num_targets, batch_size), dtype=bool)
particle_accuracies = np.zeros((num_permutations, num_targets, batch_size), dtype=bool)
for i, permutation in enumerate(event_permutation_group):
for j, (prediction, target) in enumerate(zip(jet_predictions, stacked_targets[permutation])):
jet_accuracies[i, j] = np.all(prediction == target, axis=1)
@@ -100,7 +100,7 @@ def validation_step(self, batch, batch_idx) -> Dict[str, np.float32]:

# Stack all of the targets into single array, we will also move to numpy for easier the numba computations.
stacked_targets = np.zeros(num_targets, dtype=object)
stacked_masks = np.zeros((num_targets, batch_size), dtype=np.bool)
stacked_masks = np.zeros((num_targets, batch_size), dtype=bool)
for i, (target, mask) in enumerate(targets):
stacked_targets[i] = target.detach().cpu().numpy()
stacked_masks[i] = mask.detach().cpu().numpy()
@@ -151,13 +151,5 @@ def validation_step(self, batch, batch_idx) -> Dict[str, np.float32]:

return metrics

def validation_epoch_end(self, outputs):
# Optionally use this accuracy score for something like hyperparameter search
# validation_accuracy = sum(x['validation_accuracy'] for x in outputs) / len(outputs)

if self.options.verbose_output:
for name, parameter in self.named_parameters():
self.logger.experiment.add_histogram(name, parameter)

def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)
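The deleted validation_epoch_end hook no longer exists in Lightning 2.0; its replacement is on_validation_epoch_end, which receives no outputs argument. If the dropped histogram logging is still wanted, a sketch of the 2.x-style hook on the same LightningModule, assuming the options and logger attributes used above:

def on_validation_epoch_end(self) -> None:
    # Lightning 2.x hook on the LightningModule: no `outputs` argument is passed,
    # so anything needed across batches has to be accumulated in validation_step.
    if self.options.verbose_output:
        for name, parameter in self.named_parameters():
            self.logger.experiment.add_histogram(name, parameter)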
106 changes: 49 additions & 57 deletions spanet/train.py
@@ -6,10 +6,11 @@

import torch
import pytorch_lightning as pl
from pytorch_lightning.profiler import PyTorchProfiler
from pytorch_lightning.profilers import PyTorchProfiler
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy, DDPFullyShardedStrategy
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE, WandbLogger

from pytorch_lightning.callbacks import (
LearningRateMonitor,
ModelCheckpoint,
@@ -36,19 +36,18 @@ def main(
name: str,

torch_script: bool,
fairscale: bool,
fp16: bool,
graph: bool,
verbose: bool,
full_events: bool,

profile: bool,
gpus: Optional[int],
epochs: Optional[int],
time_limit: Optional[str],
batch_size: Optional[int],
limit_dataset: Optional[float],
random_seed: int,
):
):

# Whether or not this script version is the master run or a worker
master = True
@@ -116,6 +116,7 @@ def main(
if state_dict is not None:
if master:
print(f"Loading state dict from: {state_dict}")

state_dict = torch.load(state_dict, map_location="cpu")["state_dict"]
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

@@ -128,25 +128,13 @@ def main(
if pname in state_dict:
parameter.requires_grad_(False)

# if torch_script:
# model = torch.jit.script(model)

# If we are using more than one gpu, then switch to DDP training
# distributed_backend = 'dp' if options.num_gpu > 1 else None
distributed_backend = None
if options.num_gpu > 1:
if fairscale:
distributed_backend = DDPFullyShardedStrategy(
reshard_after_forward=False
)
else:
distributed_backend = DDPStrategy(
find_unused_parameters=False
)

# Construct the logger for this training run. Logs will be saved in {logdir}/{name}/version_i
log_dir = getcwd() if log_dir is None else log_dir
logger = TensorBoardLogger(save_dir=log_dir, name=name, log_graph=graph)
logger = (
WandbLogger(name=name, save_dir=log_dir)
if _WANDB_AVAILABLE else
TensorBoardLogger(save_dir=log_dir, name=name)
)

# Create the checkpoint for this training run. We will save the best validation networks based on 'accuracy'
callbacks = [
@@ -170,17 +159,20 @@ def main(
profiler = PyTorchProfiler(emit_nvtx=True)

# Create the final pytorch-lightning manager
trainer = pl.Trainer(logger=logger,
max_epochs=epochs,
callbacks=callbacks,
resume_from_checkpoint=checkpoint,
strategy=distributed_backend,
accelerator="gpu" if options.num_gpu > 0 else None,
devices=options.num_gpu if options.num_gpu > 0 else None,
track_grad_norm=2 if options.verbose_output else -1,
gradient_clip_val=options.gradient_clip if options.gradient_clip > 0 else None,
precision=16 if fp16 else 32,
profiler=profiler)
trainer = pl.Trainer(
accelerator="gpu" if options.num_gpu > 0 else "auto",
devices=options.num_gpu if options.num_gpu > 0 else "auto",
strategy="ddp" if options.num_gpu > 1 else "auto",
precision="16-mixed" if fp16 else "32-true",

gradient_clip_val=options.gradient_clip if options.gradient_clip > 0 else None,
max_epochs=epochs,
max_time=time_limit,

logger=logger,
profiler=profiler,
callbacks=callbacks
)

# Save the current hyperparameters to a json file in the checkpoint directory
if master:
@@ -192,7 +184,7 @@ def main(

shutil.copy2(options.event_info_file, f"{trainer.logger.log_dir}/event.yaml")

trainer.fit(model)
trainer.fit(model, ckpt_path=checkpoint)
# -------------------------------------------------------------------------------------------------------
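The new Trainer construction collects the Lightning 2.x API migrations in one place: resume_from_checkpoint moved from the constructor to fit(ckpt_path=...), precision takes string values, track_grad_norm is gone, and accelerator/devices/strategy accept "auto". A minimal, self-contained sketch of the same style of call, with illustrative values rather than the script's real arguments:

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="auto",       # replaces the explicit "gpu"/None switching on num_gpu
    devices="auto",
    strategy="auto",          # string form; an explicit DDPStrategy object is no longer required here
    precision="32-true",      # string form; "16-mixed" replaces the old precision=16
    max_epochs=1,
    logger=False,             # keeps the sketch self-contained; the real run passes logger=logger
)

# resume_from_checkpoint= was removed from the Trainer constructor;
# the checkpoint path is now passed to fit(), as done above:
#   trainer.fit(model, ckpt_path=checkpoint)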


@@ -212,32 +204,32 @@ def main(
help="JSON file with option overloads.")

parser.add_argument("-cf", "--checkpoint", type=str, default=None,
help="Optional checkpoint to load from")
help="Optional checkpoint to load the training state from. "
"Fully restores model weights and optimizer state.")

parser.add_argument("-sf", "--state_dict", type=str, default=None,
help="Load from checkpoint but only the model weights.")
help="Load from checkpoint but only the model weights. "
"Can be partial as the weights don't have to match one-to-one.")

parser.add_argument("-fsf", "--freeze_state_dict", action='store_true',
help="Freeze any weights that were loaded from the state dict.")
help="Freeze any weights that were loaded from the state dict. "
"Used for finetuning new layers.")

parser.add_argument("-l", "--log_dir", type=str, default=None,
help="Output directory for the checkpoints and tensorboard logs. Default to current directory.")

parser.add_argument("-n", "--name", type=str, default="spanet_output",
help="The sub-directory to create for this run.")
help="The sub-directory to create for this run and an identifier for WANDB.")

parser.add_argument("-fp16", action="store_true",
help="Use AMP for training.")

parser.add_argument("--fairscale", action="store_true",
help="Use Fairscale Sharded Training.")

parser.add_argument("-g", "--graph", action="store_true",
help="Log the computation graph.")

parser.add_argument("-v", "--verbose", action='store_true',
help="Output additional information to console and log.")
parser.add_argument("-e", "--epochs", type=int, default=None,
help="Override number of epochs to train for")

parser.add_argument("-t", "--time_limit", type=str, default=None,
help="Time limit for training, in the format DD:HH:MM:SS.")

parser.add_argument("-g", "--gpus", type=int, default=None,
help="Override GPU count in hyperparameters.")

parser.add_argument("-b", "--batch_size", type=int, default=None,
help="Override batch size in hyperparameters.")

@@ -247,19 +239,19 @@ def main(
parser.add_argument("-p", "--limit_dataset", type=float, default=None,
help="Limit dataset to only the first L percent of the data (0 - 100).")

parser.add_argument("-fp16", "--fp16", action="store_true",
help="Use Torch AMP for training.")

parser.add_argument("-v", "--verbose", action='store_true',
help="Output additional information to console and log.")

parser.add_argument("-r", "--random_seed", type=int, default=0,
help="Set random seed for cross-validation.")

parser.add_argument("-ts", "--torch_script", action='store_true',
help="Compile the neural network using torchscript.")

parser.add_argument("--profile", action='store_true',
help="Profile network for a single training run.")

parser.add_argument("--epochs", type=int, default=None,
help="Override number of epochs to train for")

parser.add_argument("--gpus", type=int, default=None,
help="Override GPU count in hyperparameters.")
help="Profile network for a single training epoch.")

main(**parser.parse_args().__dict__)
