Probabilistic Emulation of a Global Climate Model with Spherical DYffusion (NeurIPS 2024, Spotlight)


✨Official implementation of our Spherical DYffusion paper✨

| Environment Setup

We recommend installing into a virtual environment (e.g., a venv or Conda environment). Then, run:

python3 -m pip install ".[dev]"
python3 -m pip install --no-deps nvidia-modulus@git+https://github.com/ai2cm/modulus.git@94f62e1ce2083640829ec12d80b00619c40a47f8

Alternatively, use the provided environment/install_dependencies.sh script.

Note that on some compute setups you may want to install PyTorch first to ensure proper GPU support. For more details about installing PyTorch, please refer to the official PyTorch documentation.

| Dataset

The final training and validation data can be downloaded from Google Cloud Storage following the instructions of the ACE paper at https://zenodo.org/records/10791087. The data are licensed under Creative Commons Attribution 4.0 International.

| Checkpoints

Model weights are available at https://huggingface.co/salv47/spherical-dyffusion.

| Running experiments

Inference

First, download the validation data as described in the Dataset section.

Second, run the run_inference.py script with a corresponding configuration file. The configuration files used for our paper can be found in the src/configs/inference directory. For example, you can run inference with the following command:

python run_inference.py <path-to-inference-config>.yaml

The available inference configurations are:

To use these configs, you need to correctly specify the dataset.data_path parameter in the configuration file to point to the downloaded validation data.
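Parameters such as dataset.data_path use dotted notation to address nested keys of the YAML config, i.e. a data_path entry inside the dataset section. As a minimal sketch, assuming the config has already been parsed into a plain nested dict, the hypothetical helpers below read and set such a dotted key. The helper names and the data path are placeholders, not part of the repository.

```python
from typing import Any


def get_dotted(cfg: dict, key: str) -> Any:
    """Return the value at a dotted key such as 'dataset.data_path'."""
    node: Any = cfg
    for part in key.split("."):
        node = node[part]  # raises KeyError if a section is missing
    return node


def set_dotted(cfg: dict, key: str, value: Any) -> None:
    """Set the value at a dotted key, creating intermediate sections if needed."""
    *parents, leaf = key.split(".")
    node = cfg
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


# Hypothetical inference config, already parsed from YAML into a dict.
config = {"dataset": {"data_path": None}, "module": {"num_predictions": 4}}
set_dotted(config, "dataset.data_path", "/data/ace/validation")  # placeholder path
print(get_dotted(config, "dataset.data_path"))  # /data/ace/validation
```

In the YAML file itself, this corresponds to editing the data_path entry under the dataset section.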

Training

We use Hydra for configuration management and PyTorch Lightning for training. We recommend familiarizing yourself with these tools before running training experiments.

Tips & Tricks

Memory Considerations and OOM Errors

To control memory usage and avoid out-of-memory (OOM) errors, you can adjust the training and evaluation batch sizes:

For training, you can adjust the datamodule.batch_size_per_gpu parameter. This automatically adjusts trainer.accumulate_grad_batches so that the effective batch size (set by datamodule.batch_size) stays constant. Consequently, datamodule.batch_size needs to be divisible by datamodule.batch_size_per_gpu.
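The relationship between these parameters can be sketched as follows. This is a hypothetical helper, not the repository's code, and it assumes training on a single device (how multiple GPUs are accounted for is not covered above); the batch sizes are illustrative.

```python
def accumulate_grad_batches(batch_size: int, batch_size_per_gpu: int) -> int:
    """Gradient-accumulation steps that keep the effective batch size fixed.

    Assumes a single device, so effective batch = per-GPU batch * steps.
    """
    if batch_size % batch_size_per_gpu != 0:
        raise ValueError(
            f"datamodule.batch_size={batch_size} must be divisible by "
            f"datamodule.batch_size_per_gpu={batch_size_per_gpu}"
        )
    return batch_size // batch_size_per_gpu


# Halving the per-GPU batch doubles the number of accumulation steps.
print(accumulate_grad_batches(32, 8))  # 4
print(accumulate_grad_batches(32, 4))  # 8
```

Lowering batch_size_per_gpu therefore trades memory for more (smaller) forward/backward passes per optimizer step, without changing the effective batch size.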

For evaluation, or if you encounter OOMs during validation, you can adjust the datamodule.eval_batch_size parameter. Note that the effective validation-time batch size is datamodule.eval_batch_size * module.num_predictions, so keep this in mind when choosing eval_batch_size. You can control how many ensemble members are run in memory at once with module.num_predictions_in_memory.
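A back-of-the-envelope sketch of how these settings interact. It assumes (this is not stated above) that ensemble members are generated in sequential chunks of size module.num_predictions_in_memory, so the peak batch held in memory is eval_batch_size * num_predictions_in_memory rather than the full effective batch. The function name and values are hypothetical.

```python
import math


def eval_memory_plan(
    eval_batch_size: int, num_predictions: int, num_predictions_in_memory: int
) -> dict:
    """Rough per-step sizes for ensemble evaluation (assumes sequential chunks)."""
    effective = eval_batch_size * num_predictions
    chunk = min(num_predictions_in_memory, num_predictions)
    return {
        "effective_batch_size": effective,
        "peak_batch_in_memory": eval_batch_size * chunk,
        "num_chunks": math.ceil(num_predictions / chunk),
    }


print(eval_memory_plan(eval_batch_size=4, num_predictions=10, num_predictions_in_memory=5))
# {'effective_batch_size': 40, 'peak_batch_in_memory': 20, 'num_chunks': 2}
```

Under this assumption, reducing num_predictions_in_memory lowers peak memory at the cost of more sequential passes, while the effective batch size stays the same.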

Beyond these main knobs, you can enable mixed-precision training with trainer.precision=16 to reduce memory usage, and adjust the datamodule.num_workers parameter to control the number of data-loading processes.

Wandb Integration

We use Weights & Biases for logging and checkpointing. Please set your wandb username/entity with one of the following options:

Checkpointing

By default, checkpoints are saved locally in the <work_dir>/checkpoints directory in the root of the repository, which you can control with the work_dir=<path> argument.

When using the wandb logger (the default), checkpoints may also be saved to wandb (logger.wandb.save_to_wandb) or to S3 storage (logger.wandb.save_to_s3_bucket). Set these to False to disable saving checkpoints to wandb or S3, respectively. If you disable both (i.e., save checkpoints only locally), make sure to also set logger.wandb.save_best_ckpt=False and logger.wandb.save_last_ckpt=False. You can set these preferences in your local config file (see src/configs/local/example_local_config.yaml for an example).
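The rule above (disabling both remote targets requires also disabling best/last checkpoint saving) can be expressed as a small consistency check. This is a hypothetical sketch, not the repository's code: it treats the logger.wandb section as a plain dict and assumes missing keys default to True.

```python
# Overrides for local-only checkpointing, exactly as listed in the text above.
LOCAL_ONLY_OVERRIDES = [
    "logger.wandb.save_to_wandb=False",
    "logger.wandb.save_to_s3_bucket=False",
    "logger.wandb.save_best_ckpt=False",
    "logger.wandb.save_last_ckpt=False",
]


def check_checkpoint_settings(wandb_cfg: dict) -> None:
    """Raise if remote saving is off but best/last checkpoint saving is still on.

    wandb_cfg mirrors the logger.wandb section (hypothetical helper).
    Missing keys are treated as True, which is an assumption about defaults.
    """
    remote_disabled = not wandb_cfg.get("save_to_wandb", True) and not wandb_cfg.get(
        "save_to_s3_bucket", True
    )
    if not remote_disabled:
        return  # at least one remote target is active; nothing to check
    still_enabled = [
        key for key in ("save_best_ckpt", "save_last_ckpt") if wandb_cfg.get(key, True)
    ]
    if still_enabled:
        fixes = " ".join(f"logger.wandb.{key}=False" for key in still_enabled)
        raise ValueError(f"Checkpoints are only saved locally; also set: {fixes}")


# The local-only combination passes the check.
check_checkpoint_settings(
    {"save_to_wandb": False, "save_to_s3_bucket": False,
     "save_best_ckpt": False, "save_last_ckpt": False}
)
```

Assuming run.py accepts these keys as Hydra command-line overrides (as it does for the debug flags), the list could be passed along, e.g. `subprocess.run(["python", "run.py", *LOCAL_ONLY_OVERRIDES], check=True)`; putting them in your local config file achieves the same.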

Debugging

For minimal data and model size, you can use the following:

python run.py ++model.debug_mode=True ++datamodule.debug_mode=True

Note that this only works if the model and datamodule implement appropriate handling of the debug mode.

Code Quality

Code quality is automatically checked when pushing to the repository. However, it is recommended that you also run the checks locally with make quality.

To automatically fix some issues (as much as possible), run:

make style

hydra.errors.InstantiationException

The hydra.errors.InstantiationException itself is not very informative, so you need to look at the preceding exception(s) (i.e. scroll up) to see what went wrong.
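Python prints chained exceptions oldest-first, which is why the informative error appears above the InstantiationException in the traceback. When the traceback is long, a small stdlib-only helper that walks `__cause__`/`__context__` can surface the underlying error directly. This is a sketch that assumes Hydra chains the original error to the InstantiationException (as the advice above implies); the demo uses a stand-in exception class and a hypothetical error message rather than importing Hydra.

```python
def exception_chain(exc: BaseException) -> list:
    """Return exc followed by the exceptions that led to it, outermost first."""
    chain, seen = [], set()
    current = exc
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        chain.append(current)
        # Follow the explicit cause (`raise ... from err`) first, then the
        # implicit context unless it was suppressed, mirroring Python's traceback.
        if current.__cause__ is not None:
            current = current.__cause__
        elif not current.__suppress_context__:
            current = current.__context__
        else:
            current = None
    return chain


class InstantiationException(Exception):
    """Stand-in for hydra.errors.InstantiationException (illustration only)."""


def build_model():
    try:
        raise TypeError("unexpected keyword argument 'hidden_dim'")  # hypothetical
    except TypeError as err:
        raise InstantiationException("could not instantiate model") from err


try:
    build_model()
except InstantiationException as exc:
    for e in exception_chain(exc):
        print(f"{type(e).__name__}: {e}")
# InstantiationException: could not instantiate model
# TypeError: unexpected keyword argument 'hidden_dim'
```

The last element of the returned list is usually the root cause worth reading first.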

Local Configurations

You can use a local config file that defines the local data directory, working directory, etc., by placing a default.yaml config in the src/configs/local/ subdirectory. By default, Hydra searches for and uses the file configs/local/default.yaml if it exists. You may take inspiration from the example_local_config.yaml file.

| Citation

@inproceedings{cachay2024spherical,
    title={Probabilistic Emulation of a Global Climate Model with Spherical {DY}ffusion},
    author={Salva R{\"u}hling Cachay and Brian Henn and Oliver Watt-Meyer and Christopher S. Bretherton and Rose Yu},
    booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
    year={2024},
    url={https://openreview.net/forum?id=Ib2iHIJRTh}
}