Skip to content

Commit

Permalink
Add probabilistic Graph-EFM model from paper and update repo
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Jun 6, 2024
1 parent b0050b9 commit 89c5ce9
Show file tree
Hide file tree
Showing 22 changed files with 2,439 additions and 201 deletions.
121 changes: 58 additions & 63 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,38 @@
<img src="figures/neural_lam_header.png" width="700">
</p>

Neural-LAM is a repository of graph-based neural weather prediction models for Limited Area Modeling (LAM).
Neural-LAM is a repository of graph-based neural weather prediction models.
The code uses [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/pytorch-lightning).
Graph Neural Networks are implemented using [PyG](https://pyg.org/) and logging is set up through [Weights & Biases](https://wandb.ai/).

The repository contains LAM versions of:
# This Branch: Probabilistic LAM Forecasting
<p align="middle">
<img src="figures/graph_efm_forecast_nlwrs.gif" width="700"/>
</p>
<p align="middle">
<em>Example ensemble forecast from Graph-EFM for net solar longwave radiation.</em>
</p>

This branch contains the code for our paper *Probabilistic Weather Forecasting with Hierarchical Graph Neural Networks*, for Limited Area Modeling (LAM).
In particular, it contains implementations of:

* The graph-based model from [Keisler (2022)](https://arxiv.org/abs/2202.07575).
* GraphCast, by [Lam et al. (2023)](https://arxiv.org/abs/2212.12794).
* The hierarchical model from [Oskarsson et al. (2023)](https://arxiv.org/abs/2309.17370).
* Our ensemble forecasting model Graph-EFM.
* The hierarchical Graph-FM model (also called Hi-LAM in [Oskarsson et al. (2023)](https://arxiv.org/abs/2309.17370) and on the `main` branch).
* Our re-implementation of GraphCast, by [Lam et al. (2023)](https://arxiv.org/abs/2212.12794).

For more information see our paper: [*Graph-based Neural Weather Prediction for Limited Area Modeling*](https://arxiv.org/abs/2309.17370).
If you use Neural-LAM in your work, please cite:
If you use these models in your work, please cite:
```
@inproceedings{oskarsson2023graphbased,
title={Graph-based Neural Weather Prediction for Limited Area Modeling},
author={Oskarsson, Joel and Landelius, Tomas and Lindsten, Fredrik},
booktitle={NeurIPS 2023 Workshop on Tackling Climate Change with Machine Learning},
year={2023}
@article{probabilistic_weather_forecasting,
title={Probabilistic Weather Forecasting with Hierarchical Graph Neural Networks},
author={Oskarsson, Joel and Landelius, Tomas and Deisenroth, Marc Peter and Lindsten, Fredrik},
year={2024},
journal={arXiv preprint}
}
```
As the code in the repository is continuously evolving, the latest version might feature some small differences to what was used in the paper.
See the branch [`ccai_paper_2023`](https://github.com/joeloskarsson/neural-lam/tree/ccai_paper_2023) for a revision of the code that reproduces the workshop paper.

We plan to continue updating this repository as we improve existing models and develop new ones.
Collaborations around this implementation are very welcome.
If you are working with Neural-LAM feel free to get in touch and/or submit pull requests to the repository.
We are currently working to merge these models also to the `main` branch of Neural-LAM.
This README describes how to use the different models and run the experiments from the paper.
Do also check the [`main` branch](https://github.com/mllam/neural-lam) for further details and more updated implementations for parts of the codebase.

# Modularity
The Neural-LAM code is designed to modularize the different components involved in training and evaluating neural weather prediction models.
Expand All @@ -46,21 +52,18 @@ Still, some restrictions are inevitable:
Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
There are still some parts of the code that is quite specific for the MEPS area use case.
This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`).
If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
We would be happy to support such enhancements.
See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
Work is being done to refactor the code to be fully area-agnostic.

# Using Neural-LAM
Below follows instructions on how to use Neural-LAM to train and evaluate models.

## Installation
Follow the steps below to create the necessary python environment.

1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement.
2. Use python 3.9.
3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system.
4. Install required packages specified in `requirements.txt`.
5. Install PyTorch Geometric version 2.2.0. This can be done by running
1. Use python 3.10.
2. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system.
3. Install required packages specified in `requirements.txt`.
4. Install PyTorch Geometric version 2.3.1. This can be done by running
```
TORCH="2.0.1"
CUDA="cu117"
Expand All @@ -75,9 +78,9 @@ Datasets should be stored in a directory called `data`.
See the [repository format section](#format-of-data-directory) for details on the directory structure.

The full MEPS dataset can be shared with other researchers on request, contact us for this.
A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX).
A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:u:/g/personal/joeos82_liu_se/EaUUq6h9og1EsLwJmKAltWwB7zP2gmObe-K8pL6qGYYiGg?e=yQbFuV).
Download the file and unzip in the neural-lam directory.
All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `create_mesh.py`).
After downloading this data, the graphs used in the paper can be generated as described below.
Note that this is far too little data to train any useful models, but all scripts can be ran with it.
It should thus be useful to make sure that your python environment is set up correctly and that all the code can be ran without any issues.

Expand All @@ -94,11 +97,10 @@ In order to start training models at least three pre-processing scripts have to

### Create graph
Run `create_mesh.py` with suitable options to generate the graph you want to use (see `python create_mesh.py --help` for a list of options).
The graphs used for the different models in the [paper](https://arxiv.org/abs/2309.17370) can be created as:
The graphs used in the paper can be created as:

* **GC-LAM**: `python create_mesh.py --graph multiscale`
* **Hi-LAM**: `python create_mesh.py --graph hierarchical --hierarchical 1` (also works for Hi-LAM-Parallel)
* **L1-LAM**: `python create_mesh.py --graph 1level --levels 1`
* **multi-scale**: `python create_mesh.py --graph multiscale`
* **hierarchical**: `python create_mesh.py --graph hierarchical --hierarchical 1 --levels 3`

The graph-related files are stored in a directory called `graphs`.

Expand Down Expand Up @@ -130,42 +132,44 @@ A few of the key ones are outlined below:
* `--dataset`: Which data to train on
* `--model`: Which model to train
* `--graph`: Which graph to use with the model
* `--processor_layers`: Number of GNN layers to use in the processing part of the model
* `--processor_layers`: Number of GNN layers to use in the processing part of deterministic models, or in the predictor (decoder) for Graph-EFM
* `--encoder_processor_layers`: Number of GNN layers to use in the variatonal approximation for Graph-EFM
* `--prior_processor_layers`: Number of GNN layers to use in the latent map (prior) for Graph-EFM
* `--ar_steps`: Number of time steps to unroll for when making predictions and computing the loss

Checkpoints of trained models are stored in the `saved_models` directory.
Checkpoints of trained models are stored in the `saved_models` directory when training.
For detailed hyperparameter settings we refer to the paper, in particular the appendices with model and experiment details.

The implemented models are:

### Graph-LAM
This is the basic graph-based LAM model.
### GraphCast
This is our re-implementation of GraphCast, and really can be used with any type of non-hierarchical graph (not just multi-scale).
The encode-process-decode framework is used with a mesh graph in order to make one-step pedictions.
This model class is used both for the L1-LAM and GC-LAM models from the [paper](https://arxiv.org/abs/2309.17370), only with different graphs.

To train 1L-LAM use
To train GraphCast use
```
python train_model.py --model graph_lam --graph 1level ...
python train_model.py --model graphcast --graph multiscale ...
```

To train GC-LAM use
### Graph-FM
Deterministic graph-based forecasting model that uses a hierarchical mesh graph and performs sequential message passing through the hierarchy during processing.

To train Graph-FM use
```
python train_model.py --model graph_lam --graph multiscale ...
python train_model.py --model graph_fm --graph hierarchical ...
```

### Hi-LAM
A version of Graph-LAM that uses a hierarchical mesh graph and performs sequential message passing through the hierarchy during processing.
### Graph-EFM
This is the probabibilistic graph-based ensemble model.
The same model can be used both with multi-scale and hierarchical graphs, with different behaviour internally.

To train Hi-LAM use
To train Graph-EFM use e.g.
```
python train_model.py --model hi_lam --graph hierarchical ...
python train_model.py --model graph_efm --graph multiscale ...
```

### Hi-LAM-Parallel
A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is ran in parallel.
Not included in the paper as initial experiments showed worse results than Hi-LAM, but could be interesting to try in more settings.

To train Hi-LAM-Parallel use
or
```
python train_model.py --model hi_lam_parallel --graph hierarchical ...
python train_model.py --model graph_efm --graph hierarchical ...
```

Checkpoint files for our models trained on the MEPS data are available upon request.
Expand All @@ -178,8 +182,9 @@ Some options specifically important for evaluation are:

* `--load`: Path to model checkpoint file (`.ckpt`) to load parameters from
* `--n_example_pred`: Number of example predictions to plot during evaluation.
* `--ensemble_size`: Number of ensemble members to sample (for Graph-EFM)

**Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. This issue stems from PyTorch Lightning. See for example [this draft PR](https://github.com/Lightning-AI/torchmetrics/pull/1886) for more discussion and ongoing work to remedy this.
**Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged if using a batch size > 1. If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. This issue stems from PyTorch Lightning. See for example [this draft PR](https://github.com/Lightning-AI/torchmetrics/pull/1886) for more discussion and ongoing work to remedy this.

# Repository Structure
Except for training and pre-processing scripts all the source code can be found in the `neural_lam` directory.
Expand Down Expand Up @@ -270,16 +275,6 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w
These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels).
Entries 0 in these lists describe edges between the lowest levels 1 and 2.

# Development and Contributing
Any push or Pull-Request to the main branch will trigger a selection of pre-commit hooks.
These hooks will run a series of checks on the code, like formatting and linting.
If any of these checks fail the push or PR will be rejected.
To test whether your code passes these checks before pushing, run
``` bash
pre-commit run --all-files
```
from the root directory of the repository.

# Contact
If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch.
For questions about our implementation or ideas for extending it, feel free to get in touch.
You can open a github issue on this page, or (if more suitable) send an email to [[email protected]](mailto:[email protected]).
Binary file added figures/graph_efm_forecast_nlwrs.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 23 additions & 6 deletions neural_lam/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@

# Log prediction error for these lead times
VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19])
# Also save checkpoints for minimum loss at these lead times
VAL_STEP_CHECKPOINTS = (1, 19)

# Log these metrics to wandb as scalar values for
# specific variables and lead times
# List of metrics to watch, including any prefix (e.g. val_rmse)
METRICS_WATCH = []
METRICS_WATCH = [
"val_spsk_ratio",
"val_spread",
]
# Dict with variables and lead times to log watched metrics for
# Format is a dictionary that maps from a variable index to
# a list of lead time steps
Expand All @@ -24,6 +29,18 @@
15: [2, 19], # z_1000
}

# Plot forecasts for these variables at given lead times during validation step
# Format is a dictionary that maps from a variable index to a list of
# lead time steps
VAL_PLOT_VARS = {
4: [2, 19], # r_2
14: [2, 19], # wvint_0
}

# During validation, plot example samples of latent variable from prior and
# variational distribution
LATENT_SAMPLES_PLOT = 4 # Number of samples to plot

# Variable names
PARAM_NAMES = [
"pres_heightAboveGround_0_instant",
Expand Down Expand Up @@ -67,8 +84,8 @@
PARAM_UNITS = [
"Pa",
"Pa",
"W/m\\textsuperscript{2}",
"W/m\\textsuperscript{2}",
"W/m²",
"W/m²",
"-", # unitless
"-",
"K",
Expand All @@ -79,9 +96,9 @@
"m/s",
"m/s",
"m/s",
"kg/m\\textsuperscript{2}",
"m\\textsuperscript{2}/s\\textsuperscript{2}",
"m\\textsuperscript{2}/s\\textsuperscript{2}",
"kg/m²",
"m²/s²",
"m²/s²",
]

# Projection and grid
Expand Down
68 changes: 68 additions & 0 deletions neural_lam/interaction_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,74 @@ def aggregate(self, inputs, index, ptr, dim_size):
return aggr, inputs


class PropagationNet(InteractionNet):
"""
Alternative version of InteractionNet that incentivices the propagation of
information from sender nodes to receivers.
"""

def __init__(
self,
edge_index,
input_dim,
update_edges=True,
hidden_layers=1,
hidden_dim=None,
edge_chunk_sizes=None,
aggr_chunk_sizes=None,
aggr="sum",
):
# Use mean aggregation in propagation version to avoid instability
super().__init__(
edge_index,
input_dim,
update_edges=update_edges,
hidden_layers=hidden_layers,
hidden_dim=hidden_dim,
edge_chunk_sizes=edge_chunk_sizes,
aggr_chunk_sizes=aggr_chunk_sizes,
aggr="mean",
)

def forward(self, send_rep, rec_rep, edge_rep):
"""
Apply propagation network to update the representations of receiver
nodes, and optionally the edge representations.
send_rep: (N_send, d_h), vector representations of sender nodes
rec_rep: (N_rec, d_h), vector representations of receiver nodes
edge_rep: (M, d_h), vector representations of edges used
Returns:
rec_rep: (N_rec, d_h), updated vector representations of receiver nodes
(optionally) edge_rep: (M, d_h), updated vector representations
of edges
"""
# Always concatenate to [rec_nodes, send_nodes] for propagation,
# but only aggregate to rec_nodes
node_reps = torch.cat((rec_rep, send_rep), dim=-2)
edge_rep_aggr, edge_diff = self.propagate(
self.edge_index, x=node_reps, edge_attr=edge_rep
)
rec_diff = self.aggr_mlp(torch.cat((rec_rep, edge_rep_aggr), dim=-1))

# Residual connections
rec_rep = edge_rep_aggr + rec_diff # residual is to aggregation

if self.update_edges:
edge_rep = edge_rep + edge_diff
return rec_rep, edge_rep

return rec_rep

def message(self, x_j, x_i, edge_attr):
"""
Compute messages from node j to node i.
"""
# Residual connection is to sender node, propagating information to edge
return x_j + self.edge_mlp(torch.cat((edge_attr, x_j, x_i), dim=-1))


class SplitMLPs(nn.Module):
"""
Module that feeds chunks of input through different MLPs.
Expand Down
Loading

0 comments on commit 89c5ce9

Please sign in to comment.