From 4537427cd3c68974e6975f0cab6e7775827aaff0 Mon Sep 17 00:00:00 2001
From: sadamov <45732287+sadamov@users.noreply.github.com>
Date: Thu, 7 Mar 2024 17:42:18 +0100
Subject: [PATCH] Merge pull request #10 from MeteoSwiss/merge_upstream
Merge upstream
Conflicts:
.github/workflows/pre-commit.yml
.gitignore
.pre-commit-config.yaml
README.md
create_grid_features.py
create_mesh.py
create_parameter_weights.py
create_static_features.py
create_zarr_archive.py
environment.yml
helper.py
neural_lam/constants.py
neural_lam/models/ar_model.py
neural_lam/models/base_graph_model.py
neural_lam/models/base_hi_graph_model.py
neural_lam/rotate_grid.py
neural_lam/utils.py
neural_lam/vis.py
neural_lam/weather_dataset.py
plot_graph.py
pyproject.toml
requirements.txt
slurm_eval.sh
slurm_train.sh
train_model.py
---
README.md | 46 +++++++++++++++
create_mesh.py | 1 +
create_parameter_weights.py | 1 +
neural_lam/models/ar_model.py | 104 ++++++++++++++++++++++++++++++----
neural_lam/vis.py | 1 +
neural_lam/weather_dataset.py | 3 +
train_model.py | 49 ++++++----------
7 files changed, 162 insertions(+), 43 deletions(-)
diff --git a/README.md b/README.md
index ba0bb3fe..4b25d9bd 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,37 @@ We plan to continue updating this repository as we improve existing models and d
Collaborations around this implementation are very welcome.
If you are working with Neural-LAM feel free to get in touch and/or submit pull requests to the repository.
+Additions relevant to the COSMO Neural-LAM implementation are highlighted in __blue__.
+# Quick Start
+
+Follow the steps below to get started with Neural-LAM on Balfrin.cscs.ch.
+Don't worry everything is carried out on a small subset of data for a limited number of epochs.
+
+
+```{bash}
+# Clone the repository
+git clone https://github.com/MeteoSwiss/neural-lam/
+cd neural-lam
+
+# Link the data folder containing the COSMO zarr archives
+ln -s /scratch/mch/sadamov/pyprojects_data/neural_lam/data
+mkdir lightning_logs
+
+# Create the conda environment (~10min)
+mamba env create -f environment.yml
+mamba activate neural-lam
+
+# Run the preprocessing/training scripts
+# (don't execute preprocessing scripts at the same time as training)
+sbatch slurm_train.sh
+
+# Run the evaluation script and generate plots and gif for TQV
+# (by default this will use the pre-trained model from `wandb/example.ckpt`)
+sbatch slurm_eval.sh
+
+```
+
+
# Modularity
The Neural-LAM code is designed to modularize the different components involved in training and evaluating neural weather prediction models.
Models, graphs and data are stored separately and it should be possible to swap out individual components.
@@ -56,6 +87,21 @@ Below follows instructions on how to use Neural-LAM to train and evaluate models
## Installation
Follow the steps below to create the necessary python environment.
+
+
+For COSMO we use conda to avoid the Cartopy installation issues and because conda environments usually work well on the vCluster called Balfrin.cscs.ch.
+
+1. Simply run `conda env create -f environment.yml` to create the environment.
+2. Activate the environment with `conda activate neural-lam`.
+3. Happy Coding \o/
+
+Note that only the cuda version is pinned to 11.8, otherwise all the latest libraries are installed. This might break in the future and must be adjusted to the users conda version.
+
+
+
+\
+Follow the steps below to create the necessary python environment.
+
1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement.
2. Use python 3.9.
3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system.
diff --git a/create_mesh.py b/create_mesh.py
index f04b4d4b..a06b59b8 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -16,6 +16,7 @@
from neural_lam import config
+
def plot_graph(graph, title=None):
fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H
edge_index = graph.edge_index
diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index cae1ae3e..13321cef 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -82,6 +82,7 @@ def main():
# Compute mean and std.-dev. of each parameter (+ flux forcing)
# across full dataset
print("Computing mean and std.-dev. for parameters...")
+
means = []
squares = []
flux_means = []
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 29b169d4..3dad27e6 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -1,17 +1,22 @@
+# pylint: disable=wrong-import-order
# Standard library
+import glob
import os
# Third-party
+import imageio
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
+from torch import nn
# First-party
-from neural_lam import config, metrics, utils, vis
+from neural_lam import constants, metrics, utils, vis
+# pylint: disable=too-many-public-methods
class ARModel(pl.LightningModule):
"""
Generic auto-regressive weather model.
@@ -106,6 +111,13 @@ def interior_mask_bool(self):
"""
return self.interior_mask[:, 0].to(torch.bool)
+ @property
+ def interior_mask_bool(self):
+ """
+ Get the interior mask as a boolean (N,) mask.
+ """
+ return self.interior_mask[:, 0].to(torch.bool)
+
@staticmethod
def expand_to_batch(x, batch_size):
"""
@@ -113,34 +125,99 @@ def expand_to_batch(x, batch_size):
"""
return x.unsqueeze(0).expand(batch_size, -1, -1)
- def predict_step(self, prev_state, prev_prev_state, forcing):
+ def precompute_variable_indices(self):
+ """
+ Precompute indices for each variable in the input tensor
+ """
+ variable_indices = {}
+ all_vars = []
+ index = 0
+ # Create a list of tuples for all variables, using level 0 for 2D
+ # variables
+ for var_name in constants.PARAM_NAMES_SHORT:
+ if constants.IS_3D[var_name]:
+ for level in constants.VERTICAL_LEVELS:
+ all_vars.append((var_name, level))
+ else:
+ all_vars.append((var_name, 0)) # Use level 0 for 2D variables
+
+ # Sort the variables based on the tuples
+ sorted_vars = sorted(all_vars)
+
+ for var in sorted_vars:
+ var_name, level = var
+ if var_name not in variable_indices:
+ variable_indices[var_name] = []
+ variable_indices[var_name].append(index)
+ index += 1
+
+ return variable_indices
+
+ def apply_constraints(self, prediction):
+ """
+ Apply constraints to prediction to ensure values are within the
+ specified bounds
+ """
+ for param, (min_val, max_val) in constants.PARAM_CONSTRAINTS.items():
+ indices = self.variable_indices[param]
+ for index in indices:
+ # Apply clamping to ensure values are within the specified
+ # bounds
+ prediction[:, :, index] = torch.clamp(
+ prediction[:, :, index],
+ min=min_val,
+ max=max_val if max_val is not None else float("inf"),
+ )
+ return prediction
+
+ def predict_step(
+ self,
+ prev_state,
+ prev_prev_state,
+ batch_static_features=None,
+ forcing=None,
+ ):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
- forcing: (B, num_grid_nodes, forcing_dim)
+ batch_static_features: (B, num_grid_nodes, batch_static_feature_dim)
+ forcing: (B, num_grid_nodes, forcing_dim), optional
"""
raise NotImplementedError("No prediction step implemented")
- def unroll_prediction(self, init_states, forcing_features, true_states):
+ def unroll_prediction(
+ self,
+ init_states,
+ true_states,
+ batch_static_features=None,
+ forcing_features=None,
+ ):
"""
Roll out prediction taking multiple autoregressive steps with model
init_states: (B, 2, num_grid_nodes, d_f)
- forcing_features: (B, pred_steps, num_grid_nodes, d_static_f)
+ batch_static_features: (B, num_grid_nodes, d_static_f), optional
+ forcing_features: (B, pred_steps, num_grid_nodes, d_static_f), optional
true_states: (B, pred_steps, num_grid_nodes, d_f)
"""
prev_prev_state = init_states[:, 0]
prev_state = init_states[:, 1]
prediction_list = []
pred_std_list = []
- pred_steps = forcing_features.shape[1]
+ pred_steps = (
+ forcing_features.shape[1]
+ if forcing_features is not None
+ else true_states.shape[1]
+ )
for i in range(pred_steps):
- forcing = forcing_features[:, i]
+ forcing = (
+ forcing_features[:, i] if forcing_features is not None else None
+ )
border_state = true_states[:, i]
pred_state, pred_std = self.predict_step(
- prev_state, prev_prev_state, forcing
+ prev_state, prev_prev_state, batch_static_features, forcing
)
# state: (B, num_grid_nodes, d_f)
# pred_std: (B, num_grid_nodes, d_f) or None
@@ -521,10 +598,15 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
wandb.log(log_dict) # Log all
plt.close("all") # Close all figs
- def on_test_epoch_end(self):
+ def smooth_prediction_borders(self, prediction_rescaled):
"""
- Compute test metrics and make plots at the end of test epoch.
- Will gather stored tensors and perform plotting and logging on rank 0.
+ Smooths the prediction at the borders to avoid artifacts.
+
+ Args:
+ prediction_rescaled (torch.Tensor): The rescaled prediction tensor.
+
+ Returns:
+ torch.Tensor: The prediction tensor after smoothing the borders.
"""
# Create error maps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 8c9ca77c..c42573a1 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -1,4 +1,5 @@
# Third-party
+import cartopy.feature as cf
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index a782806b..3fddfc3c 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -10,6 +10,9 @@
# First-party
from neural_lam import utils
+# pylint: disable=W0613:unused-argument
+# pylint: disable=W0201:attribute-defined-outside-init
+
class WeatherDataset(torch.utils.data.Dataset):
"""
diff --git a/train_model.py b/train_model.py
index fe064384..219d5475 100644
--- a/train_model.py
+++ b/train_model.py
@@ -1,11 +1,13 @@
# Standard library
import random
+import os
import time
from argparse import ArgumentParser
# Third-party
import pytorch_lightning as pl
import torch
+import wandb
from lightning_fabric.utilities import seed
# First-party
@@ -225,32 +227,11 @@ def main():
# Set seed
seed.seed_everything(args.seed)
- # Load data
- train_loader = torch.utils.data.DataLoader(
- WeatherDataset(
- config_loader.dataset.name,
- pred_length=args.ar_steps,
- split="train",
- subsample_step=args.step_length,
- subset=bool(args.subset_ds),
- control_only=args.control_only,
- ),
- args.batch_size,
- shuffle=True,
- num_workers=args.n_workers,
- )
- max_pred_length = (65 // args.step_length) - 2 # 19
- val_loader = torch.utils.data.DataLoader(
- WeatherDataset(
- config_loader.dataset.name,
- pred_length=max_pred_length,
- split="val",
- subsample_step=args.step_length,
- subset=bool(args.subset_ds),
- control_only=args.control_only,
- ),
- args.batch_size,
- shuffle=False,
+ # Create datamodule
+ data_module = WeatherDataModule(
+ args.dataset,
+ subset=bool(args.subset_ds),
+ batch_size=args.batch_size,
num_workers=args.n_workers,
)
@@ -323,12 +304,16 @@ def main():
trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load)
else:
# Train model
- trainer.fit(
- model=model,
- train_dataloaders=train_loader,
- val_dataloaders=val_loader,
- ckpt_path=args.load,
- )
+ data_module.split = "train"
+ if args.load:
+ trainer.fit(
+ model=model, datamodule=data_module, ckpt_path=args.load
+ )
+ else:
+ trainer.fit(model=model, datamodule=data_module)
+
+ # Print profiler
+ print(trainer.profiler) # pylint: disable=no-member
if __name__ == "__main__":