diff --git a/README.md b/README.md index 67d9d9b1..0f309db2 100644 --- a/README.md +++ b/README.md @@ -2,32 +2,38 @@
-Neural-LAM is a repository of graph-based neural weather prediction models for Limited Area Modeling (LAM). +Neural-LAM is a repository of graph-based neural weather prediction models. The code uses [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/pytorch-lightning). Graph Neural Networks are implemented using [PyG](https://pyg.org/) and logging is set up through [Weights & Biases](https://wandb.ai/). -The repository contains LAM versions of: +# This Branch: Probabilistic LAM Forecasting ++ +
++ Example ensemble forecast from Graph-EFM for net solar longwave radiation. +
+ +This branch contains the code for our paper *Probabilistic Weather Forecasting with Hierarchical Graph Neural Networks*, for Limited Area Modeling (LAM). +In particular, it contains implementations of: -* The graph-based model from [Keisler (2022)](https://arxiv.org/abs/2202.07575). -* GraphCast, by [Lam et al. (2023)](https://arxiv.org/abs/2212.12794). -* The hierarchical model from [Oskarsson et al. (2023)](https://arxiv.org/abs/2309.17370). +* Our ensemble forecasting model Graph-EFM. +* The hierarchical Graph-FM model (also called Hi-LAM in [Oskarsson et al. (2023)](https://arxiv.org/abs/2309.17370) and on the `main` branch). +* Our re-implementation of GraphCast, by [Lam et al. (2023)](https://arxiv.org/abs/2212.12794). -For more information see our paper: [*Graph-based Neural Weather Prediction for Limited Area Modeling*](https://arxiv.org/abs/2309.17370). -If you use Neural-LAM in your work, please cite: +If you use these models in your work, please cite: ``` -@inproceedings{oskarsson2023graphbased, - title={Graph-based Neural Weather Prediction for Limited Area Modeling}, - author={Oskarsson, Joel and Landelius, Tomas and Lindsten, Fredrik}, - booktitle={NeurIPS 2023 Workshop on Tackling Climate Change with Machine Learning}, - year={2023} +@article{probabilistic_weather_forecasting, + title={Probabilistic Weather Forecasting with Hierarchical Graph Neural Networks}, + author={Oskarsson, Joel and Landelius, Tomas and Deisenroth, Marc Peter and Lindsten, Fredrik}, + year={2024}, + journal={arXiv preprint} } ``` -As the code in the repository is continuously evolving, the latest version might feature some small differences to what was used in the paper. -See the branch [`ccai_paper_2023`](https://github.com/joeloskarsson/neural-lam/tree/ccai_paper_2023) for a revision of the code that reproduces the workshop paper. -We plan to continue updating this repository as we improve existing models and develop new ones. -Collaborations around this implementation are very welcome. -If you are working with Neural-LAM feel free to get in touch and/or submit pull requests to the repository. +We are currently working to merge these models also to the `main` branch of Neural-LAM. +This README describes how to use the different models and run the experiments from the paper. +Do also check the [`main` branch](https://github.com/mllam/neural-lam) for further details and more updated implementations for parts of the codebase. # Modularity The Neural-LAM code is designed to modularize the different components involved in training and evaluating neural weather prediction models. @@ -46,9 +52,7 @@ Still, some restrictions are inevitable: Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)). There are still some parts of the code that is quite specific for the MEPS area use case. This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`). -If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic. -We would be happy to support such enhancements. -See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done. +Work is being done to refactor the code to be fully area-agnostic. # Using Neural-LAM Below follows instructions on how to use Neural-LAM to train and evaluate models. @@ -56,11 +60,10 @@ Below follows instructions on how to use Neural-LAM to train and evaluate models ## Installation Follow the steps below to create the necessary python environment. -1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement. -2. Use python 3.9. -3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system. -4. Install required packages specified in `requirements.txt`. -5. Install PyTorch Geometric version 2.2.0. This can be done by running +1. Use python 3.10. +2. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system. +3. Install required packages specified in `requirements.txt`. +4. Install PyTorch Geometric version 2.3.1. This can be done by running ``` TORCH="2.0.1" CUDA="cu117" @@ -75,9 +78,9 @@ Datasets should be stored in a directory called `data`. See the [repository format section](#format-of-data-directory) for details on the directory structure. The full MEPS dataset can be shared with other researchers on request, contact us for this. -A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX). +A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:u:/g/personal/joeos82_liu_se/EaUUq6h9og1EsLwJmKAltWwB7zP2gmObe-K8pL6qGYYiGg?e=yQbFuV). Download the file and unzip in the neural-lam directory. -All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `create_mesh.py`). +After downloading this data, the graphs used in the paper can be generated as described below. Note that this is far too little data to train any useful models, but all scripts can be ran with it. It should thus be useful to make sure that your python environment is set up correctly and that all the code can be ran without any issues. @@ -94,11 +97,10 @@ In order to start training models at least three pre-processing scripts have to ### Create graph Run `create_mesh.py` with suitable options to generate the graph you want to use (see `python create_mesh.py --help` for a list of options). -The graphs used for the different models in the [paper](https://arxiv.org/abs/2309.17370) can be created as: +The graphs used in the paper can be created as: -* **GC-LAM**: `python create_mesh.py --graph multiscale` -* **Hi-LAM**: `python create_mesh.py --graph hierarchical --hierarchical 1` (also works for Hi-LAM-Parallel) -* **L1-LAM**: `python create_mesh.py --graph 1level --levels 1` +* **multi-scale**: `python create_mesh.py --graph multiscale` +* **hierarchical**: `python create_mesh.py --graph hierarchical --hierarchical 1 --levels 3` The graph-related files are stored in a directory called `graphs`. @@ -130,42 +132,44 @@ A few of the key ones are outlined below: * `--dataset`: Which data to train on * `--model`: Which model to train * `--graph`: Which graph to use with the model -* `--processor_layers`: Number of GNN layers to use in the processing part of the model +* `--processor_layers`: Number of GNN layers to use in the processing part of deterministic models, or in the predictor (decoder) for Graph-EFM +* `--encoder_processor_layers`: Number of GNN layers to use in the variatonal approximation for Graph-EFM +* `--prior_processor_layers`: Number of GNN layers to use in the latent map (prior) for Graph-EFM * `--ar_steps`: Number of time steps to unroll for when making predictions and computing the loss -Checkpoints of trained models are stored in the `saved_models` directory. +Checkpoints of trained models are stored in the `saved_models` directory when training. +For detailed hyperparameter settings we refer to the paper, in particular the appendices with model and experiment details. + The implemented models are: -### Graph-LAM -This is the basic graph-based LAM model. +### GraphCast +This is our re-implementation of GraphCast, and really can be used with any type of non-hierarchical graph (not just multi-scale). The encode-process-decode framework is used with a mesh graph in order to make one-step pedictions. -This model class is used both for the L1-LAM and GC-LAM models from the [paper](https://arxiv.org/abs/2309.17370), only with different graphs. -To train 1L-LAM use +To train GraphCast use ``` -python train_model.py --model graph_lam --graph 1level ... +python train_model.py --model graphcast --graph multiscale ... ``` -To train GC-LAM use +### Graph-FM +Deterministic graph-based forecasting model that uses a hierarchical mesh graph and performs sequential message passing through the hierarchy during processing. + +To train Graph-FM use ``` -python train_model.py --model graph_lam --graph multiscale ... +python train_model.py --model graph_fm --graph hierarchical ... ``` -### Hi-LAM -A version of Graph-LAM that uses a hierarchical mesh graph and performs sequential message passing through the hierarchy during processing. +### Graph-EFM +This is the probabibilistic graph-based ensemble model. +The same model can be used both with multi-scale and hierarchical graphs, with different behaviour internally. -To train Hi-LAM use +To train Graph-EFM use e.g. ``` -python train_model.py --model hi_lam --graph hierarchical ... +python train_model.py --model graph_efm --graph multiscale ... ``` - -### Hi-LAM-Parallel -A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is ran in parallel. -Not included in the paper as initial experiments showed worse results than Hi-LAM, but could be interesting to try in more settings. - -To train Hi-LAM-Parallel use +or ``` -python train_model.py --model hi_lam_parallel --graph hierarchical ... +python train_model.py --model graph_efm --graph hierarchical ... ``` Checkpoint files for our models trained on the MEPS data are available upon request. @@ -178,8 +182,9 @@ Some options specifically important for evaluation are: * `--load`: Path to model checkpoint file (`.ckpt`) to load parameters from * `--n_example_pred`: Number of example predictions to plot during evaluation. +* `--ensemble_size`: Number of ensemble members to sample (for Graph-EFM) -**Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. This issue stems from PyTorch Lightning. See for example [this draft PR](https://github.com/Lightning-AI/torchmetrics/pull/1886) for more discussion and ongoing work to remedy this. +**Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged if using a batch size > 1. If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. This issue stems from PyTorch Lightning. See for example [this draft PR](https://github.com/Lightning-AI/torchmetrics/pull/1886) for more discussion and ongoing work to remedy this. # Repository Structure Except for training and pre-processing scripts all the source code can be found in the `neural_lam` directory. @@ -270,16 +275,6 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels). Entries 0 in these lists describe edges between the lowest levels 1 and 2. -# Development and Contributing -Any push or Pull-Request to the main branch will trigger a selection of pre-commit hooks. -These hooks will run a series of checks on the code, like formatting and linting. -If any of these checks fail the push or PR will be rejected. -To test whether your code passes these checks before pushing, run -``` bash -pre-commit run --all-files -``` -from the root directory of the repository. - # Contact -If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch. +For questions about our implementation or ideas for extending it, feel free to get in touch. You can open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/figures/graph_efm_forecast_nlwrs.gif b/figures/graph_efm_forecast_nlwrs.gif new file mode 100644 index 00000000..8995ac80 Binary files /dev/null and b/figures/graph_efm_forecast_nlwrs.gif differ diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 527c31d8..8b971922 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -10,11 +10,16 @@ # Log prediction error for these lead times VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19]) +# Also save checkpoints for minimum loss at these lead times +VAL_STEP_CHECKPOINTS = (1, 19) # Log these metrics to wandb as scalar values for # specific variables and lead times # List of metrics to watch, including any prefix (e.g. val_rmse) -METRICS_WATCH = [] +METRICS_WATCH = [ + "val_spsk_ratio", + "val_spread", +] # Dict with variables and lead times to log watched metrics for # Format is a dictionary that maps from a variable index to # a list of lead time steps @@ -24,6 +29,18 @@ 15: [2, 19], # z_1000 } +# Plot forecasts for these variables at given lead times during validation step +# Format is a dictionary that maps from a variable index to a list of +# lead time steps +VAL_PLOT_VARS = { + 4: [2, 19], # r_2 + 14: [2, 19], # wvint_0 +} + +# During validation, plot example samples of latent variable from prior and +# variational distribution +LATENT_SAMPLES_PLOT = 4 # Number of samples to plot + # Variable names PARAM_NAMES = [ "pres_heightAboveGround_0_instant", @@ -67,8 +84,8 @@ PARAM_UNITS = [ "Pa", "Pa", - "W/m\\textsuperscript{2}", - "W/m\\textsuperscript{2}", + "W/m²", + "W/m²", "-", # unitless "-", "K", @@ -79,9 +96,9 @@ "m/s", "m/s", "m/s", - "kg/m\\textsuperscript{2}", - "m\\textsuperscript{2}/s\\textsuperscript{2}", - "m\\textsuperscript{2}/s\\textsuperscript{2}", + "kg/m²", + "m²/s²", + "m²/s²", ] # Projection and grid diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 663f27e4..5f1c1dcf 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -131,6 +131,74 @@ def aggregate(self, inputs, index, ptr, dim_size): return aggr, inputs +class PropagationNet(InteractionNet): + """ + Alternative version of InteractionNet that incentivices the propagation of + information from sender nodes to receivers. + """ + + def __init__( + self, + edge_index, + input_dim, + update_edges=True, + hidden_layers=1, + hidden_dim=None, + edge_chunk_sizes=None, + aggr_chunk_sizes=None, + aggr="sum", + ): + # Use mean aggregation in propagation version to avoid instability + super().__init__( + edge_index, + input_dim, + update_edges=update_edges, + hidden_layers=hidden_layers, + hidden_dim=hidden_dim, + edge_chunk_sizes=edge_chunk_sizes, + aggr_chunk_sizes=aggr_chunk_sizes, + aggr="mean", + ) + + def forward(self, send_rep, rec_rep, edge_rep): + """ + Apply propagation network to update the representations of receiver + nodes, and optionally the edge representations. + + send_rep: (N_send, d_h), vector representations of sender nodes + rec_rep: (N_rec, d_h), vector representations of receiver nodes + edge_rep: (M, d_h), vector representations of edges used + + Returns: + rec_rep: (N_rec, d_h), updated vector representations of receiver nodes + (optionally) edge_rep: (M, d_h), updated vector representations + of edges + """ + # Always concatenate to [rec_nodes, send_nodes] for propagation, + # but only aggregate to rec_nodes + node_reps = torch.cat((rec_rep, send_rep), dim=-2) + edge_rep_aggr, edge_diff = self.propagate( + self.edge_index, x=node_reps, edge_attr=edge_rep + ) + rec_diff = self.aggr_mlp(torch.cat((rec_rep, edge_rep_aggr), dim=-1)) + + # Residual connections + rec_rep = edge_rep_aggr + rec_diff # residual is to aggregation + + if self.update_edges: + edge_rep = edge_rep + edge_diff + return rec_rep, edge_rep + + return rec_rep + + def message(self, x_j, x_i, edge_attr): + """ + Compute messages from node j to node i. + """ + # Residual connection is to sender node, propagating information to edge + return x_j + self.edge_mlp(torch.cat((edge_attr, x_j, x_i), dim=-1)) + + class SplitMLPs(nn.Module): """ Module that feeds chunks of input through different MLPs. diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py index 7db2cca6..1cac1f05 100644 --- a/neural_lam/metrics.py +++ b/neural_lam/metrics.py @@ -84,6 +84,8 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): ) +# Allow for unused pred_std for consistent signature +# pylint: disable-next=unused-argument def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ (Unweighted) Mean Squared Error @@ -92,7 +94,7 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target - pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. (unused) mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) sum_vars: boolean, if variable dimension -1 should be reduced (sum @@ -104,7 +106,7 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ # Replace pred_std with constant ones return wmse( - pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars + pred, target, torch.ones_like(pred), mask, average_grid, sum_vars ) @@ -139,6 +141,8 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): ) +# Allow for unused pred_std for consistent signature +# pylint: disable-next=unused-argument def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ (Unweighted) Mean Absolute Error @@ -147,7 +151,7 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): but broadcastable pred: (..., N, d_state), prediction target: (..., N, d_state), target - pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. (unused) mask: (N,), boolean mask describing which grid nodes to use in metric average_grid: boolean, if grid dimension -2 should be reduced (mean over N) sum_vars: boolean, if variable dimension -1 should be reduced (sum @@ -159,7 +163,7 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): """ # Replace pred_std with constant ones return wmae( - pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars + pred, target, torch.ones_like(pred), mask, average_grid, sum_vars ) @@ -227,6 +231,132 @@ def crps_gauss( ) +# Allow for unused pred_std for consistent signature +def crps_ens( + pred, + target, + pred_std, # pylint: disable=unused-argument + mask=None, + average_grid=True, + sum_vars=True, + ens_dim=1, +): + """ + (Negative) Continuous Ranked Probability Score (CRPS) + Unbiased estimator from samples. See e.g. Weatherbench 2. + + (..., M, ...,) is any number of batch dimensions, including ensemble + dimension M + pred: (..., M, ..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., M, ..., N, d_state) or (d_state,), predicted std.-dev. + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced + (sum over d_state) + ens_dim: batch dimension where ensemble members are laid out, to reduce over + + Returns: + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. + """ + num_ens = pred.shape[ens_dim] # Number of ensemble members + if num_ens == 1: + # With one sample CRPS reduces to MAE + return mae( + pred.squeeze(ens_dim), + target, + None, + mask=mask, + average_grid=average_grid, + ) + + if num_ens == 2: + mean_mae = torch.mean( + torch.abs(pred - target.unsqueeze(ens_dim)), dim=ens_dim + ) # (..., N, d_state) + + # Use simpler estimator + pair_diffs_term = -0.5 * torch.abs( + pred.select(ens_dim, 0) - pred.select(ens_dim, 1) + ) # (..., N, d_state) + + crps_estimator = mean_mae + pair_diffs_term # (..., N, d_state) + elif num_ens < 10: + # This is the rank-based implementation with O(M*log(M)) compute and + # O(M) memory. See Zamo and Naveau and WB2 for explanation. + # For smaller ensemble we can compute all of this directly in memory. + mean_mae = torch.mean( + torch.abs(pred - target.unsqueeze(ens_dim)), dim=ens_dim + ) # (..., N, d_state) + + # Ranks start at 1, two argsorts will compute entry ranks + ranks = pred.argsort(dim=ens_dim).argsort(ens_dim) + 1 + + pair_diffs_term = (1 / (num_ens - 1)) * torch.mean( + (num_ens + 1 - 2 * ranks) * pred, + dim=ens_dim, + ) # (..., N, d_state) + + crps_estimator = mean_mae + pair_diffs_term # (..., N, d_state) + else: + # For large ensembles we batch this over the variable dimension + crps_res = [] + for var_i in range(pred.shape[-1]): + pred_var = pred[..., var_i] + target_var = target[..., var_i] + + mean_mae = torch.mean( + torch.abs(pred_var - target_var.unsqueeze(ens_dim)), dim=ens_dim + ) # (..., N) + + # Ranks start at 1, two argsorts will compute entry ranks + ranks = pred_var.argsort(dim=ens_dim).argsort(ens_dim) + 1 + # (..., M, ..., N) + + pair_diffs_term = (1 / (num_ens - 1)) * torch.mean( + (num_ens + 1 - 2 * ranks) * pred_var, + dim=ens_dim, + ) # (..., N) + crps_res.append(mean_mae + pair_diffs_term) + + crps_estimator = torch.stack(crps_res, dim=-1) + + return mask_and_reduce_metric(crps_estimator, mask, average_grid, sum_vars) + + +def spread_squared( + pred, + target, # pylint: disable=unused-argument + pred_std, # pylint: disable=unused-argument + mask=None, + average_grid=True, + sum_vars=True, + ens_dim=1, +): + """ + (Squared) spread of ensemble. + Similarly to RMSE, we want to take sqrt after spatial and sample averaging, + so we need to average the squared spread. + + (..., M, ...,) is any number of batch dimensions, including ensemble + dimension M + pred: (..., M, ..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., M, ..., N, d_state) or (d_state,), predicted std.-dev. + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced + (sum over d_state) + ens_dim: batch dimension where ensemble members are laid out, to reduce over + + Returns: + metric_val: One of (...,), (..., d_state) depending on reduction arguments. + """ + entry_var = torch.var(pred, dim=ens_dim) # (..., N, d_state) + return mask_and_reduce_metric(entry_var, mask, average_grid, sum_vars) + + DEFINED_METRICS = { "mse": mse, "mae": mae, @@ -234,4 +364,6 @@ def crps_gauss( "wmae": wmae, "nll": nll, "crps_gauss": crps_gauss, + "crps_ens": crps_ens, + "spread_squared": spread_squared, } diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 7d0a8320..aa38b585 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -503,10 +503,13 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): metric_tensor_averaged = torch.mean(metric_tensor, dim=0) # (pred_steps, d_f) - # Take square root after all averaging to change MSE to RMSE + # Take square root after averaging to change squared metrics if "mse" in metric_name: metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) metric_name = metric_name.replace("mse", "rmse") + elif metric_name.endswith("_squared"): + metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) + metric_name = metric_name[: -len("_squared")] # Note: we here assume rescaling for all metrics is linear metric_rescaled = metric_tensor_averaged * self.data_std diff --git a/neural_lam/models/base_graph_latent_decoder.py b/neural_lam/models/base_graph_latent_decoder.py new file mode 100644 index 00000000..b2222c01 --- /dev/null +++ b/neural_lam/models/base_graph_latent_decoder.py @@ -0,0 +1,105 @@ +# Third-party +from torch import nn + +# First-party +from neural_lam import constants, utils + + +class BaseGraphLatentDecoder(nn.Module): + """ + Decoder that maps grid input + latent variable on mesh to prediction on grid + """ + + def __init__( + self, + hidden_dim, + latent_dim, + hidden_layers=1, + output_std=True, + ): + super().__init__() + + # MLP for residual mapping of grid rep. + self.grid_update_mlp = utils.make_mlp( + [hidden_dim] * (hidden_layers + 2) + ) + + # Embedder for latent variable + self.latent_embedder = utils.make_mlp( + [latent_dim] + [hidden_dim] * (hidden_layers + 1) + ) + + # Either output input-dependent per-grid-node std or + # use common per-variable std + self.output_std = output_std + if self.output_std: + output_dim = 2 * constants.GRID_STATE_DIM + else: + output_dim = constants.GRID_STATE_DIM + + # Mapping to parameters of state distribution + self.param_map = utils.make_mlp( + [hidden_dim] * (hidden_layers + 1) + [output_dim], layer_norm=False + ) + + def combine_with_latent( + self, original_grid_rep, latent_rep, residual_grid_rep, graph_emb + ): + """ + Combine the grid representation with representation of latent variable. + The output should be on the grid again. + + original_grid_rep: (B, num_grid_nodes, d_h) + latent_rep: (B, num_mesh_nodes, d_h) + residual_grid_rep: (B, num_grid_nodes, d_h) + + Returns: + residual_grid_rep: (B, num_grid_nodes, d_h) + """ + raise NotImplementedError("combine_with_latent not implemented") + + def forward(self, grid_rep, latent_samples, last_state, graph_emb): + """ + Compute prediction (mean and std.-dev.) of next weather state + + grid_rep: (B, num_grid_nodes, d_h) + latent_samples: (B, N_mesh, d_latent) + last_state: (B, num_grid_nodes, d_state) + graph_emb: dict with graph embedding vectors, entries at least + g2m: (B, M_g2m, d_h) + m2m: (B, M_g2m, d_h) + m2g: (B, M_m2g, d_h) + + Returns: + mean: (B, N_mesh, d_latent), predicted mean + std: (B, N_mesh, d_latent), predicted std.-dev. + """ + # To mesh + latent_emb = self.latent_embedder(latent_samples) # (B, N_mesh, d_h) + + # Resiudal MLP for grid representation + residual_grid_rep = grid_rep + self.grid_update_mlp( + grid_rep + ) # (B, num_grid_nodes, d_h) + + combined_grid_rep = self.combine_with_latent( + grid_rep, latent_emb, residual_grid_rep, graph_emb + ) + + state_params = self.param_map( + combined_grid_rep + ) # (B, N_mesh, d_state_params) + + if self.output_std: + mean_delta, std_raw = state_params.chunk( + 2, dim=-1 + ) # (B, num_grid_nodes, d_state),(B, num_grid_nodes, d_state) + # pylint: disable-next=not-callable + pred_std = nn.functional.softplus(std_raw) # positive std. + else: + mean_delta = state_params # (B, num_grid_nodes, d_state) + pred_std = None + + pred_mean = last_state + mean_delta + + return pred_mean, pred_std diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 256d4adc..b0a06f1a 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -3,7 +3,7 @@ # First-party from neural_lam import utils -from neural_lam.interaction_net import InteractionNet +from neural_lam.interaction_net import InteractionNet, PropagationNet from neural_lam.models.ar_model import ARModel @@ -48,8 +48,9 @@ def __init__(self, args): self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) # GNNs + gnn_class = PropagationNet if args.vertical_propnets else InteractionNet # encoder - self.g2m_gnn = InteractionNet( + self.g2m_gnn = gnn_class( self.g2m_edge_index, args.hidden_dim, hidden_layers=args.hidden_layers, @@ -60,7 +61,7 @@ def __init__(self, args): ) # decoder - self.m2g_gnn = InteractionNet( + self.m2g_gnn = gnn_class( self.m2g_edge_index, args.hidden_dim, hidden_layers=args.hidden_layers, diff --git a/neural_lam/models/base_latent_encoder.py b/neural_lam/models/base_latent_encoder.py new file mode 100644 index 00000000..7ddc18ac --- /dev/null +++ b/neural_lam/models/base_latent_encoder.py @@ -0,0 +1,72 @@ +# Third-party +import torch +from torch import distributions as tdists +from torch import nn + + +class BaseLatentEncoder(nn.Module): + """ + Abstract class for encoder that maps input to distribution + over latent variable + """ + + def __init__( + self, + latent_dim, + output_dist="isotropic", + ): + super().__init__() + + # Mapping to parameters of latent distribution + self.output_dist = output_dist + if output_dist == "isotropic": + # Isotopic Gaussian, output only mean (\Sigma = I) + self.output_dim = latent_dim + elif output_dist == "diagonal": + # Isotopic Gaussian, output mean and std + self.output_dim = 2 * latent_dim + + # Small epsilon to prevent enccoding to dist. with std.-dev. 0 + self.latent_std_eps = 1e-4 + else: + assert False, f"Unknown encoder output distribution: {output_dist}" + + def compute_dist_params(self, grid_rep, **kwargs): + """ + Compute parameters of distribution over latent variable using the + grid representation + + grid_rep: (B, num_grid_nodes, d_h) + + Returns: + parameters: (B, num_mesh_nodes, d_output) + """ + raise NotImplementedError("compute_dist_params not implemented") + + def forward(self, grid_rep, **kwargs): + """ + Compute distribution over latent variable + + grid_rep: (B, N_grid, d_h) + mesh_rep: (B, N_mesh, d_h) + g2m_rep: (B, M_g2m, d_h) + + Returns: + distribution: latent var. dist. shaped (B, N_mesh, d_latent) + """ + latent_dist_params = self.compute_dist_params(grid_rep, **kwargs) + + if self.output_dist == "diagonal": + latent_mean, latent_std_raw = latent_dist_params.chunk( + 2, dim=-1 + ) # (B, N_mesh, d_latent) and (B, N_mesh, d_latent) + # pylint: disable-next=not-callable + latent_std = self.latent_std_eps + nn.functional.softplus( + latent_std_raw + ) # positive std. + else: + # isotropic + latent_mean = latent_dist_params + latent_std = torch.ones_like(latent_mean) + + return tdists.Normal(latent_mean, latent_std) diff --git a/neural_lam/models/constant_latent_encoder.py b/neural_lam/models/constant_latent_encoder.py new file mode 100644 index 00000000..582d6faa --- /dev/null +++ b/neural_lam/models/constant_latent_encoder.py @@ -0,0 +1,41 @@ +# Third-party +import torch + +# First-party +from neural_lam.models.base_latent_encoder import BaseLatentEncoder + + +class ConstantLatentEncoder(BaseLatentEncoder): + """ + Latent encoder parametrizing constant distribution + """ + + def __init__( + self, + latent_dim, + num_mesh_nodes, + output_dist="isotropic", + ): + super().__init__( + latent_dim, + output_dist, + ) + + self.num_mesh_nodes = num_mesh_nodes + + def compute_dist_params(self, grid_rep, **kwargs): + """ + Compute parameters of distribution over latent variable using the + grid representation + + grid_rep: (B, num_grid_nodes, d_h) + + Returns: + parameters: (B, num_mesh_nodes, d_output) + """ + return torch.ones( + grid_rep.shape[0], + self.num_mesh_nodes, + self.output_dim, + device=grid_rep.device, + ) # (B, num_mesh_nodes, d_output) diff --git a/neural_lam/models/graph_efm.py b/neural_lam/models/graph_efm.py new file mode 100644 index 00000000..d7cc2876 --- /dev/null +++ b/neural_lam/models/graph_efm.py @@ -0,0 +1,1128 @@ +# Third-party +import matplotlib.pyplot as plt +import numpy as np +import torch +import wandb + +# First-party +from neural_lam import constants, metrics, utils, vis +from neural_lam.models.ar_model import ARModel +from neural_lam.models.constant_latent_encoder import ConstantLatentEncoder +from neural_lam.models.graph_latent_decoder import GraphLatentDecoder +from neural_lam.models.graph_latent_encoder import GraphLatentEncoder +from neural_lam.models.hi_graph_latent_decoder import HiGraphLatentDecoder +from neural_lam.models.hi_graph_latent_encoder import HiGraphLatentEncoder + + +class GraphEFM(ARModel): + """ + Graph-based Ensemble Forecasting Model + """ + + def __init__(self, args): + super().__init__(args) + + assert ( + args.n_example_pred <= args.batch_size + ), "Can not plot more examples than batch size in GraphEFM" + self.sample_obs_noise = bool(args.sample_obs_noise) + self.ensemble_size = args.ensemble_size + self.kl_beta = args.kl_beta + self.crps_weight = args.crps_weight + + # Load graph with static features + self.hierarchical_graph, graph_ldict = utils.load_graph(args.graph) + for name, attr_value in graph_ldict.items(): + # Make BufferLists module members and register tensors as buffers + if isinstance(attr_value, torch.Tensor): + self.register_buffer(name, attr_value, persistent=False) + else: + setattr(self, name, attr_value) + + # Specify dimensions of data + # grid_dim from data + static + grid_current_dim = self.grid_dim + constants.GRID_STATE_DIM + g2m_dim = self.g2m_features.shape[1] + m2g_dim = self.m2g_features.shape[1] + + # Define sub-models + # Feature embedders for grid + self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) + self.grid_prev_embedder = utils.make_mlp( + [self.grid_dim] + self.mlp_blueprint_end + ) # For states up to t-1 + self.grid_current_embedder = utils.make_mlp( + [grid_current_dim] + self.mlp_blueprint_end + ) # For states including t + # Embedders for mesh + self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) + self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) + if self.hierarchical_graph: + # Print some useful info + print("Loaded hierarchical graph with structure:") + level_mesh_sizes = [ + mesh_feat.shape[0] for mesh_feat in self.mesh_static_features + ] + self.num_mesh_nodes = level_mesh_sizes[-1] + num_levels = len(self.mesh_static_features) + for level_index, level_mesh_size in enumerate(level_mesh_sizes): + same_level_edges = self.m2m_features[level_index].shape[0] + print( + f"level {level_index} - {level_mesh_size} nodes, " + f"{same_level_edges} same-level edges" + ) + + if level_index < (num_levels - 1): + up_edges = self.mesh_up_features[level_index].shape[0] + down_edges = self.mesh_down_features[level_index].shape[0] + print(f" {level_index}<->{level_index+1}") + print(f" - {up_edges} up edges, {down_edges} down edges") + # Embedders + # Assume all levels have same static feature dimensionality + mesh_dim = self.mesh_static_features[0].shape[1] + m2m_dim = self.m2m_features[0].shape[1] + mesh_up_dim = self.mesh_up_features[0].shape[1] + mesh_down_dim = self.mesh_down_features[0].shape[1] + + # Separate mesh node embedders for each level + self.mesh_embedders = torch.nn.ModuleList( + [ + utils.make_mlp([mesh_dim] + self.mlp_blueprint_end) + for _ in range(num_levels) + ] + ) + self.mesh_up_embedders = torch.nn.ModuleList( + [ + utils.make_mlp([mesh_up_dim] + self.mlp_blueprint_end) + for _ in range(num_levels - 1) + ] + ) + self.mesh_down_embedders = torch.nn.ModuleList( + [ + utils.make_mlp([mesh_down_dim] + self.mlp_blueprint_end) + for _ in range(num_levels - 1) + ] + ) + # If not using any processor layers, no need to embed m2m + self.embedd_m2m = ( + max( + args.prior_processor_layers, + args.encoder_processor_layers, + args.processor_layers, + ) + > 0 + ) + if self.embedd_m2m: + self.m2m_embedders = torch.nn.ModuleList( + [ + utils.make_mlp([m2m_dim] + self.mlp_blueprint_end) + for _ in range(num_levels) + ] + ) + else: + self.num_mesh_nodes, mesh_static_dim = ( + self.mesh_static_features.shape + ) + print( + f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes}" + f"nodes ({self.num_grid_nodes} grid, " + f"{self.num_mesh_nodes} mesh)" + ) + mesh_static_dim = self.mesh_static_features.shape[1] + self.mesh_embedder = utils.make_mlp( + [mesh_static_dim] + self.mlp_blueprint_end + ) + m2m_dim = self.m2m_features.shape[1] + self.m2m_embedder = utils.make_mlp( + [m2m_dim] + self.mlp_blueprint_end + ) + + latent_dim = ( + args.latent_dim if args.latent_dim is not None else args.hidden_dim + ) + # Prior + if args.learn_prior: + if self.hierarchical_graph: + self.prior_model = HiGraphLatentEncoder( + latent_dim, + self.g2m_edge_index, + self.m2m_edge_index, + self.mesh_up_edge_index, + args.hidden_dim, + args.prior_processor_layers, + hidden_layers=args.hidden_layers, + output_dist=args.prior_dist, + ) + else: + self.prior_model = GraphLatentEncoder( + latent_dim, + self.g2m_edge_index, + self.m2m_edge_index, + args.hidden_dim, + args.prior_processor_layers, + hidden_layers=args.hidden_layers, + output_dist=args.prior_dist, + ) + else: + self.prior_model = ConstantLatentEncoder( + latent_dim, + self.num_mesh_nodes, + output_dist=args.prior_dist, + ) + + # Enc. + Dec. + if self.hierarchical_graph: + # Encoder + self.encoder = HiGraphLatentEncoder( + latent_dim, + self.g2m_edge_index, + self.m2m_edge_index, + self.mesh_up_edge_index, + args.hidden_dim, + args.encoder_processor_layers, + hidden_layers=args.hidden_layers, + output_dist="diagonal", + ) + # Decoder + self.decoder = HiGraphLatentDecoder( + self.g2m_edge_index, + self.m2m_edge_index, + self.m2g_edge_index, + self.mesh_up_edge_index, + self.mesh_down_edge_index, + args.hidden_dim, + latent_dim, + args.processor_layers, + hidden_layers=args.hidden_layers, + output_std=bool(args.output_std), + ) + else: + # Encoder + self.encoder = GraphLatentEncoder( + latent_dim, + self.g2m_edge_index, + self.m2m_edge_index, + args.hidden_dim, + args.encoder_processor_layers, + hidden_layers=args.hidden_layers, + output_dist="diagonal", + ) + # Decoder + self.decoder = GraphLatentDecoder( + self.g2m_edge_index, + self.m2m_edge_index, + self.m2g_edge_index, + args.hidden_dim, + latent_dim, + args.processor_layers, + hidden_layers=args.hidden_layers, + output_std=bool(args.output_std), + ) + + # Add lists for val and test errors of ensemble prediction + self.val_metrics.update( + { + "spread_squared": [], + "ens_mse": [], + } + ) + self.test_metrics.update( + { + "ens_mae": [], + "ens_mse": [], + "crps_ens": [], + "spread_squared": [], + } + ) + + def sample_next_state(self, pred_mean, pred_std): + """ + Sample state at next time step given Gaussian distribution. + If self.sample_obs_noise is False, only return mean. + + pred_mean: (B, num_grid_nodes, d_state) + pred_std: (B, num_grid_nodes, d_state) or + None (if not output_std) + + Return: + next_state: (B, num_grid_nodes, d_state) + """ + if not self.output_std: + pred_std = self.per_var_std # (d_f,) + + if self.sample_obs_noise: + return torch.distributions.Normal(pred_mean, pred_std).rsample() + # (B, num_grid_nodes, d_state) + + return pred_mean # (B, num_grid_nodes, d_state) + + def embedd_current( + self, + prev_state, + prev_prev_state, + forcing, + current_state, + ): + """ + embed grid representation including current (target) state. Used as + input to the encoder, which is conditioned also on the target. + + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + forcing: (B, num_grid_nodes, forcing_dim) + current_state: (B, num_grid_nodes, feature_dim), X_{t+1} + + Returns: + current_emb: (B, num_grid_nodes, d_h) + """ + batch_size = prev_state.shape[0] + + grid_current_features = torch.cat( + ( + prev_prev_state, + prev_state, + forcing, + self.expand_to_batch(self.grid_static_features, batch_size), + current_state, + ), + dim=-1, + ) # (B, num_grid_nodes, grid_current_dim) + + return self.grid_current_embedder( + grid_current_features + ) # (B, num_grid_nodes, d_h) + + def embedd_all(self, prev_state, prev_prev_state, forcing): + """ + embed all node and edge representations + + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + forcing: (B, num_grid_nodes, forcing_dim) + + Returns: + grid_emb: (B, num_grid_nodes, d_h) + graph_embedding: dict with entries of shape (B, *, d_h) + """ + batch_size = prev_state.shape[0] + + grid_features = torch.cat( + ( + prev_prev_state, + prev_state, + forcing, + self.expand_to_batch(self.grid_static_features, batch_size), + ), + dim=-1, + ) # (B, num_grid_nodes, grid_dim) + + grid_emb = self.grid_prev_embedder(grid_features) + # (B, num_grid_nodes, d_h) + + # Graph embedding + graph_emb = { + "g2m": self.expand_to_batch( + self.g2m_embedder(self.g2m_features), batch_size + ), # (B, M_g2m, d_h) + "m2g": self.expand_to_batch( + self.m2g_embedder(self.m2g_features), batch_size + ), # (B, M_m2g, d_h) + } + + if self.hierarchical_graph: + graph_emb["mesh"] = [ + self.expand_to_batch(emb(node_static_features), batch_size) + for emb, node_static_features in zip( + self.mesh_embedders, + self.mesh_static_features, + ) + ] # each (B, num_mesh_nodes[l], d_h) + + if self.embedd_m2m: + graph_emb["m2m"] = [ + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.m2m_embedders, self.m2m_features + ) + ] + else: + # Need a placeholder otherwise, just use raw features + graph_emb["m2m"] = list(self.m2m_features) + + graph_emb["mesh_up"] = [ + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.mesh_up_embedders, self.mesh_up_features + ) + ] + graph_emb["mesh_down"] = [ + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.mesh_down_embedders, self.mesh_down_features + ) + ] + else: + graph_emb["mesh"] = self.expand_to_batch( + self.mesh_embedder(self.mesh_static_features), batch_size + ) # (B, num_mesh_nodes, d_h) + graph_emb["m2m"] = self.expand_to_batch( + self.m2m_embedder(self.m2m_features), batch_size + ) # (B, M_m2m, d_h) + + return grid_emb, graph_emb + + def compute_step_loss( + self, + prev_states, + current_state, + forcing_features, + ): + """ + Perform forward pass and compute loss for one time step + + prev_states: (B, 2, num_grid_nodes, d_features), X^{t-p}, ..., X^{t-1} + current_state: (B, num_grid_nodes, d_features) X^t + forcing_features: (B, num_grid_nodes, d_forcing) corresponding to + index 1 of prev_states + """ + # embed all features + grid_prev_emb, graph_emb = self.embedd_all( + prev_states[:, 1], + prev_states[:, 0], + forcing_features, + ) + # embed also including current grid state, for encoder + grid_current_emb = self.embedd_current( + prev_states[:, 1], + prev_states[:, 0], + forcing_features, + current_state, + ) # (B, num_grid_nodes, d_h) + + # Compute variational approximation (encoder) + var_dist = self.encoder( + grid_current_emb, graph_emb=graph_emb + ) # Gaussian, (B, num_mesh_nodes, d_latent) + + # Compute likelihood + last_state = prev_states[:, -1] + likelihood_term, pred_mean, pred_std = self.estimate_likelihood( + var_dist, current_state, last_state, grid_prev_emb, graph_emb + ) + if self.kl_beta > 0: + # Compute prior + prior_dist = self.prior_model( + grid_prev_emb, graph_emb=graph_emb + ) # Gaussian, (B, num_mesh_nodes, d_latent) + + # Compute KL + kl_term = torch.sum( + torch.distributions.kl_divergence(var_dist, prior_dist), + dim=(1, 2), + ) # (B,) + else: + # If beta=0, do not need to even compute prior nor KL + kl_term = None # Set to None to crash if erroneously used + + return likelihood_term, kl_term, pred_mean, pred_std + + def estimate_likelihood( + self, latent_dist, current_state, last_state, grid_prev_emb, graph_emb + ): + """ + Estimate (masked) likelihood using given distribution over + latent variables + + latent_dist: distribution, (B, num_mesh_nodes, d_latent) + current_state: (B, num_grid_nodes, d_state) + last_state: (B, num_grid_nodes, d_state) + grid_prev_emb: (B, num_grid_nodes, d_state) + g2m_emb: (B, M_g2m, d_h) + m2m_emb: (B, M_m2m, d_h) + m2g_emb: (B, M_m2g, d_h) + + Returns: + likelihood_term: (B,) + pred_mean: (B, num_grid_nodes, d_state) + pred_std: (B, num_grid_nodes, d_state) or (d_state,) + """ + # Sample from variational distribution + latent_samples = latent_dist.rsample() # (B, num_mesh_nodes, d_latent) + + # Compute reconstruction (decoder) + pred_mean, model_pred_std = self.decoder( + grid_prev_emb, latent_samples, last_state, graph_emb + ) # both (B, num_grid_nodes, d_state) + + if self.output_std: + pred_std = model_pred_std # (B, num_grid_nodes, d_state) + else: + # Use constant set std.-devs. + pred_std = self.per_var_std # (d_f,) + + # Compute likelihood (negative loss, exactly likelihood for nll loss) + # Note: There are some round-off errors here due to float32 + # and large values + entry_likelihoods = -self.loss( + pred_mean, + current_state, + pred_std, + mask=self.interior_mask_bool, + average_grid=False, + sum_vars=False, + ) # (B, num_grid_nodes', d_state) + likelihood_term = torch.sum(entry_likelihoods, dim=(1, 2)) # (B,) + return likelihood_term, pred_mean, pred_std + + def training_step(self, batch): + """ + Train on single batch + + batch, containing: + init_states: (B, 2, num_grid_nodes, d_state) + target_states: (B, pred_steps, num_grid_nodes, d_state) + forcing_features: (B, pred_steps, num_grid_nodes, d_forcing), where + index 0 corresponds to index 1 of init_states + """ + init_states, target_states, forcing_features = batch + + prev_prev_state = init_states[:, 0] # (B, num_grid_nodes, d_state) + prev_state = init_states[:, 1] # (B, num_grid_nodes, d_state) + pred_steps = forcing_features.shape[1] + + loss_like_list = [] + loss_kl_list = [] + + for i in range(pred_steps): + forcing = forcing_features[:, i] # (B, num_grid_nodes, d_forcing) + target_state = target_states[:, i] # (B, num_grid_nodes, d_state) + + prev_states_stacked = torch.stack( + (prev_prev_state, prev_state), dim=1 + ) # (B, 2, num_grid_nodes, d_state) + loss_like_term, loss_kl_term, pred_mean, pred_std = ( + self.compute_step_loss( + prev_states_stacked, + target_state, + forcing, + ) + ) + # (B,), (B,), (B, num_grid_nodes, d_state), + # pred_std is (B, num_grid_nodes, d_state) or (d_state) + + loss_like_list.append(loss_like_term) + loss_kl_list.append(loss_kl_term) + + # Get predicted next state (sample or mean) + predicted_state = self.sample_next_state(pred_mean, pred_std) + + # Overwrite border with true state + new_state = ( + self.border_mask * target_state + + self.interior_mask * predicted_state + ) + + # Update conditioning states + prev_prev_state = prev_state + prev_state = new_state + + # Compute final ELBO and loss, sum over time, mean over batch + per_sample_likelihood = torch.sum( + torch.stack(loss_like_list, dim=1), dim=1 + ) # (B,) + mean_likelihood = torch.mean(per_sample_likelihood) + log_dict = { + "elbo_likelihood": mean_likelihood, + } + + if self.kl_beta > 0: + # Only compute full KL + ELBO if beta > 0 + per_sample_kl = torch.sum( + torch.stack(loss_kl_list, dim=1), dim=1 + ) # (B,) + mean_kl = torch.mean(per_sample_kl) + elbo = mean_likelihood - mean_kl + loss = -mean_likelihood + self.kl_beta * mean_kl + + log_dict["elbo"] = elbo + log_dict["elbo_kl"] = mean_kl + else: + # Pure auto-encoder training + loss = -mean_likelihood + + # Optionally sample trajectories and compute CRPS loss + if self.crps_weight > 0: + # Sample trajectories using prior + pred_traj_means, pred_traj_stds = self.sample_trajectories( + init_states, + forcing_features, + target_states, + 2, + ) + # (B, S=2, pred_steps, num_grid_nodes, d_f), always 2 samples + + # Compute CRPS + crps_estimate = metrics.crps_ens( + pred_traj_means, + target_states, + pred_traj_stds, + mask=self.interior_mask_bool, + ) # (B, pred_steps) + crps_loss = torch.mean(crps_estimate) + + # Add onto loss + loss = loss + self.crps_weight * crps_loss + log_dict["crps_loss"] = crps_loss + + log_dict["train_loss"] = loss + self.log_dict( + log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True + ) + return loss + + def predict_step(self, prev_state, prev_prev_state, forcing): + """ + Sample one time step prediction + + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + forcing: (B, num_grid_nodes, forcing_dim) + + Returns: + new_state: (B, num_grid_nodes, feature_dim) + """ + # embed all features + grid_prev_emb, graph_emb = self.embedd_all( + prev_state, prev_prev_state, forcing + ) + + # Compute prior + prior_dist = self.prior_model( + grid_prev_emb, graph_emb=graph_emb + ) # (B, num_mesh_nodes, d_latent) + + # Sample from prior + latent_samples = prior_dist.rsample() + # (B, num_mesh_nodes, d_latent) + + # Compute reconstruction (decoder) + last_state = prev_state + pred_mean, pred_std = self.decoder( + grid_prev_emb, latent_samples, last_state, graph_emb + ) # (B, num_grid_nodes, d_state) + + return self.sample_next_state(pred_mean, pred_std), pred_std + + def sample_trajectories( + self, + init_states, + forcing_features, + true_states, + num_traj, + use_encoder=False, + ): + """ + init_states: (B, 2, num_grid_nodes, d_f) + forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) + true_states: (B, pred_steps, num_grid_nodes, d_f) + num_traj: S, number of trajectories to sample + use_encoder: bool, if latent variables should be sampled from + var. distribution + + Returns + traj_means: (B, S, pred_steps, num_grid_nodes, d_f) + traj_stds: (B, S, pred_steps, num_grid_nodes, d_f) or (d_f) + """ + unroll_func = ( + self.unroll_prediction_vi if use_encoder else self.unroll_prediction + ) + traj_list = [ + unroll_func( + init_states, + forcing_features, + true_states, + ) + for _ in range(num_traj) + ] + # List of tuples, each containing + # mean: (B, pred_steps, num_grid_nodes, d_f) and + # std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + + traj_means = torch.stack( + [pred_pair[0] for pred_pair in traj_list], dim=1 + ) + if self.output_std: + traj_stds = torch.stack( + [pred_pair[1] for pred_pair in traj_list], dim=1 + ) + else: + traj_stds = self.per_var_std + + return traj_means, traj_stds + + def unroll_prediction_vi(self, init_states, forcing_features, true_states): + """ + Roll out prediction, sampling latent var. from variational + encoder distribution + + init_states: (B, 2, num_grid_nodes, d_f) + forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) + true_states: (B, pred_steps, num_grid_nodes, d_f) + """ + prev_prev_state = init_states[:, 0] + prev_state = init_states[:, 1] + prediction_list = [] + pred_std_list = [] + pred_steps = forcing_features.shape[1] + + for i in range(pred_steps): + # Compute 1-step prediction, but using encoder + forcing = forcing_features[:, i] + current_state = true_states[:, i] + + # embed all features + grid_prev_emb, graph_emb = self.embedd_all( + prev_state, prev_prev_state, forcing + ) + + # embed also including current grid state, for encoder + grid_current_emb = self.embedd_current( + prev_state, + prev_prev_state, + forcing, + current_state, + ) + + # Compute variational distribution + var_dist = self.encoder( + grid_current_emb, graph_emb=graph_emb + ) # Gaussian, (B, num_mesh_nodes, d_latent) + + # Sample from var. dist. + latent_samples = var_dist.rsample() + # (B, num_mesh_nodes, d_latent) + + # Compute reconstruction (decoder) + pred_mean, pred_std = self.decoder( + grid_prev_emb, latent_samples, prev_state, graph_emb + ) # (B, num_grid_nodes, d_state) + + pred_state = self.sample_next_state(pred_mean, pred_std) + # pred_state: (B, num_grid_nodes, d_f) + # pred_std: (B, num_grid_nodes, d_f) or None + + # Overwrite border with true state + new_state = ( + self.border_mask * current_state + + self.interior_mask * pred_state + ) + + prediction_list.append(new_state) + if self.output_std: + pred_std_list.append(pred_std) + + # Update conditioning states + prev_prev_state = prev_state + prev_state = new_state + + prediction = torch.stack( + prediction_list, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + if self.output_std: + pred_std = torch.stack( + pred_std_list, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + else: + pred_std = self.per_var_std # (d_f,) + + return prediction, pred_std + + def plot_examples(self, batch, n_examples, prediction=None): + """ + Plot ensemble forecast + mean and std + """ + init_states, target_states, forcing_features = batch + + trajectories, _ = self.sample_trajectories( + init_states, + forcing_features, + target_states, + self.ensemble_size, + ) + # (B, S, pred_steps, num_grid_nodes, d_f) + + # Rescale to original data scale + traj_rescaled = trajectories * self.data_std + self.data_mean + target_rescaled = target_states * self.data_std + self.data_mean + + # Compute mean and std of ensemble + ens_mean = torch.mean( + traj_rescaled, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + ens_std = torch.std( + traj_rescaled, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + + # Iterate over the examples + for traj_slice, target_slice, ens_mean_slice, ens_std_slice in zip( + traj_rescaled[:n_examples], + target_rescaled[:n_examples], + ens_mean[:n_examples], + ens_std[:n_examples], + ): + # traj_slice is (S, pred_steps, num_grid_nodes, d_f) + # others are (pred_steps, num_grid_nodes, d_f) + self.plotted_examples += 1 # Increment already here + + # Note: min and max values can not be in ensemble mean + var_vmin = ( + torch.minimum( + traj_slice.flatten(0, 2).min(dim=0)[0], + target_slice.flatten(0, 1).min(dim=0)[0], + ) + .cpu() + .numpy() + ) # (d_f,) + var_vmax = ( + torch.maximum( + traj_slice.flatten(0, 2).max(dim=0)[0], + target_slice.flatten(0, 1).max(dim=0)[0], + ) + .cpu() + .numpy() + ) # (d_f,) + var_vranges = list(zip(var_vmin, var_vmax)) + + # Iterate over prediction horizon time steps + for t_i, (samples_t, target_t, ens_mean_t, ens_std_t) in enumerate( + zip( + traj_slice.transpose(0, 1), + # (pred_steps, S, num_grid_nodes, d_f) + target_slice, + ens_mean_slice, + ens_std_slice, + ), + start=1, + ): + time_title_part = f"t={t_i} ({self.step_length*t_i} h)" + # Create one figure per variable at this time step + var_figs = [ + vis.plot_ensemble_prediction( + samples_t[:, :, var_i], + target_t[:, var_i], + ens_mean_t[:, var_i], + ens_std_t[:, var_i], + self.interior_mask[:, 0], + title=f"{var_name} ({var_unit}), {time_title_part}", + vrange=var_vrange, + ) + for var_i, (var_name, var_unit, var_vrange) in enumerate( + zip( + constants.PARAM_NAMES_SHORT, + constants.PARAM_UNITS, + var_vranges, + ) + ) + ] + + example_title = f"example_{self.plotted_examples}" + wandb.log( + { + f"{var_name}_{example_title}": wandb.Image(fig) + for var_name, fig in zip( + constants.PARAM_NAMES_SHORT, var_figs + ) + } + ) + plt.close( + "all" + ) # Close all figs for this time step, saves memory + + def ensemble_common_step(self, batch): + """ + Perform ensemble forecast and compute basic metrics. + Common step done during both evaluation and testing + + batch: tuple of tensors, batch to perform ensemble forecast on + + Returns: + trajectories: (B, S, pred_steps, num_grid_nodes, d_f) + traj_stds: (B, S, pred_steps, num_grid_nodes, d_f) + target_states: (B, pred_steps, num_grid_nodes, d_f) + spread_squared_batch: (B, pred_steps, d_f) + ens_mse_batch: (B, pred_steps, d_f) + """ + # Compute and store metrics for ensemble forecast + init_states, target_states, forcing_features = batch + + trajectories, traj_stds = self.sample_trajectories( + init_states, + forcing_features, + target_states, + self.ensemble_size, + ) + # (B, S, pred_steps, num_grid_nodes, d_f) + + spread_squared_batch = metrics.spread_squared( + trajectories, + target_states, + traj_stds, + mask=self.interior_mask_bool, + sum_vars=False, + ) + # (B, pred_steps, d_f) + + ens_mean = torch.mean( + trajectories, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + ens_mse_batch = metrics.mse( + ens_mean, + target_states, + None, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + + return ( + trajectories, + traj_stds, + target_states, + spread_squared_batch, + ens_mse_batch, + ) + + def validation_step(self, batch, *args): + """ + Run validation on single batch + """ + super().validation_step(batch, *args) + batch_idx = args[0] + + # Run ensemble forecast + prior_trajectories, _, _, spread_squared_batch, ens_mse_batch = ( + self.ensemble_common_step(batch) + ) + self.val_metrics["spread_squared"].append(spread_squared_batch) + self.val_metrics["ens_mse"].append(ens_mse_batch) + + # Plot some example predictions using prior and encoder + if ( + self.trainer.is_global_zero + and batch_idx == 0 + and self.n_example_pred > 0 + ): + # Roll out trajectories using variational distribution (encoder) + ( + init_states, + target_states, + forcing_features, + ) = batch + # Only create ens. forecast for as many examples as needed + init_states = init_states[: self.n_example_pred] + target_states = target_states[: self.n_example_pred] + forcing_features = forcing_features[: self.n_example_pred] + + # Sample trajectories using variational dist. for latent var. + enc_trajectories, _ = self.sample_trajectories( + init_states, + forcing_features, + target_states, + self.ensemble_size, + use_encoder=True, + ) + + # Only need n_example_pred prior trajectories + prior_trajectories = prior_trajectories[: self.n_example_pred] + + # Plot samples + log_plot_dict = {} + for example_i, (prior_traj, enc_traj, target_traj) in enumerate( + zip(prior_trajectories, enc_trajectories, target_states), + start=1, + ): + # prior_traj and enc traj are + # (S, pred_steps, num_grid_nodes, d_f) + + for var_i, timesteps in constants.VAL_PLOT_VARS.items(): + var_name = constants.PARAM_NAMES_SHORT[var_i] + var_unit = constants.PARAM_UNITS[var_i] + for step in timesteps: + prior_states = prior_traj[ + :, step - 1, :, var_i + ] # (S, num_grid_nodes) + enc_states = enc_traj[ + :, step - 1, :, var_i + ] # (S, num_grid_nodes) + target_state = target_traj[ + step - 1, :, var_i + ] # (num_grid_nodes,) + + plot_title = ( + f"{var_name} ({var_unit}), t={step} " + f"({self.step_length*step} h)" + ) + + # Make plots + log_plot_dict[ + f"prior_{var_name}_step_{step}_ex{example_i}" + ] = vis.plot_ensemble_prediction( + prior_states, + target_state, + prior_states.mean(dim=0), + prior_states.std(dim=0), + self.interior_mask[:, 0], + title=f"{plot_title} (prior)", + ) + log_plot_dict[ + f"vi_{var_name}_step_{step}_ex{example_i}" + ] = vis.plot_ensemble_prediction( + enc_states, + target_state, + enc_states.mean(dim=0), + enc_states.std(dim=0), + self.interior_mask[:, 0], + title=f"{plot_title} (vi)", + ) + + # Sample latent variable and plot + # embed all features + grid_prev_emb, graph_emb = self.embedd_all( + init_states[:, 1], + init_states[:, 0], + forcing_features[:, 0], + ) # (B, num_grid_nodes, d_h) + # embed also including current grid state, for encoder + grid_current_emb = self.embedd_current( + init_states[:, 1], + init_states[:, 0], + forcing_features[:, 0], + target_states[:, 0], + ) # (B, num_grid_nodes, d_h) + + # Create latent variable samples + prior_dist = self.prior_model( + grid_prev_emb, graph_emb=graph_emb + ) # Gaussian, (B, num_mesh_nodes, d_latent) + prior_samples = prior_dist.rsample( + (constants.LATENT_SAMPLES_PLOT,) + ).transpose( + 0, 1 + ) # (B, samples, num_mesh_nodes, d_latent) + + vi_dist = self.encoder( + grid_current_emb, graph_emb=graph_emb + ) # Gaussian, (B, num_mesh_nodes, d_latent) + vi_samples = vi_dist.rsample( + (constants.LATENT_SAMPLES_PLOT,) + ).transpose( + 0, 1 + ) # (B, samples, num_mesh_nodes, d_latent) + + # Make plot for each example + for example_i, (prior_ex_samples, vi_ex_samples) in enumerate( + zip(prior_samples, vi_samples), start=1 + ): + log_plot_dict[f"latent_samples_ex{example_i}"] = ( + vis.plot_latent_samples(prior_ex_samples, vi_ex_samples) + ) + + if not self.trainer.sanity_checking: + # Log all plots to wandb + wandb.log(log_plot_dict) + + plt.close("all") + + def log_spsk_ratio(self, metric_vals, prefix): + """ + Compute the mean spread-skill ratio for logging in evaluation + + metric_vals: dict with all metric values + prefix: string, prefix to use for logging + """ + # Compute mean spsk_ratio + spread_squared_tensor = self.all_gather_cat( + torch.cat(metric_vals["spread_squared"], dim=0) + ) # (N_eval, pred_steps, d_f) + ens_mse_tensor = self.all_gather_cat( + torch.cat(metric_vals["ens_mse"], dim=0) + ) # (N_eval, pred_steps, d_f) + + # Do not log during sanity check? + if self.trainer.is_global_zero and not self.trainer.sanity_checking: + # Note that spsk_ratio is scale-invariant, so do not have to rescale + spread = torch.sqrt(torch.mean(spread_squared_tensor, dim=0)) + skill = torch.sqrt(torch.mean(ens_mse_tensor, dim=0)) + # Both (pred_steps, d_f) + + # Include finite sample correction + spsk_ratios = np.sqrt( + (self.ensemble_size + 1) / self.ensemble_size + ) * ( + spread / skill + ) # (pred_steps, d_f) + log_dict = self.create_metric_log_dict( + spsk_ratios, prefix, "spsk_ratio" + ) + + log_dict[f"{prefix}_mean_spsk_ratio"] = torch.mean( + spsk_ratios + ) # log mean + wandb.log(log_dict) + + def on_validation_epoch_end(self): + """ + Compute val metrics at the end of val epoch + """ + # Must log before super call, as metric lists are cleared at end of step + self.log_spsk_ratio(self.val_metrics, "val") + super().on_validation_epoch_end() + + def test_step(self, batch, batch_idx): + """ + Run test on single batch + Include metrics computation for ensemble mean prediction + """ + super().test_step(batch, batch_idx) + + ( + trajectories, + traj_stds, + target_states, + spread_squared_batch, + ens_mse_batch, + ) = self.ensemble_common_step(batch) + self.test_metrics["spread_squared"].append(spread_squared_batch) + self.test_metrics["ens_mse"].append(ens_mse_batch) + + # Compute additional ensemble metrics + ens_mean = torch.mean( + trajectories, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + ens_std = torch.std(trajectories, dim=1) + # (B, pred_steps, num_grid_nodes, d_f) + + # Compute MAE for ensemble mean + ensemble CRPS + ens_maes = metrics.mae( + ens_mean, + target_states, + ens_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + self.test_metrics["ens_mae"].append(ens_maes) + crps_batch = metrics.crps_ens( + trajectories, + target_states, + traj_stds, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + self.test_metrics["crps_ens"].append(crps_batch) + + def on_test_epoch_end(self): + """ + Compute test metrics and make plots at the end of test epoch. + Will gather stored tensors and perform plotting and logging on rank 0. + """ + super().on_test_epoch_end() + self.log_spsk_ratio(self.test_metrics, "test") diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/graph_fm.py similarity index 95% rename from neural_lam/models/hi_lam.py rename to neural_lam/models/graph_fm.py index 4d7eb94c..e69fba4d 100644 --- a/neural_lam/models/hi_lam.py +++ b/neural_lam/models/graph_fm.py @@ -2,15 +2,15 @@ from torch import nn # First-party -from neural_lam.interaction_net import InteractionNet +from neural_lam.interaction_net import InteractionNet, PropagationNet from neural_lam.models.base_hi_graph_model import BaseHiGraphModel -class HiLAM(BaseHiGraphModel): +class GraphFM(BaseHiGraphModel): """ - Hierarchical graph model with message passing that goes sequentially down + Hierarchical Graph-based Forecasting Model + with message passing that goes sequentially down and up the hierarchy during processing. - The Hi-LAM model from Oskarsson et al. (2023) """ def __init__(self, args): @@ -51,9 +51,10 @@ def make_up_gnns(self, args): """ Make GNNs for processing steps up through the hierarchy. """ + gnn_class = PropagationNet if args.vertical_propnets else InteractionNet return nn.ModuleList( [ - InteractionNet( + gnn_class( edge_index, args.hidden_dim, hidden_layers=args.hidden_layers, diff --git a/neural_lam/models/graph_latent_decoder.py b/neural_lam/models/graph_latent_decoder.py new file mode 100644 index 00000000..327ec9cc --- /dev/null +++ b/neural_lam/models/graph_latent_decoder.py @@ -0,0 +1,86 @@ +# Third-party +import torch_geometric as pyg + +# First-party +from neural_lam.interaction_net import InteractionNet, PropagationNet +from neural_lam.models.base_graph_latent_decoder import BaseGraphLatentDecoder + + +class GraphLatentDecoder(BaseGraphLatentDecoder): + """ + Decoder that maps grid input + latent variable on mesh to prediction on grid + Uses non-hierarchical graph + """ + + def __init__( + self, + g2m_edge_index, + m2m_edge_index, + m2g_edge_index, + hidden_dim, + latent_dim, + processor_layers, + hidden_layers=1, + output_std=True, + ): + super().__init__(hidden_dim, latent_dim, hidden_layers, output_std) + + # GNN from grid to mesh + self.g2m_gnn = InteractionNet( + g2m_edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + + # Processor layers on mesh + self.processor = pyg.nn.Sequential( + "mesh_rep, edge_rep", + [ + ( + InteractionNet( + m2m_edge_index, hidden_dim, hidden_layers=hidden_layers + ), + "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep", + ) + for _ in range(processor_layers) + ], + ) + + # GNN from mesh to grid + self.m2g_gnn = PropagationNet( + m2g_edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + + def combine_with_latent( + self, original_grid_rep, latent_rep, residual_grid_rep, graph_emb + ): + """ + Combine the grid representation with representation of latent variable. + The output should be on the grid again. + + original_grid_rep: (B, num_grid_nodes, d_h) + latent_rep: (B, num_mesh_nodes, d_h) + residual_grid_rep: (B, num_grid_nodes, d_h) + + Returns: + grid_rep: (B, num_grid_nodes, d_h) + """ + mesh_rep = self.g2m_gnn( + original_grid_rep, latent_rep, graph_emb["g2m"] + ) # (B, N_mesh, d_h) + + # Process on mesh + mesh_rep, _ = self.processor( + mesh_rep, graph_emb["m2m"] + ) # (B, N_mesh, d_h) + + # Back to grid + grid_rep = self.m2g_gnn( + mesh_rep, residual_grid_rep, graph_emb["m2g"] + ) # (B, N_mesh, d_h) + + return grid_rep diff --git a/neural_lam/models/graph_latent_encoder.py b/neural_lam/models/graph_latent_encoder.py new file mode 100644 index 00000000..43487961 --- /dev/null +++ b/neural_lam/models/graph_latent_encoder.py @@ -0,0 +1,65 @@ +# First-party +from neural_lam import utils +from neural_lam.interaction_net import PropagationNet +from neural_lam.models.base_latent_encoder import BaseLatentEncoder + + +class GraphLatentEncoder(BaseLatentEncoder): + """ + Encoder that maps from grid to mesh and defines a latent distribution + on mesh + """ + + def __init__( + self, + latent_dim, + g2m_edge_index, + m2m_edge_index, + hidden_dim, + processor_layers, + hidden_layers=1, + output_dist="isotropic", + ): + super().__init__( + latent_dim, + output_dist, + ) + + # GNN from grid to mesh + self.g2m_gnn = PropagationNet( + g2m_edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + + # Processor layers on mesh + self.processor = utils.make_gnn_seq( + m2m_edge_index, processor_layers, hidden_layers, hidden_dim + ) + + self.latent_param_map = utils.make_mlp( + [hidden_dim] * (hidden_layers + 1) + [self.output_dim], + layer_norm=False, + ) + + # pylint: disable-next=arguments-differ + def compute_dist_params(self, grid_rep, graph_emb, **kwargs): + """ + Compute parameters of distribution over latent variable using the + grid representation + + grid_rep: (B, N_grid, d_h) + graph_emb: dict with graph embedding vectors, entries at least + mesh: (B, N_mesh, d_h) + g2m: (B, M_g2m, d_h) + m2m: (B, M_g2m, d_h) + + Returns: + parameters: (B, num_mesh_nodes, d_output) + """ + mesh_rep = self.g2m_gnn( + grid_rep, graph_emb["mesh"], graph_emb["g2m"] + ) # (B, N_mesh, d_h) + mesh_rep, _ = self.processor(mesh_rep, graph_emb["m2m"]) + return self.latent_param_map(mesh_rep) # (B, N_mesh, d_output) diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graphcast.py similarity index 88% rename from neural_lam/models/graph_lam.py rename to neural_lam/models/graphcast.py index f767fba0..93efbe49 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graphcast.py @@ -7,12 +7,11 @@ from neural_lam.models.base_graph_model import BaseGraphModel -class GraphLAM(BaseGraphModel): +class GraphCast(BaseGraphModel): """ - Full graph-based LAM model that can be used with different - (non-hierarchical )graphs. Mainly based on GraphCast, but the model from - Keisler (2022) is almost identical. Used for GC-LAM and L1-LAM in - Oskarsson et al. (2023). + Full graph-based model that can be used with different + (non-hierarchical) graphs. Mainly based on GraphCast, but the model from + Keisler (2022) is almost identical. """ def __init__(self, args): @@ -20,7 +19,7 @@ def __init__(self, args): assert ( not self.hierarchical - ), "GraphLAM does not use a hierarchical mesh graph" + ), "GraphCast does not use a hierarchical mesh graph" # grid_dim from data + static + batch_static mesh_dim = self.mesh_static_features.shape[1] diff --git a/neural_lam/models/hi_graph_latent_decoder.py b/neural_lam/models/hi_graph_latent_decoder.py new file mode 100644 index 00000000..e0ad2c38 --- /dev/null +++ b/neural_lam/models/hi_graph_latent_decoder.py @@ -0,0 +1,177 @@ +# Third-party +from torch import nn + +# First-party +from neural_lam import utils +from neural_lam.interaction_net import InteractionNet, PropagationNet +from neural_lam.models.base_graph_latent_decoder import BaseGraphLatentDecoder + + +class HiGraphLatentDecoder(BaseGraphLatentDecoder): + """ + Decoder that maps grid input + latent variable on mesh to prediction on grid + Uses hierarchical graph + """ + + def __init__( + self, + g2m_edge_index, + m2m_edge_index, + m2g_edge_index, + mesh_up_edge_index, + mesh_down_edge_index, + hidden_dim, + latent_dim, + intra_level_layers, + hidden_layers=1, + output_std=True, + ): + super().__init__(hidden_dim, latent_dim, hidden_layers, output_std) + + # GNN from grid to mesh + self.g2m_gnn = InteractionNet( + g2m_edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + # GNN from mesh to grid + self.m2g_gnn = PropagationNet( + m2g_edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + + # GNNs going up through mesh levels + self.mesh_up_gnns = nn.ModuleList( + [ + # Note: We keep these as InteractionNets + InteractionNet( + edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + for edge_index in mesh_up_edge_index + ] + ) + # GNNs going down through mesh levels + self.mesh_down_gnns = nn.ModuleList( + [ + PropagationNet( + edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + for edge_index in mesh_down_edge_index + ] + ) + # GNNs applied on intra-level in-between up and down propagation + # Identity mappings if intra_level_layers = 0 + self.intra_up_gnns = nn.ModuleList( + [ + utils.make_gnn_seq( + edge_index, intra_level_layers, hidden_layers, hidden_dim + ) + for edge_index in m2m_edge_index + ] + ) + self.intra_down_gnns = nn.ModuleList( + [ + utils.make_gnn_seq( + edge_index, intra_level_layers, hidden_layers, hidden_dim + ) + for edge_index in list(m2m_edge_index)[:-1] + # Not needed for level L + ] + ) + + def combine_with_latent( + self, original_grid_rep, latent_rep, residual_grid_rep, graph_emb + ): + """ + Combine the grid representation with representation of latent variable. + The output should be on the grid again. + + original_grid_rep: (B, num_grid_nodes, d_h) + latent_rep: (B, num_mesh_nodes, d_h) + residual_grid_rep: (B, num_grid_nodes, d_h) + + Returns: + grid_rep: (B, num_grid_nodes, d_h) + """ + # Map to bottom mesh level + current_mesh_rep = self.g2m_gnn( + original_grid_rep, graph_emb["mesh"][0], graph_emb["g2m"] + ) # (B, num_mesh_nodes[0], d_h) + + # Up hierarchy + # Run intra-level processing before propagating up + mesh_level_reps = [] + m2m_level_reps = [] + for ( + up_gnn, + intra_gnn_seq, + mesh_up_level_rep, + m2m_level_rep, + mesh_level_rep, + ) in zip( + self.mesh_up_gnns, + self.intra_up_gnns[:-1], + graph_emb["mesh_up"], + graph_emb["m2m"][:-1], + # Last propagation up combines with latent representation + graph_emb["mesh"][1:-1] + [latent_rep], + ): # Loop goes L-1 times, from intra-level processing at l=1 to l=L-1 + # Run intra-level processing on level l + new_mesh_rep, new_m2m_rep = intra_gnn_seq( + current_mesh_rep, m2m_level_rep + ) # (B, num_mesh_nodes[l], d_h) + + # Store representation for this level for downward pass + mesh_level_reps.append(new_mesh_rep) # Will append L-1 times + m2m_level_reps.append(new_m2m_rep) + + # Apply up GNN, don't need to store these reps. + current_mesh_rep = up_gnn( + new_mesh_rep, mesh_level_rep, mesh_up_level_rep + ) # (B, num_mesh_nodes[l], d_h) + + # Run intra-level processing for highest mesh level + current_mesh_rep, _ = self.intra_up_gnns[-1]( + current_mesh_rep, graph_emb["m2m"][-1] + ) # (B, num_mesh_nodes[L], d_h) + + # Down hierarchy + # Propagate down before running intra-level processing + for ( + down_gnn, + intra_gnn_seq, + mesh_down_level_rep, + m2m_level_rep, + mesh_level_rep, + ) in zip( + reversed(self.mesh_down_gnns), + reversed(self.intra_down_gnns), + reversed(graph_emb["mesh_down"]), + reversed(m2m_level_reps), # Residual connections to up pass + reversed(mesh_level_reps), # ^ + ): # Loop goes L-1 times, from intra level processing at l=L-1 to l=1 + # Apply down GNN, don't need to store these reps. + new_mesh_rep = down_gnn( + current_mesh_rep, mesh_level_rep, mesh_down_level_rep + ) # (B, num_mesh_nodes[l], d_h) + + # Run same level processing on level l + current_mesh_rep, _ = intra_gnn_seq( + new_mesh_rep, m2m_level_rep + ) # (B, num_mesh_nodes[l], d_h) + + # Map back to grid + grid_rep = self.m2g_gnn( + current_mesh_rep, residual_grid_rep, graph_emb["m2g"] + ) # (B, num_mesh_nodes[0], d_h) + + return grid_rep diff --git a/neural_lam/models/hi_graph_latent_encoder.py b/neural_lam/models/hi_graph_latent_encoder.py new file mode 100644 index 00000000..62be8110 --- /dev/null +++ b/neural_lam/models/hi_graph_latent_encoder.py @@ -0,0 +1,124 @@ +# Third-party +from torch import nn + +# First-party +from neural_lam import utils +from neural_lam.interaction_net import PropagationNet +from neural_lam.models.base_latent_encoder import BaseLatentEncoder + + +class HiGraphLatentEncoder(BaseLatentEncoder): + """ + Encoder that maps from grid to mesh and defines a latent distribution + on mesh. + Uses a hierarchical mesh graph. + """ + + def __init__( + self, + latent_dim, + g2m_edge_index, + m2m_edge_index, + mesh_up_edge_index, + hidden_dim, + intra_level_layers, + hidden_layers=1, + output_dist="isotropic", + ): + super().__init__( + latent_dim, + output_dist, + ) + + # GNN from grid to mesh + self.g2m_gnn = PropagationNet( + g2m_edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + + # GNNs going up through mesh levels + self.mesh_up_gnns = nn.ModuleList( + [ + PropagationNet( + edge_index, + hidden_dim, + hidden_layers=hidden_layers, + update_edges=False, + ) + for edge_index in mesh_up_edge_index + ] + ) + + # GNNs applied on intra-level in-between upwards propagation + # Identity mappings if intra_level_layers = 0 + self.intra_level_gnns = nn.ModuleList( + [ + utils.make_gnn_seq( + edge_index, intra_level_layers, hidden_layers, hidden_dim + ) + for edge_index in m2m_edge_index + ] + ) + + # Final map to parameters + self.latent_param_map = utils.make_mlp( + [hidden_dim] * (hidden_layers + 1) + [self.output_dim], + layer_norm=False, + ) + + # pylint: disable-next=arguments-differ + def compute_dist_params(self, grid_rep, graph_emb, **kwargs): + """ + Compute parameters of distribution over latent variable using the + grid representation + + grid_rep: (B, N_grid, d_h) + graph_emb: dict with graph embedding vectors, entries at least + mesh: list of (B, N_mesh, d_h) + g2m: (B, M_g2m, d_h) + m2m: (B, M_g2m, d_h) + mesh_up: list of (B, N_mesh, d_h) + + Returns: + parameters: (B, num_mesh_nodes, d_output) + """ + current_mesh_rep = self.g2m_gnn( + grid_rep, graph_emb["mesh"][0], graph_emb["g2m"] + ) # (B, N_mesh, d_h) + + # Run same level processing on level 0 + current_mesh_rep, _ = self.intra_level_gnns[0]( + current_mesh_rep, graph_emb["m2m"][0] + ) + + # Do not need to keep track of old edge or mesh reps here + # Go from mesh level 1 to L + for ( + up_gnn, + intra_gnn_seq, + mesh_up_level_rep, + m2m_level_rep, + mesh_level_rep, + ) in zip( + self.mesh_up_gnns, + self.intra_level_gnns[1:], + graph_emb["mesh_up"], + graph_emb["m2m"][1:], + graph_emb["mesh"][1:], + ): + # Apply up GNN + new_node_rep = up_gnn( + current_mesh_rep, mesh_level_rep, mesh_up_level_rep + ) # (B, N_mesh[l], d_h) + + # Run same level processing on level l + current_mesh_rep, _ = intra_gnn_seq( + new_node_rep, m2m_level_rep + ) # (B, N_mesh[l], d_h) + + # At final mesh level, map to parameter dim + return self.latent_param_map( + current_mesh_rep + ) # (B, N_mesh[L], d_output) diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py deleted file mode 100644 index 740824e1..00000000 --- a/neural_lam/models/hi_lam_parallel.py +++ /dev/null @@ -1,96 +0,0 @@ -# Third-party -import torch -import torch_geometric as pyg - -# First-party -from neural_lam.interaction_net import InteractionNet -from neural_lam.models.base_hi_graph_model import BaseHiGraphModel - - -class HiLAMParallel(BaseHiGraphModel): - """ - Version of HiLAM where all message passing in the hierarchical mesh (up, - down, inter-level) is ran in parallel. - - This is a somewhat simpler alternative to the sequential message passing - of Hi-LAM. - """ - - def __init__(self, args): - super().__init__(args) - - # Processor GNNs - # Create the complete edge_index combining all edges for processing - total_edge_index_list = ( - list(self.m2m_edge_index) - + list(self.mesh_up_edge_index) - + list(self.mesh_down_edge_index) - ) - total_edge_index = torch.cat(total_edge_index_list, dim=1) - self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list] - - if args.processor_layers == 0: - self.processor = lambda x, edge_attr: (x, edge_attr) - else: - processor_nets = [ - InteractionNet( - total_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, - edge_chunk_sizes=self.edge_split_sections, - aggr_chunk_sizes=self.level_mesh_sizes, - ) - for _ in range(args.processor_layers) - ] - self.processor = pyg.nn.Sequential( - "mesh_rep, edge_rep", - [ - (net, "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep") - for net in processor_nets - ], - ) - - def hi_processor_step( - self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep - ): - """ - Internal processor step of hierarchical graph models. - Between mesh init and read out. - - Each input is list with representations, each with shape - - mesh_rep_levels: (B, N_mesh[l], d_h) - mesh_same_rep: (B, M_same[l], d_h) - mesh_up_rep: (B, M_up[l -> l+1], d_h) - mesh_down_rep: (B, M_down[l <- l+1], d_h) - - Returns same lists - """ - - # First join all node and edge representations to single tensors - mesh_rep = torch.cat(mesh_rep_levels, dim=1) # (B, N_mesh, d_h) - mesh_edge_rep = torch.cat( - mesh_same_rep + mesh_up_rep + mesh_down_rep, axis=1 - ) # (B, M_mesh, d_h) - - # Here, update mesh_*_rep and mesh_rep - mesh_rep, mesh_edge_rep = self.processor(mesh_rep, mesh_edge_rep) - - # Split up again for read-out step - mesh_rep_levels = list( - torch.split(mesh_rep, self.level_mesh_sizes, dim=1) - ) - mesh_edge_rep_sections = torch.split( - mesh_edge_rep, self.edge_split_sections, dim=1 - ) - - mesh_same_rep = mesh_edge_rep_sections[: self.num_levels] - mesh_up_rep = mesh_edge_rep_sections[ - self.num_levels : self.num_levels + (self.num_levels - 1) - ] - mesh_down_rep = mesh_edge_rep_sections[ - self.num_levels + (self.num_levels - 1) : - ] # Last are down edges - - # Note: We return all, even though only down edges really are used later - return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 31715502..53f1a4c6 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -4,11 +4,13 @@ # Third-party import numpy as np import torch +import torch_geometric as pyg from torch import nn from tueplots import bundles, figsizes # First-party from neural_lam import constants +from neural_lam.interaction_net import InteractionNet def load_dataset_stats(dataset_name, device="cpu"): @@ -253,7 +255,7 @@ def fractional_plot_bundle(fraction): Get the tueplots bundle, but with figure width as a fraction of the page width. """ - bundle = bundles.neurips2023(usetex=True, family="serif") + bundle = bundles.neurips2023(usetex=False, family="serif") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] bundle["figure.figsize"] = ( @@ -271,3 +273,36 @@ def init_wandb_metrics(wandb_logger): experiment.define_metric("val_mean_loss", summary="min") for step in constants.VAL_STEP_LOG_ERRORS: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +class IdentityModule(nn.Module): + """ + A identity operator that can return multiple inputs + """ + + def forward(self, *args): + """Return input args""" + return args + + +def make_gnn_seq(edge_index, num_gnn_layers, hidden_layers, hidden_dim): + """ + Make a sequential GNN module propagating both node and edge representations + """ + if num_gnn_layers == 0: + # If no layers, return identity + return IdentityModule() + return pyg.nn.Sequential( + "mesh_rep, edge_rep", + [ + ( + InteractionNet( + edge_index, + hidden_dim, + hidden_layers=hidden_layers, + ), + "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep", + ) + for _ in range(num_gnn_layers) + ], + ) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index cef34a84..57fb56b0 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -2,6 +2,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import torch # First-party from neural_lam import constants, utils @@ -111,6 +112,117 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None): return fig +@matplotlib.rc_context(utils.fractional_plot_bundle(1)) +def plot_ensemble_prediction( + samples, target, ens_mean, ens_std, obs_mask, title=None, vrange=None +): + """ + Plot example predictions, ground truth, mean and std.-dev. + from ensemble forecast + + samples: (S, N_grid,) + target: (N_grid,) + ens_mean: (N_grid,) + ens_std: (N_grid,) + obs_mask: (N_grid,) + (optional) title: title of plot + (optional) vrange: tuple of length with common min and max of values + (not for std.) + """ + # Get common scale for values + if vrange is None: + vmin = min(vals.min().cpu().item() for vals in (samples, target)) + vmax = max(vals.max().cpu().item() for vals in (samples, target)) + else: + vmin, vmax = vrange + + # Set up masking of border region + mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE) + pixel_alpha = ( + mask_reshaped.clamp(0.7, 1).cpu().numpy() + ) # Faded border region + + fig, axes = plt.subplots( + 3, + 3, + figsize=(15, 15), + subplot_kw={"projection": constants.LAMBERT_PROJ}, + ) + axes = axes.flatten() + + # Plot target, ensemble mean and std. + gt_im = plot_on_axis( + axes[0], + target, + alpha=pixel_alpha, + vmin=vmin, + vmax=vmax, + ax_title="Ground Truth", + ) + plot_on_axis( + axes[1], + ens_mean, + alpha=pixel_alpha, + vmin=vmin, + vmax=vmax, + ax_title="Ens. Mean", + ) + std_im = plot_on_axis( + axes[2], ens_std, alpha=pixel_alpha, ax_title="Ens. Std." + ) # Own vrange + + # Plot samples + for member_i, (ax, member) in enumerate( + zip(axes[3:], samples[:6]), start=1 + ): + plot_on_axis( + ax, + member, + alpha=pixel_alpha, + vmin=vmin, + vmax=vmax, + ax_title=f"Member {member_i}", + ) + + # Turn off unused axes + for ax in axes[(3 + samples.shape[0]) :]: + ax.axis("off") + + # Add colorbars + values_cbar = fig.colorbar( + gt_im, ax=axes[:2], aspect=60, location="bottom", shrink=0.9 + ) + values_cbar.ax.tick_params(labelsize=10) + std_cbar = fig.colorbar(std_im, aspect=30, location="bottom", shrink=0.9) + std_cbar.ax.tick_params(labelsize=10) + + if title: + fig.suptitle(title, size=20) + + return fig + + +def plot_on_axis(ax, data, alpha=None, vmin=None, vmax=None, ax_title=None): + """ + Plot weather state on given axis + """ + ax.coastlines() # Add coastline outlines + data_grid = data.reshape(*constants.GRID_SHAPE).cpu().numpy() + im = ax.imshow( + data_grid, + origin="lower", + extent=constants.GRID_LIMITS, + alpha=alpha, + vmin=vmin, + vmax=vmax, + cmap="plasma", + ) + + if ax_title: + ax.set_title(ax_title, size=15) + return im + + @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_spatial_error(error, obs_mask, title=None, vrange=None): """ @@ -157,3 +269,82 @@ def plot_spatial_error(error, obs_mask, title=None, vrange=None): fig.suptitle(title, size=10) return fig + + +@matplotlib.rc_context(utils.fractional_plot_bundle(1)) +def plot_latent_samples(prior_samples, vi_samples, title=None): + """ + Plot samples of latent variable drawn from prior and + variational distribution + + prior_samples: (samples, N_mesh, d_latent) + vi_samples: (samples, N_mesh, d_latent) + + Returns: + fig: the plot figure + """ + num_samples, num_mesh_nodes, latent_dim = prior_samples.shape + plot_dims = min(latent_dim, 3) # Plot first 3 dimensions + img_side_size = int(np.sqrt(num_mesh_nodes)) + assert img_side_size**2 == num_mesh_nodes, ( + "Number of mesh nodes is not a " + "square number, can not plot latent samples as images" + ) + + # Get common scale for values + vmin = min( + vals[..., :plot_dims].min().cpu().item() + for vals in (prior_samples, vi_samples) + ) + vmax = max( + vals[..., :plot_dims].max().cpu().item() + for vals in (prior_samples, vi_samples) + ) + + # Create figure + fig, axes = plt.subplots(num_samples, 2 * plot_dims, figsize=(20, 16)) + + # Plot samples + for row_i, (axes_row, prior_sample, vi_sample) in enumerate( + zip(axes, prior_samples, vi_samples) + ): + + for dim_i in range(plot_dims): + prior_sample_reshaped = ( + prior_sample[:, dim_i] + .reshape(img_side_size, img_side_size) + .cpu() + .to(torch.float32) + .numpy() + ) + vi_sample_reshaped = ( + vi_sample[:, dim_i] + .reshape(img_side_size, img_side_size) + .cpu() + .to(torch.float32) + .numpy() + ) + # Plot every other as prior and vi + prior_ax = axes_row[2 * dim_i] + vi_ax = axes_row[2 * dim_i + 1] + prior_ax.imshow(prior_sample_reshaped, vmin=vmin, vmax=vmax) + vi_im = vi_ax.imshow(vi_sample_reshaped, vmin=vmin, vmax=vmax) + + if row_i == 0: + # Add titles at top of columns + prior_ax.set_title(f"d{dim_i} (prior)", size=15) + vi_ax.set_title(f"d{dim_i} (vi)", size=15) + + # Remove ticks from all axes + for ax in axes.flatten(): + ax.set_xticks([]) + ax.set_yticks([]) + + # Add colorbar + cbar = fig.colorbar(vi_im, ax=axes, aspect=60, location="bottom") + cbar.ax.tick_params(labelsize=15) + + if title: + fig.suptitle(title, size=20) + + return fig diff --git a/pyproject.toml b/pyproject.toml index b513a258..53f82b4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,8 +51,10 @@ generated-members = [ [tool.pylint.'MESSAGES CONTROL'] disable = [ "C0114", # 'missing-module-docstring', Do not require module docstrings + "C0302", # 'too-many-lines ', Allow longer files "R0901", # 'too-many-ancestors', Allow many layers of sub-classing "R0902", # 'too-many-instance-attribtes', Allow many attributes + "R0912", # 'too-many-branches', Allow more branching "R0913", # 'too-many-arguments', Allow many function arguments "R0914", # 'too-many-locals', Allow many local variables "W0223", # 'abstract-method', Subclasses do not have to override all abstract methods diff --git a/train_model.py b/train_model.py index 96d21a3f..3978054c 100644 --- a/train_model.py +++ b/train_model.py @@ -10,15 +10,15 @@ # First-party from neural_lam import constants, utils -from neural_lam.models.graph_lam import GraphLAM -from neural_lam.models.hi_lam import HiLAM -from neural_lam.models.hi_lam_parallel import HiLAMParallel +from neural_lam.models.graph_efm import GraphEFM +from neural_lam.models.graph_fm import GraphFM +from neural_lam.models.graphcast import GraphCast from neural_lam.weather_dataset import WeatherDataset MODELS = { - "graph_lam": GraphLAM, - "hi_lam": HiLAM, - "hi_lam_parallel": HiLAMParallel, + "graphcast": GraphCast, + "graph_fm": GraphFM, + "graph_efm": GraphEFM, } @@ -102,6 +102,13 @@ def main(): default=64, help="Dimensionality of all hidden representations (default: 64)", ) + parser.add_argument( + "--latent_dim", + type=int, + default=None, + help="Dimensionality of latent R.V. at each node (if different than" + " hidden_dim) (default: None (same as hidden_dim))", + ) parser.add_argument( "--hidden_layers", type=int, @@ -112,7 +119,20 @@ def main(): "--processor_layers", type=int, default=4, - help="Number of GNN layers in processor GNN (default: 4)", + help="Number of GNN layers in processor GNN (for prob. model: in " + "decoder) (default: 4)", + ) + parser.add_argument( + "--encoder_processor_layers", + type=int, + default=2, + help="Number of on-mesh GNN layers in encoder GNN (default: 2)", + ) + parser.add_argument( + "--prior_processor_layers", + type=int, + default=2, + help="Number of on-mesh GNN layers in prior GNN (default: 2)", ) parser.add_argument( "--mesh_aggr", @@ -129,6 +149,28 @@ def main(): "output dimensions " "(default: 0 (no))", ) + parser.add_argument( + "--prior_dist", + type=str, + default="isotropic", + help="Structure of Gaussian distribution in prior network output " + "(isotropic/diagonal) (default: isotropic)", + ) + parser.add_argument( + "--learn_prior", + type=int, + default=1, + help="If the prior should be learned as a mapping from previous state " + "and forcing, otherwise static with mean 0 (default: 1 (yes))", + ) + parser.add_argument( + "--vertical_propnets", + type=int, + default=0, + help="If PropagationNets should be used for all vertical message " + "passing (g2m, m2g, up in hierarchy), in deterministic models." + "(default: 0 (no))", + ) # Training options parser.add_argument( @@ -168,6 +210,28 @@ def main(): help="Number of epochs training between each validation run " "(default: 1)", ) + parser.add_argument( + "--kl_beta", + type=float, + default=1.0, + help="Beta weighting in front of kl-term in ELBO (default: 1)", + ) + parser.add_argument( + "--crps_weight", + type=float, + default=0, + help="Weighting for CRPS term of loss, not computed if = 0. CRPS is " + "computed based on trajectories sampled using prior distribution. " + "(default: 0)", + ) + parser.add_argument( + "--sample_obs_noise", + type=int, + default=0, + help="If observation noise should be sampled during rollouts (both " + "training and eval), or just mean prediction used " + "(default: 0 (no))", + ) # Evaluation options parser.add_argument( @@ -180,9 +244,15 @@ def main(): "--n_example_pred", type=int, default=1, - help="Number of example predictions to plot during evaluation " + help="Number of example predictions to plot during val/test " "(default: 1)", ) + parser.add_argument( + "--ensemble_size", + type=int, + default=5, + help="Number of ensemble members during evaluation (default: 5)", + ) args = parser.parse_args() # Asserts for arguments @@ -256,24 +326,46 @@ def main(): f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}" ) - checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=f"saved_models/{run_name}", - filename="min_val_loss", - monitor="val_mean_loss", - mode="min", - save_last=True, + + # Callbacks for saving model checkpoint + callbacks = [] + callbacks.append( + pl.callbacks.ModelCheckpoint( + dirpath=f"saved_models/{run_name}", + filename="min_val_loss", + monitor="val_mean_loss", + mode="min", + save_last=True, + ) ) + # Save checkpoints for minimum loss at specific lead times + for unroll_time in constants.VAL_STEP_CHECKPOINTS: + metric_name = f"val_loss_unroll{unroll_time}" + callbacks.append( + pl.callbacks.ModelCheckpoint( + dirpath=f"saved_models/{run_name}", + filename=f"min_{metric_name}", + monitor=metric_name, + mode="min", + ) + ) logger = pl.loggers.WandbLogger( project=constants.WANDB_PROJECT, name=run_name, config=args ) + + # Training strategy + # If doing pure autoencoder training (kl_beta = 0), the prior network is not + # used at all in producing the loss. This is desired, but DDP complains. + strategy = "ddp" if args.kl_beta > 0 else "ddp_find_unused_parameters_true" + trainer = pl.Trainer( max_epochs=args.epochs, deterministic=True, - strategy="ddp", + strategy=strategy, accelerator=device_name, logger=logger, log_every_n_steps=1, - callbacks=[checkpoint_callback], + callbacks=callbacks, check_val_every_n_epoch=args.val_interval, precision=args.precision, )