Merge pull request MeteoSwiss#10 from MeteoSwiss/merge_upstream
Merge upstream
Conflicts:
	.github/workflows/pre-commit.yml
	.gitignore
	.pre-commit-config.yaml
	README.md
	create_grid_features.py
	create_mesh.py
	create_parameter_weights.py
	create_static_features.py
	create_zarr_archive.py
	environment.yml
	helper.py
	neural_lam/constants.py
	neural_lam/models/ar_model.py
	neural_lam/models/base_graph_model.py
	neural_lam/models/base_hi_graph_model.py
	neural_lam/rotate_grid.py
	neural_lam/utils.py
	neural_lam/vis.py
	neural_lam/weather_dataset.py
	plot_graph.py
	pyproject.toml
	requirements.txt
	slurm_eval.sh
	slurm_train.sh
	train_model.py
sadamov authored and Capucine Lechartre committed May 30, 2024
1 parent 879cfec commit 4537427
Showing 7 changed files with 162 additions and 43 deletions.
46 changes: 46 additions & 0 deletions README.md
@@ -29,6 +29,37 @@ We plan to continue updating this repository as we improve existing models and d
Collaborations around this implementation are very welcome.
If you are working with Neural-LAM feel free to get in touch and/or submit pull requests to the repository.

<span style="color:blue;">Additions relevant to the COSMO Neural-LAM implementation are highlighted in __blue__.</span>
# Quick Start
<span style="color:blue;">
Follow the steps below to get started with Neural-LAM on Balfrin.cscs.ch.
Don't worry: everything runs on a small subset of the data for a limited number of epochs.
</span>

```bash
# Clone the repository
git clone https://github.com/MeteoSwiss/neural-lam/
cd neural-lam
# Link the data folder containing the COSMO zarr archives
ln -s /scratch/mch/sadamov/pyprojects_data/neural_lam/data
mkdir lightning_logs
# Create the conda environment (~10min)
mamba env create -f environment.yml
mamba activate neural-lam
# Run the preprocessing/training scripts
# (don't execute preprocessing scripts at the same time as training)
sbatch slurm_train.sh
# Run the evaluation script and generate plots and gif for TQV
# (by default this will use the pre-trained model from `wandb/example.ckpt`)
sbatch slurm_eval.sh
```


# Modularity
The Neural-LAM code is designed to modularize the different components involved in training and evaluating neural weather prediction models.
Models, graphs and data are stored separately and it should be possible to swap out individual components.
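
Because these components are decoupled, swapping a model while keeping the same graph and data can come down to a name lookup. Below is a minimal illustrative sketch; the `MODELS` registry, `build_model` helper and import paths are assumptions for illustration, not necessarily the repository's exact code.

```python
# Minimal sketch of component swapping (illustrative; the registry and
# helper below are assumptions, not necessarily the repository's code).
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.models.hi_lam import HiLAM

MODELS = {
    "graph_lam": GraphLAM,
    "hi_lam": HiLAM,
}


def build_model(name, args):
    """Instantiate the requested model; graph and data stay unchanged."""
    return MODELS[name](args)
```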
@@ -56,6 +87,21 @@ Below follows instructions on how to use Neural-LAM to train and evaluate models
## Installation
Follow the steps below to create the necessary python environment.

<span style="color:blue;">

For COSMO we use conda, both to avoid Cartopy installation issues and because conda environments usually work well on the Balfrin.cscs.ch vCluster.

1. Simply run `conda env create -f environment.yml` to create the environment.
2. Activate the environment with `conda activate neural-lam`.
3. Happy Coding \o/

Note that only the CUDA version is pinned (to 11.8); otherwise the latest versions of all libraries are installed. This might break in the future and may need to be adjusted to the user's conda setup.

</span>

\
Follow the steps below to create the necessary python environment.

1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement.
2. Use Python 3.9.
3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system.
1 change: 1 addition & 0 deletions create_mesh.py
@@ -16,6 +16,7 @@
from neural_lam import config



def plot_graph(graph, title=None):
fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H
edge_index = graph.edge_index
1 change: 1 addition & 0 deletions create_parameter_weights.py
@@ -82,6 +82,7 @@ def main():
# Compute mean and std.-dev. of each parameter (+ flux forcing)
# across full dataset
print("Computing mean and std.-dev. for parameters...")

means = []
squares = []
flux_means = []
104 changes: 93 additions & 11 deletions neural_lam/models/ar_model.py
@@ -1,17 +1,22 @@
# pylint: disable=wrong-import-order
# Standard library
import glob
import os

# Third-party
import imageio
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from torch import nn

# First-party
from neural_lam import config, metrics, utils, vis
from neural_lam import constants, metrics, utils, vis


# pylint: disable=too-many-public-methods
class ARModel(pl.LightningModule):
"""
Generic auto-regressive weather model.
@@ -106,41 +111,113 @@ def interior_mask_bool(self):
"""
return self.interior_mask[:, 0].to(torch.bool)

@property
def interior_mask_bool(self):
"""
Get the interior mask as a boolean (N,) mask.
"""
return self.interior_mask[:, 0].to(torch.bool)

@staticmethod
def expand_to_batch(x, batch_size):
"""
Expand tensor with initial batch dimension
"""
return x.unsqueeze(0).expand(batch_size, -1, -1)
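# Example: x with shape (num_grid_nodes, d_f) becomes
# (batch_size, num_grid_nodes, d_f) after unsqueeze + expand.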

def predict_step(self, prev_state, prev_prev_state, forcing):
def precompute_variable_indices(self):
"""
Precompute indices for each variable in the input tensor
"""
variable_indices = {}
all_vars = []
index = 0
# Create a list of tuples for all variables, using level 0 for 2D
# variables
for var_name in constants.PARAM_NAMES_SHORT:
if constants.IS_3D[var_name]:
for level in constants.VERTICAL_LEVELS:
all_vars.append((var_name, level))
else:
all_vars.append((var_name, 0)) # Use level 0 for 2D variables

# Sort the variables based on the tuples
sorted_vars = sorted(all_vars)

for var in sorted_vars:
var_name, level = var
if var_name not in variable_indices:
variable_indices[var_name] = []
variable_indices[var_name].append(index)
index += 1

return variable_indices
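# Worked example (variable names assumed for illustration): with
# PARAM_NAMES_SHORT = ["T", "PS"], IS_3D = {"T": True, "PS": False} and
# VERTICAL_LEVELS = [1, 2], sorted_vars is [("PS", 0), ("T", 1), ("T", 2)],
# so the returned mapping is {"PS": [0], "T": [1, 2]}.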

def apply_constraints(self, prediction):
"""
Apply constraints to prediction to ensure values are within the
specified bounds
"""
for param, (min_val, max_val) in constants.PARAM_CONSTRAINTS.items():
indices = self.variable_indices[param]
for index in indices:
# Apply clamping to ensure values are within the specified
# bounds
prediction[:, :, index] = torch.clamp(
prediction[:, :, index],
min=min_val,
max=max_val if max_val is not None else float("inf"),
)
return prediction
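# Worked example (entries assumed for illustration): with
# PARAM_CONSTRAINTS = {"RELHUM": (0, 100), "TQV": (0, None)}, relative
# humidity is clamped to [0, 100] and total water vapour to [0, inf),
# since a None upper bound maps to float("inf").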

def predict_step(
self,
prev_state,
prev_prev_state,
batch_static_features=None,
forcing=None,
):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
forcing: (B, num_grid_nodes, forcing_dim)
batch_static_features: (B, num_grid_nodes, batch_static_feature_dim)
forcing: (B, num_grid_nodes, forcing_dim), optional
"""
raise NotImplementedError("No prediction step implemented")

def unroll_prediction(self, init_states, forcing_features, true_states):
def unroll_prediction(
self,
init_states,
true_states,
batch_static_features=None,
forcing_features=None,
):
"""
Roll out prediction taking multiple autoregressive steps with model
init_states: (B, 2, num_grid_nodes, d_f)
forcing_features: (B, pred_steps, num_grid_nodes, d_static_f)
batch_static_features: (B, num_grid_nodes, d_static_f), optional
forcing_features: (B, pred_steps, num_grid_nodes, d_static_f), optional
true_states: (B, pred_steps, num_grid_nodes, d_f)
"""
prev_prev_state = init_states[:, 0]
prev_state = init_states[:, 1]
prediction_list = []
pred_std_list = []
pred_steps = forcing_features.shape[1]
pred_steps = (
forcing_features.shape[1]
if forcing_features is not None
else true_states.shape[1]
)
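# When no forcing is supplied, the rollout length is taken from the
# ground-truth sequence instead.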

for i in range(pred_steps):
forcing = forcing_features[:, i]
forcing = (
forcing_features[:, i] if forcing_features is not None else None
)
border_state = true_states[:, i]

pred_state, pred_std = self.predict_step(
prev_state, prev_prev_state, forcing
prev_state, prev_prev_state, batch_static_features, forcing
)
# state: (B, num_grid_nodes, d_f)
# pred_std: (B, num_grid_nodes, d_f) or None
@@ -521,10 +598,15 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
wandb.log(log_dict) # Log all
plt.close("all") # Close all figs

def on_test_epoch_end(self):
def smooth_prediction_borders(self, prediction_rescaled):
"""
Compute test metrics and make plots at the end of test epoch.
Will gather stored tensors and perform plotting and logging on rank 0.
Smooths the prediction at the borders to avoid artifacts.
Args:
prediction_rescaled (torch.Tensor): The rescaled prediction tensor.
Returns:
torch.Tensor: The prediction tensor after smoothing the borders.
"""
# Create error maps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")
1 change: 1 addition & 0 deletions neural_lam/vis.py
@@ -1,4 +1,5 @@
# Third-party
import cartopy.feature as cf
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
3 changes: 3 additions & 0 deletions neural_lam/weather_dataset.py
@@ -10,6 +10,9 @@
# First-party
from neural_lam import utils

# pylint: disable=unused-argument
# pylint: disable=attribute-defined-outside-init


class WeatherDataset(torch.utils.data.Dataset):
"""
49 changes: 17 additions & 32 deletions train_model.py
@@ -1,11 +1,13 @@
# Standard library
import random
import os
import time
from argparse import ArgumentParser

# Third-party
import pytorch_lightning as pl
import torch
import wandb
from lightning_fabric.utilities import seed

# First-party
@@ -225,32 +227,11 @@ def main():
# Set seed
seed.seed_everything(args.seed)

# Load data
train_loader = torch.utils.data.DataLoader(
WeatherDataset(
config_loader.dataset.name,
pred_length=args.ar_steps,
split="train",
subsample_step=args.step_length,
subset=bool(args.subset_ds),
control_only=args.control_only,
),
args.batch_size,
shuffle=True,
num_workers=args.n_workers,
)
max_pred_length = (65 // args.step_length) - 2 # 19
val_loader = torch.utils.data.DataLoader(
WeatherDataset(
config_loader.dataset.name,
pred_length=max_pred_length,
split="val",
subsample_step=args.step_length,
subset=bool(args.subset_ds),
control_only=args.control_only,
),
args.batch_size,
shuffle=False,
# Create datamodule
data_module = WeatherDataModule(
args.dataset,
subset=bool(args.subset_ds),
batch_size=args.batch_size,
num_workers=args.n_workers,
)
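# WeatherDataModule is assumed to implement Lightning's standard hooks
# (setup/train_dataloader/val_dataloader), replacing the manual DataLoader
# construction removed above; trainer.fit(..., datamodule=...) then
# requests the loaders itself.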

@@ -323,12 +304,16 @@ def main():
trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load)
else:
# Train model
trainer.fit(
model=model,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
ckpt_path=args.load,
)
data_module.split = "train"
if args.load:
trainer.fit(
model=model, datamodule=data_module, ckpt_path=args.load
)
else:
trainer.fit(model=model, datamodule=data_module)

# Print profiler
print(trainer.profiler) # pylint: disable=no-member


if __name__ == "__main__":
