From 5df1bff46f22f818a01389f2d8bf5148d822bde9 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 19:46:37 +0100 Subject: [PATCH 01/90] add datastore_boundary to neural_lam --- neural_lam/train_model.py | 22 ++++++++++++++++++++++ neural_lam/weather_dataset.py | 10 ++++++++++ 2 files changed, 32 insertions(+) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c89..37bf6db7 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -34,6 +34,11 @@ def main(input_args=None): type=str, help="Path to the configuration for neural-lam", ) + parser.add_argument( + "--config_path_boundary", + type=str, + help="Path to the configuration for boundary conditions", + ) parser.add_argument( "--model", type=str, @@ -212,6 +217,9 @@ def main(input_args=None): assert ( args.config_path is not None ), "Specify your config with --config_path" + assert ( + args.config_path_boundary is not None + ), "Specify your config with --config_path_boundary" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None, @@ -227,10 +235,24 @@ def main(input_args=None): # Load neural-lam configuration and datastore to use config, datastore = load_config_and_datastore(config_path=args.config_path) + config_boundary, datastore_boundary = load_config_and_datastore( + config_path=args.config_path_boundary + ) + + # TODO this should not be required, make more flexible + assert ( + datastore.num_past_forcing_steps + == datastore_boundary.num_past_forcing_steps + ), "Mismatch in num_past_forcing_steps" + assert ( + datastore.num_future_forcing_steps + == datastore_boundary.num_future_forcing_steps + ), "Mismatch in num_future_forcing_steps" # Create datamodule data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=args.ar_steps_train, ar_steps_eval=args.ar_steps_eval, standardize=True, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 532e3c90..51256e41 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -22,6 +22,8 @@ class WeatherDataset(torch.utils.data.Dataset): ---------- datastore : BaseDatastore The datastore to load the data from (e.g. mdp). + datastore_boundary : BaseDatastore + The boundary datastore to load the data from (e.g. mdp). split : str, optional The data split to use ("train", "val" or "test"). Default is "train". ar_steps : int, optional @@ -43,6 +45,7 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, split="train", ar_steps=3, num_past_forcing_steps=1, @@ -54,6 +57,7 @@ def __init__( self.split = split self.ar_steps = ar_steps self.datastore = datastore + self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps @@ -605,6 +609,7 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, ar_steps_train=3, ar_steps_eval=25, standardize=True, @@ -615,6 +620,7 @@ def __init__( ): super().__init__() self._datastore = datastore + self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps self.ar_steps_train = ar_steps_train @@ -626,6 +632,7 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: + # BUG: There also seem to be issues with "spawn", to be investigated # default to spawn for now, as the default on linux "fork" hangs # when using dask (which the npyfilesmeps datastore uses) self.multiprocessing_context = "spawn" @@ -636,6 +643,7 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="train", ar_steps=self.ar_steps_train, standardize=self.standardize, @@ -644,6 +652,7 @@ def setup(self, stage=None): ) self.val_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="val", ar_steps=self.ar_steps_eval, standardize=self.standardize, @@ -654,6 +663,7 @@ def setup(self, stage=None): if stage == "test" or stage is None: self.test_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="test", ar_steps=self.ar_steps_eval, standardize=self.standardize, From 46590efc277cb809d788ce5af44133f8b95eb279 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 20:15:41 +0100 Subject: [PATCH 02/90] complete integration of boundary in weatherDataset --- neural_lam/weather_dataset.py | 55 ++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 51256e41..10b74086 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -67,6 +67,9 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + self.da_boundary = self.datastore_boundary.get_dataarray( + category="boundary", split=self.split + ) # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -118,6 +121,15 @@ def __init__( self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + if self.da_boundary is not None: + self.ds_boundary_stats = ( + self.datastore_boundary.get_standardization_dataarray( + category="boundary" + ) + ) + self.da_boundary_mean = self.ds_boundary_stats.boundary_mean + self.da_boundary_std = self.ds_boundary_stats.boundary_std + def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time @@ -352,6 +364,8 @@ def _build_item_dataarrays(self, idx): The dataarray for the target states. da_forcing_windowed : xr.DataArray The dataarray for the forcing data, windowed for the sample. + da_boundary_windowed : xr.DataArray + The dataarray for the boundary data, windowed for the sample. da_target_times : xr.DataArray The dataarray for the target times. """ @@ -381,6 +395,11 @@ def _build_item_dataarrays(self, idx): else: da_forcing = None + if self.da_boundary is not None: + da_boundary = self.da_boundary + else: + da_boundary = None + # handle time sampling in a way that is compatible with both analysis # and forecast data da_state = self._slice_state_time( @@ -390,11 +409,17 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed = self._slice_forcing_time( da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) + if da_boundary is not None: + da_boundary_windowed = self._slice_forcing_time( + da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps + ) # load the data into memory da_state.load() if da_forcing is not None: da_forcing_windowed.load() + if da_boundary is not None: + da_boundary_windowed.load() da_init_states = da_state.isel(time=slice(0, 2)) da_target_states = da_state.isel(time=slice(2, None)) @@ -417,6 +442,11 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed - self.da_forcing_mean ) / self.da_forcing_std + if da_boundary is not None: + da_boundary_windowed = ( + da_boundary_windowed - self.da_boundary_mean + ) / self.da_boundary_std + if da_forcing is not None: # stack the `forcing_feature` and `window_sample` dimensions into a # single `forcing_feature` dimension @@ -436,11 +466,31 @@ def _build_item_dataarrays(self, idx): "forcing_feature": [], }, ) + if da_boundary is not None: + # stack the `forcing_feature` and `window_sample` dimensions into a + # single `forcing_feature` dimension + da_boundary_windowed = da_boundary_windowed.stack( + boundary_feature_windowed=("boundary_feature", "window") + ) + else: + # create an empty forcing tensor with the right shape + da_boundary_windowed = xr.DataArray( + data=np.empty( + (self.ar_steps, da_state.grid_index.size, 0), + ), + dims=("time", "grid_index", "boundary_feature"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + "boundary_feature": [], + }, + ) return ( da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) @@ -475,6 +525,7 @@ def __getitem__(self, idx): da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) = self._build_item_dataarrays(idx=idx) @@ -491,13 +542,15 @@ def __getitem__(self, idx): ) forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype) + boundary = torch.tensor(da_boundary_windowed.values, dtype=tensor_dtype) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) # forcing: (ar_steps, N_grid, d_windowed_forcing) + # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) - return init_states, target_states, forcing, target_times + return init_states, target_states, forcing, boundary, target_times def __iter__(self): """ From b990f4941bd7167160a2f265b1e9fe17026ed31e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 20:15:55 +0100 Subject: [PATCH 03/90] Add test to check timestep length and spacing --- neural_lam/weather_dataset.py | 76 +++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 10b74086..97d9f9c3 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -101,6 +101,82 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # Check time coverage for forcing and boundary data + if self.da_forcing is not None or self.da_boundary is not None: + state_times = self.da_state.time + state_time_min = state_times.min().values + state_time_max = state_times.max().values + + def get_time_step(times): + """Calculate the time step from the data""" + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + if self.da_forcing is not None: + forcing_times = self.da_forcing.time + forcing_time_step = get_time_step(forcing_times.values) + forcing_time_min = forcing_times.min().values + forcing_time_max = forcing_times.max().values + + # Calculate required bounds for forcing using its time step + forcing_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * forcing_time_step + ) + forcing_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * forcing_time_step + ) + + if forcing_time_min > forcing_required_time_min: + raise ValueError( + f"Forcing data starts too late." + f"Required start: {forcing_required_time_min}, " + f"but forcing starts at {forcing_time_min}." + ) + + if forcing_time_max < forcing_required_time_max: + raise ValueError( + f"Forcing data ends too early." + f"Required end: {forcing_required_time_max}," + f"but forcing ends at {forcing_time_max}." + ) + + if self.da_boundary is not None: + boundary_times = self.da_boundary.time + boundary_time_step = get_time_step(boundary_times.values) + boundary_time_min = boundary_times.min().values + boundary_time_max = boundary_times.max().values + + # Calculate required bounds for boundary using its time step + boundary_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * boundary_time_step + ) + boundary_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * boundary_time_step + ) + + if boundary_time_min > boundary_required_time_min: + raise ValueError( + f"Boundary data starts too late." + f"Required start: {boundary_required_time_min}, " + f"but boundary starts at {boundary_time_min}." + ) + + if boundary_time_max < boundary_required_time_max: + raise ValueError( + f"Boundary data ends too early." + f"Required end: {boundary_required_time_max}, " + f"but boundary ends at {boundary_time_max}." + ) + # Set up for standardization # TODO: This will become part of ar_model.py soon! self.standardize = standardize From 3fd1d6be82d0174b106922a7ff9c74255bac5a35 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 21:43:57 +0100 Subject: [PATCH 04/90] setting default mdp boundary to 0 gridcells --- neural_lam/datastore/mdp.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 10593a82..8c67fe58 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -26,7 +26,7 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): + def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at `config_path`. A boundary mask is created with `n_boundary_points` @@ -335,19 +335,22 @@ def boundary_mask(self) -> xr.DataArray: boundary point and 0 is not. """ - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) + if self._n_boundary_points > 0: + ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) + da_state_variable = ( + ds_unstacked["state"].isel(time=0).isel(state_feature=0) + ) + da_domain_allzero = xr.zeros_like(da_state_variable) + ds_unstacked["boundary_mask"] = da_domain_allzero.isel( + x=slice(self._n_boundary_points, -self._n_boundary_points), + y=slice(self._n_boundary_points, -self._n_boundary_points), + ) + ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( + 1 + ).astype(int) + return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) + else: + return None @property def coords_projection(self) -> ccrs.Projection: From 1f2499c3b3fb8493b89d2be97ff301181c756f72 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 21:44:54 +0100 Subject: [PATCH 05/90] implement time-based slicing combine two slicing fcts into one --- neural_lam/weather_dataset.py | 300 ++++++++++++++++++---------------- 1 file changed, 161 insertions(+), 139 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 97d9f9c3..5d35a4b7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -67,8 +67,9 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + # XXX For now boundary data is always considered forcing data self.da_boundary = self.datastore_boundary.get_dataarray( - category="boundary", split=self.split + category="forcing", split=self.split ) # check that with the provided data-arrays and ar_steps that we have a @@ -200,7 +201,7 @@ def get_time_step(times): if self.da_boundary is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( - category="boundary" + category="forcing" ) ) self.da_boundary_mean = self.ds_boundary_stats.boundary_mean @@ -252,175 +253,156 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_state_time(self, da_state, idx, n_steps: int): + def _slice_time(self, da_state, da_forcing, idx, n_steps: int): """ - Produce a time slice of the given dataarray `da_state` (state) starting - at `idx` and with `n_steps` steps. An `offset`is calculated based on the - `num_past_forcing_steps` class attribute. `Offset` is used to offset the - start of the sample, to assert that enough previous time steps are - available for the 2 initial states and any corresponding forcings - (calculated in `_slice_forcing_time`). + Produce time slices of the given dataarrays `da_state` (state) and + `da_forcing` (forcing). For the state data, slicing is done as before + based on `idx`. For the forcing data, nearest neighbor matching is + performed based on the state times. Additionally, the time difference + between the matched forcing times and state times (in multiples of state + time steps) is added to the forcing dataarray. Parameters ---------- da_state : xr.DataArray - The dataarray to slice. This is expected to have a `time` dimension - if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. + The state dataarray to slice. + da_forcing : xr.DataArray + The forcing dataarray to slice. idx : int - The index of the time step to start the sample from. + The index of the time step to start the sample from in the state + data. n_steps : int The number of time steps to include in the sample. Returns ------- - da_sliced : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', + da_state_sliced : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). + da_forcing_matched : xr.DataArray + The forcing dataarray matched to state times with an added + coordinate 'time_diff', representing the time difference to state + times in multiples of state time steps. """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). + # Number of initial steps required (e.g., for initializing models) init_steps = 2 - # slice the dataarray to include the required number of time steps + + # Slice the state data as before if self.datastore.is_forecast: + # Calculate start and end indices for slicing start_idx = max(0, self.num_past_forcing_steps - init_steps) end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps - # this implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select a analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast, always starting at forecast time 2. - da_sliced = da_state.isel( + + # Slice the state data over the elapsed forecast duration + da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) - # create a new time dimension so that the produced sample has a - # `time` dimension, similarly to the analysis only data - da_sliced["time"] = ( - da_sliced.analysis_time + da_sliced.elapsed_forecast_duration + + # Create a new 'time' dimension + da_state_sliced["time"] = ( + da_state_sliced.analysis_time + + da_state_sliced.elapsed_forecast_duration ) - da_sliced = da_sliced.swap_dims( + da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + else: - # For analysis data we slice the time dimension directly. The offset - # is only relevant for the very first (and last) samples in the - # dataset. + # For analysis data, slice the time dimension directly start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) end_idx = ( idx + max(init_steps, self.num_past_forcing_steps) + n_steps ) - da_sliced = da_state.isel(time=slice(start_idx, end_idx)) - return da_sliced + da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - def _slice_forcing_time(self, da_forcing, idx, n_steps: int): - """ - Produce a time slice of the given dataarray `da_forcing` (forcing) - starting at `idx` and with `n_steps` steps. An `offset` is calculated - based on the `num_past_forcing_steps` class attribute. It is used to - offset the start of the sample, to ensure that enough previous time - steps are available for the forcing data. The forcing data is windowed - around the current autoregressive time step to include the past and - future forcings. - - Parameters - ---------- - da_forcing : xr.DataArray - The forcing dataarray to slice. This is expected to have a `time` - dimension if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. - idx : int - The index of the time step to start the sample from. - n_steps : int - The number of time steps to include in the sample. - - Returns - ------- - da_concat : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', - 'window', 'forcing_feature'). - """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). The forcing data is windowed around the - # current autregressive time step. The two `init_steps` can also be used - # as past forcings. - init_steps = 2 - da_list = [] + # Get the state times for matching + state_times = da_state_sliced["time"] + # Match forcing data to state times based on nearest neighbor if self.datastore.is_forecast: - # This implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select an analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast. - # Add a 'time' dimension using the actual forecast times - offset = max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - current_time = ( - da_forcing.analysis_time[idx] - + da_forcing.elapsed_forecast_duration[offset + step] - ) - - da_sliced = da_forcing.isel( - analysis_time=idx, - elapsed_forecast_duration=slice(start_idx, end_idx + 1), - ) - - da_sliced = da_sliced.rename( - {"elapsed_forecast_duration": "window"} - ) + # Calculate all possible forcing times + forcing_times = ( + da_forcing.analysis_time + da_forcing.elapsed_forecast_duration + ) + forcing_times_flat = forcing_times.stack( + forecast_time=("analysis_time", "elapsed_forecast_duration") + ) - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + # Compute time differences + time_deltas = ( + forcing_times_flat.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) + time_diffs = np.abs(time_deltas) + idx_min = time_diffs.argmin(axis=0) + + # Retrieve corresponding indices for analysis_time and + # elapsed_forecast_duration + forecast_time_index = forcing_times_flat["forecast_time"][idx_min] + analysis_time_indices = forecast_time_index["analysis_time"] + elapsed_forecast_duration_indices = forecast_time_index[ + "elapsed_forecast_duration" + ] + + # Slice the forcing data using matched indices + da_forcing_matched = da_forcing.isel( + analysis_time=("time", analysis_time_indices), + elapsed_forecast_duration=( + "time", + elapsed_forecast_duration_indices, + ), + ) - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + # Assign matched state times to the forcing data + da_forcing_matched["time"] = state_times + da_forcing_matched = da_forcing_matched.swap_dims( + {"elapsed_forecast_duration": "time"} + ) - da_list.append(da_sliced) + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] + time_diff_steps = ( + time_deltas[idx_min, np.arange(len(state_times))] + / state_time_step + ) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Add time difference as a new coordinate + da_forcing_matched = da_forcing_matched.assign_coords( + time_diff=("time", time_diff_steps) + ) else: - # For analysis data, we slice the time dimension directly. The - # offset is only relevant for the very first (and last) samples in - # the dataset. - offset = idx + max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - # Slice the data over the desired time window - da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1)) - - da_sliced = da_sliced.rename({"time": "window"}) - - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + # For analysis data, match directly using the 'time' coordinate + forcing_times = da_forcing["time"] - # Add a 'time' dimension to keep track of steps using actual - # time coordinates - current_time = da_forcing.time[offset + step] - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + # Compute time differences + time_deltas = ( + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) + time_diffs = np.abs(time_deltas) + idx_min = time_diffs.argmin(axis=0) - da_list.append(da_sliced) + # Slice the forcing data using matched indices + da_forcing_matched = da_forcing.isel(time=idx_min) + da_forcing_matched = da_forcing_matched.assign_coords( + time=state_times + ) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] + time_diff_steps = ( + time_deltas[idx_min, np.arange(len(state_times))] + / state_time_step + ) - return da_concat + # Add time difference as a new coordinate + da_forcing_matched = da_forcing_matched.assign_coords( + time_diff=("time", time_diff_steps) + ) + + return da_state_sliced, da_forcing_matched def _build_item_dataarrays(self, idx): """ @@ -442,6 +424,7 @@ def _build_item_dataarrays(self, idx): The dataarray for the forcing data, windowed for the sample. da_boundary_windowed : xr.DataArray The dataarray for the boundary data, windowed for the sample. + Boundary data is always considered forcing data. da_target_times : xr.DataArray The dataarray for the target times. """ @@ -478,15 +461,15 @@ def _build_item_dataarrays(self, idx): # handle time sampling in a way that is compatible with both analysis # and forecast data - da_state = self._slice_state_time( + da_state = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps ) if da_forcing is not None: - da_forcing_windowed = self._slice_forcing_time( + da_forcing_windowed = self._slice_time( da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) if da_boundary is not None: - da_boundary_windowed = self._slice_forcing_time( + da_boundary_windowed = self._slice_time( da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps ) @@ -524,13 +507,32 @@ def _build_item_dataarrays(self, idx): ) / self.da_boundary_std if da_forcing is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension + # Expand 'time_diff' to align with 'forcing_feature' and 'window' + # dimensions 'time_diff' has dimension ('time'), expand to ('time', + # 'forcing_feature', 'window') + time_diff_expanded = da_forcing_windowed["time_diff"].expand_dims( + forcing_feature=da_forcing_windowed["forcing_feature"], + window=da_forcing_windowed["window"], + ) + + # Stack 'forcing_feature' and 'window' into a single + # 'forcing_feature_windowed' dimension da_forcing_windowed = da_forcing_windowed.stack( forcing_feature_windowed=("forcing_feature", "window") ) + time_diff_expanded = time_diff_expanded.stack( + forcing_feature_windowed=("forcing_feature", "window") + ) + + # Assign 'time_diff' as a coordinate to 'forcing_feature_windowed' + da_forcing_windowed = da_forcing_windowed.assign_coords( + time_diff=( + "forcing_feature_windowed", + time_diff_expanded.values, + ) + ) else: - # create an empty forcing tensor with the right shape + # Create an empty forcing tensor with the right shape da_forcing_windowed = xr.DataArray( data=np.empty( (self.ar_steps, da_state.grid_index.size, 0), @@ -542,14 +544,34 @@ def _build_item_dataarrays(self, idx): "forcing_feature": [], }, ) + if da_boundary is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension + # If 'da_boundary_windowed' also has 'time_diff', process similarly + # Expand 'time_diff' to align with 'boundary_feature' and 'window' + # dimensions + time_diff_expanded = da_boundary_windowed["time_diff"].expand_dims( + boundary_feature=da_boundary_windowed["boundary_feature"], + window=da_boundary_windowed["window"], + ) + + # Stack 'boundary_feature' and 'window' into a single + # 'boundary_feature_windowed' dimension da_boundary_windowed = da_boundary_windowed.stack( boundary_feature_windowed=("boundary_feature", "window") ) + time_diff_expanded = time_diff_expanded.stack( + boundary_feature_windowed=("boundary_feature", "window") + ) + + # Assign 'time_diff' as a coordinate to 'boundary_feature_windowed' + da_boundary_windowed = da_boundary_windowed.assign_coords( + time_diff=( + "boundary_feature_windowed", + time_diff_expanded.values, + ) + ) else: - # create an empty forcing tensor with the right shape + # Create an empty boundary tensor with the right shape da_boundary_windowed = xr.DataArray( data=np.empty( (self.ar_steps, da_state.grid_index.size, 0), From 1af1481e6884f89ccf39befa37e0d61ed16bbcc3 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 06:26:54 +0100 Subject: [PATCH 06/90] remove all interior_mask and boundary_mask --- neural_lam/datastore/base.py | 17 ------- neural_lam/datastore/mdp.py | 34 -------------- neural_lam/datastore/npyfilesmeps/store.py | 28 ------------ neural_lam/models/ar_model.py | 53 ++++------------------ neural_lam/vis.py | 16 ------- tests/dummy_datastore.py | 22 --------- tests/test_datastores.py | 21 --------- 7 files changed, 10 insertions(+), 181 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 0317c2e5..5aeedb2e 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -228,23 +228,6 @@ def get_dataarray( """ pass - @cached_property - @abc.abstractmethod - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - pass - @abc.abstractmethod def get_xy(self, category: str) -> np.ndarray: """ diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 8c67fe58..5365c723 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -318,40 +318,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: ds_stats = self._ds[stats_variables.keys()].rename(stats_variables) return ds_stats - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Produce a 0/1 mask for the boundary points of the dataset, these will - sit at the edges of the domain (in x/y extent) and will be used to mask - out the boundary points from the loss function and to overwrite the - boundary points from the prediction. For now this is created when the - mask is requested, but in the future this could be saved to the zarr - file. - - Returns - ------- - xr.DataArray - A 0/1 mask for the boundary points of the dataset, where 1 is a - boundary point and 0 is not. - - """ - if self._n_boundary_points > 0: - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) - else: - return None - @property def coords_projection(self) -> ccrs.Projection: """ diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 42e80706..146b0627 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -668,34 +668,6 @@ def grid_shape_state(self) -> CartesianGridShape: ny, nx = self.config.grid_shape_state return CartesianGridShape(x=nx, y=ny) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """The boundary mask for the dataset. This is a binary mask that is 1 - where the grid cell is on the boundary of the domain, and 0 otherwise. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions `[grid_index]`. - - """ - xy = self.get_xy(category="state", stacked=False) - xs = xy[:, :, 0] - ys = xy[:, :, 1] - # Check if x-coordinates are constant along columns - assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant" - # Check if y-coordinates are constant along rows - assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant" - # Extract unique x and y coordinates - x = xs[:, 0] # Unique x-coordinates (changes along the first axis) - y = ys[0, :] # Unique y-coordinates (changes along the second axis) - values = np.load(self.root_path / "static" / "border_mask.npy") - da_mask = xr.DataArray( - values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask" - ) - da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int) - return da_mask_stacked_xy - def get_standardization_dataarray(self, category: str) -> xr.Dataset: """Return the standardization dataarray for the given category. This should contain a `{category}_mean` and `{category}_std` variable for diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index bc4c6719..4ab73cc7 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -42,7 +42,6 @@ def __init__( da_state_stats = datastore.get_standardization_dataarray( category="state" ) - da_boundary_mask = datastore.boundary_mask num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps @@ -115,18 +114,6 @@ def __init__( # Instantiate loss function self.loss = metrics.get_metric(args.loss) - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim - - self.register_buffer("boundary_mask", boundary_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.boundary_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - self.val_metrics = { "mse": [], } @@ -153,13 +140,6 @@ def configure_optimizers(self): ) return opt - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. - """ - return self.interior_mask[:, 0].to(torch.bool) - @staticmethod def expand_to_batch(x, batch_size): """ @@ -191,7 +171,6 @@ def unroll_prediction(self, init_states, forcing_features, true_states): for i in range(pred_steps): forcing = forcing_features[:, i] - border_state = true_states[:, i] pred_state, pred_std = self.predict_step( prev_state, prev_prev_state, forcing @@ -199,19 +178,13 @@ def unroll_prediction(self, init_states, forcing_features, true_states): # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, # d_f) or None - # Overwrite border with true state - new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state - ) - - prediction_list.append(new_state) + prediction_list.append(pred_state) if self.output_std: pred_std_list.append(pred_std) # Update conditioning states prev_prev_state = prev_state - prev_state = new_state + prev_state = pred_state prediction = torch.stack( prediction_list, dim=1 @@ -249,12 +222,14 @@ def training_step(self, batch): """ prediction, target, pred_std, _ = self.common_step(batch) - # Compute loss + # Compute loss - mean over unrolled times and batch batch_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ) - ) # mean over unrolled times and batch + ) log_dict = {"train_loss": batch_loss} self.log_dict( @@ -287,9 +262,7 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) @@ -314,7 +287,6 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.val_metrics["mse"].append(entry_mses) @@ -341,9 +313,7 @@ def test_step(self, batch, batch_idx): # pred_steps, num_grid_nodes, d_f) or (d_f,) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) @@ -372,16 +342,13 @@ def test_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.test_metrics[metric_name].append(batch_metric_vals) if self.output_std: # Store output std. per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) + mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f) self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times diff --git a/neural_lam/vis.py b/neural_lam/vis.py index b9d18b39..31de8f32 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -86,13 +86,6 @@ def plot_prediction( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region - fig, axes = plt.subplots( 1, 2, @@ -112,7 +105,6 @@ def plot_prediction( data_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="plasma", @@ -147,13 +139,6 @@ def plot_spatial_error( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region - fig, ax = plt.subplots( figsize=(5, 4.8), subplot_kw={"projection": datastore.coords_projection}, @@ -170,7 +155,6 @@ def plot_spatial_error( error_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd", diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 9075d404..d62c7356 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -148,12 +148,6 @@ def __init__( times = [self.T0 + dt * i for i in range(n_timesteps)] self.ds.coords["time"] = times - # Add boundary mask - self.ds["boundary_mask"] = xr.DataArray( - np.random.choice([0, 1], size=(n_points_1d, n_points_1d)), - dims=["x", "y"], - ) - # Stack the spatial dimensions into grid_index self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS) @@ -342,22 +336,6 @@ def get_dataarray( dim_order = self.expected_dim_order(category=category) return self.ds[category].transpose(*dim_order) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - return self.ds["boundary_mask"] - def get_xy(self, category: str, stacked: bool) -> ndarray: """Return the x, y coordinates of the dataset. diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4a4b1100..a91f6245 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -18,8 +18,6 @@ dataarray for the given category. - `get_dataarray` (method): Return the processed data (as a single `xr.DataArray`) for the given category and test/train/val-split. -- `boundary_mask` (property): Return the boundary mask for the dataset, - with spatial dimensions stacked. - `config` (property): Return the configuration of the datastore. In addition BaseRegularGridDatastore must have the following methods and @@ -213,25 +211,6 @@ def test_get_dataarray(datastore_name): assert n_features["train"] == n_features["val"] == n_features["test"] -@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_boundary_mask(datastore_name): - """Check that the `datastore.boundary_mask` property is implemented and - that the returned object is an xarray DataArray with the correct shape.""" - datastore = init_datastore_example(datastore_name) - da_mask = datastore.boundary_mask - - assert isinstance(da_mask, xr.DataArray) - assert set(da_mask.dims) == {"grid_index"} - assert da_mask.dtype == "int" - assert set(da_mask.values) == {0, 1} - assert da_mask.sum() > 0 - assert da_mask.sum() < da_mask.size - - if isinstance(datastore, BaseRegularGridDatastore): - grid_shape = datastore.grid_shape_state - assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y - - @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_xy_extent(datastore_name): """Check that the `datastore.get_xy_extent` method is implemented and that From d545cb7576de020b7d721c08741e784bc2b69c24 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:55:56 +0100 Subject: [PATCH 07/90] added gcsfs dependency for era5 weatherbench download --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f0bc0851..5bbe4d92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "torch-geometric==2.3.1", "parse>=1.20.2", "dataclass-wizard>=0.22.3", - "mllam-data-prep>=0.5.0", + "gcsfs>=2021.10.0", + "mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep@temp/for-neural-lam-datastores", ] requires-python = ">=3.9" From 5c1a7d7cf9a4befb874ce847424787e818cced75 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:57:57 +0100 Subject: [PATCH 08/90] added new era5 datastore config for boundary --- tests/conftest.py | 19 +++- .../mdp/era5_1000hPa_winds/.gitignore | 2 + .../mdp/era5_1000hPa_winds/config.yaml | 3 + .../era5_1000hPa_winds/era5.datastore.yaml | 90 +++++++++++++++++++ 4 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml diff --git a/tests/conftest.py b/tests/conftest.py index 6f579621..be5cf3e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,15 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) +DATASTORES_BOUNDARY_EXAMPLES = dict( + mdp=( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "era5_1000hPa_winds" + / "era5.datastore.yaml" + ) +) + DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore @@ -102,5 +111,13 @@ def init_datastore_example(datastore_kind): datastore_kind=datastore_kind, config_path=DATASTORES_EXAMPLES[datastore_kind], ) - return datastore + + +def init_datastore_boundary_example(datastore_kind): + datastore_boundary = init_datastore( + datastore_kind=datastore_kind, + config_path=DATASTORES_BOUNDARY_EXAMPLES[datastore_kind], + ) + + return datastore_boundary diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore new file mode 100644 index 00000000..f2828f46 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore @@ -0,0 +1,2 @@ +*.zarr/ +graph/ diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml new file mode 100644 index 00000000..5d1e05f2 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml @@ -0,0 +1,3 @@ +datastore: + kind: mdp + config_path: era5.datastore.yaml diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml new file mode 100644 index 00000000..36b39501 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml @@ -0,0 +1,90 @@ +#TODO: What do these versions mean? Should they be updated? +schema_version: v0.2.0+dev +dataset_version: v1.0.0 + +output: + variables: + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-02T00:00 + end: 1990-09-10T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-02T00:00 + end: 1990-09-07T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-05T00:00 + end: 1990-09-08T00:00 + test: + start: 1990-09-06T00:00 + end: 1990-09-10T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + v_component_of_wind: + level: + values: [1000, ] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_surface_net_short_wave_radiation_flux + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 From 30e4f05e1c9cc726180868450286d9cf8279ce07 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:58:36 +0100 Subject: [PATCH 09/90] removed left-over boundary-mask references --- neural_lam/datastore/mdp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 5365c723..fd9acb4e 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -26,11 +26,10 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): + def __init__(self, config_path, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at - `config_path`. A boundary mask is created with `n_boundary_points` - boundary points. If `reuse_existing` is True, the dataset is loaded + `config_path`. If `reuse_existing` is True, the dataset is loaded from a zarr file if it exists (unless the config has been modified since the zarr was created), otherwise it is created from the configuration file. @@ -41,8 +40,6 @@ def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): The path to the configuration file, this will be fed to the `mllam_data_prep.Config.from_yaml_file` method to then call `mllam_data_prep.create_dataset` to create the dataset. - n_boundary_points : int - The number of boundary points to use in the boundary mask. reuse_existing : bool Whether to reuse an existing dataset zarr file if it exists and its creation date is newer than the configuration file. @@ -69,7 +66,6 @@ def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): if self._ds is None: self._ds = mdp.create_dataset(config=self._config) self._ds.to_zarr(fp_ds) - self._n_boundary_points = n_boundary_points print("The loaded datastore contains the following features:") for category in ["state", "forcing", "static"]: From 6a8c593f422c2844545feb2cc7e57de520dc1062 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:59:12 +0100 Subject: [PATCH 10/90] make check for existing category in datastore more flexible (for boundary) --- neural_lam/datastore/mdp.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index fd9acb4e..67aaa9d0 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -153,8 +153,8 @@ def get_vars_units(self, category: str) -> List[str]: The units of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_units"].values.tolist() @@ -172,8 +172,8 @@ def get_vars_names(self, category: str) -> List[str]: The names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature"].values.tolist() @@ -192,8 +192,8 @@ def get_vars_long_names(self, category: str) -> List[str]: The long names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_long_name"].values.tolist() @@ -248,9 +248,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: The xarray DataArray object with processed dataset. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") - return None + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] da_category = self._ds[category] From 17c920d36848d61153fd53781d8ec3ac90e5de56 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 20 Nov 2024 16:00:15 +0100 Subject: [PATCH 11/90] implement xarray based (mostly) time slicing and windowing --- neural_lam/weather_dataset.py | 255 +++++++++++++++------------------- 1 file changed, 111 insertions(+), 144 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5d35a4b7..c8806d1c 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -64,10 +64,16 @@ def __init__( self.da_state = self.datastore.get_dataarray( category="state", split=self.split ) + if self.da_state is None: + raise ValueError( + "A non-empty state dataarray must be provided. " + "The datastore.get_dataarray() returned None or empty array " + "for category='state'" + ) self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) - # XXX For now boundary data is always considered forcing data + # XXX For now boundary data is always considered mdp-forcing data self.da_boundary = self.datastore_boundary.get_dataarray( category="forcing", split=self.split ) @@ -102,53 +108,36 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + def get_time_step(times): + """Calculate the time step from the data""" + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + # Check time step consistency in state data + _ = get_time_step(self.da_state.time.values) + # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: state_times = self.da_state.time state_time_min = state_times.min().values state_time_max = state_times.max().values - def get_time_step(times): - """Calculate the time step from the data""" - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] - if self.da_forcing is not None: + # Forcing data is part of the same datastore as state data + # During creation the time dimension of the forcing data + # is matched to the state data forcing_times = self.da_forcing.time - forcing_time_step = get_time_step(forcing_times.values) - forcing_time_min = forcing_times.min().values - forcing_time_max = forcing_times.max().values - - # Calculate required bounds for forcing using its time step - forcing_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * forcing_time_step - ) - forcing_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * forcing_time_step - ) - - if forcing_time_min > forcing_required_time_min: - raise ValueError( - f"Forcing data starts too late." - f"Required start: {forcing_required_time_min}, " - f"but forcing starts at {forcing_time_min}." - ) - - if forcing_time_max < forcing_required_time_max: - raise ValueError( - f"Forcing data ends too early." - f"Required end: {forcing_required_time_max}," - f"but forcing ends at {forcing_time_max}." - ) + _ = get_time_step(forcing_times.values) if self.da_boundary is not None: + # Boundary data is part of a separate datastore + # The boundary data is allowed to have a different time_step + # Check that the boundary data covers the required time range boundary_times = self.da_boundary.time boundary_time_step = get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values @@ -204,8 +193,8 @@ def get_time_step(times): category="forcing" ) ) - self.da_boundary_mean = self.ds_boundary_stats.boundary_mean - self.da_boundary_std = self.ds_boundary_stats.boundary_std + self.da_boundary_mean = self.ds_boundary_stats.forcing_mean + self.da_boundary_std = self.ds_boundary_stats.forcing_std def __len__(self): if self.datastore.is_forecast: @@ -253,7 +242,7 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_time(self, da_state, da_forcing, idx, n_steps: int): + def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): """ Produce time slices of the given dataarrays `da_state` (state) and `da_forcing` (forcing). For the state data, slicing is done as before @@ -316,8 +305,13 @@ def _slice_time(self, da_state, da_forcing, idx, n_steps: int): ) da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) + if da_forcing is None: + return da_state_sliced, None + # Get the state times for matching state_times = da_state_sliced["time"] + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor if self.datastore.is_forecast: @@ -371,39 +365,80 @@ def _slice_time(self, da_state, da_forcing, idx, n_steps: int): da_forcing_matched = da_forcing_matched.assign_coords( time_diff=("time", time_diff_steps) ) - else: # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] # Compute time differences time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] + state_times.values[np.newaxis, :] + - forcing_times.values[:, np.newaxis] + ) + idx_min = np.abs(time_deltas).argmin(axis=0) + + time_diff_steps = xr.DataArray( + np.stack( + [ + np.diagonal(time_deltas, offset=offset)[ + -len(state_times) + init_steps : + ] + / state_time_step + for offset in range( + -self.num_past_forcing_steps, + self.num_future_forcing_steps + 1, + ) + ], + axis=1, + ), + dims=["time", "window"], + coords={ + "time": state_times.isel(time=slice(init_steps, None)), + "window": np.arange( + -self.num_past_forcing_steps, + self.num_future_forcing_steps + 1, + ), + }, + name="time_diff_steps", ) - time_diffs = np.abs(time_deltas) - idx_min = time_diffs.argmin(axis=0) - # Slice the forcing data using matched indices - da_forcing_matched = da_forcing.isel(time=idx_min) - da_forcing_matched = da_forcing_matched.assign_coords( - time=state_times + # Create window dimension using rolling + window_size = ( + self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) - - # Calculate time differences in multiples of state time steps - state_time_step = state_times.values[1] - state_times.values[0] - time_diff_steps = ( - time_deltas[idx_min, np.arange(len(state_times))] - / state_time_step + da_forcing_windowed = da_forcing.rolling( + time=window_size, center=True + ).construct(window_dim="window") + da_forcing_matched = da_forcing_windowed.isel( + time=idx_min[init_steps:] ) # Add time difference as a new coordinate da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=("time", time_diff_steps) + time_diff=time_diff_steps ) return da_state_sliced, da_forcing_matched + def _process_windowed_data(self, da_windowed, da_state, da_target_times): + """Helper function to process windowed data after standardization.""" + stacked_dim = "forcing_feature_windowed" + if da_windowed is not None: + # Stack the 'feature' and 'window' dimensions + da_windowed = da_windowed.stack( + {stacked_dim: ("forcing_feature", "window")} + ) + else: + # Create empty DataArray with the correct dimensions and coordinates + return xr.DataArray( + data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), + dims=("time", "grid_index", f"{stacked_dim}"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + f"{stacked_dim}": [], + }, + ) + def _build_item_dataarrays(self, idx): """ Create the dataarrays for the initial states, target states and forcing @@ -459,18 +494,21 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # handle time sampling in a way that is compatible with both analysis - # and forecast data - da_state = self._slice_time( - da_state=da_state, idx=idx, n_steps=self.ar_steps + # if da_forcing is None, the function will return None for + # da_forcing_windowed + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, ) - if da_forcing is not None: - da_forcing_windowed = self._slice_time( - da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps - ) + if da_boundary is not None: - da_boundary_windowed = self._slice_time( - da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps + _, da_boundary_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_boundary, ) # load the data into memory @@ -506,83 +544,12 @@ def _build_item_dataarrays(self, idx): da_boundary_windowed - self.da_boundary_mean ) / self.da_boundary_std - if da_forcing is not None: - # Expand 'time_diff' to align with 'forcing_feature' and 'window' - # dimensions 'time_diff' has dimension ('time'), expand to ('time', - # 'forcing_feature', 'window') - time_diff_expanded = da_forcing_windowed["time_diff"].expand_dims( - forcing_feature=da_forcing_windowed["forcing_feature"], - window=da_forcing_windowed["window"], - ) - - # Stack 'forcing_feature' and 'window' into a single - # 'forcing_feature_windowed' dimension - da_forcing_windowed = da_forcing_windowed.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - time_diff_expanded = time_diff_expanded.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - - # Assign 'time_diff' as a coordinate to 'forcing_feature_windowed' - da_forcing_windowed = da_forcing_windowed.assign_coords( - time_diff=( - "forcing_feature_windowed", - time_diff_expanded.values, - ) - ) - else: - # Create an empty forcing tensor with the right shape - da_forcing_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "forcing_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "forcing_feature": [], - }, - ) - - if da_boundary is not None: - # If 'da_boundary_windowed' also has 'time_diff', process similarly - # Expand 'time_diff' to align with 'boundary_feature' and 'window' - # dimensions - time_diff_expanded = da_boundary_windowed["time_diff"].expand_dims( - boundary_feature=da_boundary_windowed["boundary_feature"], - window=da_boundary_windowed["window"], - ) - - # Stack 'boundary_feature' and 'window' into a single - # 'boundary_feature_windowed' dimension - da_boundary_windowed = da_boundary_windowed.stack( - boundary_feature_windowed=("boundary_feature", "window") - ) - time_diff_expanded = time_diff_expanded.stack( - boundary_feature_windowed=("boundary_feature", "window") - ) - - # Assign 'time_diff' as a coordinate to 'boundary_feature_windowed' - da_boundary_windowed = da_boundary_windowed.assign_coords( - time_diff=( - "boundary_feature_windowed", - time_diff_expanded.values, - ) - ) - else: - # Create an empty boundary tensor with the right shape - da_boundary_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "boundary_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "boundary_feature": [], - }, - ) + da_forcing_windowed = self._process_windowed_data( + da_forcing_windowed, da_state, da_target_times + ) + da_boundary_windowed = self._process_windowed_data( + da_boundary_windowed, da_state, da_target_times + ) return ( da_init_states, From 79199956225277cb88b255a514be1a72634926c5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 07:09:52 +0100 Subject: [PATCH 12/90] cleanup analysis based time-slicing --- neural_lam/weather_dataset.py | 85 +++++++++++++++++------------------ 1 file changed, 42 insertions(+), 43 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index c8806d1c..bbfb5705 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -245,11 +245,12 @@ def __len__(self): def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing` (forcing). For the state data, slicing is done as before - based on `idx`. For the forcing data, nearest neighbor matching is - performed based on the state times. Additionally, the time difference - between the matched forcing times and state times (in multiples of state - time steps) is added to the forcing dataarray. + `da_forcing` (forcing). For the state data, slicing is done based on + `idx`. For the forcing data, nearest neighbor matching is performed + based on the state times. Additionally, the time difference between the + matched forcing times and state times (in multiples of state time steps) + is added to the forcing dataarray. This will be used as an additional + feature in the model (temporal embedding). Parameters ---------- @@ -269,9 +270,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). da_forcing_matched : xr.DataArray - The forcing dataarray matched to state times with an added - coordinate 'time_diff', representing the time difference to state - times in multiples of state time steps. + The sliced state dataarray with dims ('time', 'grid_index', + 'forcing_feature_windowed'). """ # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -308,9 +308,9 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): if da_forcing is None: return da_state_sliced, None - # Get the state times for matching + # Get the state times and its temporal resolution for matching with + # forcing data state_times = da_state_sliced["time"] - # Calculate time differences in multiples of state time steps state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor @@ -369,39 +369,29 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] - # Compute time differences + # Compute time differences between forcing and state times + # (in multiples of state time steps) + # Retrieve the indices of the closest times in the forcing data time_deltas = ( - state_times.values[np.newaxis, :] - - forcing_times.values[:, np.newaxis] - ) + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) - time_diff_steps = xr.DataArray( - np.stack( - [ - np.diagonal(time_deltas, offset=offset)[ - -len(state_times) + init_steps : - ] - / state_time_step - for offset in range( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ) - ], - axis=1, - ), - dims=["time", "window"], - coords={ - "time": state_times.isel(time=slice(init_steps, None)), - "window": np.arange( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ), - }, - name="time_diff_steps", + time_diff_steps = np.stack( + [ + time_deltas[ + idx_i + - self.num_past_forcing_steps : idx_i + + self.num_future_forcing_steps + + 1, + init_steps + step_i, + ] + for (step_i, idx_i) in enumerate(idx_min[init_steps:]) + ], ) - # Create window dimension using rolling + # Create window dimension for forcing data to stack later window_size = ( self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) @@ -412,9 +402,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): time=idx_min[init_steps:] ) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=time_diff_steps + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, ) return da_state_sliced, da_forcing_matched @@ -423,13 +415,19 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" stacked_dim = "forcing_feature_windowed" if da_windowed is not None: - # Stack the 'feature' and 'window' dimensions + # Stack the 'feature' and 'window' dimensions and add the + # time step differences to the existing features as a temporal + # embedding da_windowed = da_windowed.stack( {stacked_dim: ("forcing_feature", "window")} ) + da_windowed = xr.concat( + [da_windowed, da_windowed.time_diff_steps], + dim="forcing_feature_windowed", + ) else: # Create empty DataArray with the correct dimensions and coordinates - return xr.DataArray( + da_windowed = xr.DataArray( data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), dims=("time", "grid_index", f"{stacked_dim}"), coords={ @@ -438,6 +436,7 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): f"{stacked_dim}": [], }, ) + return da_windowed def _build_item_dataarrays(self, idx): """ From 9bafceec0480ead53e4cdd32b24be669c195316c Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:59:42 +0100 Subject: [PATCH 13/90] implement datastore_boundary in existing tests --- tests/test_datasets.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 419aece0..67eac70e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -14,12 +14,19 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) from tests.dummy_datastore import DummyDatastore @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_shapes(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_shapes(datastore_name, datastore_boundary_name): """Check that the `datastore.get_dataarray` method is implemented. Validate the shapes of the tensors match between the different @@ -31,6 +38,9 @@ def test_dataset_item_shapes(datastore_name): """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_gridpoints = datastore.num_grid_points N_pred_steps = 4 @@ -38,6 +48,7 @@ def test_dataset_item_shapes(datastore_name): num_future_forcing_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -48,7 +59,7 @@ def test_dataset_item_shapes(datastore_name): # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - init_states, target_states, forcing, target_times = item + init_states, target_states, forcing, boundary, target_times = item # initial states assert init_states.ndim == 3 @@ -81,14 +92,23 @@ def test_dataset_item_shapes(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_create_dataarray_from_tensor(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_create_dataarray_from_tensor( + datastore_name, datastore_boundary_name +): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -158,13 +178,19 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): @pytest.mark.parametrize("split", ["train", "val", "test"]) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_single_batch(datastore_name, split): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_single_batch(datastore_name, datastore_boundary_name, split): """Check that the `datastore.get_dataarray` method is implemented. And that it returns an xarray DataArray with the correct dimensions. """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) device_name = ( torch.device("cuda") if torch.cuda.is_available() else "cpu" @@ -210,7 +236,9 @@ def _create_graph(): ) ) - dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2) + dataset = WeatherDataset( + datastore=datastore, datastore_boundary=datastore_boundary, split=split + ) model = GraphLAM(args=args, datastore=datastore, config=config) # noqa From ce06bbc24dc4765944c0b937ace0dc4d0f11f364 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 16:39:27 +0100 Subject: [PATCH 14/90] allow for grid shape retrieval from forcing data --- neural_lam/datastore/mdp.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 67aaa9d0..57a3249f 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -377,8 +377,17 @@ def grid_shape_state(self): The shape of the cartesian grid for the state variables. """ - ds_state = self.unstack_grid_coords(self._ds["state"]) - da_x, da_y = ds_state.x, ds_state.y + # Boundary data often has no state features + if "state" not in self._ds: + warnings.warn( + "no state data found in datastore" + "returning grid shape from forcing data" + ) + ds_forcing = self.unstack_grid_coords(self._ds["forcing"]) + da_x, da_y = ds_forcing.x, ds_forcing.y + else: + ds_state = self.unstack_grid_coords(self._ds["state"]) + da_x, da_y = ds_state.x, ds_state.y assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) From 884b5c623117cb18c405ac869caaff028625e5fb Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 16:40:47 +0100 Subject: [PATCH 15/90] rearrange time slicing, boundary first --- neural_lam/weather_dataset.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index bbfb5705..32add37a 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -495,13 +495,6 @@ def _build_item_dataarrays(self, idx): # if da_forcing is None, the function will return None for # da_forcing_windowed - da_state, da_forcing_windowed = self._slice_time( - da_state=da_state, - idx=idx, - n_steps=self.ar_steps, - da_forcing=da_forcing, - ) - if da_boundary is not None: _, da_boundary_windowed = self._slice_time( da_state=da_state, @@ -509,6 +502,12 @@ def _build_item_dataarrays(self, idx): n_steps=self.ar_steps, da_forcing=da_boundary, ) + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, + ) # load the data into memory da_state.load() From 5904cbe9da67d3e98eaab0cebd501a2ad0ded7f3 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Mon, 25 Nov 2024 16:42:21 +0100 Subject: [PATCH 16/90] identified issue, cleanup next --- neural_lam/datastore/base.py | 9 ++++- neural_lam/datastore/mdp.py | 5 ++- neural_lam/models/ar_model.py | 46 ++++++++++++++++++++-- neural_lam/train_model.py | 2 +- neural_lam/vis.py | 73 +++++++++++++++++++++++++---------- 5 files changed, 107 insertions(+), 28 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 0317c2e5..b0055e39 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -295,8 +295,13 @@ def get_xy_extent(self, category: str) -> List[float]: The extent of the x, y coordinates. """ - xy = self.get_xy(category, stacked=False) - extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()] + xy = self.get_xy(category, stacked=True) + extent = [ + xy[:, 0].min(), + xy[:, 0].max(), + xy[:, 1].min(), + xy[:, 1].max(), + ] return [float(v) for v in extent] @property diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 10593a82..0d1aac7b 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -1,4 +1,5 @@ # Standard library +import copy import warnings from functools import cached_property from pathlib import Path @@ -394,7 +395,9 @@ def coords_projection(self) -> ccrs.Projection: class_name = projection_info["class_name"] ProjectionClass = getattr(ccrs, class_name) - kwargs = projection_info["kwargs"] + # need to copy otherwise we modify the dict stored in the dataclass + # in-place + kwargs = copy.deepcopy(projection_info["kwargs"]) globe_kwargs = kwargs.pop("globe", {}) if len(globe_kwargs) > 0: diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index bc4c6719..b55143f0 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -7,12 +7,14 @@ import pytorch_lightning as pl import torch import wandb +from loguru import logger # Local from .. import metrics, vis from ..config import NeuralLAMConfig from ..datastore import BaseDatastore from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset class ARModel(pl.LightningModule): @@ -147,6 +149,14 @@ def __init__( # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] + def _create_dataarray_from_tensor(self, tensor, time, split, category): + weather_dataset = WeatherDataset(datastore=self._datastore, split=split) + time = np.array(time, dtype="datetime64[ns]") + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor, time=time, category=category + ) + return da + def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) @@ -406,10 +416,13 @@ def test_step(self, batch, batch_idx): ) self.plot_examples( - batch, n_additional_examples, prediction=prediction + batch, + n_additional_examples, + prediction=prediction, + split="test", ) - def plot_examples(self, batch, n_examples, prediction=None): + def plot_examples(self, batch, n_examples, split, prediction=None): """ Plot the first n_examples forecasts from batch @@ -422,18 +435,34 @@ def plot_examples(self, batch, n_examples, prediction=None): prediction, target, _, _ = self.common_step(batch) target = batch[1] + time = batch[3] # Rescale to original data scale prediction_rescaled = prediction * self.state_std + self.state_mean target_rescaled = target * self.state_std + self.state_mean # Iterate over the examples - for pred_slice, target_slice in zip( - prediction_rescaled[:n_examples], target_rescaled[:n_examples] + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], ): # Each slice is (pred_steps, num_grid_nodes, d_f) self.plotted_examples += 1 # Increment already here + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + da_target = self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + var_vmin = ( torch.minimum( pred_slice.flatten(0, 1).min(dim=0)[0], @@ -465,6 +494,10 @@ def plot_examples(self, batch, n_examples, prediction=None): title=f"{var_name} ({var_unit}), " f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i + ).squeeze(), + da_target=da_target.isel(state_feature=var_i).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( @@ -476,6 +509,11 @@ def plot_examples(self, batch, n_examples, prediction=None): ] example_i = self.plotted_examples + for i, fig in enumerate(var_figs): + fn = f"example_{i}_{example_i}_t{t_i}.png" + fig.savefig(fn) + logger.info(f"Saved example plot to {fn}") + wandb.log( { f"{var_name}_example_{example_i}": wandb.Image(fig) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c89..9d1d5039 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -23,7 +23,7 @@ } -@logger.catch +@logger.catch(reraise=True) def main(input_args=None): """Main function for training and evaluating models.""" parser = ArgumentParser( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index b9d18b39..357a8977 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -68,6 +68,8 @@ def plot_prediction( pred, target, datastore: BaseRegularGridDatastore, + da_prediction=None, + da_target=None, title=None, vrange=None, ): @@ -88,10 +90,8 @@ def plot_prediction( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + mask_values = np.invert(da_mask.values.astype(bool)).astype(float) + pixel_alpha = mask_values.clip(0.7, 1) # Faded border region fig, axes = plt.subplots( 1, @@ -100,29 +100,62 @@ def plot_prediction( subplot_kw={"projection": datastore.coords_projection}, ) + use_xarray = True + # Plot pred and target - for ax, data in zip(axes, (target, pred)): + + if not use_xarray: + for ax, data in zip(axes, (target, pred)): + ax.coastlines() # Add coastline outlines + data_grid = ( + data.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() + .numpy() + ) + im = ax.imshow( + data_grid, + origin="lower", + extent=extent, + alpha=pixel_alpha, + vmin=vmin, + vmax=vmax, + cmap="plasma", + ) + + cbar = fig.colorbar(im, aspect=30) + cbar.ax.tick_params(labelsize=10) + + x = da_target.x.values + y = da_target.y.values + extent = [x.min(), x.max(), y.min(), y.max()] + for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines - data_grid = ( - data.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() - .numpy() - ) - im = ax.imshow( - data_grid, + im = da.plot.imshow( + ax=ax, origin="lower", + x="x", extent=extent, - alpha=pixel_alpha, + alpha=pixel_alpha.T, vmin=vmin, vmax=vmax, cmap="plasma", + transform=datastore.coords_projection, ) + # da.plot.pcolormesh( + # ax=ax, + # x="x", + # vmin=vmin, + # vmax=vmax, + # transform=datastore.coords_projection, + # cmap="plasma", + # ) + # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) if title: fig.suptitle(title, size=20) @@ -150,9 +183,7 @@ def plot_spatial_error( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region fig, ax = plt.subplots( figsize=(5, 4.8), @@ -161,8 +192,10 @@ def plot_spatial_error( ax.coastlines() # Add coastline outlines error_grid = ( - error.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() + error.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() .numpy() ) From efe03027842a22139d6554d68ffee7b6ebe0ad73 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 13:46:05 +0100 Subject: [PATCH 17/90] use xarray plot only --- neural_lam/models/ar_model.py | 47 +++++++++++++++++++++++++++-------- neural_lam/vis.py | 43 +++----------------------------- 2 files changed, 39 insertions(+), 51 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index b55143f0..0af25367 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,5 +1,6 @@ # Standard library import os +from typing import List, Union # Third-party import matplotlib.pyplot as plt @@ -7,7 +8,7 @@ import pytorch_lightning as pl import torch import wandb -from loguru import logger +import xarray as xr # Local from .. import metrics, vis @@ -149,7 +150,35 @@ def __init__( # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] - def _create_dataarray_from_tensor(self, tensor, time, split, category): + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: Union[int, List[int]], + split: str, + category: str, + ) -> xr.DataArray: + """ + Create an `xr.DataArray` from a tensor, with the correct dimensions and + coordinates to match the datastore used by the model. This function in + in effect is the inverse of what is returned by + `WeatherDataset.__getitem__`. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to convert to a `xr.DataArray` with dimensions [time, + grid_index, feature] + time : Union[int,List[int]] + The time index or indices for the data, given as integers or a list + of integers representing epoch time in nanoseconds. + split : str + The split of the data, either 'train', 'val', or 'test' + category : str + The category of the data, either 'state' or 'forcing' + """ + # TODO: creating an instance of WeatherDataset here on every call is + # not how this should be done but whether WeatherDataset should be + # provided to ARModel or where to put plotting still needs discussion weather_dataset = WeatherDataset(datastore=self._datastore, split=split) time = np.array(time, dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( @@ -482,14 +511,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None): var_vranges = list(zip(var_vmin, var_vmax)) # Iterate over prediction horizon time steps - for t_i, (pred_t, target_t) in enumerate( - zip(pred_slice, target_slice), start=1 - ): + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): # Create one figure per variable at this time step var_figs = [ vis.plot_prediction( - pred=pred_t[:, var_i], - target=target_t[:, var_i], datastore=self._datastore, title=f"{var_name} ({var_unit}), " f"t={t_i} ({self._datastore.step_length * t_i} h)", @@ -509,10 +534,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None): ] example_i = self.plotted_examples - for i, fig in enumerate(var_figs): - fn = f"example_{i}_{example_i}_t{t_i}.png" - fig.savefig(fn) - logger.info(f"Saved example plot to {fn}") + # for i, fig in enumerate(var_figs): + # fn = f"example_{i}_{example_i}_t{t_i}.png" + # fig.savefig(fn) + # logger.info(f"Saved example plot to {fn}") wandb.log( { diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 357a8977..47c68e4f 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -65,8 +65,6 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( - pred, - target, datastore: BaseRegularGridDatastore, da_prediction=None, da_target=None, @@ -81,8 +79,8 @@ def plot_prediction( """ # Get common scale for values if vrange is None: - vmin = min(vals.min().cpu().item() for vals in (pred, target)) - vmax = max(vals.max().cpu().item() for vals in (pred, target)) + vmin = min(da_prediction.min(), da_target.min()) + vmax = max(da_prediction.max(), da_target.max()) else: vmin, vmax = vrange @@ -100,39 +98,13 @@ def plot_prediction( subplot_kw={"projection": datastore.coords_projection}, ) - use_xarray = True - # Plot pred and target - - if not use_xarray: - for ax, data in zip(axes, (target, pred)): - ax.coastlines() # Add coastline outlines - data_grid = ( - data.reshape( - [datastore.grid_shape_state.x, datastore.grid_shape_state.y] - ) - .T.cpu() - .numpy() - ) - im = ax.imshow( - data_grid, - origin="lower", - extent=extent, - alpha=pixel_alpha, - vmin=vmin, - vmax=vmax, - cmap="plasma", - ) - - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) - x = da_target.x.values y = da_target.y.values extent = [x.min(), x.max(), y.min(), y.max()] for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines - im = da.plot.imshow( + da.plot.imshow( ax=ax, origin="lower", x="x", @@ -144,15 +116,6 @@ def plot_prediction( transform=datastore.coords_projection, ) - # da.plot.pcolormesh( - # ax=ax, - # x="x", - # vmin=vmin, - # vmax=vmax, - # transform=datastore.coords_projection, - # cmap="plasma", - # ) - # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) From a489c2ed974397ea230d2e61b842d8d9384867dc Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 14:07:06 +0100 Subject: [PATCH 18/90] don't reraise --- neural_lam/train_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 9d1d5039..74146c89 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -23,7 +23,7 @@ } -@logger.catch(reraise=True) +@logger.catch def main(input_args=None): """Main function for training and evaluating models.""" parser = ArgumentParser( From 242d08bcb5374cdd90aecfd49f501ed233f1ce0c Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 14:50:03 +0100 Subject: [PATCH 19/90] remove debug plot --- neural_lam/models/ar_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 0af25367..c875688b 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -534,10 +534,6 @@ def plot_examples(self, batch, n_examples, split, prediction=None): ] example_i = self.plotted_examples - # for i, fig in enumerate(var_figs): - # fn = f"example_{i}_{example_i}_t{t_i}.png" - # fig.savefig(fn) - # logger.info(f"Saved example plot to {fn}") wandb.log( { From c1f706c29542d770ed49e910f8b9bd5caff1fdec Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 16:04:24 +0100 Subject: [PATCH 20/90] remove extent calc used in diagnosing issue --- neural_lam/vis.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 47c68e4f..c814aacf 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -99,9 +99,6 @@ def plot_prediction( ) # Plot pred and target - x = da_target.x.values - y = da_target.y.values - extent = [x.min(), x.max(), y.min(), y.max()] for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines da.plot.imshow( From cf8e3e4c1be93a6ec074368aaf6f91c8042b5278 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 14:51:36 +0100 Subject: [PATCH 21/90] add type annotation --- neural_lam/vis.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index c814aacf..d6b57f88 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -2,6 +2,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import xarray as xr # Local from . import utils @@ -66,8 +67,8 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( datastore: BaseRegularGridDatastore, - da_prediction=None, - da_target=None, + da_prediction: xr.DataArray = None, + da_target: xr.DataArray = None, title=None, vrange=None, ): From 85160cecf13ecfc9fc6a589ac1a9e3542da45e23 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 15:03:06 +0100 Subject: [PATCH 22/90] ensure tensor copy to cpu mem before data-array creation --- neural_lam/models/ar_model.py | 10 ++++++---- neural_lam/weather_dataset.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index c875688b..0d8e6e3c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -167,10 +167,12 @@ def _create_dataarray_from_tensor( ---------- tensor : torch.Tensor The tensor to convert to a `xr.DataArray` with dimensions [time, - grid_index, feature] + grid_index, feature]. The tensor will be copied to the CPU if it is + not already there. time : Union[int,List[int]] The time index or indices for the data, given as integers or a list - of integers representing epoch time in nanoseconds. + of integers representing epoch time in nanoseconds. The ints will be + copied to the CPU memory if they are not already there. split : str The split of the data, either 'train', 'val', or 'test' category : str @@ -180,9 +182,9 @@ def _create_dataarray_from_tensor( # not how this should be done but whether WeatherDataset should be # provided to ARModel or where to put plotting still needs discussion weather_dataset = WeatherDataset(datastore=self._datastore, split=split) - time = np.array(time, dtype="datetime64[ns]") + time = np.array(time.cpu(), dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor, time=time, category=category + tensor=tensor.cpu().numpy(), time=time, category=category ) return da diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 532e3c90..b5f85580 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -529,7 +529,8 @@ def create_dataarray_from_tensor( tensor : torch.Tensor The tensor to construct the DataArray from, this assumed to have the same dimension ordering as returned by the __getitem__ method - (i.e. time, grid_index, {category}_feature). + (i.e. time, grid_index, {category}_feature). The tensor will be + copied to the CPU before constructing the DataArray. time : datetime.datetime or list[datetime.datetime] The time or times of the tensor. category : str @@ -581,7 +582,7 @@ def _is_listlike(obj): coords["time"] = time da = xr.DataArray( - tensor.numpy(), + tensor.cpu().numpy(), dims=dims, coords=coords, ) From 52c452879f56c7f982cfd5d55a5259f37cb6b030 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 15:05:36 +0100 Subject: [PATCH 23/90] apply time-indexing to support ar_steps_val > 1 --- neural_lam/models/ar_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 0d8e6e3c..44baf9c2 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -522,9 +522,11 @@ def plot_examples(self, batch, n_examples, split, prediction=None): f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, da_prediction=da_prediction.isel( - state_feature=var_i + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 ).squeeze(), - da_target=da_target.isel(state_feature=var_i).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( From b96d8ebc0c5c22f980e22384efafcd08db20577f Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:42:05 +0100 Subject: [PATCH 24/90] renaming test datastores --- tests/datastore_examples/.gitignore | 3 +- .../.gitignore | 0 .../era5_1000hPa_danra_100m_winds/config.yaml | 12 +++ .../danra.datastore.yaml | 99 +++++++++++++++++++ .../era5.datastore.yaml | 23 ++--- .../mdp/era5_1000hPa_winds/config.yaml | 3 - 6 files changed, 122 insertions(+), 18 deletions(-) rename tests/datastore_examples/mdp/{era5_1000hPa_winds => era5_1000hPa_danra_100m_winds}/.gitignore (100%) create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml rename tests/datastore_examples/mdp/{era5_1000hPa_winds => era5_1000hPa_danra_100m_winds}/era5.datastore.yaml (80%) delete mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore index e84e6493..4fbd2326 100644 --- a/tests/datastore_examples/.gitignore +++ b/tests/datastore_examples/.gitignore @@ -1,2 +1,3 @@ npyfilesmeps/*.zip -npyfilesmeps/meps_example_reduced/ +npyfilesmeps/meps_example_reduced +npyfilesmeps/era5_1000hPa_temp_meps_example_reduced diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore similarity index 100% rename from tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore rename to tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml new file mode 100644 index 00000000..a158bee3 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml @@ -0,0 +1,12 @@ +datastore: + kind: mdp + config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml new file mode 100644 index 00000000..3edf1267 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml @@ -0,0 +1,99 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + static: [grid_index, static_feature] + state: [time, grid_index, state_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + # use surface incoming shortwave radiation as forcing + - swavr0m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: forcing + + danra_lsm: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr + dims: [x, y] + variables: + - lsm + dim_mapping: + grid_index: + method: stack + dims: [x, y] + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: static + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml similarity index 80% rename from tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml rename to tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index 36b39501..c97da4bc 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -1,5 +1,4 @@ -#TODO: What do these versions mean? Should they be updated? -schema_version: v0.2.0+dev +schema_version: v0.5.0 dataset_version: v1.0.0 output: @@ -7,8 +6,8 @@ output: forcing: [time, grid_index, forcing_feature] coord_ranges: time: - start: 1990-09-02T00:00 - end: 1990-09-10T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 step: PT6H chunking: time: 1 @@ -16,17 +15,17 @@ output: dim: time splits: train: - start: 1990-09-02T00:00 - end: 1990-09-07T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 compute_statistics: ops: [mean, std, diff_mean, diff_std] dims: [grid_index, time] val: - start: 1990-09-05T00:00 - end: 1990-09-08T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 test: - start: 1990-09-06T00:00 - end: 1990-09-10T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: @@ -37,10 +36,6 @@ inputs: level: values: [1000,] units: hPa - v_component_of_wind: - level: - values: [1000, ] - units: hPa dim_mapping: time: method: rename diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml deleted file mode 100644 index 5d1e05f2..00000000 --- a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml +++ /dev/null @@ -1,3 +0,0 @@ -datastore: - kind: mdp - config_path: era5.datastore.yaml From 72da25fd15d46a4497728935e9767c34330f1ccc Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:44:15 +0100 Subject: [PATCH 25/90] adding num_past/future_boundary_step args --- neural_lam/train_model.py | 37 +++++++++++++++------------------ tests/test_datasets.py | 43 +++++++++++++++++++++++++++++++++------ tests/test_training.py | 24 ++++++++++++++++++++-- 3 files changed, 75 insertions(+), 29 deletions(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 37bf6db7..2a61e86c 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -34,11 +34,6 @@ def main(input_args=None): type=str, help="Path to the configuration for neural-lam", ) - parser.add_argument( - "--config_path_boundary", - type=str, - help="Path to the configuration for boundary conditions", - ) parser.add_argument( "--model", type=str, @@ -208,6 +203,18 @@ def main(input_args=None): default=1, help="Number of future time steps to use as input for forcing data", ) + parser.add_argument( + "--num_past_boundary_steps", + type=int, + default=1, + help="Number of past time steps to use as input for boundary data", + ) + parser.add_argument( + "--num_future_boundary_steps", + type=int, + default=1, + help="Number of future time steps to use as input for boundary data", + ) args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() @@ -217,9 +224,6 @@ def main(input_args=None): assert ( args.config_path is not None ), "Specify your config with --config_path" - assert ( - args.config_path_boundary is not None - ), "Specify your config with --config_path_boundary" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None, @@ -234,21 +238,10 @@ def main(input_args=None): seed.seed_everything(args.seed) # Load neural-lam configuration and datastore to use - config, datastore = load_config_and_datastore(config_path=args.config_path) - config_boundary, datastore_boundary = load_config_and_datastore( - config_path=args.config_path_boundary + config, datastore, datastore_boundary = load_config_and_datastore( + config_path=args.config_path ) - # TODO this should not be required, make more flexible - assert ( - datastore.num_past_forcing_steps - == datastore_boundary.num_past_forcing_steps - ), "Mismatch in num_past_forcing_steps" - assert ( - datastore.num_future_forcing_steps - == datastore_boundary.num_future_forcing_steps - ), "Mismatch in num_future_forcing_steps" - # Create datamodule data_module = WeatherDataModule( datastore=datastore, @@ -258,6 +251,8 @@ def main(input_args=None): standardize=True, num_past_forcing_steps=args.num_past_forcing_steps, num_future_forcing_steps=args.num_future_forcing_steps, + num_past_boundary_steps=args.num_past_boundary_steps, + num_future_boundary_steps=args.num_future_boundary_steps, batch_size=args.batch_size, num_workers=args.num_workers, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 67eac70e..5fbe4a5d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -42,10 +42,13 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): datastore_boundary_name ) N_gridpoints = datastore.num_grid_points + N_gridpoints_boundary = datastore_boundary.num_grid_points N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, datastore_boundary=datastore_boundary, @@ -53,6 +56,8 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) item = dataset[0] @@ -77,8 +82,23 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): assert forcing.ndim == 3 assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints - assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * ( - num_past_forcing_steps + num_future_forcing_steps + 1 + # each stacked forcing feature has one corresponding temporal embedding + assert ( + forcing.shape[2] + == datastore.get_num_data_vars("forcing") + * (num_past_forcing_steps + num_future_forcing_steps + 1) + * 2 + ) + + # boundary + assert boundary.ndim == 3 + assert boundary.shape[0] == N_pred_steps + assert boundary.shape[1] == N_gridpoints_boundary + assert ( + boundary.shape[2] + == datastore_boundary.get_num_data_vars("forcing") + * (num_past_boundary_steps + num_future_boundary_steps + 1) + * 2 ) # batch times @@ -88,6 +108,7 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): # try to get the last item of the dataset to ensure slicing and stacking # operations are working as expected and are consistent with the dataset # length + dataset[len(dataset) - 1] @@ -106,6 +127,9 @@ def test_dataset_item_create_dataarray_from_tensor( N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 + dataset = WeatherDataset( datastore=datastore, datastore_boundary=datastore_boundary, @@ -113,16 +137,22 @@ def test_dataset_item_create_dataarray_from_tensor( ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) idx = 0 # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - _, target_states, _, target_times_arr = dataset[idx] - _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays( - idx=idx - ) + _, target_states, _, _, target_times_arr = dataset[idx] + ( + _, + da_target_true, + _, + _, + da_target_times_true, + ) = dataset._build_item_dataarrays(idx=idx) target_times = np.array(target_times_arr, dtype="datetime64[ns]") np.testing.assert_equal(target_times, da_target_times_true.values) @@ -272,6 +302,7 @@ def test_dataset_length(dataset_config): dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=dataset_config["ar_steps"], num_past_forcing_steps=dataset_config["past"], diff --git a/tests/test_training.py b/tests/test_training.py index 1ed1847d..28566a4b 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -14,18 +14,33 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataModule -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_training(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( f"Skipping test for {datastore_name} as it is not a regular " "grid datastore." ) + if not isinstance(datastore_boundary, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_boundary_name} as it is not a regular " + "grid datastore." + ) if torch.cuda.is_available(): device_name = "cuda" @@ -59,6 +74,7 @@ def test_training(datastore_name): data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=3, ar_steps_eval=5, standardize=True, @@ -66,6 +82,8 @@ def test_training(datastore_name): num_workers=1, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, ) class ModelArgs: @@ -85,6 +103,8 @@ class ModelArgs: metrics_watch = [] num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 model_args = ModelArgs() From 244f1ccb77e9d12852e3a59feddff5034f54ef95 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:44:51 +0100 Subject: [PATCH 26/90] using combined config file --- neural_lam/config.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/neural_lam/config.py b/neural_lam/config.py index d3e09697..914ebb38 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -168,4 +168,15 @@ def load_config_and_datastore( datastore_kind=config.datastore.kind, config_path=datastore_config_path ) - return config, datastore + if config.datastore_boundary is not None: + datastore_boundary_config_path = ( + Path(config_path).parent / config.datastore_boundary.config_path + ) + datastore_boundary = init_datastore( + datastore_kind=config.datastore_boundary.kind, + config_path=datastore_boundary_config_path, + ) + else: + datastore_boundary = None + + return config, datastore, datastore_boundary From a9cc36e23de294f21fce15f903a4ba7d0a8496a6 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:45:12 +0100 Subject: [PATCH 27/90] proper handling of state/forcing/boundary in dataset --- neural_lam/weather_dataset.py | 304 +++++++++++++++++++--------------- 1 file changed, 167 insertions(+), 137 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 32add37a..b717c40a 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -38,6 +38,16 @@ class WeatherDataset(torch.utils.data.Dataset): forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before t, given num_past_forcing_steps) are included as forcing inputs at time t. Default is 1. + num_past_boundary_steps: int, optional + Number of past time steps to include in boundary input. If set to i, + boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, + given num_future_forcing_steps) are included as boundary inputs at time t + Default is 1. + num_future_boundary_steps: int, optional + Number of future time steps to include in boundary input. If set to j, + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times before + t, given num_past_forcing_steps) are included as boundary inputs at time + t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -50,6 +60,8 @@ def __init__( ar_steps=3, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, standardize=True, ): super().__init__() @@ -60,10 +72,10 @@ def __init__( self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps - self.da_state = self.datastore.get_dataarray( - category="state", split=self.split - ) + self.da_state = self.datastore.get_dataarray(category="state", split=self.split) if self.da_state is None: raise ValueError( "A non-empty state dataarray must be provided. " @@ -74,9 +86,12 @@ def __init__( category="forcing", split=self.split ) # XXX For now boundary data is always considered mdp-forcing data - self.da_boundary = self.datastore_boundary.get_dataarray( - category="forcing", split=self.split - ) + if self.datastore_boundary is not None: + self.da_boundary = self.datastore_boundary.get_dataarray( + category="forcing", split=self.split + ) + else: + self.da_boundary = None # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -97,9 +112,7 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): - expected_dim_order = self.datastore.expected_dim_order( - category=part - ) + expected_dim_order = self.datastore.expected_dim_order(category=part) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -108,6 +121,23 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # handling ensemble data + if self.datastore.is_ensemble: + # for the now the strategy is to only include the first ensemble + # member + # XXX: this could be changed to include all ensemble members by + # splitting `idx` into two parts, one for the analysis time and one + # for the ensemble member and then increasing self.__len__ to + # include all ensemble members + warnings.warn( + "only use of ensemble member 0 (the first member) is " + "implemented for ensemble data" + ) + i_ensemble = 0 + self.da_state = self.da_state.isel(ensemble_member=i_ensemble) + else: + self.da_state = self.da_state + def get_time_step(times): """Calculate the time step from the data""" time_diffs = np.diff(times) @@ -119,11 +149,18 @@ def get_time_step(times): return time_diffs[0] # Check time step consistency in state data - _ = get_time_step(self.da_state.time.values) + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time + _ = get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: - state_times = self.da_state.time + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time state_time_min = state_times.min().values state_time_max = state_times.max().values @@ -131,26 +168,30 @@ def get_time_step(times): # Forcing data is part of the same datastore as state data # During creation the time dimension of the forcing data # is matched to the state data - forcing_times = self.da_forcing.time - _ = get_time_step(forcing_times.values) + if self.datastore.is_forecast: + forcing_times = self.da_forcing.analysis_time + else: + forcing_times = self.da_forcing.time + get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore # The boundary data is allowed to have a different time_step # Check that the boundary data covers the required time range - boundary_times = self.da_boundary.time + if self.datastore_boundary.is_forecast: + boundary_times = self.da_boundary.analysis_time + else: + boundary_times = self.da_boundary.time boundary_time_step = get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values # Calculate required bounds for boundary using its time step boundary_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * boundary_time_step + state_time_min - self.num_past_forcing_steps * boundary_time_step ) boundary_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * boundary_time_step + state_time_max + self.num_future_forcing_steps * boundary_time_step ) if boundary_time_min > boundary_required_time_min: @@ -179,10 +220,8 @@ def get_time_step(times): self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = ( - self.datastore.get_standardization_dataarray( - category="forcing" - ) + self.ds_forcing_stats = self.datastore.get_standardization_dataarray( + category="forcing" ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -208,7 +247,7 @@ def __len__(self): warnings.warn( "only using first ensemble member, so dataset size is " " effectively reduced by the number of ensemble members " - f"({self.da_state.ensemble_member.size})", + f"({self.datastore._num_ensemble_members})", UserWarning, ) @@ -242,36 +281,50 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): + def _slice_time( + self, + da_state, + idx, + n_steps: int, + da_forcing_boundary=None, + num_past_steps=None, + num_future_steps=None, + ): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing` (forcing). For the state data, slicing is done based on - `idx`. For the forcing data, nearest neighbor matching is performed - based on the state times. Additionally, the time difference between the - matched forcing times and state times (in multiples of state time steps) - is added to the forcing dataarray. This will be used as an additional - feature in the model (temporal embedding). + `da_forcing_boundary`. For the state data, slicing is done + based on `idx`. For the forcing/boundary data, nearest neighbor matching + is performed based on the state times. Additionally, the time difference + between the matched forcing/boundary times and state times (in multiples + of state time steps) is added to the forcing dataarray. This will be + used as an additional feature in the model (temporal embedding). Parameters ---------- da_state : xr.DataArray The state dataarray to slice. - da_forcing : xr.DataArray - The forcing dataarray to slice. idx : int The index of the time step to start the sample from in the state data. n_steps : int The number of time steps to include in the sample. + da_forcing_boundary : xr.DataArray + The forcing/boundary dataarray to slice. + num_past_steps : int, optional + The number of past time steps to include in the forcing/boundary + data. Default is `None`. + num_future_steps : int, optional + The number of future time steps to include in the forcing/boundary + data. Default is `None`. Returns ------- da_state_sliced : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). - da_forcing_matched : xr.DataArray + da_forcing_boundary_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', - 'forcing_feature_windowed'). + 'forcing/boundary_feature_windowed'). """ # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -279,8 +332,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # Slice the state data as before if self.datastore.is_forecast: # Calculate start and end indices for slicing - start_idx = max(0, self.num_past_forcing_steps - init_steps) - end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps + start_idx = max(0, num_past_steps - init_steps) + end_idx = max(init_steps, num_past_steps) + n_steps # Slice the state data over the elapsed forecast duration da_state_sliced = da_state.isel( @@ -299,13 +352,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): else: # For analysis data, slice the time dimension directly - start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) - end_idx = ( - idx + max(init_steps, self.num_past_forcing_steps) + n_steps - ) + start_idx = idx + max(0, num_past_steps - init_steps) + end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - if da_forcing is None: + if da_forcing_boundary is None: return da_state_sliced, None # Get the state times and its temporal resolution for matching with @@ -313,78 +364,66 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] - # Match forcing data to state times based on nearest neighbor - if self.datastore.is_forecast: - # Calculate all possible forcing times - forcing_times = ( - da_forcing.analysis_time + da_forcing.elapsed_forecast_duration - ) - forcing_times_flat = forcing_times.stack( - forecast_time=("analysis_time", "elapsed_forecast_duration") - ) + if "analysis_time" in da_forcing_boundary.dims: + idx = np.abs( + da_forcing_boundary.analysis_time.values + - self.da_state.analysis_time.values[idx] + ).argmin() + # Add a 'time' dimension using the actual forecast times + offset = max(init_steps, num_past_steps) + da_list = [] + for step in range(n_steps): + start_idx = offset + step - num_past_steps + end_idx = offset + step + num_future_steps + + current_time = ( + da_forcing_boundary.analysis_time[idx] + + da_forcing_boundary.elapsed_forecast_duration[offset + step] + ) - # Compute time differences - time_deltas = ( - forcing_times_flat.values[:, np.newaxis] - - state_times.values[np.newaxis, :] - ) - time_diffs = np.abs(time_deltas) - idx_min = time_diffs.argmin(axis=0) - - # Retrieve corresponding indices for analysis_time and - # elapsed_forecast_duration - forecast_time_index = forcing_times_flat["forecast_time"][idx_min] - analysis_time_indices = forecast_time_index["analysis_time"] - elapsed_forecast_duration_indices = forecast_time_index[ - "elapsed_forecast_duration" - ] - - # Slice the forcing data using matched indices - da_forcing_matched = da_forcing.isel( - analysis_time=("time", analysis_time_indices), - elapsed_forecast_duration=( - "time", - elapsed_forecast_duration_indices, - ), - ) + da_sliced = da_forcing_boundary.isel( + analysis_time=idx, + elapsed_forecast_duration=slice(start_idx, end_idx + 1), + ) - # Assign matched state times to the forcing data - da_forcing_matched["time"] = state_times - da_forcing_matched = da_forcing_matched.swap_dims( - {"elapsed_forecast_duration": "time"} - ) + da_sliced = da_sliced.rename({"elapsed_forecast_duration": "window"}) + da_sliced = da_sliced.assign_coords( + window=np.arange(-num_past_steps, num_future_steps + 1) + ) - # Calculate time differences in multiples of state time steps - state_time_step = state_times.values[1] - state_times.values[0] - time_diff_steps = ( - time_deltas[idx_min, np.arange(len(state_times))] - / state_time_step - ) + da_sliced = da_sliced.expand_dims(dim={"time": [current_time.values]}) + + da_list.append(da_sliced) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=("time", time_diff_steps) + # Concatenate the list of DataArrays along the 'time' dimension + da_forcing_boundary_matched = xr.concat(da_list, dim="time") + forcing_time_step = ( + da_forcing_boundary_matched.time.values[1] + - da_forcing_boundary_matched.time.values[0] ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched["window"] * ( + forcing_time_step / state_time_step + ) + time_diff_steps = da_forcing_boundary_matched.isel( + grid_index=0, forcing_feature=0 + ).data + else: # For analysis data, match directly using the 'time' coordinate - forcing_times = da_forcing["time"] + forcing_times = da_forcing_boundary["time"] # Compute time differences between forcing and state times # (in multiples of state time steps) # Retrieve the indices of the closest times in the forcing data time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] + forcing_times.values[:, np.newaxis] - state_times.values[np.newaxis, :] ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) time_diff_steps = np.stack( [ time_deltas[ - idx_i - - self.num_past_forcing_steps : idx_i - + self.num_future_forcing_steps - + 1, + idx_i - num_past_steps : idx_i + num_future_steps + 1, init_steps + step_i, ] for (step_i, idx_i) in enumerate(idx_min[init_steps:]) @@ -392,24 +431,22 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): ) # Create window dimension for forcing data to stack later - window_size = ( - self.num_past_forcing_steps + self.num_future_forcing_steps + 1 - ) - da_forcing_windowed = da_forcing.rolling( - time=window_size, center=True + window_size = num_past_steps + num_future_steps + 1 + da_forcing_boundary_windowed = da_forcing_boundary.rolling( + time=window_size, center=False ).construct(window_dim="window") - da_forcing_matched = da_forcing_windowed.isel( + da_forcing_boundary_matched = da_forcing_boundary_windowed.isel( time=idx_min[init_steps:] ) - # Add time difference as a new coordinate to concatenate to the - # forcing features later - da_forcing_matched["time_diff_steps"] = ( - ("time", "window"), - time_diff_steps, - ) + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_boundary_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, + ) - return da_state_sliced, da_forcing_matched + return da_state_sliced, da_forcing_boundary_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" @@ -462,23 +499,7 @@ def _build_item_dataarrays(self, idx): da_target_times : xr.DataArray The dataarray for the target times. """ - # handling ensemble data - if self.datastore.is_ensemble: - # for the now the strategy is to only include the first ensemble - # member - # XXX: this could be changed to include all ensemble members by - # splitting `idx` into two parts, one for the analysis time and one - # for the ensemble member and then increasing self.__len__ to - # include all ensemble members - warnings.warn( - "only use of ensemble member 0 (the first member) is " - "implemented for ensemble data" - ) - i_ensemble = 0 - da_state = self.da_state.isel(ensemble_member=i_ensemble) - else: - da_state = self.da_state - + da_state = self.da_state if self.da_forcing is not None: if "ensemble_member" in self.da_forcing.dims: raise NotImplementedError( @@ -500,13 +521,19 @@ def _build_item_dataarrays(self, idx): da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing=da_boundary, + da_forcing_boundary=da_boundary, + num_future_steps=self.num_future_boundary_steps, + num_past_steps=self.num_past_boundary_steps, ) + else: + da_boundary_windowed = None da_state, da_forcing_windowed = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing=da_forcing, + da_forcing_boundary=da_forcing, + num_future_steps=self.num_future_forcing_steps, + num_past_steps=self.num_past_forcing_steps, ) # load the data into memory @@ -521,9 +548,7 @@ def _build_item_dataarrays(self, idx): da_target_times = da_target_states.time if self.standardize: - da_init_states = ( - da_init_states - self.da_state_mean - ) / self.da_state_std + da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -595,9 +620,7 @@ def __getitem__(self, idx): tensor_dtype = torch.float32 init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) - target_states = torch.tensor( - da_target_states.values, dtype=tensor_dtype - ) + target_states = torch.tensor(da_target_states.values, dtype=tensor_dtype) target_times = torch.tensor( da_target_times.astype("datetime64[ns]").astype("int64").values, @@ -707,10 +730,7 @@ def _is_listlike(obj): ) for grid_coord in ["x", "y"]: - if ( - grid_coord in da_datastore_state.coords - and grid_coord not in da.coords - ): + if grid_coord in da_datastore_state.coords and grid_coord not in da.coords: da.coords[grid_coord] = da_datastore_state[grid_coord] if not add_time_as_dim: @@ -731,6 +751,8 @@ def __init__( standardize=True, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, batch_size=4, num_workers=16, ): @@ -739,6 +761,8 @@ def __init__( self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps self.ar_steps_train = ar_steps_train self.ar_steps_eval = ar_steps_eval self.standardize = standardize @@ -765,6 +789,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) self.val_dataset = WeatherDataset( datastore=self._datastore, @@ -774,6 +800,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) if stage == "test" or stage is None: @@ -785,6 +813,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) def train_dataloader(self): From dcc0b46861ff1263c688301eca265bd62803616f Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:45:35 +0100 Subject: [PATCH 28/90] datastore_boundars=None introduced --- .../datastore/npyfilesmeps/compute_standardization_stats.py | 1 + 1 file changed, 1 insertion(+) diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index f2c80e8a..4207812f 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -172,6 +172,7 @@ def main( ar_steps = 63 ds = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=ar_steps, standardize=False, From a3b3bde9ed1b044b32afde7e4b12bc8e4a1593e6 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:02 +0100 Subject: [PATCH 29/90] bug fix for file retrieval per member --- neural_lam/datastore/npyfilesmeps/store.py | 51 +++++++++------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 146b0627..7ee583be 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -244,9 +244,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] das = [ - self._get_single_timeseries_dataarray( - features=[feature], split=split - ) + self._get_single_timeseries_dataarray(features=[feature], split=split) for feature in features ] da = xr.concat(das, dim="feature") @@ -259,9 +257,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # variable is turned into a dask array and so execution of the # calculation is delayed until the feature values are actually # used. - da_forecast_time = ( - da.analysis_time + da.elapsed_forecast_duration - ).chunk({"elapsed_forecast_duration": 1}) + da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( + {"elapsed_forecast_duration": 1} + ) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -339,10 +337,7 @@ def _get_single_timeseries_dataarray( for all categories of data """ - if ( - set(features).difference(self.get_vars_names(category="static")) - == set() - ): + if set(features).difference(self.get_vars_names(category="static")) == set(): assert split in ( "train", "val", @@ -356,12 +351,8 @@ def _get_single_timeseries_dataarray( "test", ), f"Unknown dataset split {split} for features {features}" - if member is not None and features != self.get_vars_names( - category="state" - ): - raise ValueError( - "Member can only be specified for the 'state' category" - ) + if member is not None and features != self.get_vars_names(category="state"): + raise ValueError("Member can only be specified for the 'state' category") concat_axis = 0 @@ -377,9 +368,7 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "samples" / split if self._remove_state_features_with_index: n_to_drop = len(self._remove_state_features_with_index) - feature_dim_mask = np.ones( - len(features) + n_to_drop, dtype=bool - ) + feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool) feature_dim_mask[self._remove_state_features_with_index] = False elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT @@ -445,7 +434,7 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split) + coord_values = self._get_analysis_times(split=split, member_id=member) elif d == "y": coord_values = y elif d == "x": @@ -464,9 +453,7 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format( - analysis_time=analysis_time, **file_params - ) + / filename_format.format(analysis_time=analysis_time, **file_params) for analysis_time in coords["analysis_time"] ] else: @@ -505,7 +492,7 @@ def _get_single_timeseries_dataarray( return da - def _get_analysis_times(self, split) -> List[np.datetime64]: + def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: """Get the analysis times for the given split by parsing the filenames of all the files found for the given split. @@ -513,6 +500,8 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: ---------- split : str The dataset split to get the analysis times for. + member_id : int + The ensemble member to get the analysis times for. Returns ------- @@ -520,8 +509,12 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: The analysis times for the given split. """ + if member_id is None: + # Only interior state data files have member_id, to avoid duplicates + # we only look at the first member for all other categories + member_id = 0 pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) - pattern = re.sub(r"{member_id:[^}]*}", "*", pattern) + pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern) sample_dir = self.root_path / "samples" / split sample_files = sample_dir.glob(pattern) @@ -531,9 +524,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError( - f"No files found in {sample_dir} with pattern {pattern}" - ) + raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") return times @@ -690,9 +681,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ def load_pickled_tensor(fn): - return torch.load( - self.root_path / "static" / fn, weights_only=True - ).numpy() + return torch.load(self.root_path / "static" / fn, weights_only=True).numpy() mean_diff_values = None std_diff_values = None From 3ffc413e2f669dafd4c745a50b9b723fff231316 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:17 +0100 Subject: [PATCH 30/90] rename datastore for tests --- tests/conftest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index be5cf3e7..90a86d0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,14 +94,14 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) -DATASTORES_BOUNDARY_EXAMPLES = dict( - mdp=( +DATASTORES_BOUNDARY_EXAMPLES = { + "mdp": ( DATASTORE_EXAMPLES_ROOT_PATH / "mdp" - / "era5_1000hPa_winds" + / "era5_1000hPa_danra_100m_winds" / "era5.datastore.yaml" - ) -) + ), +} DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore From 85aad66c8e9eec4e0b4e95cabb753d8492a0c49a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:31 +0100 Subject: [PATCH 31/90] aligned time with danra for easier boundary testing --- tests/dummy_datastore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index d62c7356..a958b8f5 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -28,7 +28,7 @@ class DummyDatastore(BaseRegularGridDatastore): """ SHORT_NAME = "dummydata" - T0 = isodate.parse_datetime("2021-01-01T00:00:00") + T0 = isodate.parse_datetime("1990-09-02T00:00:00") N_FEATURES = dict(state=5, forcing=2, static=1) CARTESIAN_COORDS = ["x", "y"] From 64f057f78b713e39496abfc3962affa794666369 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:50 +0100 Subject: [PATCH 32/90] Fixed test for temporal embedding --- tests/test_time_slicing.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 29161505..2f5ed96c 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,9 +40,7 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray( - values, dims=["time"], coords={"time": self._time_values} - ) + da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -78,10 +76,8 @@ def get_vars_long_names(self, category): def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): - # state and forcing variables have only on dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange( - len(ANALYSIS_STATE_VALUES) - ) + # state and forcing variables have only one dimension, `time` + time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -93,6 +89,7 @@ def test_time_slicing_analysis( dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, ar_steps=ar_steps, num_future_forcing_steps=num_future_forcing_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -101,9 +98,7 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _ = [ - tensor.numpy() for tensor in sample - ] + init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] expected_init_states = [0, 1] if ar_steps == 3: @@ -130,7 +125,7 @@ def test_time_slicing_analysis( # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) - # forcing: (ar_steps, N_grid, d_windowed_forcing) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) # target_times: (ar_steps,) assert init_states.shape == (2, 1, 1) assert init_states[:, 0, 0].tolist() == expected_init_states @@ -141,6 +136,10 @@ def test_time_slicing_analysis( assert forcing.shape == ( 3, 1, - 1 + num_past_forcing_steps + num_future_forcing_steps, + # Factor 2 because each window step has a temporal embedding + (1 + num_past_forcing_steps + num_future_forcing_steps) * 2, + ) + np.testing.assert_equal( + forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1], + np.array(expected_forcing_values), ) - np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values)) From 6205dbd88f1b208118d93da6d12c0a1be672caef Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Mon, 2 Dec 2024 10:26:54 +0100 Subject: [PATCH 33/90] pin dataclass-wizard <0.31.0 to avoid bug in dataclass-wizard --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f0bc0851..fdcb7f3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "torch>=2.3.0", "torch-geometric==2.3.1", "parse>=1.20.2", - "dataclass-wizard>=0.22.3", + "dataclass-wizard<0.31.0", "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" From 551cd267235a82378ab28f2b1a4db90523d87ea8 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 10:40:48 +0100 Subject: [PATCH 34/90] allow boundary as input to ar_model.common_step --- neural_lam/models/ar_model.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 4ab73cc7..4a08306d 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -107,7 +107,9 @@ def __init__( self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim - + num_forcing_vars + # Factor 2 because of temporal embedding or windowed features + + 2 + * num_forcing_vars * (num_past_forcing_steps + num_future_forcing_steps + 1) ) @@ -200,19 +202,20 @@ def unroll_prediction(self, init_states, forcing_features, true_states): def common_step(self, batch): """ - Predict on single batch batch consists of: init_states: (B, 2, - num_grid_nodes, d_features) target_states: (B, pred_steps, - num_grid_nodes, d_features) forcing_features: (B, pred_steps, - num_grid_nodes, d_forcing), - where index 0 corresponds to index 1 of init_states + Predict on single batch batch consists of: + init_states: (B, 2,num_grid_nodes, d_features) + target_states: (B, pred_steps,num_grid_nodes, d_features) + forcing_features: (B, pred_steps,num_grid_nodes, d_forcing) + boundary_features: (B, pred_steps,num_grid_nodes, d_boundaries) + batch_times: (B, pred_steps) """ - (init_states, target_states, forcing_features, batch_times) = batch + (init_states, target_states, forcing_features, _, batch_times) = batch prediction, pred_std = self.unroll_prediction( init_states, forcing_features, target_states - ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) + ) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) return prediction, target_states, pred_std, batch_times From fc95350a28cbdb81419962b203e0bb08e36520dd Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 10:40:56 +0100 Subject: [PATCH 35/90] linting --- neural_lam/datastore/npyfilesmeps/store.py | 43 ++++++++---- neural_lam/weather_dataset.py | 66 ++++++++++++------- .../era5.datastore.yaml | 2 +- tests/test_time_slicing.py | 12 +++- tests/test_training.py | 17 ++--- 5 files changed, 91 insertions(+), 49 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 7ee583be..24349e7e 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -244,7 +244,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] das = [ - self._get_single_timeseries_dataarray(features=[feature], split=split) + self._get_single_timeseries_dataarray( + features=[feature], split=split + ) for feature in features ] da = xr.concat(das, dim="feature") @@ -257,9 +259,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # variable is turned into a dask array and so execution of the # calculation is delayed until the feature values are actually # used. - da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( - {"elapsed_forecast_duration": 1} - ) + da_forecast_time = ( + da.analysis_time + da.elapsed_forecast_duration + ).chunk({"elapsed_forecast_duration": 1}) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -337,7 +339,10 @@ def _get_single_timeseries_dataarray( for all categories of data """ - if set(features).difference(self.get_vars_names(category="static")) == set(): + if ( + set(features).difference(self.get_vars_names(category="static")) + == set() + ): assert split in ( "train", "val", @@ -351,8 +356,12 @@ def _get_single_timeseries_dataarray( "test", ), f"Unknown dataset split {split} for features {features}" - if member is not None and features != self.get_vars_names(category="state"): - raise ValueError("Member can only be specified for the 'state' category") + if member is not None and features != self.get_vars_names( + category="state" + ): + raise ValueError( + "Member can only be specified for the 'state' category" + ) concat_axis = 0 @@ -368,7 +377,9 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "samples" / split if self._remove_state_features_with_index: n_to_drop = len(self._remove_state_features_with_index) - feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool) + feature_dim_mask = np.ones( + len(features) + n_to_drop, dtype=bool + ) feature_dim_mask[self._remove_state_features_with_index] = False elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT @@ -434,7 +445,9 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split, member_id=member) + coord_values = self._get_analysis_times( + split=split, member_id=member + ) elif d == "y": coord_values = y elif d == "x": @@ -453,7 +466,9 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format(analysis_time=analysis_time, **file_params) + / filename_format.format( + analysis_time=analysis_time, **file_params + ) for analysis_time in coords["analysis_time"] ] else: @@ -524,7 +539,9 @@ def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") + raise ValueError( + f"No files found in {sample_dir} with pattern {pattern}" + ) return times @@ -681,7 +698,9 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ def load_pickled_tensor(fn): - return torch.load(self.root_path / "static" / fn, weights_only=True).numpy() + return torch.load( + self.root_path / "static" / fn, weights_only=True + ).numpy() mean_diff_values = None std_diff_values = None diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b717c40a..b3d86292 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -41,13 +41,13 @@ class WeatherDataset(torch.utils.data.Dataset): num_past_boundary_steps: int, optional Number of past time steps to include in boundary input. If set to i, boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, - given num_future_forcing_steps) are included as boundary inputs at time t - Default is 1. + given num_future_forcing_steps) are included as boundary inputs at time + t Default is 1. num_future_boundary_steps: int, optional Number of future time steps to include in boundary input. If set to j, - boundary from times t, t+1, ..., t+j-1, t+j (and potentially times before - t, given num_past_forcing_steps) are included as boundary inputs at time - t. Default is 1. + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times + before t, given num_past_forcing_steps) are included as boundary inputs + at time t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -75,7 +75,9 @@ def __init__( self.num_past_boundary_steps = num_past_boundary_steps self.num_future_boundary_steps = num_future_boundary_steps - self.da_state = self.datastore.get_dataarray(category="state", split=self.split) + self.da_state = self.datastore.get_dataarray( + category="state", split=self.split + ) if self.da_state is None: raise ValueError( "A non-empty state dataarray must be provided. " @@ -112,7 +114,9 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): - expected_dim_order = self.datastore.expected_dim_order(category=part) + expected_dim_order = self.datastore.expected_dim_order( + category=part + ) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -188,10 +192,12 @@ def get_time_step(times): # Calculate required bounds for boundary using its time step boundary_required_time_min = ( - state_time_min - self.num_past_forcing_steps * boundary_time_step + state_time_min + - self.num_past_forcing_steps * boundary_time_step ) boundary_required_time_max = ( - state_time_max + self.num_future_forcing_steps * boundary_time_step + state_time_max + + self.num_future_forcing_steps * boundary_time_step ) if boundary_time_min > boundary_required_time_min: @@ -220,8 +226,10 @@ def get_time_step(times): self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = self.datastore.get_standardization_dataarray( - category="forcing" + self.ds_forcing_stats = ( + self.datastore.get_standardization_dataarray( + category="forcing" + ) ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -378,7 +386,9 @@ def _slice_time( current_time = ( da_forcing_boundary.analysis_time[idx] - + da_forcing_boundary.elapsed_forecast_duration[offset + step] + + da_forcing_boundary.elapsed_forecast_duration[ + offset + step + ] ) da_sliced = da_forcing_boundary.isel( @@ -386,12 +396,16 @@ def _slice_time( elapsed_forecast_duration=slice(start_idx, end_idx + 1), ) - da_sliced = da_sliced.rename({"elapsed_forecast_duration": "window"}) + da_sliced = da_sliced.rename( + {"elapsed_forecast_duration": "window"} + ) da_sliced = da_sliced.assign_coords( window=np.arange(-num_past_steps, num_future_steps + 1) ) - da_sliced = da_sliced.expand_dims(dim={"time": [current_time.values]}) + da_sliced = da_sliced.expand_dims( + dim={"time": [current_time.values]} + ) da_list.append(da_sliced) @@ -401,13 +415,13 @@ def _slice_time( da_forcing_boundary_matched.time.values[1] - da_forcing_boundary_matched.time.values[0] ) - da_forcing_boundary_matched["window"] = da_forcing_boundary_matched["window"] * ( - forcing_time_step / state_time_step - ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched[ + "window" + ] * (forcing_time_step / state_time_step) time_diff_steps = da_forcing_boundary_matched.isel( grid_index=0, forcing_feature=0 ).data - + else: # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing_boundary["time"] @@ -416,7 +430,8 @@ def _slice_time( # (in multiples of state time steps) # Retrieve the indices of the closest times in the forcing data time_deltas = ( - forcing_times.values[:, np.newaxis] - state_times.values[np.newaxis, :] + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) @@ -548,7 +563,9 @@ def _build_item_dataarrays(self, idx): da_target_times = da_target_states.time if self.standardize: - da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std + da_init_states = ( + da_init_states - self.da_state_mean + ) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -620,7 +637,9 @@ def __getitem__(self, idx): tensor_dtype = torch.float32 init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) - target_states = torch.tensor(da_target_states.values, dtype=tensor_dtype) + target_states = torch.tensor( + da_target_states.values, dtype=tensor_dtype + ) target_times = torch.tensor( da_target_times.astype("datetime64[ns]").astype("int64").values, @@ -730,7 +749,10 @@ def _is_listlike(obj): ) for grid_coord in ["x", "y"]: - if grid_coord in da_datastore_state.coords and grid_coord not in da.coords: + if ( + grid_coord in da_datastore_state.coords + and grid_coord not in da.coords + ): da.coords[grid_coord] = da_datastore_state[grid_coord] if not add_time_as_dim: diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index c97da4bc..7c5ffb3b 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -25,7 +25,7 @@ output: end: 2022-09-30T00:00 test: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 2f5ed96c..4a59c81e 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,7 +40,9 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) + da = xr.DataArray( + values, dims=["time"], coords={"time": self._time_values} + ) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -77,7 +79,9 @@ def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): # state and forcing variables have only one dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) + time_values = np.datetime64("2020-01-01") + np.arange( + len(ANALYSIS_STATE_VALUES) + ) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -98,7 +102,9 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] expected_init_states = [0, 1] if ar_steps == 3: diff --git a/tests/test_training.py b/tests/test_training.py index 28566a4b..7a1b4717 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -5,6 +5,7 @@ import pytest import pytorch_lightning as pl import torch + import wandb # First-party @@ -22,14 +23,10 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize( - "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() -) +@pytest.mark.parametrize("datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys()) def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example( - datastore_boundary_name - ) + datastore_boundary = init_datastore_boundary_example(datastore_boundary_name) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( @@ -38,15 +35,13 @@ def test_training(datastore_name, datastore_boundary_name): ) if not isinstance(datastore_boundary, BaseRegularGridDatastore): pytest.skip( - f"Skipping test for {datastore_boundary_name} as it is not a regular " - "grid datastore." + f"Skipping test for {datastore_boundary_name} as it is not a " + "regular grid datastore." ) if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision( - "high" - ) # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s else: device_name = "cpu" From 01fa807bc5ce47270e3b4568db8df8ce3b436953 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:10:29 +0100 Subject: [PATCH 36/90] improved docstrings and added some assertions --- neural_lam/weather_dataset.py | 105 ++++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 23 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b3d86292..991965d9 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -143,7 +143,13 @@ def __init__( self.da_state = self.da_state def get_time_step(times): - """Calculate the time step from the data""" + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ time_diffs = np.diff(times) if not np.all(time_diffs == time_diffs[0]): raise ValueError( @@ -234,6 +240,7 @@ def get_time_step(times): self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + # XXX: Again, the boundary data is considered forcing data for now if self.da_boundary is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( @@ -305,7 +312,7 @@ def _slice_time( is performed based on the state times. Additionally, the time difference between the matched forcing/boundary times and state times (in multiples of state time steps) is added to the forcing dataarray. This will be - used as an additional feature in the model (temporal embedding). + used as an additional input feature in the model (temporal embedding). Parameters ---------- @@ -333,23 +340,26 @@ def _slice_time( da_forcing_boundary_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'forcing/boundary_feature_windowed'). + If no forcing/boundary data is provided, this will be `None`. """ - # Number of initial steps required (e.g., for initializing models) + # The current implementation requires at least 2 time steps for the + # initial state (see GraphCast). init_steps = 2 - - # Slice the state data as before + # slice the dataarray to include the required number of time steps if self.datastore.is_forecast: - # Calculate start and end indices for slicing - start_idx = max(0, num_past_steps - init_steps) - end_idx = max(init_steps, num_past_steps) + n_steps - - # Slice the state data over the elapsed forecast duration + start_idx = max(0, self.num_past_forcing_steps - init_steps) + end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps + # this implies that the data will have both `analysis_time` and + # `elapsed_forecast_duration` dimensions for forecasts. We for now + # simply select a analysis time and the first `n_steps` forecast + # times (given no offset). Note that this means that we get one + # sample per forecast, always starting at forecast time 2. da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) - - # Create a new 'time' dimension + # create a new time dimension so that the produced sample has a + # `time` dimension, similarly to the analysis only data da_state_sliced["time"] = ( da_state_sliced.analysis_time + da_state_sliced.elapsed_forecast_duration @@ -357,9 +367,13 @@ def _slice_time( da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + # Asserting that the forecast time step is consistent + self.get_time_step(da_state_sliced.time) else: - # For analysis data, slice the time dimension directly + # For analysis data we slice the time dimension directly. The offset + # is only relevant for the very first (and last) samples in the + # dataset. start_idx = idx + max(0, num_past_steps - init_steps) end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) @@ -372,7 +386,13 @@ def _slice_time( state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] + # Here we cannot check 'self.datastore.is_forecast' directly because we + # might be dealing with a datastore_boundary if "analysis_time" in da_forcing_boundary.dims: + # Select the closest analysis time in the forcing/boundary data + # This is mostly relevant for boundary data where the time steps + # are not necessarily the same as the state data. But still fast + # enough for forcing data where the time steps are the same. idx = np.abs( da_forcing_boundary.analysis_time.values - self.da_state.analysis_time.values[idx] @@ -399,6 +419,8 @@ def _slice_time( da_sliced = da_sliced.rename( {"elapsed_forecast_duration": "window"} ) + + # Assign the 'window' coordinate to be relative positions da_sliced = da_sliced.assign_coords( window=np.arange(-num_past_steps, num_future_steps + 1) ) @@ -409,7 +431,10 @@ def _slice_time( da_list.append(da_sliced) - # Concatenate the list of DataArrays along the 'time' dimension + # Generate temporal embedding `time_diff_steps` for the + # forcing/boundary data. This is the time difference in multiples + # of state time steps between the forcing/boundary time and the + # state time. da_forcing_boundary_matched = xr.concat(da_list, dim="time") forcing_time_step = ( da_forcing_boundary_matched.time.values[1] @@ -423,7 +448,9 @@ def _slice_time( ).data else: - # For analysis data, match directly using the 'time' coordinate + # For analysis data, we slice the time dimension directly. The + # offset is only relevant for the very first (and last) samples in + # the dataset. forcing_times = da_forcing_boundary["time"] # Compute time differences between forcing and state times @@ -455,7 +482,7 @@ def _slice_time( ) # Add time difference as a new coordinate to concatenate to the - # forcing features later + # forcing features later as temporal embedding da_forcing_boundary_matched["time_diff_steps"] = ( ("time", "window"), time_diff_steps, @@ -464,7 +491,26 @@ def _slice_time( return da_state_sliced, da_forcing_boundary_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): - """Helper function to process windowed data after standardization.""" + """Helper function to process windowed data. This function stacks the + 'forcing_feature' and 'window' dimensions and adds the time step + differences to the existing features as a temporal embedding. + + Parameters + ---------- + da_windowed : xr.DataArray + The windowed data to process. Can be `None` if no data is provided. + da_state : xr.DataArray + The state dataarray. + da_target_times : xr.DataArray + The target times. + + Returns + ------- + da_windowed : xr.DataArray + The processed windowed data. If `da_windowed` is `None`, an empty + DataArray with the correct dimensions and coordinates is returned. + + """ stacked_dim = "forcing_feature_windowed" if da_windowed is not None: # Stack the 'feature' and 'window' dimensions and add the @@ -492,8 +538,8 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): def _build_item_dataarrays(self, idx): """ - Create the dataarrays for the initial states, target states and forcing - data for the sample at index `idx`. + Create the dataarrays for the initial states, target states, forcing + and boundary data for the sample at index `idx`. Parameters ---------- @@ -529,7 +575,7 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # if da_forcing is None, the function will return None for + # if da_forcing_boundary is None, the function will return None for # da_forcing_windowed if da_boundary is not None: _, da_boundary_windowed = self._slice_time( @@ -542,6 +588,9 @@ def _build_item_dataarrays(self, idx): ) else: da_boundary_windowed = None + # XXX: Currently, the order of the `slice_time` calls is important + # as `da_state` is modified in the second call. This should be + # refactored to be more robust. da_state, da_forcing_windowed = self._slice_time( da_state=da_state, idx=idx, @@ -584,6 +633,10 @@ def _build_item_dataarrays(self, idx): da_boundary_windowed - self.da_boundary_mean ) / self.da_boundary_std + # This function handles the stacking of the forcing and boundary data + # and adds the time step differences as a temporal embedding. + # It can handle `None` inputs for the forcing and boundary data + # (and simlpy return an empty DataArray in that case). da_forcing_windowed = self._process_windowed_data( da_forcing_windowed, da_state, da_target_times ) @@ -655,6 +708,11 @@ def __getitem__(self, idx): # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) + # Assert that the boundary data is an empty tensor if the corresponding + # datastore_boundary is `None` + if self.datastore_boundary is None: + assert boundary.numel() == 0 + return init_states, target_states, forcing, boundary, target_times def __iter__(self): @@ -794,9 +852,10 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: - # BUG: There also seem to be issues with "spawn", to be investigated - # default to spawn for now, as the default on linux "fork" hangs - # when using dask (which the npyfilesmeps datastore uses) + # BUG: There also seem to be issues with "spawn" and `gloo`, to be + # investigated. Defaults to spawn for now, as the default on linux + # "fork" hangs when using dask (which the npyfilesmeps datastore + # uses) self.multiprocessing_context = "spawn" else: self.multiprocessing_context = None From 5a749f3ab55d79ce27ebe5bf439815d0cbf78093 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:10:42 +0100 Subject: [PATCH 37/90] update mdp dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5bbe4d92..ef75c8d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "parse>=1.20.2", "dataclass-wizard>=0.22.3", "gcsfs>=2021.10.0", - "mllam-data-prep @ git+https://github.com/leifdenby/mllam-data-prep@temp/for-neural-lam-datastores", + "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" From 45ba60782066cfc94d621f07119f23266556a374 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:11:32 +0100 Subject: [PATCH 38/90] remove boundary datastore from tests that don't need it --- tests/test_datasets.py | 17 ++--------------- tests/test_training.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 5fbe4a5d..063ec147 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -108,37 +108,24 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): # try to get the last item of the dataset to ensure slicing and stacking # operations are working as expected and are consistent with the dataset # length - dataset[len(dataset) - 1] @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize( - "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() -) -def test_dataset_item_create_dataarray_from_tensor( - datastore_name, datastore_boundary_name -): +def test_dataset_item_create_dataarray_from_tensor(datastore_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example( - datastore_boundary_name - ) N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 - num_past_boundary_steps = 1 - num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, - datastore_boundary=datastore_boundary, + datastore_boundary=None, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, - num_past_boundary_steps=num_past_boundary_steps, - num_future_boundary_steps=num_future_boundary_steps, ) idx = 0 diff --git a/tests/test_training.py b/tests/test_training.py index 7a1b4717..ca0ebf41 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -5,7 +5,6 @@ import pytest import pytorch_lightning as pl import torch - import wandb # First-party @@ -23,10 +22,14 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize("datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example(datastore_boundary_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( @@ -41,7 +44,9 @@ def test_training(datastore_name, datastore_boundary_name): if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision( + "high" + ) # Allows using Tensor Cores on A100s else: device_name = "cpu" From f36f36040dcbfa40380880d4cc9fa03f6632da43 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:42:43 +0100 Subject: [PATCH 39/90] fix scope of _get_slice_time --- neural_lam/weather_dataset.py | 40 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 991965d9..4bc9d5c7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,28 +142,14 @@ def __init__( else: self.da_state = self.da_state - def get_time_step(times): - """Calculate the time step from the data - - Parameters - ---------- - times : xr.DataArray - The time dataarray to calculate the time step from. - """ - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] + # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time else: state_times = self.da_state.time - _ = get_time_step(state_times) + _ = self._get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: @@ -182,7 +168,7 @@ def get_time_step(times): forcing_times = self.da_forcing.analysis_time else: forcing_times = self.da_forcing.time - get_time_step(forcing_times.values) + self._get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore @@ -192,7 +178,7 @@ def get_time_step(times): boundary_times = self.da_boundary.analysis_time else: boundary_times = self.da_boundary.time - boundary_time_step = get_time_step(boundary_times.values) + boundary_time_step = self._get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values @@ -296,6 +282,22 @@ def __len__(self): - self.num_future_forcing_steps ) + def _get_time_step(self, times): + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + def _slice_time( self, da_state, @@ -368,7 +370,7 @@ def _slice_time( {"elapsed_forecast_duration": "time"} ) # Asserting that the forecast time step is consistent - self.get_time_step(da_state_sliced.time) + self._get_time_step(da_state_sliced.time) else: # For analysis data we slice the time dimension directly. The offset From 105108e9bd144c64075e0f5588f15176fc1fde52 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:43:01 +0100 Subject: [PATCH 40/90] fix scope of _get_time_step --- neural_lam/weather_dataset.py | 40 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 991965d9..4bc9d5c7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,28 +142,14 @@ def __init__( else: self.da_state = self.da_state - def get_time_step(times): - """Calculate the time step from the data - - Parameters - ---------- - times : xr.DataArray - The time dataarray to calculate the time step from. - """ - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] + # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time else: state_times = self.da_state.time - _ = get_time_step(state_times) + _ = self._get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: @@ -182,7 +168,7 @@ def get_time_step(times): forcing_times = self.da_forcing.analysis_time else: forcing_times = self.da_forcing.time - get_time_step(forcing_times.values) + self._get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore @@ -192,7 +178,7 @@ def get_time_step(times): boundary_times = self.da_boundary.analysis_time else: boundary_times = self.da_boundary.time - boundary_time_step = get_time_step(boundary_times.values) + boundary_time_step = self._get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values @@ -296,6 +282,22 @@ def __len__(self): - self.num_future_forcing_steps ) + def _get_time_step(self, times): + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + def _slice_time( self, da_state, @@ -368,7 +370,7 @@ def _slice_time( {"elapsed_forecast_duration": "time"} ) # Asserting that the forecast time step is consistent - self.get_time_step(da_state_sliced.time) + self._get_time_step(da_state_sliced.time) else: # For analysis data we slice the time dimension directly. The offset From ae0cf764bd23adfde2befa4bef8ef89122975688 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 16:58:46 +0100 Subject: [PATCH 41/90] added information about optional boundary datastore --- README.md | 22 +++++++++++++--------- neural_lam/weather_dataset.py | 2 -- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e21b7c24..7a5e5caf 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th interface that provides the data in a data-structure that can be used within neural-lam. A datastore is used to create a `pytorch.Dataset`-derived class that samples the data in time to create individual samples for - training, validation and testing. + training, validation and testing. A secondary datastore can be provided + for the boundary data. Currently, boundary datastore must be of type `mdp` + and only contain forcing features. This can easily be expanded in the future. 2. **The graph structure** is used to define message-passing GNN layers, that are trained to emulate fluid flow in the atmosphere over time. The @@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model. The path you provide to the neural-lam config (`config.yaml`) also sets the root directory relative to which all other paths are resolved, as in the parent -directory of the config becomes the root directory. Both the datastore and +directory of the config becomes the root directory. Both the datastores and graphs you generate are then stored in subdirectories of this root directory. Exactly how and where a specific datastore expects its source data to be stored and where it stores its derived data is up to the implementation of the @@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`): data/ ├── config.yaml - Configuration file for neural-lam ├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml +├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml └── graphs/ - Directory containing graphs for training ``` @@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like: datastore: kind: mdp config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml training: state_feature_weighting: __config_class__: ManualStateFeatureWeighting - values: + weights: u100m: 1.0 v100m: 1.0 ``` -For now the neural-lam config only defines two things: 1) the kind of data -store and the path to its config, and 2) the weighting of different features in -the loss function. If you don't define the state feature weighting it will default -to weighting all features equally. +For now the neural-lam config only defines two things: +1) the kind of datastores and the path to their config +2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally. (This example is taken from the `tests/datastore_examples/mdp` directory.) @@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha # Contact If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch. -There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots). -You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). +There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 4bc9d5c7..8d82229f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,8 +142,6 @@ def __init__( else: self.da_state = self.da_state - - # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time From 9af27e0741894319860d11eb22cd9e9fd398e1ec Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 19:46:37 +0100 Subject: [PATCH 42/90] add datastore_boundary to neural_lam --- neural_lam/train_model.py | 22 ++++++++++++++++++++++ neural_lam/weather_dataset.py | 10 ++++++++++ 2 files changed, 32 insertions(+) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c89..37bf6db7 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -34,6 +34,11 @@ def main(input_args=None): type=str, help="Path to the configuration for neural-lam", ) + parser.add_argument( + "--config_path_boundary", + type=str, + help="Path to the configuration for boundary conditions", + ) parser.add_argument( "--model", type=str, @@ -212,6 +217,9 @@ def main(input_args=None): assert ( args.config_path is not None ), "Specify your config with --config_path" + assert ( + args.config_path_boundary is not None + ), "Specify your config with --config_path_boundary" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None, @@ -227,10 +235,24 @@ def main(input_args=None): # Load neural-lam configuration and datastore to use config, datastore = load_config_and_datastore(config_path=args.config_path) + config_boundary, datastore_boundary = load_config_and_datastore( + config_path=args.config_path_boundary + ) + + # TODO this should not be required, make more flexible + assert ( + datastore.num_past_forcing_steps + == datastore_boundary.num_past_forcing_steps + ), "Mismatch in num_past_forcing_steps" + assert ( + datastore.num_future_forcing_steps + == datastore_boundary.num_future_forcing_steps + ), "Mismatch in num_future_forcing_steps" # Create datamodule data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=args.ar_steps_train, ar_steps_eval=args.ar_steps_eval, standardize=True, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b5f85580..75f7e04e 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -22,6 +22,8 @@ class WeatherDataset(torch.utils.data.Dataset): ---------- datastore : BaseDatastore The datastore to load the data from (e.g. mdp). + datastore_boundary : BaseDatastore + The boundary datastore to load the data from (e.g. mdp). split : str, optional The data split to use ("train", "val" or "test"). Default is "train". ar_steps : int, optional @@ -43,6 +45,7 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, split="train", ar_steps=3, num_past_forcing_steps=1, @@ -54,6 +57,7 @@ def __init__( self.split = split self.ar_steps = ar_steps self.datastore = datastore + self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps @@ -606,6 +610,7 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, ar_steps_train=3, ar_steps_eval=25, standardize=True, @@ -616,6 +621,7 @@ def __init__( ): super().__init__() self._datastore = datastore + self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps self.ar_steps_train = ar_steps_train @@ -627,6 +633,7 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: + # BUG: There also seem to be issues with "spawn", to be investigated # default to spawn for now, as the default on linux "fork" hangs # when using dask (which the npyfilesmeps datastore uses) self.multiprocessing_context = "spawn" @@ -637,6 +644,7 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="train", ar_steps=self.ar_steps_train, standardize=self.standardize, @@ -645,6 +653,7 @@ def setup(self, stage=None): ) self.val_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="val", ar_steps=self.ar_steps_eval, standardize=self.standardize, @@ -655,6 +664,7 @@ def setup(self, stage=None): if stage == "test" or stage is None: self.test_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="test", ar_steps=self.ar_steps_eval, standardize=self.standardize, From c25fb30ab6b9fc8038227a590b5551f1660dbe19 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 20:15:41 +0100 Subject: [PATCH 43/90] complete integration of boundary in weatherDataset --- neural_lam/weather_dataset.py | 55 ++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 75f7e04e..7585207c 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -67,6 +67,9 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + self.da_boundary = self.datastore_boundary.get_dataarray( + category="boundary", split=self.split + ) # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -118,6 +121,15 @@ def __init__( self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + if self.da_boundary is not None: + self.ds_boundary_stats = ( + self.datastore_boundary.get_standardization_dataarray( + category="boundary" + ) + ) + self.da_boundary_mean = self.ds_boundary_stats.boundary_mean + self.da_boundary_std = self.ds_boundary_stats.boundary_std + def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time @@ -352,6 +364,8 @@ def _build_item_dataarrays(self, idx): The dataarray for the target states. da_forcing_windowed : xr.DataArray The dataarray for the forcing data, windowed for the sample. + da_boundary_windowed : xr.DataArray + The dataarray for the boundary data, windowed for the sample. da_target_times : xr.DataArray The dataarray for the target times. """ @@ -381,6 +395,11 @@ def _build_item_dataarrays(self, idx): else: da_forcing = None + if self.da_boundary is not None: + da_boundary = self.da_boundary + else: + da_boundary = None + # handle time sampling in a way that is compatible with both analysis # and forecast data da_state = self._slice_state_time( @@ -390,11 +409,17 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed = self._slice_forcing_time( da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) + if da_boundary is not None: + da_boundary_windowed = self._slice_forcing_time( + da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps + ) # load the data into memory da_state.load() if da_forcing is not None: da_forcing_windowed.load() + if da_boundary is not None: + da_boundary_windowed.load() da_init_states = da_state.isel(time=slice(0, 2)) da_target_states = da_state.isel(time=slice(2, None)) @@ -417,6 +442,11 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed - self.da_forcing_mean ) / self.da_forcing_std + if da_boundary is not None: + da_boundary_windowed = ( + da_boundary_windowed - self.da_boundary_mean + ) / self.da_boundary_std + if da_forcing is not None: # stack the `forcing_feature` and `window_sample` dimensions into a # single `forcing_feature` dimension @@ -436,11 +466,31 @@ def _build_item_dataarrays(self, idx): "forcing_feature": [], }, ) + if da_boundary is not None: + # stack the `forcing_feature` and `window_sample` dimensions into a + # single `forcing_feature` dimension + da_boundary_windowed = da_boundary_windowed.stack( + boundary_feature_windowed=("boundary_feature", "window") + ) + else: + # create an empty forcing tensor with the right shape + da_boundary_windowed = xr.DataArray( + data=np.empty( + (self.ar_steps, da_state.grid_index.size, 0), + ), + dims=("time", "grid_index", "boundary_feature"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + "boundary_feature": [], + }, + ) return ( da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) @@ -475,6 +525,7 @@ def __getitem__(self, idx): da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) = self._build_item_dataarrays(idx=idx) @@ -491,13 +542,15 @@ def __getitem__(self, idx): ) forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype) + boundary = torch.tensor(da_boundary_windowed.values, dtype=tensor_dtype) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) # forcing: (ar_steps, N_grid, d_windowed_forcing) + # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) - return init_states, target_states, forcing, target_times + return init_states, target_states, forcing, boundary, target_times def __iter__(self): """ From 505ceeb589c3398d37100a6073fa5590e7d786c2 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 20:15:55 +0100 Subject: [PATCH 44/90] Add test to check timestep length and spacing --- neural_lam/weather_dataset.py | 76 +++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 7585207c..8e55d4a5 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -101,6 +101,82 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # Check time coverage for forcing and boundary data + if self.da_forcing is not None or self.da_boundary is not None: + state_times = self.da_state.time + state_time_min = state_times.min().values + state_time_max = state_times.max().values + + def get_time_step(times): + """Calculate the time step from the data""" + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + if self.da_forcing is not None: + forcing_times = self.da_forcing.time + forcing_time_step = get_time_step(forcing_times.values) + forcing_time_min = forcing_times.min().values + forcing_time_max = forcing_times.max().values + + # Calculate required bounds for forcing using its time step + forcing_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * forcing_time_step + ) + forcing_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * forcing_time_step + ) + + if forcing_time_min > forcing_required_time_min: + raise ValueError( + f"Forcing data starts too late." + f"Required start: {forcing_required_time_min}, " + f"but forcing starts at {forcing_time_min}." + ) + + if forcing_time_max < forcing_required_time_max: + raise ValueError( + f"Forcing data ends too early." + f"Required end: {forcing_required_time_max}," + f"but forcing ends at {forcing_time_max}." + ) + + if self.da_boundary is not None: + boundary_times = self.da_boundary.time + boundary_time_step = get_time_step(boundary_times.values) + boundary_time_min = boundary_times.min().values + boundary_time_max = boundary_times.max().values + + # Calculate required bounds for boundary using its time step + boundary_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * boundary_time_step + ) + boundary_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * boundary_time_step + ) + + if boundary_time_min > boundary_required_time_min: + raise ValueError( + f"Boundary data starts too late." + f"Required start: {boundary_required_time_min}, " + f"but boundary starts at {boundary_time_min}." + ) + + if boundary_time_max < boundary_required_time_max: + raise ValueError( + f"Boundary data ends too early." + f"Required end: {boundary_required_time_max}, " + f"but boundary ends at {boundary_time_max}." + ) + # Set up for standardization # TODO: This will become part of ar_model.py soon! self.standardize = standardize From e7330664661bd336caf40842dfb46a406b120721 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 21:43:57 +0100 Subject: [PATCH 45/90] setting default mdp boundary to 0 gridcells --- neural_lam/datastore/mdp.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 0d1aac7b..b6f1676c 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -27,7 +27,7 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): + def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at `config_path`. A boundary mask is created with `n_boundary_points` @@ -336,19 +336,22 @@ def boundary_mask(self) -> xr.DataArray: boundary point and 0 is not. """ - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) + if self._n_boundary_points > 0: + ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) + da_state_variable = ( + ds_unstacked["state"].isel(time=0).isel(state_feature=0) + ) + da_domain_allzero = xr.zeros_like(da_state_variable) + ds_unstacked["boundary_mask"] = da_domain_allzero.isel( + x=slice(self._n_boundary_points, -self._n_boundary_points), + y=slice(self._n_boundary_points, -self._n_boundary_points), + ) + ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( + 1 + ).astype(int) + return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) + else: + return None @property def coords_projection(self) -> ccrs.Projection: From d8349a4801654c152f14924aa86d08c4ab952468 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 18 Nov 2024 21:44:54 +0100 Subject: [PATCH 46/90] implement time-based slicing combine two slicing fcts into one --- neural_lam/weather_dataset.py | 300 ++++++++++++++++++---------------- 1 file changed, 161 insertions(+), 139 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 8e55d4a5..5559e838 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -67,8 +67,9 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + # XXX For now boundary data is always considered forcing data self.da_boundary = self.datastore_boundary.get_dataarray( - category="boundary", split=self.split + category="forcing", split=self.split ) # check that with the provided data-arrays and ar_steps that we have a @@ -200,7 +201,7 @@ def get_time_step(times): if self.da_boundary is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( - category="boundary" + category="forcing" ) ) self.da_boundary_mean = self.ds_boundary_stats.boundary_mean @@ -252,175 +253,156 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_state_time(self, da_state, idx, n_steps: int): + def _slice_time(self, da_state, da_forcing, idx, n_steps: int): """ - Produce a time slice of the given dataarray `da_state` (state) starting - at `idx` and with `n_steps` steps. An `offset`is calculated based on the - `num_past_forcing_steps` class attribute. `Offset` is used to offset the - start of the sample, to assert that enough previous time steps are - available for the 2 initial states and any corresponding forcings - (calculated in `_slice_forcing_time`). + Produce time slices of the given dataarrays `da_state` (state) and + `da_forcing` (forcing). For the state data, slicing is done as before + based on `idx`. For the forcing data, nearest neighbor matching is + performed based on the state times. Additionally, the time difference + between the matched forcing times and state times (in multiples of state + time steps) is added to the forcing dataarray. Parameters ---------- da_state : xr.DataArray - The dataarray to slice. This is expected to have a `time` dimension - if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. + The state dataarray to slice. + da_forcing : xr.DataArray + The forcing dataarray to slice. idx : int - The index of the time step to start the sample from. + The index of the time step to start the sample from in the state + data. n_steps : int The number of time steps to include in the sample. Returns ------- - da_sliced : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', + da_state_sliced : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). + da_forcing_matched : xr.DataArray + The forcing dataarray matched to state times with an added + coordinate 'time_diff', representing the time difference to state + times in multiples of state time steps. """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). + # Number of initial steps required (e.g., for initializing models) init_steps = 2 - # slice the dataarray to include the required number of time steps + + # Slice the state data as before if self.datastore.is_forecast: + # Calculate start and end indices for slicing start_idx = max(0, self.num_past_forcing_steps - init_steps) end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps - # this implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select a analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast, always starting at forecast time 2. - da_sliced = da_state.isel( + + # Slice the state data over the elapsed forecast duration + da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) - # create a new time dimension so that the produced sample has a - # `time` dimension, similarly to the analysis only data - da_sliced["time"] = ( - da_sliced.analysis_time + da_sliced.elapsed_forecast_duration + + # Create a new 'time' dimension + da_state_sliced["time"] = ( + da_state_sliced.analysis_time + + da_state_sliced.elapsed_forecast_duration ) - da_sliced = da_sliced.swap_dims( + da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + else: - # For analysis data we slice the time dimension directly. The offset - # is only relevant for the very first (and last) samples in the - # dataset. + # For analysis data, slice the time dimension directly start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) end_idx = ( idx + max(init_steps, self.num_past_forcing_steps) + n_steps ) - da_sliced = da_state.isel(time=slice(start_idx, end_idx)) - return da_sliced + da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - def _slice_forcing_time(self, da_forcing, idx, n_steps: int): - """ - Produce a time slice of the given dataarray `da_forcing` (forcing) - starting at `idx` and with `n_steps` steps. An `offset` is calculated - based on the `num_past_forcing_steps` class attribute. It is used to - offset the start of the sample, to ensure that enough previous time - steps are available for the forcing data. The forcing data is windowed - around the current autoregressive time step to include the past and - future forcings. - - Parameters - ---------- - da_forcing : xr.DataArray - The forcing dataarray to slice. This is expected to have a `time` - dimension if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. - idx : int - The index of the time step to start the sample from. - n_steps : int - The number of time steps to include in the sample. - - Returns - ------- - da_concat : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', - 'window', 'forcing_feature'). - """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). The forcing data is windowed around the - # current autregressive time step. The two `init_steps` can also be used - # as past forcings. - init_steps = 2 - da_list = [] + # Get the state times for matching + state_times = da_state_sliced["time"] + # Match forcing data to state times based on nearest neighbor if self.datastore.is_forecast: - # This implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select an analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast. - # Add a 'time' dimension using the actual forecast times - offset = max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - current_time = ( - da_forcing.analysis_time[idx] - + da_forcing.elapsed_forecast_duration[offset + step] - ) - - da_sliced = da_forcing.isel( - analysis_time=idx, - elapsed_forecast_duration=slice(start_idx, end_idx + 1), - ) - - da_sliced = da_sliced.rename( - {"elapsed_forecast_duration": "window"} - ) + # Calculate all possible forcing times + forcing_times = ( + da_forcing.analysis_time + da_forcing.elapsed_forecast_duration + ) + forcing_times_flat = forcing_times.stack( + forecast_time=("analysis_time", "elapsed_forecast_duration") + ) - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + # Compute time differences + time_deltas = ( + forcing_times_flat.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) + time_diffs = np.abs(time_deltas) + idx_min = time_diffs.argmin(axis=0) + + # Retrieve corresponding indices for analysis_time and + # elapsed_forecast_duration + forecast_time_index = forcing_times_flat["forecast_time"][idx_min] + analysis_time_indices = forecast_time_index["analysis_time"] + elapsed_forecast_duration_indices = forecast_time_index[ + "elapsed_forecast_duration" + ] + + # Slice the forcing data using matched indices + da_forcing_matched = da_forcing.isel( + analysis_time=("time", analysis_time_indices), + elapsed_forecast_duration=( + "time", + elapsed_forecast_duration_indices, + ), + ) - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + # Assign matched state times to the forcing data + da_forcing_matched["time"] = state_times + da_forcing_matched = da_forcing_matched.swap_dims( + {"elapsed_forecast_duration": "time"} + ) - da_list.append(da_sliced) + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] + time_diff_steps = ( + time_deltas[idx_min, np.arange(len(state_times))] + / state_time_step + ) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Add time difference as a new coordinate + da_forcing_matched = da_forcing_matched.assign_coords( + time_diff=("time", time_diff_steps) + ) else: - # For analysis data, we slice the time dimension directly. The - # offset is only relevant for the very first (and last) samples in - # the dataset. - offset = idx + max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - # Slice the data over the desired time window - da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1)) - - da_sliced = da_sliced.rename({"time": "window"}) - - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) - ) + # For analysis data, match directly using the 'time' coordinate + forcing_times = da_forcing["time"] - # Add a 'time' dimension to keep track of steps using actual - # time coordinates - current_time = da_forcing.time[offset + step] - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + # Compute time differences + time_deltas = ( + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) + time_diffs = np.abs(time_deltas) + idx_min = time_diffs.argmin(axis=0) - da_list.append(da_sliced) + # Slice the forcing data using matched indices + da_forcing_matched = da_forcing.isel(time=idx_min) + da_forcing_matched = da_forcing_matched.assign_coords( + time=state_times + ) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] + time_diff_steps = ( + time_deltas[idx_min, np.arange(len(state_times))] + / state_time_step + ) - return da_concat + # Add time difference as a new coordinate + da_forcing_matched = da_forcing_matched.assign_coords( + time_diff=("time", time_diff_steps) + ) + + return da_state_sliced, da_forcing_matched def _build_item_dataarrays(self, idx): """ @@ -442,6 +424,7 @@ def _build_item_dataarrays(self, idx): The dataarray for the forcing data, windowed for the sample. da_boundary_windowed : xr.DataArray The dataarray for the boundary data, windowed for the sample. + Boundary data is always considered forcing data. da_target_times : xr.DataArray The dataarray for the target times. """ @@ -478,15 +461,15 @@ def _build_item_dataarrays(self, idx): # handle time sampling in a way that is compatible with both analysis # and forecast data - da_state = self._slice_state_time( + da_state = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps ) if da_forcing is not None: - da_forcing_windowed = self._slice_forcing_time( + da_forcing_windowed = self._slice_time( da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) if da_boundary is not None: - da_boundary_windowed = self._slice_forcing_time( + da_boundary_windowed = self._slice_time( da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps ) @@ -524,13 +507,32 @@ def _build_item_dataarrays(self, idx): ) / self.da_boundary_std if da_forcing is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension + # Expand 'time_diff' to align with 'forcing_feature' and 'window' + # dimensions 'time_diff' has dimension ('time'), expand to ('time', + # 'forcing_feature', 'window') + time_diff_expanded = da_forcing_windowed["time_diff"].expand_dims( + forcing_feature=da_forcing_windowed["forcing_feature"], + window=da_forcing_windowed["window"], + ) + + # Stack 'forcing_feature' and 'window' into a single + # 'forcing_feature_windowed' dimension da_forcing_windowed = da_forcing_windowed.stack( forcing_feature_windowed=("forcing_feature", "window") ) + time_diff_expanded = time_diff_expanded.stack( + forcing_feature_windowed=("forcing_feature", "window") + ) + + # Assign 'time_diff' as a coordinate to 'forcing_feature_windowed' + da_forcing_windowed = da_forcing_windowed.assign_coords( + time_diff=( + "forcing_feature_windowed", + time_diff_expanded.values, + ) + ) else: - # create an empty forcing tensor with the right shape + # Create an empty forcing tensor with the right shape da_forcing_windowed = xr.DataArray( data=np.empty( (self.ar_steps, da_state.grid_index.size, 0), @@ -542,14 +544,34 @@ def _build_item_dataarrays(self, idx): "forcing_feature": [], }, ) + if da_boundary is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension + # If 'da_boundary_windowed' also has 'time_diff', process similarly + # Expand 'time_diff' to align with 'boundary_feature' and 'window' + # dimensions + time_diff_expanded = da_boundary_windowed["time_diff"].expand_dims( + boundary_feature=da_boundary_windowed["boundary_feature"], + window=da_boundary_windowed["window"], + ) + + # Stack 'boundary_feature' and 'window' into a single + # 'boundary_feature_windowed' dimension da_boundary_windowed = da_boundary_windowed.stack( boundary_feature_windowed=("boundary_feature", "window") ) + time_diff_expanded = time_diff_expanded.stack( + boundary_feature_windowed=("boundary_feature", "window") + ) + + # Assign 'time_diff' as a coordinate to 'boundary_feature_windowed' + da_boundary_windowed = da_boundary_windowed.assign_coords( + time_diff=( + "boundary_feature_windowed", + time_diff_expanded.values, + ) + ) else: - # create an empty forcing tensor with the right shape + # Create an empty boundary tensor with the right shape da_boundary_windowed = xr.DataArray( data=np.empty( (self.ar_steps, da_state.grid_index.size, 0), From fd791bfb51c3c751ff4af8d74eaa47c81b63a1eb Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 06:26:54 +0100 Subject: [PATCH 47/90] remove all interior_mask and boundary_mask --- neural_lam/datastore/base.py | 17 ---- neural_lam/datastore/mdp.py | 34 -------- neural_lam/datastore/npyfilesmeps/store.py | 28 ------ neural_lam/models/ar_model.py | 53 +++--------- neural_lam/vis.py | 12 --- .../config.yaml | 18 ++++ .../era5.datastore.yaml | 85 +++++++++++++++++++ .../meps_example_reduced.datastore.yaml | 44 ++++++++++ tests/dummy_datastore.py | 22 ----- tests/test_datastores.py | 21 ----- 10 files changed, 157 insertions(+), 177 deletions(-) create mode 100644 tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml create mode 100644 tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml create mode 100644 tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index b0055e39..e2d21404 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -228,23 +228,6 @@ def get_dataarray( """ pass - @cached_property - @abc.abstractmethod - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - pass - @abc.abstractmethod def get_xy(self, category: str) -> np.ndarray: """ diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index b6f1676c..e662cb63 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -319,40 +319,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: ds_stats = self._ds[stats_variables.keys()].rename(stats_variables) return ds_stats - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Produce a 0/1 mask for the boundary points of the dataset, these will - sit at the edges of the domain (in x/y extent) and will be used to mask - out the boundary points from the loss function and to overwrite the - boundary points from the prediction. For now this is created when the - mask is requested, but in the future this could be saved to the zarr - file. - - Returns - ------- - xr.DataArray - A 0/1 mask for the boundary points of the dataset, where 1 is a - boundary point and 0 is not. - - """ - if self._n_boundary_points > 0: - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) - else: - return None - @property def coords_projection(self) -> ccrs.Projection: """ diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 42e80706..146b0627 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -668,34 +668,6 @@ def grid_shape_state(self) -> CartesianGridShape: ny, nx = self.config.grid_shape_state return CartesianGridShape(x=nx, y=ny) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """The boundary mask for the dataset. This is a binary mask that is 1 - where the grid cell is on the boundary of the domain, and 0 otherwise. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions `[grid_index]`. - - """ - xy = self.get_xy(category="state", stacked=False) - xs = xy[:, :, 0] - ys = xy[:, :, 1] - # Check if x-coordinates are constant along columns - assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant" - # Check if y-coordinates are constant along rows - assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant" - # Extract unique x and y coordinates - x = xs[:, 0] # Unique x-coordinates (changes along the first axis) - y = ys[0, :] # Unique y-coordinates (changes along the second axis) - values = np.load(self.root_path / "static" / "border_mask.npy") - da_mask = xr.DataArray( - values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask" - ) - da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int) - return da_mask_stacked_xy - def get_standardization_dataarray(self, category: str) -> xr.Dataset: """Return the standardization dataarray for the given category. This should contain a `{category}_mean` and `{category}_std` variable for diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 44baf9c2..710efcec 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -45,7 +45,6 @@ def __init__( da_state_stats = datastore.get_standardization_dataarray( category="state" ) - da_boundary_mask = datastore.boundary_mask num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps @@ -118,18 +117,6 @@ def __init__( # Instantiate loss function self.loss = metrics.get_metric(args.loss) - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim - - self.register_buffer("boundary_mask", boundary_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.boundary_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - self.val_metrics = { "mse": [], } @@ -194,13 +181,6 @@ def configure_optimizers(self): ) return opt - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. - """ - return self.interior_mask[:, 0].to(torch.bool) - @staticmethod def expand_to_batch(x, batch_size): """ @@ -232,7 +212,6 @@ def unroll_prediction(self, init_states, forcing_features, true_states): for i in range(pred_steps): forcing = forcing_features[:, i] - border_state = true_states[:, i] pred_state, pred_std = self.predict_step( prev_state, prev_prev_state, forcing @@ -240,19 +219,13 @@ def unroll_prediction(self, init_states, forcing_features, true_states): # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, # d_f) or None - # Overwrite border with true state - new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state - ) - - prediction_list.append(new_state) + prediction_list.append(pred_state) if self.output_std: pred_std_list.append(pred_std) # Update conditioning states prev_prev_state = prev_state - prev_state = new_state + prev_state = pred_state prediction = torch.stack( prediction_list, dim=1 @@ -290,12 +263,14 @@ def training_step(self, batch): """ prediction, target, pred_std, _ = self.common_step(batch) - # Compute loss + # Compute loss - mean over unrolled times and batch batch_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ) - ) # mean over unrolled times and batch + ) log_dict = {"train_loss": batch_loss} self.log_dict( @@ -328,9 +303,7 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) @@ -355,7 +328,6 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.val_metrics["mse"].append(entry_mses) @@ -382,9 +354,7 @@ def test_step(self, batch, batch_idx): # pred_steps, num_grid_nodes, d_f) or (d_f,) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) @@ -413,16 +383,13 @@ def test_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.test_metrics[metric_name].append(batch_metric_vals) if self.output_std: # Store output std. per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) + mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f) self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times diff --git a/neural_lam/vis.py b/neural_lam/vis.py index d6b57f88..efab20bf 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,11 +87,6 @@ def plot_prediction( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_values = np.invert(da_mask.values.astype(bool)).astype(float) - pixel_alpha = mask_values.clip(0.7, 1) # Faded border region - fig, axes = plt.subplots( 1, 2, @@ -107,7 +102,6 @@ def plot_prediction( origin="lower", x="x", extent=extent, - alpha=pixel_alpha.T, vmin=vmin, vmax=vmax, cmap="plasma", @@ -141,11 +135,6 @@ def plot_spatial_error( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region - fig, ax = plt.subplots( figsize=(5, 4.8), subplot_kw={"projection": datastore.coords_projection}, @@ -164,7 +153,6 @@ def plot_spatial_error( error_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd", diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml new file mode 100644 index 00000000..27cc9764 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml @@ -0,0 +1,18 @@ +datastore: + kind: npyfilesmeps + config_path: meps_example_reduced.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + nlwrs_0: 1.0 + nswrs_0: 1.0 + pres_0g: 1.0 + pres_0s: 1.0 + r_2: 1.0 + r_65: 1.0 + t_2: 1.0 + t_65: 1.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml new file mode 100644 index 00000000..600a1845 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -0,0 +1,85 @@ +schema_version: v0.5.0 +dataset_version: v1.0.0 + +output: + variables: + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + test: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_surface_net_short_wave_radiation_flux + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml new file mode 100644 index 00000000..3d88d4a4 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml @@ -0,0 +1,44 @@ +dataset: + name: meps_example_reduced + num_forcing_features: 16 + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + var_units: + - Pa + - Pa + - W/m**2 + - W/m**2 + - '' + - '' + - K + - K + num_timesteps: 65 + num_ensemble_members: 2 + step_length: 3 +grid_shape_state: +- 134 +- 119 +projection: + class_name: LambertConformal + kwargs: + central_latitude: 63.3 + central_longitude: 15.0 + standard_parallels: + - 63.3 + - 63.3 diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 9075d404..d62c7356 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -148,12 +148,6 @@ def __init__( times = [self.T0 + dt * i for i in range(n_timesteps)] self.ds.coords["time"] = times - # Add boundary mask - self.ds["boundary_mask"] = xr.DataArray( - np.random.choice([0, 1], size=(n_points_1d, n_points_1d)), - dims=["x", "y"], - ) - # Stack the spatial dimensions into grid_index self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS) @@ -342,22 +336,6 @@ def get_dataarray( dim_order = self.expected_dim_order(category=category) return self.ds[category].transpose(*dim_order) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - return self.ds["boundary_mask"] - def get_xy(self, category: str, stacked: bool) -> ndarray: """Return the x, y coordinates of the dataset. diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4a4b1100..a91f6245 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -18,8 +18,6 @@ dataarray for the given category. - `get_dataarray` (method): Return the processed data (as a single `xr.DataArray`) for the given category and test/train/val-split. -- `boundary_mask` (property): Return the boundary mask for the dataset, - with spatial dimensions stacked. - `config` (property): Return the configuration of the datastore. In addition BaseRegularGridDatastore must have the following methods and @@ -213,25 +211,6 @@ def test_get_dataarray(datastore_name): assert n_features["train"] == n_features["val"] == n_features["test"] -@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_boundary_mask(datastore_name): - """Check that the `datastore.boundary_mask` property is implemented and - that the returned object is an xarray DataArray with the correct shape.""" - datastore = init_datastore_example(datastore_name) - da_mask = datastore.boundary_mask - - assert isinstance(da_mask, xr.DataArray) - assert set(da_mask.dims) == {"grid_index"} - assert da_mask.dtype == "int" - assert set(da_mask.values) == {0, 1} - assert da_mask.sum() > 0 - assert da_mask.sum() < da_mask.size - - if isinstance(datastore, BaseRegularGridDatastore): - grid_shape = datastore.grid_shape_state - assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y - - @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_xy_extent(datastore_name): """Check that the `datastore.get_xy_extent` method is implemented and that From ae82cdb8360d899b063bdf48a877a42184306cab Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:55:56 +0100 Subject: [PATCH 48/90] added gcsfs dependency for era5 weatherbench download --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index fdcb7f3e..38e7cb0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "torch-geometric==2.3.1", "parse>=1.20.2", "dataclass-wizard<0.31.0", + "gcsfs>=2021.10.0", "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" From 34a6cc7d24ffb218b2aef909cac7db06ffbef618 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:57:57 +0100 Subject: [PATCH 49/90] added new era5 datastore config for boundary --- tests/conftest.py | 19 +++- .../mdp/era5_1000hPa_winds/.gitignore | 2 + .../mdp/era5_1000hPa_winds/config.yaml | 3 + .../era5_1000hPa_winds/era5.datastore.yaml | 90 +++++++++++++++++++ 4 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml diff --git a/tests/conftest.py b/tests/conftest.py index 6f579621..be5cf3e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,15 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) +DATASTORES_BOUNDARY_EXAMPLES = dict( + mdp=( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "era5_1000hPa_winds" + / "era5.datastore.yaml" + ) +) + DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore @@ -102,5 +111,13 @@ def init_datastore_example(datastore_kind): datastore_kind=datastore_kind, config_path=DATASTORES_EXAMPLES[datastore_kind], ) - return datastore + + +def init_datastore_boundary_example(datastore_kind): + datastore_boundary = init_datastore( + datastore_kind=datastore_kind, + config_path=DATASTORES_BOUNDARY_EXAMPLES[datastore_kind], + ) + + return datastore_boundary diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore new file mode 100644 index 00000000..f2828f46 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore @@ -0,0 +1,2 @@ +*.zarr/ +graph/ diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml new file mode 100644 index 00000000..5d1e05f2 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml @@ -0,0 +1,3 @@ +datastore: + kind: mdp + config_path: era5.datastore.yaml diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml new file mode 100644 index 00000000..36b39501 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml @@ -0,0 +1,90 @@ +#TODO: What do these versions mean? Should they be updated? +schema_version: v0.2.0+dev +dataset_version: v1.0.0 + +output: + variables: + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-02T00:00 + end: 1990-09-10T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-02T00:00 + end: 1990-09-07T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-05T00:00 + end: 1990-09-08T00:00 + test: + start: 1990-09-06T00:00 + end: 1990-09-10T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + v_component_of_wind: + level: + values: [1000, ] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_surface_net_short_wave_radiation_flux + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 From 2dc67a02e2acad0665452bfe336384de1cc34b4e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:58:36 +0100 Subject: [PATCH 50/90] removed left-over boundary-mask references --- neural_lam/datastore/mdp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index e662cb63..b28d2650 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -27,11 +27,10 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): + def __init__(self, config_path, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at - `config_path`. A boundary mask is created with `n_boundary_points` - boundary points. If `reuse_existing` is True, the dataset is loaded + `config_path`. If `reuse_existing` is True, the dataset is loaded from a zarr file if it exists (unless the config has been modified since the zarr was created), otherwise it is created from the configuration file. @@ -42,8 +41,6 @@ def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): The path to the configuration file, this will be fed to the `mllam_data_prep.Config.from_yaml_file` method to then call `mllam_data_prep.create_dataset` to create the dataset. - n_boundary_points : int - The number of boundary points to use in the boundary mask. reuse_existing : bool Whether to reuse an existing dataset zarr file if it exists and its creation date is newer than the configuration file. @@ -70,7 +67,6 @@ def __init__(self, config_path, n_boundary_points=0, reuse_existing=True): if self._ds is None: self._ds = mdp.create_dataset(config=self._config) self._ds.to_zarr(fp_ds) - self._n_boundary_points = n_boundary_points print("The loaded datastore contains the following features:") for category in ["state", "forcing", "static"]: From 9f8628e03487a80ab3313656857b5fde3e6fde45 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:59:12 +0100 Subject: [PATCH 51/90] make check for existing category in datastore more flexible (for boundary) --- neural_lam/datastore/mdp.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index b28d2650..7b947c20 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -154,8 +154,8 @@ def get_vars_units(self, category: str) -> List[str]: The units of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_units"].values.tolist() @@ -173,8 +173,8 @@ def get_vars_names(self, category: str) -> List[str]: The names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature"].values.tolist() @@ -193,8 +193,8 @@ def get_vars_long_names(self, category: str) -> List[str]: The long names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_long_name"].values.tolist() @@ -249,9 +249,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: The xarray DataArray object with processed dataset. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") - return None + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] da_category = self._ds[category] From 388c79df3fdbbaa24ef025621a09dd25ac567ac5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Wed, 20 Nov 2024 16:00:15 +0100 Subject: [PATCH 52/90] implement xarray based (mostly) time slicing and windowing --- neural_lam/weather_dataset.py | 255 +++++++++++++++------------------- 1 file changed, 111 insertions(+), 144 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5559e838..555f2c35 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -64,10 +64,16 @@ def __init__( self.da_state = self.datastore.get_dataarray( category="state", split=self.split ) + if self.da_state is None: + raise ValueError( + "A non-empty state dataarray must be provided. " + "The datastore.get_dataarray() returned None or empty array " + "for category='state'" + ) self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) - # XXX For now boundary data is always considered forcing data + # XXX For now boundary data is always considered mdp-forcing data self.da_boundary = self.datastore_boundary.get_dataarray( category="forcing", split=self.split ) @@ -102,53 +108,36 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + def get_time_step(times): + """Calculate the time step from the data""" + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + # Check time step consistency in state data + _ = get_time_step(self.da_state.time.values) + # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: state_times = self.da_state.time state_time_min = state_times.min().values state_time_max = state_times.max().values - def get_time_step(times): - """Calculate the time step from the data""" - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] - if self.da_forcing is not None: + # Forcing data is part of the same datastore as state data + # During creation the time dimension of the forcing data + # is matched to the state data forcing_times = self.da_forcing.time - forcing_time_step = get_time_step(forcing_times.values) - forcing_time_min = forcing_times.min().values - forcing_time_max = forcing_times.max().values - - # Calculate required bounds for forcing using its time step - forcing_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * forcing_time_step - ) - forcing_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * forcing_time_step - ) - - if forcing_time_min > forcing_required_time_min: - raise ValueError( - f"Forcing data starts too late." - f"Required start: {forcing_required_time_min}, " - f"but forcing starts at {forcing_time_min}." - ) - - if forcing_time_max < forcing_required_time_max: - raise ValueError( - f"Forcing data ends too early." - f"Required end: {forcing_required_time_max}," - f"but forcing ends at {forcing_time_max}." - ) + _ = get_time_step(forcing_times.values) if self.da_boundary is not None: + # Boundary data is part of a separate datastore + # The boundary data is allowed to have a different time_step + # Check that the boundary data covers the required time range boundary_times = self.da_boundary.time boundary_time_step = get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values @@ -204,8 +193,8 @@ def get_time_step(times): category="forcing" ) ) - self.da_boundary_mean = self.ds_boundary_stats.boundary_mean - self.da_boundary_std = self.ds_boundary_stats.boundary_std + self.da_boundary_mean = self.ds_boundary_stats.forcing_mean + self.da_boundary_std = self.ds_boundary_stats.forcing_std def __len__(self): if self.datastore.is_forecast: @@ -253,7 +242,7 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_time(self, da_state, da_forcing, idx, n_steps: int): + def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): """ Produce time slices of the given dataarrays `da_state` (state) and `da_forcing` (forcing). For the state data, slicing is done as before @@ -316,8 +305,13 @@ def _slice_time(self, da_state, da_forcing, idx, n_steps: int): ) da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) + if da_forcing is None: + return da_state_sliced, None + # Get the state times for matching state_times = da_state_sliced["time"] + # Calculate time differences in multiples of state time steps + state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor if self.datastore.is_forecast: @@ -371,39 +365,80 @@ def _slice_time(self, da_state, da_forcing, idx, n_steps: int): da_forcing_matched = da_forcing_matched.assign_coords( time_diff=("time", time_diff_steps) ) - else: # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] # Compute time differences time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] + state_times.values[np.newaxis, :] + - forcing_times.values[:, np.newaxis] + ) + idx_min = np.abs(time_deltas).argmin(axis=0) + + time_diff_steps = xr.DataArray( + np.stack( + [ + np.diagonal(time_deltas, offset=offset)[ + -len(state_times) + init_steps : + ] + / state_time_step + for offset in range( + -self.num_past_forcing_steps, + self.num_future_forcing_steps + 1, + ) + ], + axis=1, + ), + dims=["time", "window"], + coords={ + "time": state_times.isel(time=slice(init_steps, None)), + "window": np.arange( + -self.num_past_forcing_steps, + self.num_future_forcing_steps + 1, + ), + }, + name="time_diff_steps", ) - time_diffs = np.abs(time_deltas) - idx_min = time_diffs.argmin(axis=0) - # Slice the forcing data using matched indices - da_forcing_matched = da_forcing.isel(time=idx_min) - da_forcing_matched = da_forcing_matched.assign_coords( - time=state_times + # Create window dimension using rolling + window_size = ( + self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) - - # Calculate time differences in multiples of state time steps - state_time_step = state_times.values[1] - state_times.values[0] - time_diff_steps = ( - time_deltas[idx_min, np.arange(len(state_times))] - / state_time_step + da_forcing_windowed = da_forcing.rolling( + time=window_size, center=True + ).construct(window_dim="window") + da_forcing_matched = da_forcing_windowed.isel( + time=idx_min[init_steps:] ) # Add time difference as a new coordinate da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=("time", time_diff_steps) + time_diff=time_diff_steps ) return da_state_sliced, da_forcing_matched + def _process_windowed_data(self, da_windowed, da_state, da_target_times): + """Helper function to process windowed data after standardization.""" + stacked_dim = "forcing_feature_windowed" + if da_windowed is not None: + # Stack the 'feature' and 'window' dimensions + da_windowed = da_windowed.stack( + {stacked_dim: ("forcing_feature", "window")} + ) + else: + # Create empty DataArray with the correct dimensions and coordinates + return xr.DataArray( + data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), + dims=("time", "grid_index", f"{stacked_dim}"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + f"{stacked_dim}": [], + }, + ) + def _build_item_dataarrays(self, idx): """ Create the dataarrays for the initial states, target states and forcing @@ -459,18 +494,21 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # handle time sampling in a way that is compatible with both analysis - # and forecast data - da_state = self._slice_time( - da_state=da_state, idx=idx, n_steps=self.ar_steps + # if da_forcing is None, the function will return None for + # da_forcing_windowed + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, ) - if da_forcing is not None: - da_forcing_windowed = self._slice_time( - da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps - ) + if da_boundary is not None: - da_boundary_windowed = self._slice_time( - da_forcing=da_boundary, idx=idx, n_steps=self.ar_steps + _, da_boundary_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_boundary, ) # load the data into memory @@ -506,83 +544,12 @@ def _build_item_dataarrays(self, idx): da_boundary_windowed - self.da_boundary_mean ) / self.da_boundary_std - if da_forcing is not None: - # Expand 'time_diff' to align with 'forcing_feature' and 'window' - # dimensions 'time_diff' has dimension ('time'), expand to ('time', - # 'forcing_feature', 'window') - time_diff_expanded = da_forcing_windowed["time_diff"].expand_dims( - forcing_feature=da_forcing_windowed["forcing_feature"], - window=da_forcing_windowed["window"], - ) - - # Stack 'forcing_feature' and 'window' into a single - # 'forcing_feature_windowed' dimension - da_forcing_windowed = da_forcing_windowed.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - time_diff_expanded = time_diff_expanded.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - - # Assign 'time_diff' as a coordinate to 'forcing_feature_windowed' - da_forcing_windowed = da_forcing_windowed.assign_coords( - time_diff=( - "forcing_feature_windowed", - time_diff_expanded.values, - ) - ) - else: - # Create an empty forcing tensor with the right shape - da_forcing_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "forcing_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "forcing_feature": [], - }, - ) - - if da_boundary is not None: - # If 'da_boundary_windowed' also has 'time_diff', process similarly - # Expand 'time_diff' to align with 'boundary_feature' and 'window' - # dimensions - time_diff_expanded = da_boundary_windowed["time_diff"].expand_dims( - boundary_feature=da_boundary_windowed["boundary_feature"], - window=da_boundary_windowed["window"], - ) - - # Stack 'boundary_feature' and 'window' into a single - # 'boundary_feature_windowed' dimension - da_boundary_windowed = da_boundary_windowed.stack( - boundary_feature_windowed=("boundary_feature", "window") - ) - time_diff_expanded = time_diff_expanded.stack( - boundary_feature_windowed=("boundary_feature", "window") - ) - - # Assign 'time_diff' as a coordinate to 'boundary_feature_windowed' - da_boundary_windowed = da_boundary_windowed.assign_coords( - time_diff=( - "boundary_feature_windowed", - time_diff_expanded.values, - ) - ) - else: - # Create an empty boundary tensor with the right shape - da_boundary_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "boundary_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "boundary_feature": [], - }, - ) + da_forcing_windowed = self._process_windowed_data( + da_forcing_windowed, da_state, da_target_times + ) + da_boundary_windowed = self._process_windowed_data( + da_boundary_windowed, da_state, da_target_times + ) return ( da_init_states, From 2529969b12eb7babdcfd3311d6eae3045fe1fe15 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 07:09:52 +0100 Subject: [PATCH 53/90] cleanup analysis based time-slicing --- neural_lam/weather_dataset.py | 85 +++++++++++++++++------------------ 1 file changed, 42 insertions(+), 43 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 555f2c35..fd40a2c8 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -245,11 +245,12 @@ def __len__(self): def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing` (forcing). For the state data, slicing is done as before - based on `idx`. For the forcing data, nearest neighbor matching is - performed based on the state times. Additionally, the time difference - between the matched forcing times and state times (in multiples of state - time steps) is added to the forcing dataarray. + `da_forcing` (forcing). For the state data, slicing is done based on + `idx`. For the forcing data, nearest neighbor matching is performed + based on the state times. Additionally, the time difference between the + matched forcing times and state times (in multiples of state time steps) + is added to the forcing dataarray. This will be used as an additional + feature in the model (temporal embedding). Parameters ---------- @@ -269,9 +270,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). da_forcing_matched : xr.DataArray - The forcing dataarray matched to state times with an added - coordinate 'time_diff', representing the time difference to state - times in multiples of state time steps. + The sliced state dataarray with dims ('time', 'grid_index', + 'forcing_feature_windowed'). """ # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -308,9 +308,9 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): if da_forcing is None: return da_state_sliced, None - # Get the state times for matching + # Get the state times and its temporal resolution for matching with + # forcing data state_times = da_state_sliced["time"] - # Calculate time differences in multiples of state time steps state_time_step = state_times.values[1] - state_times.values[0] # Match forcing data to state times based on nearest neighbor @@ -369,39 +369,29 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing["time"] - # Compute time differences + # Compute time differences between forcing and state times + # (in multiples of state time steps) + # Retrieve the indices of the closest times in the forcing data time_deltas = ( - state_times.values[np.newaxis, :] - - forcing_times.values[:, np.newaxis] - ) + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] + ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) - time_diff_steps = xr.DataArray( - np.stack( - [ - np.diagonal(time_deltas, offset=offset)[ - -len(state_times) + init_steps : - ] - / state_time_step - for offset in range( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ) - ], - axis=1, - ), - dims=["time", "window"], - coords={ - "time": state_times.isel(time=slice(init_steps, None)), - "window": np.arange( - -self.num_past_forcing_steps, - self.num_future_forcing_steps + 1, - ), - }, - name="time_diff_steps", + time_diff_steps = np.stack( + [ + time_deltas[ + idx_i + - self.num_past_forcing_steps : idx_i + + self.num_future_forcing_steps + + 1, + init_steps + step_i, + ] + for (step_i, idx_i) in enumerate(idx_min[init_steps:]) + ], ) - # Create window dimension using rolling + # Create window dimension for forcing data to stack later window_size = ( self.num_past_forcing_steps + self.num_future_forcing_steps + 1 ) @@ -412,9 +402,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): time=idx_min[init_steps:] ) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=time_diff_steps + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, ) return da_state_sliced, da_forcing_matched @@ -423,13 +415,19 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" stacked_dim = "forcing_feature_windowed" if da_windowed is not None: - # Stack the 'feature' and 'window' dimensions + # Stack the 'feature' and 'window' dimensions and add the + # time step differences to the existing features as a temporal + # embedding da_windowed = da_windowed.stack( {stacked_dim: ("forcing_feature", "window")} ) + da_windowed = xr.concat( + [da_windowed, da_windowed.time_diff_steps], + dim="forcing_feature_windowed", + ) else: # Create empty DataArray with the correct dimensions and coordinates - return xr.DataArray( + da_windowed = xr.DataArray( data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), dims=("time", "grid_index", f"{stacked_dim}"), coords={ @@ -438,6 +436,7 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): f"{stacked_dim}": [], }, ) + return da_windowed def _build_item_dataarrays(self, idx): """ From 179a035ac8b976a74e54ce4f38102addf06ed318 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 19 Nov 2024 16:59:42 +0100 Subject: [PATCH 54/90] implement datastore_boundary in existing tests --- tests/test_datasets.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 419aece0..67eac70e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -14,12 +14,19 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) from tests.dummy_datastore import DummyDatastore @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_shapes(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_shapes(datastore_name, datastore_boundary_name): """Check that the `datastore.get_dataarray` method is implemented. Validate the shapes of the tensors match between the different @@ -31,6 +38,9 @@ def test_dataset_item_shapes(datastore_name): """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_gridpoints = datastore.num_grid_points N_pred_steps = 4 @@ -38,6 +48,7 @@ def test_dataset_item_shapes(datastore_name): num_future_forcing_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -48,7 +59,7 @@ def test_dataset_item_shapes(datastore_name): # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - init_states, target_states, forcing, target_times = item + init_states, target_states, forcing, boundary, target_times = item # initial states assert init_states.ndim == 3 @@ -81,14 +92,23 @@ def test_dataset_item_shapes(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_create_dataarray_from_tensor(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_create_dataarray_from_tensor( + datastore_name, datastore_boundary_name +): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -158,13 +178,19 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): @pytest.mark.parametrize("split", ["train", "val", "test"]) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_single_batch(datastore_name, split): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_single_batch(datastore_name, datastore_boundary_name, split): """Check that the `datastore.get_dataarray` method is implemented. And that it returns an xarray DataArray with the correct dimensions. """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) device_name = ( torch.device("cuda") if torch.cuda.is_available() else "cpu" @@ -210,7 +236,9 @@ def _create_graph(): ) ) - dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2) + dataset = WeatherDataset( + datastore=datastore, datastore_boundary=datastore_boundary, split=split + ) model = GraphLAM(args=args, datastore=datastore, config=config) # noqa From 2daeb1642d276730496cc7ab183203ed5abba6ce Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 16:39:27 +0100 Subject: [PATCH 55/90] allow for grid shape retrieval from forcing data --- neural_lam/datastore/mdp.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 7b947c20..809bbdb8 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -380,8 +380,17 @@ def grid_shape_state(self): The shape of the cartesian grid for the state variables. """ - ds_state = self.unstack_grid_coords(self._ds["state"]) - da_x, da_y = ds_state.x, ds_state.y + # Boundary data often has no state features + if "state" not in self._ds: + warnings.warn( + "no state data found in datastore" + "returning grid shape from forcing data" + ) + ds_forcing = self.unstack_grid_coords(self._ds["forcing"]) + da_x, da_y = ds_forcing.x, ds_forcing.y + else: + ds_state = self.unstack_grid_coords(self._ds["state"]) + da_x, da_y = ds_state.x, ds_state.y assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) From cbcdcaee71039977090a66ec2b8b1116063cf2a4 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Thu, 21 Nov 2024 16:40:47 +0100 Subject: [PATCH 56/90] rearrange time slicing, boundary first --- neural_lam/weather_dataset.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index fd40a2c8..f172d47f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -495,13 +495,6 @@ def _build_item_dataarrays(self, idx): # if da_forcing is None, the function will return None for # da_forcing_windowed - da_state, da_forcing_windowed = self._slice_time( - da_state=da_state, - idx=idx, - n_steps=self.ar_steps, - da_forcing=da_forcing, - ) - if da_boundary is not None: _, da_boundary_windowed = self._slice_time( da_state=da_state, @@ -509,6 +502,12 @@ def _build_item_dataarrays(self, idx): n_steps=self.ar_steps, da_forcing=da_boundary, ) + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, + ) # load the data into memory da_state.load() From e6ace2727038d5a472a18e7eab7e6a26b6362fbb Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:42:05 +0100 Subject: [PATCH 57/90] renaming test datastores --- tests/datastore_examples/.gitignore | 3 +- .../.gitignore | 0 .../era5_1000hPa_danra_100m_winds/config.yaml | 12 +++ .../danra.datastore.yaml | 99 +++++++++++++++++++ .../era5.datastore.yaml | 23 ++--- .../mdp/era5_1000hPa_winds/config.yaml | 3 - 6 files changed, 122 insertions(+), 18 deletions(-) rename tests/datastore_examples/mdp/{era5_1000hPa_winds => era5_1000hPa_danra_100m_winds}/.gitignore (100%) create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml create mode 100644 tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml rename tests/datastore_examples/mdp/{era5_1000hPa_winds => era5_1000hPa_danra_100m_winds}/era5.datastore.yaml (80%) delete mode 100644 tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore index e84e6493..4fbd2326 100644 --- a/tests/datastore_examples/.gitignore +++ b/tests/datastore_examples/.gitignore @@ -1,2 +1,3 @@ npyfilesmeps/*.zip -npyfilesmeps/meps_example_reduced/ +npyfilesmeps/meps_example_reduced +npyfilesmeps/era5_1000hPa_temp_meps_example_reduced diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore similarity index 100% rename from tests/datastore_examples/mdp/era5_1000hPa_winds/.gitignore rename to tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml new file mode 100644 index 00000000..a158bee3 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml @@ -0,0 +1,12 @@ +datastore: + kind: mdp + config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml new file mode 100644 index 00000000..3edf1267 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml @@ -0,0 +1,99 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + static: [grid_index, static_feature] + state: [time, grid_index, state_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + # use surface incoming shortwave radiation as forcing + - swavr0m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: forcing + + danra_lsm: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr + dims: [x, y] + variables: + - lsm + dim_mapping: + grid_index: + method: stack + dims: [x, y] + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: static + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml similarity index 80% rename from tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml rename to tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index 36b39501..c97da4bc 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -1,5 +1,4 @@ -#TODO: What do these versions mean? Should they be updated? -schema_version: v0.2.0+dev +schema_version: v0.5.0 dataset_version: v1.0.0 output: @@ -7,8 +6,8 @@ output: forcing: [time, grid_index, forcing_feature] coord_ranges: time: - start: 1990-09-02T00:00 - end: 1990-09-10T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 step: PT6H chunking: time: 1 @@ -16,17 +15,17 @@ output: dim: time splits: train: - start: 1990-09-02T00:00 - end: 1990-09-07T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 compute_statistics: ops: [mean, std, diff_mean, diff_std] dims: [grid_index, time] val: - start: 1990-09-05T00:00 - end: 1990-09-08T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 test: - start: 1990-09-06T00:00 - end: 1990-09-10T00:00 + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: @@ -37,10 +36,6 @@ inputs: level: values: [1000,] units: hPa - v_component_of_wind: - level: - values: [1000, ] - units: hPa dim_mapping: time: method: rename diff --git a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml deleted file mode 100644 index 5d1e05f2..00000000 --- a/tests/datastore_examples/mdp/era5_1000hPa_winds/config.yaml +++ /dev/null @@ -1,3 +0,0 @@ -datastore: - kind: mdp - config_path: era5.datastore.yaml From 42818f0e91ccebb03c506b00f42e05e7d8d6fdfa Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:44:15 +0100 Subject: [PATCH 58/90] adding num_past/future_boundary_step args --- neural_lam/train_model.py | 37 +++++++++++++++------------------ tests/test_datasets.py | 43 +++++++++++++++++++++++++++++++++------ tests/test_training.py | 24 ++++++++++++++++++++-- 3 files changed, 75 insertions(+), 29 deletions(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 37bf6db7..2a61e86c 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -34,11 +34,6 @@ def main(input_args=None): type=str, help="Path to the configuration for neural-lam", ) - parser.add_argument( - "--config_path_boundary", - type=str, - help="Path to the configuration for boundary conditions", - ) parser.add_argument( "--model", type=str, @@ -208,6 +203,18 @@ def main(input_args=None): default=1, help="Number of future time steps to use as input for forcing data", ) + parser.add_argument( + "--num_past_boundary_steps", + type=int, + default=1, + help="Number of past time steps to use as input for boundary data", + ) + parser.add_argument( + "--num_future_boundary_steps", + type=int, + default=1, + help="Number of future time steps to use as input for boundary data", + ) args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() @@ -217,9 +224,6 @@ def main(input_args=None): assert ( args.config_path is not None ), "Specify your config with --config_path" - assert ( - args.config_path_boundary is not None - ), "Specify your config with --config_path_boundary" assert args.model in MODELS, f"Unknown model: {args.model}" assert args.eval in ( None, @@ -234,21 +238,10 @@ def main(input_args=None): seed.seed_everything(args.seed) # Load neural-lam configuration and datastore to use - config, datastore = load_config_and_datastore(config_path=args.config_path) - config_boundary, datastore_boundary = load_config_and_datastore( - config_path=args.config_path_boundary + config, datastore, datastore_boundary = load_config_and_datastore( + config_path=args.config_path ) - # TODO this should not be required, make more flexible - assert ( - datastore.num_past_forcing_steps - == datastore_boundary.num_past_forcing_steps - ), "Mismatch in num_past_forcing_steps" - assert ( - datastore.num_future_forcing_steps - == datastore_boundary.num_future_forcing_steps - ), "Mismatch in num_future_forcing_steps" - # Create datamodule data_module = WeatherDataModule( datastore=datastore, @@ -258,6 +251,8 @@ def main(input_args=None): standardize=True, num_past_forcing_steps=args.num_past_forcing_steps, num_future_forcing_steps=args.num_future_forcing_steps, + num_past_boundary_steps=args.num_past_boundary_steps, + num_future_boundary_steps=args.num_future_boundary_steps, batch_size=args.batch_size, num_workers=args.num_workers, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 67eac70e..5fbe4a5d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -42,10 +42,13 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): datastore_boundary_name ) N_gridpoints = datastore.num_grid_points + N_gridpoints_boundary = datastore_boundary.num_grid_points N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, datastore_boundary=datastore_boundary, @@ -53,6 +56,8 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) item = dataset[0] @@ -77,8 +82,23 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): assert forcing.ndim == 3 assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints - assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * ( - num_past_forcing_steps + num_future_forcing_steps + 1 + # each stacked forcing feature has one corresponding temporal embedding + assert ( + forcing.shape[2] + == datastore.get_num_data_vars("forcing") + * (num_past_forcing_steps + num_future_forcing_steps + 1) + * 2 + ) + + # boundary + assert boundary.ndim == 3 + assert boundary.shape[0] == N_pred_steps + assert boundary.shape[1] == N_gridpoints_boundary + assert ( + boundary.shape[2] + == datastore_boundary.get_num_data_vars("forcing") + * (num_past_boundary_steps + num_future_boundary_steps + 1) + * 2 ) # batch times @@ -88,6 +108,7 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): # try to get the last item of the dataset to ensure slicing and stacking # operations are working as expected and are consistent with the dataset # length + dataset[len(dataset) - 1] @@ -106,6 +127,9 @@ def test_dataset_item_create_dataarray_from_tensor( N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 + dataset = WeatherDataset( datastore=datastore, datastore_boundary=datastore_boundary, @@ -113,16 +137,22 @@ def test_dataset_item_create_dataarray_from_tensor( ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) idx = 0 # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - _, target_states, _, target_times_arr = dataset[idx] - _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays( - idx=idx - ) + _, target_states, _, _, target_times_arr = dataset[idx] + ( + _, + da_target_true, + _, + _, + da_target_times_true, + ) = dataset._build_item_dataarrays(idx=idx) target_times = np.array(target_times_arr, dtype="datetime64[ns]") np.testing.assert_equal(target_times, da_target_times_true.values) @@ -272,6 +302,7 @@ def test_dataset_length(dataset_config): dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=dataset_config["ar_steps"], num_past_forcing_steps=dataset_config["past"], diff --git a/tests/test_training.py b/tests/test_training.py index 1ed1847d..28566a4b 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -14,18 +14,33 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataModule -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_training(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( f"Skipping test for {datastore_name} as it is not a regular " "grid datastore." ) + if not isinstance(datastore_boundary, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_boundary_name} as it is not a regular " + "grid datastore." + ) if torch.cuda.is_available(): device_name = "cuda" @@ -59,6 +74,7 @@ def test_training(datastore_name): data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=3, ar_steps_eval=5, standardize=True, @@ -66,6 +82,8 @@ def test_training(datastore_name): num_workers=1, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, ) class ModelArgs: @@ -85,6 +103,8 @@ class ModelArgs: metrics_watch = [] num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 model_args = ModelArgs() From 0103b6e70927cb53e59b77c30245d3fa8139f8ed Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:44:51 +0100 Subject: [PATCH 59/90] using combined config file --- neural_lam/config.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/neural_lam/config.py b/neural_lam/config.py index d3e09697..914ebb38 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -168,4 +168,15 @@ def load_config_and_datastore( datastore_kind=config.datastore.kind, config_path=datastore_config_path ) - return config, datastore + if config.datastore_boundary is not None: + datastore_boundary_config_path = ( + Path(config_path).parent / config.datastore_boundary.config_path + ) + datastore_boundary = init_datastore( + datastore_kind=config.datastore_boundary.kind, + config_path=datastore_boundary_config_path, + ) + else: + datastore_boundary = None + + return config, datastore, datastore_boundary From 089634447df0c2704670df900fc4733a727fce38 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:45:12 +0100 Subject: [PATCH 60/90] proper handling of state/forcing/boundary in dataset --- neural_lam/weather_dataset.py | 304 +++++++++++++++++++--------------- 1 file changed, 167 insertions(+), 137 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index f172d47f..7dbe0567 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -38,6 +38,16 @@ class WeatherDataset(torch.utils.data.Dataset): forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before t, given num_past_forcing_steps) are included as forcing inputs at time t. Default is 1. + num_past_boundary_steps: int, optional + Number of past time steps to include in boundary input. If set to i, + boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, + given num_future_forcing_steps) are included as boundary inputs at time t + Default is 1. + num_future_boundary_steps: int, optional + Number of future time steps to include in boundary input. If set to j, + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times before + t, given num_past_forcing_steps) are included as boundary inputs at time + t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -50,6 +60,8 @@ def __init__( ar_steps=3, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, standardize=True, ): super().__init__() @@ -60,10 +72,10 @@ def __init__( self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps - self.da_state = self.datastore.get_dataarray( - category="state", split=self.split - ) + self.da_state = self.datastore.get_dataarray(category="state", split=self.split) if self.da_state is None: raise ValueError( "A non-empty state dataarray must be provided. " @@ -74,9 +86,12 @@ def __init__( category="forcing", split=self.split ) # XXX For now boundary data is always considered mdp-forcing data - self.da_boundary = self.datastore_boundary.get_dataarray( - category="forcing", split=self.split - ) + if self.datastore_boundary is not None: + self.da_boundary = self.datastore_boundary.get_dataarray( + category="forcing", split=self.split + ) + else: + self.da_boundary = None # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -97,9 +112,7 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): - expected_dim_order = self.datastore.expected_dim_order( - category=part - ) + expected_dim_order = self.datastore.expected_dim_order(category=part) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -108,6 +121,23 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # handling ensemble data + if self.datastore.is_ensemble: + # for the now the strategy is to only include the first ensemble + # member + # XXX: this could be changed to include all ensemble members by + # splitting `idx` into two parts, one for the analysis time and one + # for the ensemble member and then increasing self.__len__ to + # include all ensemble members + warnings.warn( + "only use of ensemble member 0 (the first member) is " + "implemented for ensemble data" + ) + i_ensemble = 0 + self.da_state = self.da_state.isel(ensemble_member=i_ensemble) + else: + self.da_state = self.da_state + def get_time_step(times): """Calculate the time step from the data""" time_diffs = np.diff(times) @@ -119,11 +149,18 @@ def get_time_step(times): return time_diffs[0] # Check time step consistency in state data - _ = get_time_step(self.da_state.time.values) + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time + _ = get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: - state_times = self.da_state.time + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time state_time_min = state_times.min().values state_time_max = state_times.max().values @@ -131,26 +168,30 @@ def get_time_step(times): # Forcing data is part of the same datastore as state data # During creation the time dimension of the forcing data # is matched to the state data - forcing_times = self.da_forcing.time - _ = get_time_step(forcing_times.values) + if self.datastore.is_forecast: + forcing_times = self.da_forcing.analysis_time + else: + forcing_times = self.da_forcing.time + get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore # The boundary data is allowed to have a different time_step # Check that the boundary data covers the required time range - boundary_times = self.da_boundary.time + if self.datastore_boundary.is_forecast: + boundary_times = self.da_boundary.analysis_time + else: + boundary_times = self.da_boundary.time boundary_time_step = get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values # Calculate required bounds for boundary using its time step boundary_required_time_min = ( - state_time_min - - self.num_past_forcing_steps * boundary_time_step + state_time_min - self.num_past_forcing_steps * boundary_time_step ) boundary_required_time_max = ( - state_time_max - + self.num_future_forcing_steps * boundary_time_step + state_time_max + self.num_future_forcing_steps * boundary_time_step ) if boundary_time_min > boundary_required_time_min: @@ -179,10 +220,8 @@ def get_time_step(times): self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = ( - self.datastore.get_standardization_dataarray( - category="forcing" - ) + self.ds_forcing_stats = self.datastore.get_standardization_dataarray( + category="forcing" ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -208,7 +247,7 @@ def __len__(self): warnings.warn( "only using first ensemble member, so dataset size is " " effectively reduced by the number of ensemble members " - f"({self.da_state.ensemble_member.size})", + f"({self.datastore._num_ensemble_members})", UserWarning, ) @@ -242,36 +281,50 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): + def _slice_time( + self, + da_state, + idx, + n_steps: int, + da_forcing_boundary=None, + num_past_steps=None, + num_future_steps=None, + ): """ Produce time slices of the given dataarrays `da_state` (state) and - `da_forcing` (forcing). For the state data, slicing is done based on - `idx`. For the forcing data, nearest neighbor matching is performed - based on the state times. Additionally, the time difference between the - matched forcing times and state times (in multiples of state time steps) - is added to the forcing dataarray. This will be used as an additional - feature in the model (temporal embedding). + `da_forcing_boundary`. For the state data, slicing is done + based on `idx`. For the forcing/boundary data, nearest neighbor matching + is performed based on the state times. Additionally, the time difference + between the matched forcing/boundary times and state times (in multiples + of state time steps) is added to the forcing dataarray. This will be + used as an additional feature in the model (temporal embedding). Parameters ---------- da_state : xr.DataArray The state dataarray to slice. - da_forcing : xr.DataArray - The forcing dataarray to slice. idx : int The index of the time step to start the sample from in the state data. n_steps : int The number of time steps to include in the sample. + da_forcing_boundary : xr.DataArray + The forcing/boundary dataarray to slice. + num_past_steps : int, optional + The number of past time steps to include in the forcing/boundary + data. Default is `None`. + num_future_steps : int, optional + The number of future time steps to include in the forcing/boundary + data. Default is `None`. Returns ------- da_state_sliced : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). - da_forcing_matched : xr.DataArray + da_forcing_boundary_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', - 'forcing_feature_windowed'). + 'forcing/boundary_feature_windowed'). """ # Number of initial steps required (e.g., for initializing models) init_steps = 2 @@ -279,8 +332,8 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): # Slice the state data as before if self.datastore.is_forecast: # Calculate start and end indices for slicing - start_idx = max(0, self.num_past_forcing_steps - init_steps) - end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps + start_idx = max(0, num_past_steps - init_steps) + end_idx = max(init_steps, num_past_steps) + n_steps # Slice the state data over the elapsed forecast duration da_state_sliced = da_state.isel( @@ -299,13 +352,11 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): else: # For analysis data, slice the time dimension directly - start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) - end_idx = ( - idx + max(init_steps, self.num_past_forcing_steps) + n_steps - ) + start_idx = idx + max(0, num_past_steps - init_steps) + end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - if da_forcing is None: + if da_forcing_boundary is None: return da_state_sliced, None # Get the state times and its temporal resolution for matching with @@ -313,78 +364,66 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] - # Match forcing data to state times based on nearest neighbor - if self.datastore.is_forecast: - # Calculate all possible forcing times - forcing_times = ( - da_forcing.analysis_time + da_forcing.elapsed_forecast_duration - ) - forcing_times_flat = forcing_times.stack( - forecast_time=("analysis_time", "elapsed_forecast_duration") - ) + if "analysis_time" in da_forcing_boundary.dims: + idx = np.abs( + da_forcing_boundary.analysis_time.values + - self.da_state.analysis_time.values[idx] + ).argmin() + # Add a 'time' dimension using the actual forecast times + offset = max(init_steps, num_past_steps) + da_list = [] + for step in range(n_steps): + start_idx = offset + step - num_past_steps + end_idx = offset + step + num_future_steps + + current_time = ( + da_forcing_boundary.analysis_time[idx] + + da_forcing_boundary.elapsed_forecast_duration[offset + step] + ) - # Compute time differences - time_deltas = ( - forcing_times_flat.values[:, np.newaxis] - - state_times.values[np.newaxis, :] - ) - time_diffs = np.abs(time_deltas) - idx_min = time_diffs.argmin(axis=0) - - # Retrieve corresponding indices for analysis_time and - # elapsed_forecast_duration - forecast_time_index = forcing_times_flat["forecast_time"][idx_min] - analysis_time_indices = forecast_time_index["analysis_time"] - elapsed_forecast_duration_indices = forecast_time_index[ - "elapsed_forecast_duration" - ] - - # Slice the forcing data using matched indices - da_forcing_matched = da_forcing.isel( - analysis_time=("time", analysis_time_indices), - elapsed_forecast_duration=( - "time", - elapsed_forecast_duration_indices, - ), - ) + da_sliced = da_forcing_boundary.isel( + analysis_time=idx, + elapsed_forecast_duration=slice(start_idx, end_idx + 1), + ) - # Assign matched state times to the forcing data - da_forcing_matched["time"] = state_times - da_forcing_matched = da_forcing_matched.swap_dims( - {"elapsed_forecast_duration": "time"} - ) + da_sliced = da_sliced.rename({"elapsed_forecast_duration": "window"}) + da_sliced = da_sliced.assign_coords( + window=np.arange(-num_past_steps, num_future_steps + 1) + ) - # Calculate time differences in multiples of state time steps - state_time_step = state_times.values[1] - state_times.values[0] - time_diff_steps = ( - time_deltas[idx_min, np.arange(len(state_times))] - / state_time_step - ) + da_sliced = da_sliced.expand_dims(dim={"time": [current_time.values]}) + + da_list.append(da_sliced) - # Add time difference as a new coordinate - da_forcing_matched = da_forcing_matched.assign_coords( - time_diff=("time", time_diff_steps) + # Concatenate the list of DataArrays along the 'time' dimension + da_forcing_boundary_matched = xr.concat(da_list, dim="time") + forcing_time_step = ( + da_forcing_boundary_matched.time.values[1] + - da_forcing_boundary_matched.time.values[0] ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched["window"] * ( + forcing_time_step / state_time_step + ) + time_diff_steps = da_forcing_boundary_matched.isel( + grid_index=0, forcing_feature=0 + ).data + else: # For analysis data, match directly using the 'time' coordinate - forcing_times = da_forcing["time"] + forcing_times = da_forcing_boundary["time"] # Compute time differences between forcing and state times # (in multiples of state time steps) # Retrieve the indices of the closest times in the forcing data time_deltas = ( - forcing_times.values[:, np.newaxis] - - state_times.values[np.newaxis, :] + forcing_times.values[:, np.newaxis] - state_times.values[np.newaxis, :] ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) time_diff_steps = np.stack( [ time_deltas[ - idx_i - - self.num_past_forcing_steps : idx_i - + self.num_future_forcing_steps - + 1, + idx_i - num_past_steps : idx_i + num_future_steps + 1, init_steps + step_i, ] for (step_i, idx_i) in enumerate(idx_min[init_steps:]) @@ -392,24 +431,22 @@ def _slice_time(self, da_state, idx, n_steps: int, da_forcing=None): ) # Create window dimension for forcing data to stack later - window_size = ( - self.num_past_forcing_steps + self.num_future_forcing_steps + 1 - ) - da_forcing_windowed = da_forcing.rolling( - time=window_size, center=True + window_size = num_past_steps + num_future_steps + 1 + da_forcing_boundary_windowed = da_forcing_boundary.rolling( + time=window_size, center=False ).construct(window_dim="window") - da_forcing_matched = da_forcing_windowed.isel( + da_forcing_boundary_matched = da_forcing_boundary_windowed.isel( time=idx_min[init_steps:] ) - # Add time difference as a new coordinate to concatenate to the - # forcing features later - da_forcing_matched["time_diff_steps"] = ( - ("time", "window"), - time_diff_steps, - ) + # Add time difference as a new coordinate to concatenate to the + # forcing features later + da_forcing_boundary_matched["time_diff_steps"] = ( + ("time", "window"), + time_diff_steps, + ) - return da_state_sliced, da_forcing_matched + return da_state_sliced, da_forcing_boundary_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): """Helper function to process windowed data after standardization.""" @@ -462,23 +499,7 @@ def _build_item_dataarrays(self, idx): da_target_times : xr.DataArray The dataarray for the target times. """ - # handling ensemble data - if self.datastore.is_ensemble: - # for the now the strategy is to only include the first ensemble - # member - # XXX: this could be changed to include all ensemble members by - # splitting `idx` into two parts, one for the analysis time and one - # for the ensemble member and then increasing self.__len__ to - # include all ensemble members - warnings.warn( - "only use of ensemble member 0 (the first member) is " - "implemented for ensemble data" - ) - i_ensemble = 0 - da_state = self.da_state.isel(ensemble_member=i_ensemble) - else: - da_state = self.da_state - + da_state = self.da_state if self.da_forcing is not None: if "ensemble_member" in self.da_forcing.dims: raise NotImplementedError( @@ -500,13 +521,19 @@ def _build_item_dataarrays(self, idx): da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing=da_boundary, + da_forcing_boundary=da_boundary, + num_future_steps=self.num_future_boundary_steps, + num_past_steps=self.num_past_boundary_steps, ) + else: + da_boundary_windowed = None da_state, da_forcing_windowed = self._slice_time( da_state=da_state, idx=idx, n_steps=self.ar_steps, - da_forcing=da_forcing, + da_forcing_boundary=da_forcing, + num_future_steps=self.num_future_forcing_steps, + num_past_steps=self.num_past_forcing_steps, ) # load the data into memory @@ -521,9 +548,7 @@ def _build_item_dataarrays(self, idx): da_target_times = da_target_states.time if self.standardize: - da_init_states = ( - da_init_states - self.da_state_mean - ) / self.da_state_std + da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -595,9 +620,7 @@ def __getitem__(self, idx): tensor_dtype = torch.float32 init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) - target_states = torch.tensor( - da_target_states.values, dtype=tensor_dtype - ) + target_states = torch.tensor(da_target_states.values, dtype=tensor_dtype) target_times = torch.tensor( da_target_times.astype("datetime64[ns]").astype("int64").values, @@ -708,10 +731,7 @@ def _is_listlike(obj): ) for grid_coord in ["x", "y"]: - if ( - grid_coord in da_datastore_state.coords - and grid_coord not in da.coords - ): + if grid_coord in da_datastore_state.coords and grid_coord not in da.coords: da.coords[grid_coord] = da_datastore_state[grid_coord] if not add_time_as_dim: @@ -732,6 +752,8 @@ def __init__( standardize=True, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, batch_size=4, num_workers=16, ): @@ -740,6 +762,8 @@ def __init__( self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps self.ar_steps_train = ar_steps_train self.ar_steps_eval = ar_steps_eval self.standardize = standardize @@ -766,6 +790,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) self.val_dataset = WeatherDataset( datastore=self._datastore, @@ -775,6 +801,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) if stage == "test" or stage is None: @@ -786,6 +814,8 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) def train_dataloader(self): From 355423c8412677823db63d34ad4b2649abcf1478 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:45:35 +0100 Subject: [PATCH 61/90] datastore_boundars=None introduced --- .../datastore/npyfilesmeps/compute_standardization_stats.py | 1 + 1 file changed, 1 insertion(+) diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index f2c80e8a..4207812f 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -172,6 +172,7 @@ def main( ar_steps = 63 ds = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=ar_steps, standardize=False, From 121d460930fd24ae0ff90dd0d07279c75a15b1d5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:02 +0100 Subject: [PATCH 62/90] bug fix for file retrieval per member --- neural_lam/datastore/npyfilesmeps/store.py | 51 +++++++++------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 146b0627..7ee583be 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -244,9 +244,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] das = [ - self._get_single_timeseries_dataarray( - features=[feature], split=split - ) + self._get_single_timeseries_dataarray(features=[feature], split=split) for feature in features ] da = xr.concat(das, dim="feature") @@ -259,9 +257,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # variable is turned into a dask array and so execution of the # calculation is delayed until the feature values are actually # used. - da_forecast_time = ( - da.analysis_time + da.elapsed_forecast_duration - ).chunk({"elapsed_forecast_duration": 1}) + da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( + {"elapsed_forecast_duration": 1} + ) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -339,10 +337,7 @@ def _get_single_timeseries_dataarray( for all categories of data """ - if ( - set(features).difference(self.get_vars_names(category="static")) - == set() - ): + if set(features).difference(self.get_vars_names(category="static")) == set(): assert split in ( "train", "val", @@ -356,12 +351,8 @@ def _get_single_timeseries_dataarray( "test", ), f"Unknown dataset split {split} for features {features}" - if member is not None and features != self.get_vars_names( - category="state" - ): - raise ValueError( - "Member can only be specified for the 'state' category" - ) + if member is not None and features != self.get_vars_names(category="state"): + raise ValueError("Member can only be specified for the 'state' category") concat_axis = 0 @@ -377,9 +368,7 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "samples" / split if self._remove_state_features_with_index: n_to_drop = len(self._remove_state_features_with_index) - feature_dim_mask = np.ones( - len(features) + n_to_drop, dtype=bool - ) + feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool) feature_dim_mask[self._remove_state_features_with_index] = False elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT @@ -445,7 +434,7 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split) + coord_values = self._get_analysis_times(split=split, member_id=member) elif d == "y": coord_values = y elif d == "x": @@ -464,9 +453,7 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format( - analysis_time=analysis_time, **file_params - ) + / filename_format.format(analysis_time=analysis_time, **file_params) for analysis_time in coords["analysis_time"] ] else: @@ -505,7 +492,7 @@ def _get_single_timeseries_dataarray( return da - def _get_analysis_times(self, split) -> List[np.datetime64]: + def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: """Get the analysis times for the given split by parsing the filenames of all the files found for the given split. @@ -513,6 +500,8 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: ---------- split : str The dataset split to get the analysis times for. + member_id : int + The ensemble member to get the analysis times for. Returns ------- @@ -520,8 +509,12 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: The analysis times for the given split. """ + if member_id is None: + # Only interior state data files have member_id, to avoid duplicates + # we only look at the first member for all other categories + member_id = 0 pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) - pattern = re.sub(r"{member_id:[^}]*}", "*", pattern) + pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern) sample_dir = self.root_path / "samples" / split sample_files = sample_dir.glob(pattern) @@ -531,9 +524,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError( - f"No files found in {sample_dir} with pattern {pattern}" - ) + raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") return times @@ -690,9 +681,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ def load_pickled_tensor(fn): - return torch.load( - self.root_path / "static" / fn, weights_only=True - ).numpy() + return torch.load(self.root_path / "static" / fn, weights_only=True).numpy() mean_diff_values = None std_diff_values = None From 7e82eef5d797c76a7667271603e5ea94a3485ac2 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:17 +0100 Subject: [PATCH 63/90] rename datastore for tests --- tests/conftest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index be5cf3e7..90a86d0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,14 +94,14 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) -DATASTORES_BOUNDARY_EXAMPLES = dict( - mdp=( +DATASTORES_BOUNDARY_EXAMPLES = { + "mdp": ( DATASTORE_EXAMPLES_ROOT_PATH / "mdp" - / "era5_1000hPa_winds" + / "era5_1000hPa_danra_100m_winds" / "era5.datastore.yaml" - ) -) + ), +} DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore From 320d7c4826e4055fef0edfa748c3e7b6704c589a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:31 +0100 Subject: [PATCH 64/90] aligned time with danra for easier boundary testing --- tests/dummy_datastore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index d62c7356..a958b8f5 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -28,7 +28,7 @@ class DummyDatastore(BaseRegularGridDatastore): """ SHORT_NAME = "dummydata" - T0 = isodate.parse_datetime("2021-01-01T00:00:00") + T0 = isodate.parse_datetime("1990-09-02T00:00:00") N_FEATURES = dict(state=5, forcing=2, static=1) CARTESIAN_COORDS = ["x", "y"] From f18dcc2340434ce96f709ba987af482d063de4e5 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 30 Nov 2024 20:46:50 +0100 Subject: [PATCH 65/90] Fixed test for temporal embedding --- tests/test_time_slicing.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 29161505..2f5ed96c 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,9 +40,7 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray( - values, dims=["time"], coords={"time": self._time_values} - ) + da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -78,10 +76,8 @@ def get_vars_long_names(self, category): def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): - # state and forcing variables have only on dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange( - len(ANALYSIS_STATE_VALUES) - ) + # state and forcing variables have only one dimension, `time` + time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -93,6 +89,7 @@ def test_time_slicing_analysis( dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, ar_steps=ar_steps, num_future_forcing_steps=num_future_forcing_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -101,9 +98,7 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _ = [ - tensor.numpy() for tensor in sample - ] + init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] expected_init_states = [0, 1] if ar_steps == 3: @@ -130,7 +125,7 @@ def test_time_slicing_analysis( # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) - # forcing: (ar_steps, N_grid, d_windowed_forcing) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) # target_times: (ar_steps,) assert init_states.shape == (2, 1, 1) assert init_states[:, 0, 0].tolist() == expected_init_states @@ -141,6 +136,10 @@ def test_time_slicing_analysis( assert forcing.shape == ( 3, 1, - 1 + num_past_forcing_steps + num_future_forcing_steps, + # Factor 2 because each window step has a temporal embedding + (1 + num_past_forcing_steps + num_future_forcing_steps) * 2, + ) + np.testing.assert_equal( + forcing[:, 0, : num_past_forcing_steps + num_future_forcing_steps + 1], + np.array(expected_forcing_values), ) - np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values)) From e6327d88373bb2708733f6331aebe407facc1f67 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 10:40:48 +0100 Subject: [PATCH 66/90] allow boundary as input to ar_model.common_step --- neural_lam/models/ar_model.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 710efcec..331966e4 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -110,7 +110,9 @@ def __init__( self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim - + num_forcing_vars + # Factor 2 because of temporal embedding or windowed features + + 2 + * num_forcing_vars * (num_past_forcing_steps + num_future_forcing_steps + 1) ) @@ -241,19 +243,20 @@ def unroll_prediction(self, init_states, forcing_features, true_states): def common_step(self, batch): """ - Predict on single batch batch consists of: init_states: (B, 2, - num_grid_nodes, d_features) target_states: (B, pred_steps, - num_grid_nodes, d_features) forcing_features: (B, pred_steps, - num_grid_nodes, d_forcing), - where index 0 corresponds to index 1 of init_states + Predict on single batch batch consists of: + init_states: (B, 2,num_grid_nodes, d_features) + target_states: (B, pred_steps,num_grid_nodes, d_features) + forcing_features: (B, pred_steps,num_grid_nodes, d_forcing) + boundary_features: (B, pred_steps,num_grid_nodes, d_boundaries) + batch_times: (B, pred_steps) """ - (init_states, target_states, forcing_features, batch_times) = batch + (init_states, target_states, forcing_features, _, batch_times) = batch prediction, pred_std = self.unroll_prediction( init_states, forcing_features, target_states - ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) + ) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) return prediction, target_states, pred_std, batch_times From 1374a1976f002ffba86c7c203c6fbb2bea83fb0e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 10:40:56 +0100 Subject: [PATCH 67/90] linting --- neural_lam/datastore/npyfilesmeps/store.py | 43 ++++++++---- neural_lam/weather_dataset.py | 66 ++++++++++++------- .../era5.datastore.yaml | 2 +- tests/test_time_slicing.py | 12 +++- tests/test_training.py | 17 ++--- 5 files changed, 91 insertions(+), 49 deletions(-) diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 7ee583be..24349e7e 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -244,7 +244,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # them separately features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] das = [ - self._get_single_timeseries_dataarray(features=[feature], split=split) + self._get_single_timeseries_dataarray( + features=[feature], split=split + ) for feature in features ] da = xr.concat(das, dim="feature") @@ -257,9 +259,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: # variable is turned into a dask array and so execution of the # calculation is delayed until the feature values are actually # used. - da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk( - {"elapsed_forecast_duration": 1} - ) + da_forecast_time = ( + da.analysis_time + da.elapsed_forecast_duration + ).chunk({"elapsed_forecast_duration": 1}) da_datetime_forcing_features = self._calc_datetime_forcing_features( da_time=da_forecast_time ) @@ -337,7 +339,10 @@ def _get_single_timeseries_dataarray( for all categories of data """ - if set(features).difference(self.get_vars_names(category="static")) == set(): + if ( + set(features).difference(self.get_vars_names(category="static")) + == set() + ): assert split in ( "train", "val", @@ -351,8 +356,12 @@ def _get_single_timeseries_dataarray( "test", ), f"Unknown dataset split {split} for features {features}" - if member is not None and features != self.get_vars_names(category="state"): - raise ValueError("Member can only be specified for the 'state' category") + if member is not None and features != self.get_vars_names( + category="state" + ): + raise ValueError( + "Member can only be specified for the 'state' category" + ) concat_axis = 0 @@ -368,7 +377,9 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "samples" / split if self._remove_state_features_with_index: n_to_drop = len(self._remove_state_features_with_index) - feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool) + feature_dim_mask = np.ones( + len(features) + n_to_drop, dtype=bool + ) feature_dim_mask[self._remove_state_features_with_index] = False elif features == ["toa_downwelling_shortwave_flux"]: filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT @@ -434,7 +445,9 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split, member_id=member) + coord_values = self._get_analysis_times( + split=split, member_id=member + ) elif d == "y": coord_values = y elif d == "x": @@ -453,7 +466,9 @@ def _get_single_timeseries_dataarray( if features_vary_with_analysis_time: filepaths = [ fp_samples - / filename_format.format(analysis_time=analysis_time, **file_params) + / filename_format.format( + analysis_time=analysis_time, **file_params + ) for analysis_time in coords["analysis_time"] ] else: @@ -524,7 +539,9 @@ def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: times.append(name_parts["analysis_time"]) if len(times) == 0: - raise ValueError(f"No files found in {sample_dir} with pattern {pattern}") + raise ValueError( + f"No files found in {sample_dir} with pattern {pattern}" + ) return times @@ -681,7 +698,9 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ def load_pickled_tensor(fn): - return torch.load(self.root_path / "static" / fn, weights_only=True).numpy() + return torch.load( + self.root_path / "static" / fn, weights_only=True + ).numpy() mean_diff_values = None std_diff_values = None diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 7dbe0567..60f8d316 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -41,13 +41,13 @@ class WeatherDataset(torch.utils.data.Dataset): num_past_boundary_steps: int, optional Number of past time steps to include in boundary input. If set to i, boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, - given num_future_forcing_steps) are included as boundary inputs at time t - Default is 1. + given num_future_forcing_steps) are included as boundary inputs at time + t Default is 1. num_future_boundary_steps: int, optional Number of future time steps to include in boundary input. If set to j, - boundary from times t, t+1, ..., t+j-1, t+j (and potentially times before - t, given num_past_forcing_steps) are included as boundary inputs at time - t. Default is 1. + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times + before t, given num_past_forcing_steps) are included as boundary inputs + at time t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -75,7 +75,9 @@ def __init__( self.num_past_boundary_steps = num_past_boundary_steps self.num_future_boundary_steps = num_future_boundary_steps - self.da_state = self.datastore.get_dataarray(category="state", split=self.split) + self.da_state = self.datastore.get_dataarray( + category="state", split=self.split + ) if self.da_state is None: raise ValueError( "A non-empty state dataarray must be provided. " @@ -112,7 +114,9 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): - expected_dim_order = self.datastore.expected_dim_order(category=part) + expected_dim_order = self.datastore.expected_dim_order( + category=part + ) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -188,10 +192,12 @@ def get_time_step(times): # Calculate required bounds for boundary using its time step boundary_required_time_min = ( - state_time_min - self.num_past_forcing_steps * boundary_time_step + state_time_min + - self.num_past_forcing_steps * boundary_time_step ) boundary_required_time_max = ( - state_time_max + self.num_future_forcing_steps * boundary_time_step + state_time_max + + self.num_future_forcing_steps * boundary_time_step ) if boundary_time_min > boundary_required_time_min: @@ -220,8 +226,10 @@ def get_time_step(times): self.da_state_std = self.ds_state_stats.state_std if self.da_forcing is not None: - self.ds_forcing_stats = self.datastore.get_standardization_dataarray( - category="forcing" + self.ds_forcing_stats = ( + self.datastore.get_standardization_dataarray( + category="forcing" + ) ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std @@ -378,7 +386,9 @@ def _slice_time( current_time = ( da_forcing_boundary.analysis_time[idx] - + da_forcing_boundary.elapsed_forecast_duration[offset + step] + + da_forcing_boundary.elapsed_forecast_duration[ + offset + step + ] ) da_sliced = da_forcing_boundary.isel( @@ -386,12 +396,16 @@ def _slice_time( elapsed_forecast_duration=slice(start_idx, end_idx + 1), ) - da_sliced = da_sliced.rename({"elapsed_forecast_duration": "window"}) + da_sliced = da_sliced.rename( + {"elapsed_forecast_duration": "window"} + ) da_sliced = da_sliced.assign_coords( window=np.arange(-num_past_steps, num_future_steps + 1) ) - da_sliced = da_sliced.expand_dims(dim={"time": [current_time.values]}) + da_sliced = da_sliced.expand_dims( + dim={"time": [current_time.values]} + ) da_list.append(da_sliced) @@ -401,13 +415,13 @@ def _slice_time( da_forcing_boundary_matched.time.values[1] - da_forcing_boundary_matched.time.values[0] ) - da_forcing_boundary_matched["window"] = da_forcing_boundary_matched["window"] * ( - forcing_time_step / state_time_step - ) + da_forcing_boundary_matched["window"] = da_forcing_boundary_matched[ + "window" + ] * (forcing_time_step / state_time_step) time_diff_steps = da_forcing_boundary_matched.isel( grid_index=0, forcing_feature=0 ).data - + else: # For analysis data, match directly using the 'time' coordinate forcing_times = da_forcing_boundary["time"] @@ -416,7 +430,8 @@ def _slice_time( # (in multiples of state time steps) # Retrieve the indices of the closest times in the forcing data time_deltas = ( - forcing_times.values[:, np.newaxis] - state_times.values[np.newaxis, :] + forcing_times.values[:, np.newaxis] + - state_times.values[np.newaxis, :] ) / state_time_step idx_min = np.abs(time_deltas).argmin(axis=0) @@ -548,7 +563,9 @@ def _build_item_dataarrays(self, idx): da_target_times = da_target_states.time if self.standardize: - da_init_states = (da_init_states - self.da_state_mean) / self.da_state_std + da_init_states = ( + da_init_states - self.da_state_mean + ) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean ) / self.da_state_std @@ -620,7 +637,9 @@ def __getitem__(self, idx): tensor_dtype = torch.float32 init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) - target_states = torch.tensor(da_target_states.values, dtype=tensor_dtype) + target_states = torch.tensor( + da_target_states.values, dtype=tensor_dtype + ) target_times = torch.tensor( da_target_times.astype("datetime64[ns]").astype("int64").values, @@ -731,7 +750,10 @@ def _is_listlike(obj): ) for grid_coord in ["x", "y"]: - if grid_coord in da_datastore_state.coords and grid_coord not in da.coords: + if ( + grid_coord in da_datastore_state.coords + and grid_coord not in da.coords + ): da.coords[grid_coord] = da_datastore_state[grid_coord] if not add_time_as_dim: diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml index c97da4bc..7c5ffb3b 100644 --- a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -25,7 +25,7 @@ output: end: 2022-09-30T00:00 test: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 2f5ed96c..4a59c81e 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -40,7 +40,9 @@ def get_dataarray(self, category, split): if self.is_forecast: raise NotImplementedError() else: - da = xr.DataArray(values, dims=["time"], coords={"time": self._time_values}) + da = xr.DataArray( + values, dims=["time"], coords={"time": self._time_values} + ) # add `{category}_feature` and `grid_index` dimensions da = da.expand_dims("grid_index") @@ -77,7 +79,9 @@ def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): # state and forcing variables have only one dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange(len(ANALYSIS_STATE_VALUES)) + time_values = np.datetime64("2020-01-01") + np.arange( + len(ANALYSIS_STATE_VALUES) + ) assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( @@ -98,7 +102,9 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _, _ = [tensor.numpy() for tensor in sample] + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] expected_init_states = [0, 1] if ar_steps == 3: diff --git a/tests/test_training.py b/tests/test_training.py index 28566a4b..7a1b4717 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -5,6 +5,7 @@ import pytest import pytorch_lightning as pl import torch + import wandb # First-party @@ -22,14 +23,10 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize( - "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() -) +@pytest.mark.parametrize("datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys()) def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example( - datastore_boundary_name - ) + datastore_boundary = init_datastore_boundary_example(datastore_boundary_name) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( @@ -38,15 +35,13 @@ def test_training(datastore_name, datastore_boundary_name): ) if not isinstance(datastore_boundary, BaseRegularGridDatastore): pytest.skip( - f"Skipping test for {datastore_boundary_name} as it is not a regular " - "grid datastore." + f"Skipping test for {datastore_boundary_name} as it is not a " + "regular grid datastore." ) if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision( - "high" - ) # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s else: device_name = "cpu" From 779f3e9ed31d9525851793fae409cc145a30e15a Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:10:29 +0100 Subject: [PATCH 68/90] improved docstrings and added some assertions --- neural_lam/weather_dataset.py | 105 ++++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 23 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 60f8d316..c65ec468 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -143,7 +143,13 @@ def __init__( self.da_state = self.da_state def get_time_step(times): - """Calculate the time step from the data""" + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ time_diffs = np.diff(times) if not np.all(time_diffs == time_diffs[0]): raise ValueError( @@ -234,6 +240,7 @@ def get_time_step(times): self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + # XXX: Again, the boundary data is considered forcing data for now if self.da_boundary is not None: self.ds_boundary_stats = ( self.datastore_boundary.get_standardization_dataarray( @@ -305,7 +312,7 @@ def _slice_time( is performed based on the state times. Additionally, the time difference between the matched forcing/boundary times and state times (in multiples of state time steps) is added to the forcing dataarray. This will be - used as an additional feature in the model (temporal embedding). + used as an additional input feature in the model (temporal embedding). Parameters ---------- @@ -333,23 +340,26 @@ def _slice_time( da_forcing_boundary_matched : xr.DataArray The sliced state dataarray with dims ('time', 'grid_index', 'forcing/boundary_feature_windowed'). + If no forcing/boundary data is provided, this will be `None`. """ - # Number of initial steps required (e.g., for initializing models) + # The current implementation requires at least 2 time steps for the + # initial state (see GraphCast). init_steps = 2 - - # Slice the state data as before + # slice the dataarray to include the required number of time steps if self.datastore.is_forecast: - # Calculate start and end indices for slicing - start_idx = max(0, num_past_steps - init_steps) - end_idx = max(init_steps, num_past_steps) + n_steps - - # Slice the state data over the elapsed forecast duration + start_idx = max(0, self.num_past_forcing_steps - init_steps) + end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps + # this implies that the data will have both `analysis_time` and + # `elapsed_forecast_duration` dimensions for forecasts. We for now + # simply select a analysis time and the first `n_steps` forecast + # times (given no offset). Note that this means that we get one + # sample per forecast, always starting at forecast time 2. da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) - - # Create a new 'time' dimension + # create a new time dimension so that the produced sample has a + # `time` dimension, similarly to the analysis only data da_state_sliced["time"] = ( da_state_sliced.analysis_time + da_state_sliced.elapsed_forecast_duration @@ -357,9 +367,13 @@ def _slice_time( da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + # Asserting that the forecast time step is consistent + self.get_time_step(da_state_sliced.time) else: - # For analysis data, slice the time dimension directly + # For analysis data we slice the time dimension directly. The offset + # is only relevant for the very first (and last) samples in the + # dataset. start_idx = idx + max(0, num_past_steps - init_steps) end_idx = idx + max(init_steps, num_past_steps) + n_steps da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) @@ -372,7 +386,13 @@ def _slice_time( state_times = da_state_sliced["time"] state_time_step = state_times.values[1] - state_times.values[0] + # Here we cannot check 'self.datastore.is_forecast' directly because we + # might be dealing with a datastore_boundary if "analysis_time" in da_forcing_boundary.dims: + # Select the closest analysis time in the forcing/boundary data + # This is mostly relevant for boundary data where the time steps + # are not necessarily the same as the state data. But still fast + # enough for forcing data where the time steps are the same. idx = np.abs( da_forcing_boundary.analysis_time.values - self.da_state.analysis_time.values[idx] @@ -399,6 +419,8 @@ def _slice_time( da_sliced = da_sliced.rename( {"elapsed_forecast_duration": "window"} ) + + # Assign the 'window' coordinate to be relative positions da_sliced = da_sliced.assign_coords( window=np.arange(-num_past_steps, num_future_steps + 1) ) @@ -409,7 +431,10 @@ def _slice_time( da_list.append(da_sliced) - # Concatenate the list of DataArrays along the 'time' dimension + # Generate temporal embedding `time_diff_steps` for the + # forcing/boundary data. This is the time difference in multiples + # of state time steps between the forcing/boundary time and the + # state time. da_forcing_boundary_matched = xr.concat(da_list, dim="time") forcing_time_step = ( da_forcing_boundary_matched.time.values[1] @@ -423,7 +448,9 @@ def _slice_time( ).data else: - # For analysis data, match directly using the 'time' coordinate + # For analysis data, we slice the time dimension directly. The + # offset is only relevant for the very first (and last) samples in + # the dataset. forcing_times = da_forcing_boundary["time"] # Compute time differences between forcing and state times @@ -455,7 +482,7 @@ def _slice_time( ) # Add time difference as a new coordinate to concatenate to the - # forcing features later + # forcing features later as temporal embedding da_forcing_boundary_matched["time_diff_steps"] = ( ("time", "window"), time_diff_steps, @@ -464,7 +491,26 @@ def _slice_time( return da_state_sliced, da_forcing_boundary_matched def _process_windowed_data(self, da_windowed, da_state, da_target_times): - """Helper function to process windowed data after standardization.""" + """Helper function to process windowed data. This function stacks the + 'forcing_feature' and 'window' dimensions and adds the time step + differences to the existing features as a temporal embedding. + + Parameters + ---------- + da_windowed : xr.DataArray + The windowed data to process. Can be `None` if no data is provided. + da_state : xr.DataArray + The state dataarray. + da_target_times : xr.DataArray + The target times. + + Returns + ------- + da_windowed : xr.DataArray + The processed windowed data. If `da_windowed` is `None`, an empty + DataArray with the correct dimensions and coordinates is returned. + + """ stacked_dim = "forcing_feature_windowed" if da_windowed is not None: # Stack the 'feature' and 'window' dimensions and add the @@ -492,8 +538,8 @@ def _process_windowed_data(self, da_windowed, da_state, da_target_times): def _build_item_dataarrays(self, idx): """ - Create the dataarrays for the initial states, target states and forcing - data for the sample at index `idx`. + Create the dataarrays for the initial states, target states, forcing + and boundary data for the sample at index `idx`. Parameters ---------- @@ -529,7 +575,7 @@ def _build_item_dataarrays(self, idx): else: da_boundary = None - # if da_forcing is None, the function will return None for + # if da_forcing_boundary is None, the function will return None for # da_forcing_windowed if da_boundary is not None: _, da_boundary_windowed = self._slice_time( @@ -542,6 +588,9 @@ def _build_item_dataarrays(self, idx): ) else: da_boundary_windowed = None + # XXX: Currently, the order of the `slice_time` calls is important + # as `da_state` is modified in the second call. This should be + # refactored to be more robust. da_state, da_forcing_windowed = self._slice_time( da_state=da_state, idx=idx, @@ -584,6 +633,10 @@ def _build_item_dataarrays(self, idx): da_boundary_windowed - self.da_boundary_mean ) / self.da_boundary_std + # This function handles the stacking of the forcing and boundary data + # and adds the time step differences as a temporal embedding. + # It can handle `None` inputs for the forcing and boundary data + # (and simlpy return an empty DataArray in that case). da_forcing_windowed = self._process_windowed_data( da_forcing_windowed, da_state, da_target_times ) @@ -655,6 +708,11 @@ def __getitem__(self, idx): # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) + # Assert that the boundary data is an empty tensor if the corresponding + # datastore_boundary is `None` + if self.datastore_boundary is None: + assert boundary.numel() == 0 + return init_states, target_states, forcing, boundary, target_times def __iter__(self): @@ -795,9 +853,10 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: - # BUG: There also seem to be issues with "spawn", to be investigated - # default to spawn for now, as the default on linux "fork" hangs - # when using dask (which the npyfilesmeps datastore uses) + # BUG: There also seem to be issues with "spawn" and `gloo`, to be + # investigated. Defaults to spawn for now, as the default on linux + # "fork" hangs when using dask (which the npyfilesmeps datastore + # uses) self.multiprocessing_context = "spawn" else: self.multiprocessing_context = None From f126ec27b6c7d8534893850f07427e3737418216 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:11:32 +0100 Subject: [PATCH 69/90] remove boundary datastore from tests that don't need it --- tests/test_datasets.py | 17 ++--------------- tests/test_training.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 5fbe4a5d..063ec147 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -108,37 +108,24 @@ def test_dataset_item_shapes(datastore_name, datastore_boundary_name): # try to get the last item of the dataset to ensure slicing and stacking # operations are working as expected and are consistent with the dataset # length - dataset[len(dataset) - 1] @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize( - "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() -) -def test_dataset_item_create_dataarray_from_tensor( - datastore_name, datastore_boundary_name -): +def test_dataset_item_create_dataarray_from_tensor(datastore_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example( - datastore_boundary_name - ) N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 - num_past_boundary_steps = 1 - num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, - datastore_boundary=datastore_boundary, + datastore_boundary=None, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, - num_past_boundary_steps=num_past_boundary_steps, - num_future_boundary_steps=num_future_boundary_steps, ) idx = 0 diff --git a/tests/test_training.py b/tests/test_training.py index 7a1b4717..ca0ebf41 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -5,7 +5,6 @@ import pytest import pytorch_lightning as pl import torch - import wandb # First-party @@ -23,10 +22,14 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -@pytest.mark.parametrize("datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) - datastore_boundary = init_datastore_boundary_example(datastore_boundary_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( @@ -41,7 +44,9 @@ def test_training(datastore_name, datastore_boundary_name): if torch.cuda.is_available(): device_name = "cuda" - torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision( + "high" + ) # Allows using Tensor Cores on A100s else: device_name = "cpu" From 4b656da04526d3d38d71881deab18ee69519b29d Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 12:43:01 +0100 Subject: [PATCH 70/90] fix scope of _get_time_step --- neural_lam/weather_dataset.py | 40 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index c65ec468..3685e227 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,28 +142,14 @@ def __init__( else: self.da_state = self.da_state - def get_time_step(times): - """Calculate the time step from the data - - Parameters - ---------- - times : xr.DataArray - The time dataarray to calculate the time step from. - """ - time_diffs = np.diff(times) - if not np.all(time_diffs == time_diffs[0]): - raise ValueError( - "Inconsistent time steps in data. " - f"Found different time steps: {np.unique(time_diffs)}" - ) - return time_diffs[0] + # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time else: state_times = self.da_state.time - _ = get_time_step(state_times) + _ = self._get_time_step(state_times) # Check time coverage for forcing and boundary data if self.da_forcing is not None or self.da_boundary is not None: @@ -182,7 +168,7 @@ def get_time_step(times): forcing_times = self.da_forcing.analysis_time else: forcing_times = self.da_forcing.time - get_time_step(forcing_times.values) + self._get_time_step(forcing_times.values) if self.da_boundary is not None: # Boundary data is part of a separate datastore @@ -192,7 +178,7 @@ def get_time_step(times): boundary_times = self.da_boundary.analysis_time else: boundary_times = self.da_boundary.time - boundary_time_step = get_time_step(boundary_times.values) + boundary_time_step = self._get_time_step(boundary_times.values) boundary_time_min = boundary_times.min().values boundary_time_max = boundary_times.max().values @@ -296,6 +282,22 @@ def __len__(self): - self.num_future_forcing_steps ) + def _get_time_step(self, times): + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + def _slice_time( self, da_state, @@ -368,7 +370,7 @@ def _slice_time( {"elapsed_forecast_duration": "time"} ) # Asserting that the forecast time step is consistent - self.get_time_step(da_state_sliced.time) + self._get_time_step(da_state_sliced.time) else: # For analysis data we slice the time dimension directly. The offset From 75db4b8a5ac0769dab7be8837e707b734c62ff92 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 2 Dec 2024 16:58:46 +0100 Subject: [PATCH 71/90] added information about optional boundary datastore --- README.md | 22 +++++++++++++--------- neural_lam/weather_dataset.py | 2 -- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e21b7c24..7a5e5caf 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th interface that provides the data in a data-structure that can be used within neural-lam. A datastore is used to create a `pytorch.Dataset`-derived class that samples the data in time to create individual samples for - training, validation and testing. + training, validation and testing. A secondary datastore can be provided + for the boundary data. Currently, boundary datastore must be of type `mdp` + and only contain forcing features. This can easily be expanded in the future. 2. **The graph structure** is used to define message-passing GNN layers, that are trained to emulate fluid flow in the atmosphere over time. The @@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model. The path you provide to the neural-lam config (`config.yaml`) also sets the root directory relative to which all other paths are resolved, as in the parent -directory of the config becomes the root directory. Both the datastore and +directory of the config becomes the root directory. Both the datastores and graphs you generate are then stored in subdirectories of this root directory. Exactly how and where a specific datastore expects its source data to be stored and where it stores its derived data is up to the implementation of the @@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`): data/ ├── config.yaml - Configuration file for neural-lam ├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml +├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml └── graphs/ - Directory containing graphs for training ``` @@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like: datastore: kind: mdp config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml training: state_feature_weighting: __config_class__: ManualStateFeatureWeighting - values: + weights: u100m: 1.0 v100m: 1.0 ``` -For now the neural-lam config only defines two things: 1) the kind of data -store and the path to its config, and 2) the weighting of different features in -the loss function. If you don't define the state feature weighting it will default -to weighting all features equally. +For now the neural-lam config only defines two things: +1) the kind of datastores and the path to their config +2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally. (This example is taken from the `tests/datastore_examples/mdp` directory.) @@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha # Contact If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch. -There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots). -You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). +There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 3685e227..f02cfbd4 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -142,8 +142,6 @@ def __init__( else: self.da_state = self.da_state - - # Check time step consistency in state data if self.datastore.is_forecast: state_times = self.da_state.analysis_time From 4c175452af54fa4833fd9ac67bb4b1b36cdaa777 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 05:14:38 +0100 Subject: [PATCH 72/90] moved gcsfs to dev group --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 38e7cb0e..f556ef6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,13 +26,12 @@ dependencies = [ "torch-geometric==2.3.1", "parse>=1.20.2", "dataclass-wizard<0.31.0", - "gcsfs>=2021.10.0", "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" [project.optional-dependencies] -dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"] +dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2", "gcsfs>=2021.10.0"] [tool.setuptools] py-modules = ["neural_lam"] From a700350f9c0b6161ffefa06b7fa7fc7151e51f23 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Tue, 3 Dec 2024 05:14:44 +0100 Subject: [PATCH 73/90] linting --- .../era5.datastore.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml index 600a1845..7c5ffb3b 100644 --- a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -7,7 +7,7 @@ output: coord_ranges: time: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 step: PT6H chunking: time: 1 @@ -16,16 +16,16 @@ output: splits: train: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 compute_statistics: ops: [mean, std, diff_mean, diff_std] dims: [grid_index, time] val: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 test: start: 1990-09-01T00:00 - end: 2022-09-30T00:00 + end: 2022-09-30T00:00 inputs: era_height_levels: From 315aa0fbbb4d551b4ffb30b761743dbd95a14382 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 28 Oct 2024 11:20:41 +0100 Subject: [PATCH 74/90] Propagate separation of state and boundary change through training loop --- neural_lam/models/ar_model.py | 81 ++++++++++++++++++++------- neural_lam/models/base_graph_model.py | 38 ++++++++++++- neural_lam/vis.py | 52 ++++++++++++----- neural_lam/weather_dataset.py | 1 + 4 files changed, 135 insertions(+), 37 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 331966e4..95bd1154 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -48,14 +48,33 @@ def __init__( num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps + # Set up boundary mask + boundary_mask = torch.tensor( + da_boundary_mask.values, dtype=torch.float32 + ).unsqueeze( + 1 + ) # add feature dim + + self.register_buffer("boundary_mask", boundary_mask, persistent=False) + # Pre-compute interior mask for use in loss function + self.register_buffer( + "interior_mask", 1.0 - self.boundary_mask, persistent=False + ) # (num_grid_nodes, 1), 1 for non-border + # Load static features for grid/data, NB: self.predict_step assumes # dimension order to be (grid_index, static_feature) arr_static = da_static_features.transpose( "grid_index", "static_feature" ).values + static_features_torch = torch.tensor(arr_static, dtype=torch.float32) self.register_buffer( "grid_static_features", - torch.tensor(arr_static, dtype=torch.float32), + static_features_torch[self.boundary_mask.to(torch.bool), + persistent=False, + ) + self.register_buffer( + "boundary_static_features", + static_features_torch[self.interior_mask.to(torch.bool), persistent=False, ) @@ -107,6 +126,11 @@ def __init__( grid_static_dim, ) = self.grid_static_features.shape + ( + self.num_boundary_nodes, + boundary_static_dim, # TODO Need for computation below + ) = self.boundary_static_features.shape + self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim @@ -115,6 +139,7 @@ def __init__( * num_forcing_vars * (num_past_forcing_steps + num_future_forcing_steps + 1) ) + self.boundary_dim = self.grid_dim # TODO Compute separately # Instantiate loss function self.loss = metrics.get_metric(args.loss) @@ -190,7 +215,9 @@ def expand_to_batch(x, batch_size): """ return x.unsqueeze(0).expand(batch_size, -1, -1) - def predict_step(self, prev_state, prev_prev_state, forcing): + def predict_step( + self, prev_state, prev_prev_state, forcing, boundary_forcing + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, @@ -199,29 +226,31 @@ def predict_step(self, prev_state, prev_prev_state, forcing): """ raise NotImplementedError("No prediction step implemented") - def unroll_prediction(self, init_states, forcing_features, true_states): + def unroll_prediction(self, init_states, forcing, boundary_forcing): """ Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B, - pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps, - num_grid_nodes, d_f) + init_states: (B, 2, num_grid_nodes, d_f) + forcing: (B, pred_steps, num_grid_nodes, d_static_f) + boundary_forcing: (B, pred_steps, num_boundary_nodes, d_boundary_f) """ prev_prev_state = init_states[:, 0] prev_state = init_states[:, 1] prediction_list = [] pred_std_list = [] - pred_steps = forcing_features.shape[1] + pred_steps = forcing.shape[1] for i in range(pred_steps): - forcing = forcing_features[:, i] + forcing_step = forcing[:, i] + boundary_forcing_step = boundary_forcing[:, i] pred_state, pred_std = self.predict_step( - prev_state, prev_prev_state, forcing + prev_state, prev_prev_state, forcing_step, boundary_forcing_step ) # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, # d_f) or None prediction_list.append(pred_state) + if self.output_std: pred_std_list.append(pred_std) @@ -243,20 +272,22 @@ def unroll_prediction(self, init_states, forcing_features, true_states): def common_step(self, batch): """ - Predict on single batch batch consists of: - init_states: (B, 2,num_grid_nodes, d_features) - target_states: (B, pred_steps,num_grid_nodes, d_features) - forcing_features: (B, pred_steps,num_grid_nodes, d_forcing) - boundary_features: (B, pred_steps,num_grid_nodes, d_boundaries) - batch_times: (B, pred_steps) + Predict on single batch + batch consists of: + init_states: (B, 2, num_grid_nodes, d_features) + target_states: (B, pred_steps, num_grid_nodes, d_features) + forcing: (B, pred_steps, num_grid_nodes, d_forcing), + boundary_forcing: + (B, pred_steps, num_boundary_nodes, d_boundary_forcing), + where index 0 corresponds to index 1 of init_states """ (init_states, target_states, forcing_features, _, batch_times) = batch prediction, pred_std = self.unroll_prediction( - init_states, forcing_features, target_states - ) - # prediction: (B, pred_steps, num_grid_nodes, d_f) - # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + init_states, forcing, boundary_forcing + ) # (B, pred_steps, num_grid_nodes, d_f) + # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, + # pred_steps, num_grid_nodes, d_f) or (d_f,) return prediction, target_states, pred_std, batch_times @@ -306,7 +337,11 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( - self.loss(prediction, target, pred_std), + self.loss( + prediction, + target, + pred_std, + ), dim=0, ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) @@ -357,7 +392,11 @@ def test_step(self, batch, batch_idx): # pred_steps, num_grid_nodes, d_f) or (d_f,) time_step_loss = torch.mean( - self.loss(prediction, target, pred_std), + self.loss( + prediction, + target, + pred_std, + ), dim=0, ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 6233b4d1..246cd93e 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -46,6 +46,12 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): # Define sub-models # Feature embedders for grid self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) + # TODO Optional separate embedder for boundary nodes + assert self.grid_dim == self.boundary_dim, ( + "Grid and boundary input dimension must be the same when using " + f"the same encoder, got grid_dim={self.grid_dim}, " + f"boundary_dim={self.boundary_dim}" + ) self.grid_embedder = utils.make_mlp( [self.grid_dim] + self.mlp_blueprint_end ) @@ -103,12 +109,15 @@ def process_step(self, mesh_rep): """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state, forcing): + def predict_step( + self, prev_state, prev_prev_state, forcing, boundary_forcing + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes, forcing_dim) + boundary_forcing: (B, num_boundary_nodes, boundary_forcing_dim) """ batch_size = prev_state.shape[0] @@ -122,22 +131,45 @@ def predict_step(self, prev_state, prev_prev_state, forcing): ), dim=-1, ) + # Create full boundary node features of shape + # (B, num_boundary_nodes, boundary_dim) + boundary_features = torch.cat( + ( + boundary_forcing, + self.expand_to_batch(self.boundary_static_features, batch_size), + ), + dim=-1, + ) # Embed all features grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) + boundary_emb = self.grid_embedder(boundary_features) + # (B, num_boundary_nodes, d_h) g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) mesh_emb = self.embedd_mesh_nodes() + # Merge interior and boundary emb into input embedding + # TODO Can we enforce ordering in the graph creation process to make + # this just a concat instead? + input_emb = torch.zeros( + batch_size, + self.num_input_nodes, + grid_emb.shape[2], + device=grid_emb.device, + ) + input_emb[:, self.interior_mask] = grid_emb + input_emb[:, self.boundary_mask] = boundary_emb + # Map from grid to mesh mesh_emb_expanded = self.expand_to_batch( mesh_emb, batch_size ) # (B, num_mesh_nodes, d_h) g2m_emb_expanded = self.expand_to_batch(g2m_emb, batch_size) - # This also splits representation into grid and mesh + # Encode to mesh mesh_rep = self.g2m_gnn( - grid_emb, mesh_emb_expanded, g2m_emb_expanded + input_emb, mesh_emb_expanded, g2m_emb_expanded ) # (B, num_mesh_nodes, d_h) # Also MLP with residual for grid representation grid_rep = grid_emb + self.encoding_grid_mlp( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index efab20bf..f2775328 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -64,6 +64,40 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): return fig +def plot_on_axis( + ax, + da, + datastore, + obs_mask=None, + vmin=None, + vmax=None, + ax_title=None, + cmap="plasma", +): + """ + Plot weather state on given axis + """ + ax.set_global() + ax.coastlines() # Add coastline outlines + + extent = datastore.get_xy_extent("state") + + da.plot.imshow( + ax=ax, + origin="lower", + x="x", + extent=extent, + vmin=vmin, + vmax=vmax, + cmap=cmap, + transform=datastore.coords_projection, + ) + + if ax_title: + ax.set_title(ax_title, size=15) + return im + + @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( datastore: BaseRegularGridDatastore, @@ -85,8 +119,6 @@ def plot_prediction( else: vmin, vmax = vrange - extent = datastore.get_xy_extent("state") - fig, axes = plt.subplots( 1, 2, @@ -96,16 +128,12 @@ def plot_prediction( # Plot pred and target for ax, da in zip(axes, (da_target, da_prediction)): - ax.coastlines() # Add coastline outlines - da.plot.imshow( - ax=ax, - origin="lower", - x="x", - extent=extent, + im = plot_on_axis( + ax, + da, + datastore, vmin=vmin, vmax=vmax, - cmap="plasma", - transform=datastore.coords_projection, ) # Ticks and labels @@ -133,14 +161,11 @@ def plot_spatial_error( else: vmin, vmax = vrange - extent = datastore.get_xy_extent("state") - fig, ax = plt.subplots( figsize=(5, 4.8), subplot_kw={"projection": datastore.coords_projection}, ) - ax.coastlines() # Add coastline outlines error_grid = ( error.reshape( [datastore.grid_shape_state.x, datastore.grid_shape_state.y] @@ -149,6 +174,7 @@ def plot_spatial_error( .numpy() ) + # TODO: This needs to be converted to DA and use plot_on_axis im = ax.imshow( error_grid, origin="lower", diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index f02cfbd4..ed67b6f7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -234,6 +234,7 @@ def __init__( self.da_boundary_mean = self.ds_boundary_stats.forcing_mean self.da_boundary_std = self.ds_boundary_stats.forcing_std + def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time From 19672210c761805e1ef5b8ff63e1ca3c4458ef19 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 4 Nov 2024 18:18:34 +0100 Subject: [PATCH 75/90] Start building graphs with wmg --- neural_lam/build_graph.py | 153 ++++++++++++++++++++++++++++++++++ neural_lam/models/ar_model.py | 4 +- neural_lam/plot_graph.py | 6 +- neural_lam/utils.py | 25 +++++- pyproject.toml | 1 + 5 files changed, 182 insertions(+), 7 deletions(-) create mode 100644 neural_lam/build_graph.py diff --git a/neural_lam/build_graph.py b/neural_lam/build_graph.py new file mode 100644 index 00000000..034f82cd --- /dev/null +++ b/neural_lam/build_graph.py @@ -0,0 +1,153 @@ +# Standard library +import argparse +import os + +# Third-party +import numpy as np +import weather_model_graphs as wmg + +# Local +from . import config, utils + +WMG_ARCHETYPES = { + "keisler": wmg.create.archetype.create_keisler_graph, + "graphcast": wmg.create.archetype.create_graphcast_graph, + "hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph, +} + + +def main(): + parser = argparse.ArgumentParser( + description="Graph generation using WMG", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Inputs and outputs + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file", + ) + parser.add_argument( + "--output_dir", + type=str, + default="graphs", + help="Directory to save graph to", + ) + + # Graph structure + parser.add_argument( + "--archetype", + type=str, + default="keisler", + help="Archetype to use to create graph (keisler/graphcast/hierarchical)", + ) + parser.add_argument( + "--mesh_node_distance", + type=float, + default=3.0, + help="Distance between created mesh nodes", + ) + parser.add_argument( + "--level_refinement_factor", + type=float, + default=3, + help="Refinement factor between grid points and bottom level of mesh hierarchy", + ) + parser.add_argument( + "--max_num_levels", + type=int, + help="Limit multi-scale mesh to given number of levels, " + "from bottom up", + ) + parser.add_argument( + "--hierarchical", + action="store_true", + help="Generate hierarchical mesh graph (default: False)", + ) + args = parser.parse_args() + + # Load grid positions + config_loader = config.Config.from_file(args.data_config) + + coords = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy() + # (num_nodes_full, 2) + + # Construct mask + static_data = utils.load_static_data(config_loader.dataset.name) + decode_mask = np.concatenate( + ( + np.ones(static_data["grid_static_features"].shape[0], dtype=bool), + np.zeros( + static_data["boundary_static_features"].shape[0], dtype=bool + ), + ), + axis=0, + ) + + # Build graph + assert ( + args.archetype in WMG_ARCHETYPES + ), f"Unknown archetype: {args.archetype}" + archetype_create_func = WMG_ARCHETYPES[args.archetype] + + create_kwargs = { + "coords": coords, + "mesh_node_distance": args.mesh_node_distance, + "projection": None, + "decode_mask": decode_mask, + } + if args.archetype != "keisler": + # Add additional multi-level kwargs + create_kwargs.update( + { + "level_refinement_factor": args.level_refinement_factor, + "max_num_levels": args.max_num_levels, + } + ) + + graph = archetype_create_func(**create_kwargs) + graph_comp = wmg.split_graph_by_edge_attribute(graph, attr="component") + + print("Created graph:") + for name, subgraph in graph_comp.items(): + print(f"{name}: {subgraph}") + + # Save graph + os.makedirs(args.output_dir, exist_ok=True) + for component, graph in graph_comp.items(): + # TODO This is all hack, saving in wmg needs to be consistent with nl + if component == "m2m": + if args.archetype == "hierarchical": + # Split by direction + m2m_direction_comp = wmg.split_graph_by_edge_attribute( + graph, attr="direction" + ) + for direction, graph in m2m_direction_comp.items(): + wmg.save.to_pyg( + graph=graph, + name=f"mesh_{direction}", + list_from_attribute="level", + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) + else: + wmg.save.to_pyg( + graph=graph, + name=component, + list_from_attribute="dummy", + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) + else: + wmg.save.to_pyg( + graph=graph, + name=component, + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 95bd1154..0311e542 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -69,12 +69,12 @@ def __init__( static_features_torch = torch.tensor(arr_static, dtype=torch.float32) self.register_buffer( "grid_static_features", - static_features_torch[self.boundary_mask.to(torch.bool), + static_features_torch[self.boundary_mask.to(torch.bool)], persistent=False, ) self.register_buffer( "boundary_static_features", - static_features_torch[self.interior_mask.to(torch.bool), + static_features_torch[self.interior_mask.to(torch.bool)], persistent=False, ) diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py index 999c8e53..9c1fc0ef 100644 --- a/neural_lam/plot_graph.py +++ b/neural_lam/plot_graph.py @@ -47,10 +47,6 @@ def main(): config_path=args.datastore_config_path ) - xy = datastore.get_xy("state", stacked=True) # (N_grid, 2) - pos_max = np.max(np.abs(xy)) - grid_pos = xy / pos_max # Divide by maximum coordinate - # Load graph data graph_dir_path = os.path.join(datastore.root_path, "graph", args.graph) hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) @@ -65,6 +61,8 @@ def main(): ) mesh_static_features = graph_ldict["mesh_static_features"] + # Extract values needed, turn to numpy + grid_pos = utils.get_reordered_grid_pos(datastore).numpy() # Add in z-dimension z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) grid_pos = np.concatenate( diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 4a0752e4..baa55610 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -114,7 +114,7 @@ def loads_file(fn): # Load static node features mesh_static_features = loads_file( - "mesh_features.pt" + "m2m_node_features.pt" ) # List of (N_mesh[l], d_mesh_static) # Some checks for consistency @@ -241,3 +241,26 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric("val_mean_loss", summary="min") for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +def get_reordered_grid_pos(datastore): + """ + Interior nodes first, then boundary + """ + xy_np = datastore.get_xy() # np, (num_grid, 2) + xy_torch = torch.tensor(xy_np, dtype=torch.float32) + + da_boundary_mask = datastore.boundary_mask + boundary_mask = torch.tensor( + da_boundary_mask.values, dtype=torch.bool + ) + interior_mask = torch.logical_not(boundary_mask) + + return torch.cat( + ( + xy_torch[interior_mask], + xy_torch[boundary_mask], + ), + dim=0, + ) + # (num_total_grid_nodes, 2) diff --git a/pyproject.toml b/pyproject.toml index f556ef6b..9607d1da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "parse>=1.20.2", "dataclass-wizard<0.31.0", "mllam-data-prep>=0.5.0", + "weather-model-graphs>=0.2.0" ] requires-python = ">=3.9" From cb74e3f05d808608a56be1b1de927aec5c73a848 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 11 Nov 2024 14:28:06 +0100 Subject: [PATCH 76/90] Change forward pass to concat according to enforced node ordering --- neural_lam/build_graph.py | 5 +++-- neural_lam/models/base_graph_model.py | 12 ++---------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/neural_lam/build_graph.py b/neural_lam/build_graph.py index 034f82cd..a0c675ac 100644 --- a/neural_lam/build_graph.py +++ b/neural_lam/build_graph.py @@ -117,7 +117,8 @@ def main(): # Save graph os.makedirs(args.output_dir, exist_ok=True) for component, graph in graph_comp.items(): - # TODO This is all hack, saving in wmg needs to be consistent with nl + # This seems like a bit of a hack, maybe better if saving in wmg + # was made consistent with nl if component == "m2m": if args.archetype == "hierarchical": # Split by direction @@ -136,7 +137,7 @@ def main(): wmg.save.to_pyg( graph=graph, name=component, - list_from_attribute="dummy", + list_from_attribute="dummy", # Note: Needed to output list edge_features=["len", "vdiff"], output_directory=args.output_dir, ) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 246cd93e..481353b4 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -150,16 +150,8 @@ def predict_step( mesh_emb = self.embedd_mesh_nodes() # Merge interior and boundary emb into input embedding - # TODO Can we enforce ordering in the graph creation process to make - # this just a concat instead? - input_emb = torch.zeros( - batch_size, - self.num_input_nodes, - grid_emb.shape[2], - device=grid_emb.device, - ) - input_emb[:, self.interior_mask] = grid_emb - input_emb[:, self.boundary_mask] = boundary_emb + # We enforce ordering (interior, boundary) of nodes + input_emb = torch.cat((grid_emb, boundary_emb), dim=1) # Map from grid to mesh mesh_emb_expanded = self.expand_to_batch( From 9715ed8eb855254e4e628f3e38ab982a5878faf9 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 11 Nov 2024 18:13:40 +0100 Subject: [PATCH 77/90] wip to make tests pass --- neural_lam/build_graph.py | 35 ++++++----- neural_lam/interaction_net.py | 6 +- neural_lam/plot_graph.py | 109 +++++++++++++++++++++++++++------- neural_lam/utils.py | 30 ++++++++-- 4 files changed, 136 insertions(+), 44 deletions(-) diff --git a/neural_lam/build_graph.py b/neural_lam/build_graph.py index a0c675ac..dcbff49d 100644 --- a/neural_lam/build_graph.py +++ b/neural_lam/build_graph.py @@ -16,7 +16,7 @@ } -def main(): +def main(input_args=None): parser = argparse.ArgumentParser( description="Graph generation using WMG", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -61,16 +61,12 @@ def main(): help="Limit multi-scale mesh to given number of levels, " "from bottom up", ) - parser.add_argument( - "--hierarchical", - action="store_true", - help="Generate hierarchical mesh graph (default: False)", - ) - args = parser.parse_args() + args = parser.parse_args(input_args) # Load grid positions config_loader = config.Config.from_file(args.data_config) + # TODO Do not get normalised positions coords = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy() # (num_nodes_full, 2) @@ -126,13 +122,24 @@ def main(): graph, attr="direction" ) for direction, graph in m2m_direction_comp.items(): - wmg.save.to_pyg( - graph=graph, - name=f"mesh_{direction}", - list_from_attribute="level", - edge_features=["len", "vdiff"], - output_directory=args.output_dir, - ) + if direction == "same": + # Name just m2m to be consistent with non-hierarchical + wmg.save.to_pyg( + graph=graph, + name="m2m", + list_from_attribute="level", + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) + else: + # up and down directions + wmg.save.to_pyg( + graph=graph, + name=f"mesh_{direction}", + list_from_attribute="levels", + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) else: wmg.save.to_pyg( graph=graph, diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 2f45b03f..8b8c5c85 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -30,7 +30,8 @@ def __init__( """ Create a new InteractionNet - edge_index: (2,M), Edges in pyg format + edge_index: (2,M), Edges in pyg format, with boeth sender and receiver + node indices starting at 0 input_dim: Dimensionality of input representations, for both nodes and edges update_edges: If new edge representations should be computed @@ -52,8 +53,7 @@ def __init__( # Default to input dim if not explicitly given hidden_dim = input_dim - # Make both sender and receiver indices of edge_index start at 0 - edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0] + # any edge_index used here must start sender and rec. nodes at index 0 # Store number of receiver nodes according to edge_index self.num_rec = edge_index[1].max() + 1 edge_index[0] = ( diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py index 9c1fc0ef..f621d201 100644 --- a/neural_lam/plot_graph.py +++ b/neural_lam/plot_graph.py @@ -69,11 +69,9 @@ def main(): (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 ) - # List of edges to plot, (edge_index, color, line_width, label) - edge_plot_list = [ - (m2g_edge_index.numpy(), "black", 0.4, "M2G"), - (g2m_edge_index.numpy(), "black", 0.4, "G2M"), - ] + # List of edges to plot, (edge_index, from_pos, to_pos, color, + # line_width, label) + edge_plot_list = [] # Mesh positioning and edges to plot differ if we have a hierarchical graph if hierarchical: @@ -92,24 +90,80 @@ def main(): mesh_static_features, start=1 ) ] - mesh_pos = np.concatenate(mesh_level_pos, axis=0) + all_mesh_pos = np.concatenate(mesh_level_pos, axis=0) + grid_con_mesh_pos = mesh_level_pos[0] # Add inter-level mesh edges edge_plot_list += [ - (level_ei.numpy(), "blue", 1, f"M2M Level {level}") - for level, level_ei in enumerate(m2m_edge_index) + ( + level_ei.numpy(), + level_pos, + level_pos, + "blue", + 1, + f"M2M Level {level}", + ) + for level, (level_ei, level_pos) in enumerate( + zip(m2m_edge_index, mesh_level_pos) + ) ] # Add intra-level mesh edges - up_edges_ei = np.concatenate( - [level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1 + up_edges_ei = [ + level_up_ei.numpy() for level_up_ei in mesh_up_edge_index + ] + down_edges_ei = [ + level_down_ei.numpy() for level_down_ei in mesh_down_edge_index + ] + # Add up edges + for level_i, (up_ei, from_pos, to_pos) in enumerate( + zip(up_edges_ei, mesh_level_pos[:-1], mesh_level_pos[1:]) + ): + edge_plot_list.append( + ( + up_ei, + from_pos, + to_pos, + "green", + 1, + f"Mesh up {level_i}-{level_i+1}", + ) + ) + # Add down edges + for level_i, (down_ei, from_pos, to_pos) in enumerate( + zip(down_edges_ei, mesh_level_pos[1:], mesh_level_pos[:-1]) + ): + edge_plot_list.append( + ( + down_ei, + from_pos, + to_pos, + "green", + 1, + f"Mesh down {level_i+1}-{level_i}", + ) + ) + + edge_plot_list.append( + ( + m2g_edge_index.numpy(), + grid_con_mesh_pos, + grid_pos, + "black", + 0.4, + "M2G", + ) ) - down_edges_ei = np.concatenate( - [level_down_ei.numpy() for level_down_ei in mesh_down_edge_index], - axis=1, + edge_plot_list.append( + ( + g2m_edge_index.numpy(), + grid_pos, + grid_con_mesh_pos, + "black", + 0.4, + "G2M", + ) ) - edge_plot_list.append((up_edges_ei, "green", 1, "Mesh up")) - edge_plot_list.append((down_edges_ei, "green", 1, "Mesh down")) mesh_node_size = 2.5 else: @@ -123,21 +177,30 @@ def main(): (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1 ) - edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M")) + edge_plot_list.append( + (m2m_edge_index.numpy(), mesh_pos, mesh_pos, "blue", 1, "M2M") + ) + edge_plot_list.append( + (m2g_edge_index.numpy(), mesh_pos, grid_pos, "black", 0.4, "M2G") + ) + edge_plot_list.append( + (g2m_edge_index.numpy(), grid_pos, mesh_pos, "black", 0.4, "G2M") + ) - # All node positions in one array - node_pos = np.concatenate((mesh_pos, grid_pos), axis=0) + all_mesh_pos = mesh_pos # Add edges data_objs = [] for ( ei, + from_pos, + to_pos, col, width, label, ) in edge_plot_list: - edge_start = node_pos[ei[0]] # (M, 2) - edge_end = node_pos[ei[1]] # (M, 2) + edge_start = from_pos[ei[0]] # (M, 2) + edge_end = to_pos[ei[1]] # (M, 2) n_edges = edge_start.shape[0] x_edges = np.stack( @@ -174,9 +237,9 @@ def main(): ) data_objs.append( go.Scatter3d( - x=mesh_pos[:, 0], - y=mesh_pos[:, 1], - z=mesh_pos[:, 2], + x=all_mesh_pos[:, 0], + y=all_mesh_pos[:, 1], + z=all_mesh_pos[:, 2], mode="markers", marker={"color": "blue", "size": mesh_node_size}, name="Mesh nodes", diff --git a/neural_lam/utils.py b/neural_lam/utils.py index baa55610..c0207123 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -33,6 +33,13 @@ def __iter__(self): return (self[i] for i in range(len(self))) +def zero_index_edge_index(edge_index): + """ + Make both sender and receiver indices of edge_index start at 0 + """ + return edge_index - edge_index.min(dim=1, keepdim=True)[0] + + def load_graph(graph_dir_path, device="cpu"): """Load all tensors representing the graph from `graph_dir_path`. @@ -71,11 +78,13 @@ def load_graph(graph_dir_path, device="cpu"): - mesh_down_edge_index - g2m_features - m2g_features - - m2m_features + - m2m_node_features - mesh_up_features - mesh_down_features - mesh_static_features + + Load all tensors representing the graph """ def loads_file(fn): @@ -87,11 +96,16 @@ def loads_file(fn): # Load edges (edge_index) m2m_edge_index = BufferList( - loads_file("m2m_edge_index.pt"), persistent=False + [zero_index_edge_index(ei) for ei in loads_file("m2m_edge_index.pt")], + persistent=False, ) # List of (2, M_m2m[l]) g2m_edge_index = loads_file("g2m_edge_index.pt") # (2, M_g2m) m2g_edge_index = loads_file("m2g_edge_index.pt") # (2, M_m2g) + # Change first indices to 0 + g2m_edge_index = zero_index_edge_index(g2m_edge_index) + m2g_edge_index = zero_index_edge_index(m2g_edge_index) + n_levels = len(m2m_edge_index) hierarchical = n_levels > 1 # Nor just single level mesh graph @@ -128,10 +142,18 @@ def loads_file(fn): if hierarchical: # Load up and down edges and features mesh_up_edge_index = BufferList( - loads_file("mesh_up_edge_index.pt"), persistent=False + [ + zero_index_edge_index(ei) + for ei in loads_file("mesh_up_edge_index.pt") + ], + persistent=False, ) # List of (2, M_up[l]) mesh_down_edge_index = BufferList( - loads_file("mesh_down_edge_index.pt"), persistent=False + [ + zero_index_edge_index(ei) + for ei in loads_file("mesh_down_edge_index.pt") + ], + persistent=False, ) # List of (2, M_down[l]) mesh_up_features = loads_file( From 336fba9c6843838533222f7ea0618d59f44ff427 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Tue, 12 Nov 2024 14:02:22 +0100 Subject: [PATCH 78/90] Fix edge index manipulation to make training work again --- neural_lam/interaction_net.py | 4 ++-- neural_lam/models/base_graph_model.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 8b8c5c85..417aae1a 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -56,8 +56,8 @@ def __init__( # any edge_index used here must start sender and rec. nodes at index 0 # Store number of receiver nodes according to edge_index self.num_rec = edge_index[1].max() + 1 - edge_index[0] = ( - edge_index[0] + self.num_rec + edge_index = torch.stack( + (edge_index[0] + self.num_rec, edge_index[1]), dim=0 ) # Make sender indices after rec self.register_buffer("edge_index", edge_index, persistent=False) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 481353b4..c0b21a75 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -19,8 +19,6 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): super().__init__(args, config=config, datastore=datastore) # Load graph with static features - # NOTE: (IMPORTANT!) mesh nodes MUST have the first - # num_mesh_nodes indices, graph_dir_path = datastore.root_path / "graph" / args.graph self.hierarchical, graph_ldict = utils.load_graph( graph_dir_path=graph_dir_path From ce3ea6d7d44f5126ec088daec6665a65f04fe83b Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Tue, 12 Nov 2024 16:06:28 +0100 Subject: [PATCH 79/90] Work on fixing plotting functionality --- neural_lam/vis.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index f2775328..d744f542 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -73,6 +73,7 @@ def plot_on_axis( vmax=None, ax_title=None, cmap="plasma", + grid_limits=None ): """ Plot weather state on given axis @@ -82,7 +83,7 @@ def plot_on_axis( extent = datastore.get_xy_extent("state") - da.plot.imshow( + im = da.plot.imshow( ax=ax, origin="lower", x="x", @@ -95,6 +96,7 @@ def plot_on_axis( if ax_title: ax.set_title(ax_title, size=15) + return im @@ -173,6 +175,7 @@ def plot_spatial_error( .T.cpu() .numpy() ) + extent = datastore.get_xy_extent("state") # TODO: This needs to be converted to DA and use plot_on_axis im = ax.imshow( From a520505ceac8a584c9c7e6698c0a2c3911cb8fa2 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 13 Nov 2024 13:32:16 +0100 Subject: [PATCH 80/90] Linting --- neural_lam/build_graph.py | 8 +++++--- neural_lam/vis.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/neural_lam/build_graph.py b/neural_lam/build_graph.py index dcbff49d..c13dc629 100644 --- a/neural_lam/build_graph.py +++ b/neural_lam/build_graph.py @@ -41,7 +41,8 @@ def main(input_args=None): "--archetype", type=str, default="keisler", - help="Archetype to use to create graph (keisler/graphcast/hierarchical)", + help="Archetype to use to create graph " + "(keisler/graphcast/hierarchical)", ) parser.add_argument( "--mesh_node_distance", @@ -53,7 +54,8 @@ def main(input_args=None): "--level_refinement_factor", type=float, default=3, - help="Refinement factor between grid points and bottom level of mesh hierarchy", + help="Refinement factor between grid points and bottom level of " + "mesh hierarchy", ) parser.add_argument( "--max_num_levels", @@ -144,7 +146,7 @@ def main(input_args=None): wmg.save.to_pyg( graph=graph, name=component, - list_from_attribute="dummy", # Note: Needed to output list + list_from_attribute="dummy", # Note: Needed to output list edge_features=["len", "vdiff"], output_directory=args.output_dir, ) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index d744f542..7e7bbf42 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -73,7 +73,7 @@ def plot_on_axis( vmax=None, ax_title=None, cmap="plasma", - grid_limits=None + grid_limits=None, ): """ Plot weather state on given axis From 793e6c04436a3afd842bd051bde8705d405829e5 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 13 Nov 2024 13:54:53 +0100 Subject: [PATCH 81/90] Add optional separate grid embedder for boundary --- neural_lam/models/ar_model.py | 2 +- neural_lam/models/base_graph_model.py | 22 +++++++++++++++------- neural_lam/train_model.py | 7 +++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 0311e542..c3870fbc 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -128,7 +128,7 @@ def __init__( ( self.num_boundary_nodes, - boundary_static_dim, # TODO Need for computation below + boundary_static_dim, # TODO Will need for computation below ) = self.boundary_static_features.shape self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes self.grid_dim = ( diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index c0b21a75..de8d87db 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -44,15 +44,23 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): # Define sub-models # Feature embedders for grid self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) - # TODO Optional separate embedder for boundary nodes - assert self.grid_dim == self.boundary_dim, ( - "Grid and boundary input dimension must be the same when using " - f"the same encoder, got grid_dim={self.grid_dim}, " - f"boundary_dim={self.boundary_dim}" - ) self.grid_embedder = utils.make_mlp( [self.grid_dim] + self.mlp_blueprint_end ) + # Optional separate embedder for boundary nodes + print(args.shared_grid_embedder) + if args.shared_grid_embedder: + assert self.grid_dim == self.boundary_dim, ( + "Grid and boundary input dimension must be the same when using " + f"the same embedder, got grid_dim={self.grid_dim}, " + f"boundary_dim={self.boundary_dim}" + ) + self.boundary_embedder = self.grid_embedder + else: + self.boundary_embedder = utils.make_mlp( + [self.boundary_dim] + self.mlp_blueprint_end + ) + self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) @@ -141,7 +149,7 @@ def predict_step( # Embed all features grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) - boundary_emb = self.grid_embedder(boundary_features) + boundary_emb = self.boundary_embedder(boundary_features) # (B, num_boundary_nodes, d_h) g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 2a61e86c..7e0b47c6 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -116,6 +116,13 @@ def main(input_args=None): "output dimensions " "(default: False (no))", ) + parser.add_argument( + "--shared_grid_embedder", + action="store_true", # Default to separate embedders + help="If the same embedder MLP should be used for interior and boundary" + " grid nodes. Note that this requires the same dimensionality for " + "both kinds of grid inputs. (default: False (no))", + ) # Training options parser.add_argument( From 3515460cfc6959a26f2de47b9770c2d738c94d73 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 13 Nov 2024 13:58:39 +0100 Subject: [PATCH 82/90] Make new graph creation script main and only one --- ...ld_graph.py => build_rectangular_graph.py} | 2 +- neural_lam/create_graph.py | 610 ------------------ 2 files changed, 1 insertion(+), 611 deletions(-) rename neural_lam/{build_graph.py => build_rectangular_graph.py} (98%) delete mode 100644 neural_lam/create_graph.py diff --git a/neural_lam/build_graph.py b/neural_lam/build_rectangular_graph.py similarity index 98% rename from neural_lam/build_graph.py rename to neural_lam/build_rectangular_graph.py index c13dc629..84585540 100644 --- a/neural_lam/build_graph.py +++ b/neural_lam/build_rectangular_graph.py @@ -18,7 +18,7 @@ def main(input_args=None): parser = argparse.ArgumentParser( - description="Graph generation using WMG", + description="Rectangular graph generation using weather-models-graph", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py deleted file mode 100644 index ef979be3..00000000 --- a/neural_lam/create_graph.py +++ /dev/null @@ -1,610 +0,0 @@ -# Standard library -import os -from argparse import ArgumentParser - -# Third-party -import matplotlib -import matplotlib.pyplot as plt -import networkx -import numpy as np -import scipy.spatial -import torch -import torch_geometric as pyg -from torch_geometric.utils.convert import from_networkx - -# Local -from .config import load_config_and_datastore -from .datastore.base import BaseRegularGridDatastore - - -def plot_graph(graph, title=None): - fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H - edge_index = graph.edge_index - pos = graph.pos - - # Fix for re-indexed edge indices only containing mesh nodes at - # higher levels in hierarchy - edge_index = edge_index - edge_index.min() - - if pyg.utils.is_undirected(edge_index): - # Keep only 1 direction of edge_index - edge_index = edge_index[:, edge_index[0] < edge_index[1]] # (2, M/2) - # TODO: indicate direction of directed edges - - # Move all to cpu and numpy, compute (in)-degrees - degrees = ( - pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy() - ) - edge_index = edge_index.cpu().numpy() - pos = pos.cpu().numpy() - - # Plot edges - from_pos = pos[edge_index[0]] # (M/2, 2) - to_pos = pos[edge_index[1]] # (M/2, 2) - edge_lines = np.stack((from_pos, to_pos), axis=1) - axis.add_collection( - matplotlib.collections.LineCollection( - edge_lines, lw=0.4, colors="black", zorder=1 - ) - ) - - # Plot nodes - node_scatter = axis.scatter( - pos[:, 0], - pos[:, 1], - c=degrees, - s=3, - marker="o", - zorder=2, - cmap="viridis", - clim=None, - ) - - plt.colorbar(node_scatter, aspect=50) - - if title is not None: - axis.set_title(title) - - return fig, axis - - -def sort_nodes_internally(nx_graph): - # For some reason the networkx .nodes() return list can not be sorted, - # but this is the ordering used by pyg when converting. - # This function fixes this. - H = networkx.DiGraph() - H.add_nodes_from(sorted(nx_graph.nodes(data=True))) - H.add_edges_from(nx_graph.edges(data=True)) - return H - - -def save_edges(graph, name, base_path): - torch.save( - graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt") - ) - edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( - torch.float32 - ) # Save as float32 - torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) - - -def save_edges_list(graphs, name, base_path): - torch.save( - [graph.edge_index for graph in graphs], - os.path.join(base_path, f"{name}_edge_index.pt"), - ) - edge_features = [ - torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( - torch.float32 - ) - for graph in graphs - ] # Save as float32 - torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) - - -def from_networkx_with_start_index(nx_graph, start_index): - pyg_graph = from_networkx(nx_graph) - pyg_graph.edge_index += start_index - return pyg_graph - - -def mk_2d_graph(xy, nx, ny): - xm, xM = np.amin(xy[:, :, 0][:, 0]), np.amax(xy[:, :, 0][:, 0]) - ym, yM = np.amin(xy[:, :, 1][0, :]), np.amax(xy[:, :, 1][0, :]) - - # avoid nodes on border - dx = (xM - xm) / nx - dy = (yM - ym) / ny - lx = np.linspace(xm + dx / 2, xM - dx / 2, nx) - ly = np.linspace(ym + dy / 2, yM - dy / 2, ny) - - mg = np.meshgrid(lx, ly, indexing="ij") # Use 'ij' indexing for (Nx,Ny) - g = networkx.grid_2d_graph(len(lx), len(ly)) - - for node in g.nodes: - g.nodes[node]["pos"] = np.array([mg[0][node], mg[1][node]]) - - # add diagonal edges - g.add_edges_from( - [((x, y), (x + 1, y + 1)) for y in range(ny - 1) for x in range(nx - 1)] - + [ - ((x + 1, y), (x, y + 1)) - for y in range(ny - 1) - for x in range(nx - 1) - ] - ) - - # turn into directed graph - dg = networkx.DiGraph(g) - for u, v in g.edges(): - d = np.sqrt(np.sum((g.nodes[u]["pos"] - g.nodes[v]["pos"]) ** 2)) - dg.edges[u, v]["len"] = d - dg.edges[u, v]["vdiff"] = g.nodes[u]["pos"] - g.nodes[v]["pos"] - dg.add_edge(v, u) - dg.edges[v, u]["len"] = d - dg.edges[v, u]["vdiff"] = g.nodes[v]["pos"] - g.nodes[u]["pos"] - - return dg - - -def prepend_node_index(graph, new_index): - # Relabel node indices in graph, insert (graph_level, i, j) - ijk = [tuple((new_index,) + x) for x in graph.nodes] - to_mapping = dict(zip(graph.nodes, ijk)) - return networkx.relabel_nodes(graph, to_mapping, copy=True) - - -def create_graph( - graph_dir_path: str, - xy: np.ndarray, - n_max_levels: int, - hierarchical: bool, - create_plot: bool, -): - """ - Create graph components from `xy` grid coordinates and store in - `graph_dir_path`. - - Creates the following files for all graphs: - - g2m_edge_index.pt [2, N_g2m_edges] - - g2m_features.pt [N_g2m_edges, d_features] - - m2g_edge_index.pt [2, N_m2m_edges] - - m2g_features.pt [N_m2m_edges, d_features] - - m2m_edge_index.pt list of [2, N_m2m_edges_level], length==n_levels - - m2m_features.pt list of [N_m2m_edges_level, d_features], - length==n_levels - - mesh_features.pt list of [N_mesh_nodes_level, d_mesh_static], - length==n_levels - - where - d_features: - number of features per edge (currently d_features==3, for - edge-length, x and y) - N_g2m_edges: - number of edges in the graph from grid-to-mesh - N_m2g_edges: - number of edges in the graph from mesh-to-grid - N_m2m_edges_level: - number of edges in the graph from mesh-to-mesh at a given level - (list index corresponds to the level) - d_mesh_static: - number of static features per mesh node (currently - d_mesh_static==2, for x and y) - N_mesh_nodes_level: - number of nodes in the mesh at a given level - - And in addition for hierarchical graphs: - - mesh_up_edge_index.pt - list of [2, N_mesh_updown_edges_level], length==n_levels-1 - - mesh_up_features.pt - list of [N_mesh_updown_edges_level, d_features], length==n_levels-1 - - mesh_down_edge_index.pt - list of [2, N_mesh_updown_edges_level], length==n_levels-1 - - mesh_down_features.pt - list of [N_mesh_updown_edges_level, d_features], length==n_levels-1 - - where N_mesh_updown_edges_level is the number of edges in the graph from - mesh-to-mesh between two consecutive levels (list index corresponds index - of lower level) - - - Parameters - ---------- - graph_dir_path : str - Path to store the graph components. - xy : np.ndarray - Grid coordinates, expected to be of shape (Nx, Ny, 2). - n_max_levels : int - Limit multi-scale mesh to given number of levels, from bottom up - (default: None (no limit)). - hierarchical : bool - Generate hierarchical mesh graph (default: False). - create_plot : bool - If graphs should be plotted during generation (default: False). - - Returns - ------- - None - - """ - os.makedirs(graph_dir_path, exist_ok=True) - - print(f"Writing graph components to {graph_dir_path}") - - grid_xy = torch.tensor(xy) - pos_max = torch.max(torch.abs(grid_xy)) - - # - # Mesh - # - - # graph geometry - nx = 3 # number of children =nx**2 - nlev = int(np.log(max(xy.shape[:2])) / np.log(nx)) - nleaf = nx**nlev # leaves at the bottom = nleaf**2 - - mesh_levels = nlev - 1 - if n_max_levels: - # Limit the levels in mesh graph - mesh_levels = min(mesh_levels, n_max_levels) - - # print(f"nlev: {nlev}, nleaf: {nleaf}, mesh_levels: {mesh_levels}") - - # multi resolution tree levels - G = [] - for lev in range(1, mesh_levels + 1): - n = int(nleaf / (nx**lev)) - g = mk_2d_graph(xy, n, n) - if create_plot: - plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}") - plt.show() - - G.append(g) - - if hierarchical: - # Relabel nodes of each level with level index first - G = [ - prepend_node_index(graph, level_i) - for level_i, graph in enumerate(G) - ] - - num_nodes_level = np.array([len(g_level.nodes) for g_level in G]) - # First node index in each level in the hierarchical graph - first_index_level = np.concatenate( - (np.zeros(1, dtype=int), np.cumsum(num_nodes_level[:-1])) - ) - - # Create inter-level mesh edges - up_graphs = [] - down_graphs = [] - for from_level, to_level, G_from, G_to, start_index in zip( - range(1, mesh_levels), - range(0, mesh_levels - 1), - G[1:], - G[:-1], - first_index_level[: mesh_levels - 1], - ): - # start out from graph at from level - G_down = G_from.copy() - G_down.clear_edges() - G_down = networkx.DiGraph(G_down) - - # Add nodes of to level - G_down.add_nodes_from(G_to.nodes(data=True)) - - # build kd tree for mesh point pos - # order in vm should be same as in vm_xy - v_to_list = list(G_to.nodes) - v_from_list = list(G_from.nodes) - v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) - kdt_m = scipy.spatial.KDTree(v_from_xy) - - # add edges from mesh to grid - for v in v_to_list: - # find 1(?) nearest neighbours (index to vm_xy) - neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], 1)[1] - u = v_from_list[neigh_idx] - - # add edge from mesh to grid - G_down.add_edge(u, v) - d = np.sqrt( - np.sum( - (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2 - ) - ) - G_down.edges[u, v]["len"] = d - G_down.edges[u, v]["vdiff"] = ( - G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] - ) - - # relabel nodes to integers (sorted) - G_down_int = networkx.convert_node_labels_to_integers( - G_down, first_label=start_index, ordering="sorted" - ) # Issue with sorting here - G_down_int = sort_nodes_internally(G_down_int) - pyg_down = from_networkx_with_start_index(G_down_int, start_index) - - # Create up graph, invert downwards edges - up_edges = torch.stack( - (pyg_down.edge_index[1], pyg_down.edge_index[0]), dim=0 - ) - pyg_up = pyg_down.clone() - pyg_up.edge_index = up_edges - - up_graphs.append(pyg_up) - down_graphs.append(pyg_down) - - if create_plot: - plot_graph( - pyg_down, title=f"Down graph, {from_level} -> {to_level}" - ) - plt.show() - - plot_graph( - pyg_down, title=f"Up graph, {to_level} -> {from_level}" - ) - plt.show() - - # Save up and down edges - save_edges_list(up_graphs, "mesh_up", graph_dir_path) - save_edges_list(down_graphs, "mesh_down", graph_dir_path) - - # Extract intra-level edges for m2m - m2m_graphs = [ - from_networkx_with_start_index( - networkx.convert_node_labels_to_integers( - level_graph, first_label=start_index, ordering="sorted" - ), - start_index, - ) - for level_graph, start_index in zip(G, first_index_level) - ] - - mesh_pos = [graph.pos.to(torch.float32) for graph in m2m_graphs] - - # For use in g2m and m2g - G_bottom_mesh = G[0] - - joint_mesh_graph = networkx.union_all([graph for graph in G]) - all_mesh_nodes = joint_mesh_graph.nodes(data=True) - - else: - # combine all levels to one graph - G_tot = G[0] - for lev in range(1, len(G)): - nodes = list(G[lev - 1].nodes) - n = int(np.sqrt(len(nodes))) - ij = ( - np.array(nodes) - .reshape((n, n, 2))[1::nx, 1::nx, :] - .reshape(int(n / nx) ** 2, 2) - ) - ij = [tuple(x) for x in ij] - G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij))) - G_tot = networkx.compose(G_tot, G[lev]) - - # Relabel mesh nodes to start with 0 - G_tot = prepend_node_index(G_tot, 0) - - # relabel nodes to integers (sorted) - G_int = networkx.convert_node_labels_to_integers( - G_tot, first_label=0, ordering="sorted" - ) - - # Graph to use in g2m and m2g - G_bottom_mesh = G_tot - all_mesh_nodes = G_tot.nodes(data=True) - - # export the nx graph to PyTorch geometric - pyg_m2m = from_networkx(G_int) - m2m_graphs = [pyg_m2m] - mesh_pos = [pyg_m2m.pos.to(torch.float32)] - - if create_plot: - plot_graph(pyg_m2m, title="Mesh-to-mesh") - plt.show() - - # Save m2m edges - save_edges_list(m2m_graphs, "m2m", graph_dir_path) - - # Divide mesh node pos by max coordinate of grid cell - mesh_pos = [pos / pos_max for pos in mesh_pos] - - # Save mesh positions - torch.save( - mesh_pos, os.path.join(graph_dir_path, "mesh_features.pt") - ) # mesh pos, in float32 - - # - # Grid2Mesh - # - - # radius within which grid nodes are associated with a mesh node - # (in terms of mesh distance) - DM_SCALE = 0.67 - - # mesh nodes on lowest level - vm = G_bottom_mesh.nodes - vm_xy = np.array([xy for _, xy in vm.data("pos")]) - # distance between mesh nodes - dm = np.sqrt( - np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) - ) - - # grid nodes - Nx, Ny = xy.shape[:2] - - G_grid = networkx.grid_2d_graph(Ny, Nx) - G_grid.clear_edges() - - # vg features (only pos introduced here) - for node in G_grid.nodes: - # pos is in feature but here explicit for convenience - G_grid.nodes[node]["pos"] = xy[ - node[1], node[0] - ] # xy is already (Nx,Ny,2) - - # add 1000 to node key to separate grid nodes (1000,i,j) from mesh nodes - # (i,j) and impose sorting order such that vm are the first nodes - G_grid = prepend_node_index(G_grid, 1000) - - # build kd tree for grid point pos - # order in vg_list should be same as in vg_xy - vg_list = list(G_grid.nodes) - vg_xy = np.array( - [xy[node[2], node[1]] for node in vg_list] - ) # xy is already (Nx,Ny,2) - kdt_g = scipy.spatial.KDTree(vg_xy) - - # now add (all) mesh nodes, include features (pos) - G_grid.add_nodes_from(all_mesh_nodes) - - # Re-create graph with sorted node indices - # Need to do sorting of nodes this way for indices to map correctly to pyg - G_g2m = networkx.Graph() - G_g2m.add_nodes_from(sorted(G_grid.nodes(data=True))) - - # turn into directed graph - G_g2m = networkx.DiGraph(G_g2m) - - # add edges - for v in vm: - # find neighbours (index to vg_xy) - neigh_idxs = kdt_g.query_ball_point(vm[v]["pos"], dm * DM_SCALE) - for i in neigh_idxs: - u = vg_list[i] - # add edge from grid to mesh - G_g2m.add_edge(u, v) - d = np.sqrt( - np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2) - ) - G_g2m.edges[u, v]["len"] = d - G_g2m.edges[u, v]["vdiff"] = ( - G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"] - ) - - pyg_g2m = from_networkx(G_g2m) - - if create_plot: - plot_graph(pyg_g2m, title="Grid-to-mesh") - plt.show() - - # - # Mesh2Grid - # - - # start out from Grid2Mesh and then replace edges - G_m2g = G_g2m.copy() - G_m2g.clear_edges() - - # build kd tree for mesh point pos - # order in vm should be same as in vm_xy - vm_list = list(vm) - kdt_m = scipy.spatial.KDTree(vm_xy) - - # add edges from mesh to grid - for v in vg_list: - # find 4 nearest neighbours (index to vm_xy) - neigh_idxs = kdt_m.query(G_m2g.nodes[v]["pos"], 4)[1] - for i in neigh_idxs: - u = vm_list[i] - # add edge from mesh to grid - G_m2g.add_edge(u, v) - d = np.sqrt( - np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2) - ) - G_m2g.edges[u, v]["len"] = d - G_m2g.edges[u, v]["vdiff"] = ( - G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"] - ) - - # relabel nodes to integers (sorted) - G_m2g_int = networkx.convert_node_labels_to_integers( - G_m2g, first_label=0, ordering="sorted" - ) - pyg_m2g = from_networkx(G_m2g_int) - - if create_plot: - plot_graph(pyg_m2g, title="Mesh-to-grid") - plt.show() - - # Save g2m and m2g everything - # g2m - save_edges(pyg_g2m, "g2m", graph_dir_path) - # m2g - save_edges(pyg_m2g, "m2g", graph_dir_path) - - -def create_graph_from_datastore( - datastore: BaseRegularGridDatastore, - output_root_path: str, - n_max_levels: int = None, - hierarchical: bool = False, - create_plot: bool = False, -): - if isinstance(datastore, BaseRegularGridDatastore): - xy = datastore.get_xy(category="state", stacked=False) - else: - raise NotImplementedError( - "Only graph creation for BaseRegularGridDatastore is supported" - ) - - create_graph( - graph_dir_path=output_root_path, - xy=xy, - n_max_levels=n_max_levels, - hierarchical=hierarchical, - create_plot=create_plot, - ) - - -def cli(input_args=None): - parser = ArgumentParser(description="Graph generation arguments") - parser.add_argument( - "--config_path", - type=str, - help="Path to neural-lam configuration file", - ) - parser.add_argument( - "--name", - type=str, - default="multiscale", - help="Name to save graph as (default: multiscale)", - ) - parser.add_argument( - "--plot", - action="store_true", - help="If graphs should be plotted during generation " - "(default: False)", - ) - parser.add_argument( - "--levels", - type=int, - help="Limit multi-scale mesh to given number of levels, " - "from bottom up (default: None (no limit))", - ) - parser.add_argument( - "--hierarchical", - action="store_true", - help="Generate hierarchical mesh graph (default: False)", - ) - args = parser.parse_args(input_args) - - assert ( - args.config_path is not None - ), "Specify your config with --config_path" - - # Load neural-lam configuration and datastore to use - _, datastore = load_config_and_datastore(config_path=args.config_path) - - create_graph_from_datastore( - datastore=datastore, - output_root_path=os.path.join(datastore.root_path, "graph", args.name), - n_max_levels=args.levels, - hierarchical=args.hierarchical, - create_plot=args.plot, - ) - - -if __name__ == "__main__": - cli() From 05d91f1c0065428ae2e572b6c59cffd61a2c5e1d Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 13 Nov 2024 14:06:48 +0100 Subject: [PATCH 83/90] Fix some typos and forgot code --- neural_lam/interaction_net.py | 2 +- neural_lam/models/base_graph_model.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 417aae1a..46223b88 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -30,7 +30,7 @@ def __init__( """ Create a new InteractionNet - edge_index: (2,M), Edges in pyg format, with boeth sender and receiver + edge_index: (2,M), Edges in pyg format, with both sender and receiver node indices starting at 0 input_dim: Dimensionality of input representations, for both nodes and edges diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index de8d87db..d5b39bf7 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -48,7 +48,6 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): [self.grid_dim] + self.mlp_blueprint_end ) # Optional separate embedder for boundary nodes - print(args.shared_grid_embedder) if args.shared_grid_embedder: assert self.grid_dim == self.boundary_dim, ( "Grid and boundary input dimension must be the same when using " From 3eba43c2f8072e755a39ac9cb8de9a50320bd578 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 27 Nov 2024 12:23:37 +0100 Subject: [PATCH 84/90] Correct handling of node indices for m2g when using decode_mask --- neural_lam/build_rectangular_graph.py | 5 ++-- neural_lam/utils.py | 38 ++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/neural_lam/build_rectangular_graph.py b/neural_lam/build_rectangular_graph.py index 84585540..7c3151f4 100644 --- a/neural_lam/build_rectangular_graph.py +++ b/neural_lam/build_rectangular_graph.py @@ -93,8 +93,8 @@ def main(input_args=None): create_kwargs = { "coords": coords, "mesh_node_distance": args.mesh_node_distance, - "projection": None, "decode_mask": decode_mask, + "return_components": True, } if args.archetype != "keisler": # Add additional multi-level kwargs @@ -105,8 +105,7 @@ def main(input_args=None): } ) - graph = archetype_create_func(**create_kwargs) - graph_comp = wmg.split_graph_by_edge_attribute(graph, attr="component") + graph_comp = archetype_create_func(**create_kwargs) print("Created graph:") for name, subgraph in graph_comp.items(): diff --git a/neural_lam/utils.py b/neural_lam/utils.py index c0207123..6241c1ca 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -94,6 +94,11 @@ def loads_file(fn): weights_only=True, ) + # Load static node features + mesh_static_features = loads_file( + "m2m_node_features.pt" + ) # List of (N_mesh[l], d_mesh_static) + # Load edges (edge_index) m2m_edge_index = BufferList( [zero_index_edge_index(ei) for ei in loads_file("m2m_edge_index.pt")], @@ -104,7 +109,33 @@ def loads_file(fn): # Change first indices to 0 g2m_edge_index = zero_index_edge_index(g2m_edge_index) - m2g_edge_index = zero_index_edge_index(m2g_edge_index) + # m2g has to be handled specially as not all mesh nodes might be indexed in + # m2g_edge_index + m2g_min_indices = m2g_edge_index.min(dim=1, keepdim=True)[0] + if m2g_min_indices[0] < m2g_min_indices[1]: + # mesh has the first indices + # Number of mesh nodes at level that connects to grid + num_mesh_nodes = mesh_static_features[0].shape[0] + + m2g_edge_index = torch.stack( + ( + m2g_edge_index[0], + m2g_edge_index[1] - num_mesh_nodes, + ), + dim=0, + ) + else: + # grid (interior) has the first indices + # NOTE: Below works, but would be good with a better way to get this + num_interior_nodes = m2g_edge_index[1].max() + 1 + + m2g_edge_index = torch.stack( + ( + m2g_edge_index[0] - num_interior_nodes, + m2g_edge_index[1], + ), + dim=0, + ) n_levels = len(m2m_edge_index) hierarchical = n_levels > 1 # Nor just single level mesh graph @@ -126,11 +157,6 @@ def loads_file(fn): g2m_features = g2m_features / longest_edge m2g_features = m2g_features / longest_edge - # Load static node features - mesh_static_features = loads_file( - "m2m_node_features.pt" - ) # List of (N_mesh[l], d_mesh_static) - # Some checks for consistency assert ( len(m2m_features) == n_levels From f1b73592f09db5e177f57f3f44a188ccaa129250 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Thu, 28 Nov 2024 10:53:17 +0100 Subject: [PATCH 85/90] Linting and bugfixes --- neural_lam/models/ar_model.py | 8 +++++++- neural_lam/utils.py | 6 ++---- neural_lam/vis.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index c3870fbc..f8eef057 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -281,7 +281,13 @@ def common_step(self, batch): (B, pred_steps, num_boundary_nodes, d_boundary_forcing), where index 0 corresponds to index 1 of init_states """ - (init_states, target_states, forcing_features, _, batch_times) = batch + ( + init_states, + target_states, + forcing, + boundary_forcing, + batch_times, + ) = batch prediction, pred_std = self.unroll_prediction( init_states, forcing, boundary_forcing diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 6241c1ca..8e43fa40 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -295,13 +295,11 @@ def get_reordered_grid_pos(datastore): """ Interior nodes first, then boundary """ - xy_np = datastore.get_xy() # np, (num_grid, 2) + xy_np = datastore.get_xy() # np, (num_grid, 2) xy_torch = torch.tensor(xy_np, dtype=torch.float32) da_boundary_mask = datastore.boundary_mask - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.bool - ) + boundary_mask = torch.tensor(da_boundary_mask.values, dtype=torch.bool) interior_mask = torch.logical_not(boundary_mask) return torch.cat( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 7e7bbf42..10b84fb7 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -130,7 +130,7 @@ def plot_prediction( # Plot pred and target for ax, da in zip(axes, (da_target, da_prediction)): - im = plot_on_axis( + plot_on_axis( ax, da, datastore, From fa6c9e3071627769112ed2fb3f872e4f019ea62f Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 2 Dec 2024 14:31:01 +0100 Subject: [PATCH 86/90] Make graph creation and plotting work with datastores --- neural_lam/build_rectangular_graph.py | 47 ++++++++++++++-------- neural_lam/datastore/mdp.py | 2 +- neural_lam/datastore/npyfilesmeps/store.py | 2 +- neural_lam/models/base_graph_model.py | 2 +- neural_lam/plot_graph.py | 41 ++++++++++--------- neural_lam/utils.py | 2 +- 6 files changed, 57 insertions(+), 39 deletions(-) diff --git a/neural_lam/build_rectangular_graph.py b/neural_lam/build_rectangular_graph.py index 7c3151f4..e4570397 100644 --- a/neural_lam/build_rectangular_graph.py +++ b/neural_lam/build_rectangular_graph.py @@ -7,7 +7,8 @@ import weather_model_graphs as wmg # Local -from . import config, utils +from . import utils +from .config import load_config_and_datastore WMG_ARCHETYPES = { "keisler": wmg.create.archetype.create_keisler_graph, @@ -24,10 +25,14 @@ def main(input_args=None): # Inputs and outputs parser.add_argument( - "--data_config", + "--config_path", type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file", + help="Path to the configuration for neural-lam", + ) + parser.add_argument( + "--name", + type=str, + help="Name to save graph as (default: multiscale)", ) parser.add_argument( "--output_dir", @@ -65,21 +70,28 @@ def main(input_args=None): ) args = parser.parse_args(input_args) - # Load grid positions - config_loader = config.Config.from_file(args.data_config) + assert ( + args.config_path is not None + ), "Specify your config with --config_path" + assert ( + args.name is not None + ), "Specify the name to save graph as with --name" + _, datastore = load_config_and_datastore(config_path=args.config_path) + + # Load grid positions # TODO Do not get normalised positions - coords = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy() + coords = utils.get_reordered_grid_pos(datastore).numpy() # (num_nodes_full, 2) # Construct mask - static_data = utils.load_static_data(config_loader.dataset.name) + num_full_grid = coords.shape[0] + num_boundary = datastore.boundary_mask.to_numpy().sum() + num_interior = num_full_grid - num_boundary decode_mask = np.concatenate( ( - np.ones(static_data["grid_static_features"].shape[0], dtype=bool), - np.zeros( - static_data["boundary_static_features"].shape[0], dtype=bool - ), + np.ones(num_interior, dtype=bool), + np.zeros(num_boundary, dtype=bool), ), axis=0, ) @@ -112,7 +124,8 @@ def main(input_args=None): print(f"{name}: {subgraph}") # Save graph - os.makedirs(args.output_dir, exist_ok=True) + graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name) + os.makedirs(graph_dir_path, exist_ok=True) for component, graph in graph_comp.items(): # This seems like a bit of a hack, maybe better if saving in wmg # was made consistent with nl @@ -130,7 +143,7 @@ def main(input_args=None): name="m2m", list_from_attribute="level", edge_features=["len", "vdiff"], - output_directory=args.output_dir, + output_directory=graph_dir_path, ) else: # up and down directions @@ -139,7 +152,7 @@ def main(input_args=None): name=f"mesh_{direction}", list_from_attribute="levels", edge_features=["len", "vdiff"], - output_directory=args.output_dir, + output_directory=graph_dir_path, ) else: wmg.save.to_pyg( @@ -147,14 +160,14 @@ def main(input_args=None): name=component, list_from_attribute="dummy", # Note: Needed to output list edge_features=["len", "vdiff"], - output_directory=args.output_dir, + output_directory=graph_dir_path, ) else: wmg.save.to_pyg( graph=graph, name=component, edge_features=["len", "vdiff"], - output_directory=args.output_dir, + output_directory=graph_dir_path, ) diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 809bbdb8..0ed92129 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -394,7 +394,7 @@ def grid_shape_state(self): assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) - def get_xy(self, category: str, stacked: bool) -> ndarray: + def get_xy(self, category: str, stacked: bool = True) -> ndarray: """Return the x, y coordinates of the dataset. Parameters diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 24349e7e..dfb0b9c9 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -614,7 +614,7 @@ def get_vars_long_names(self, category: str) -> List[str]: def get_num_data_vars(self, category: str) -> int: return len(self.get_vars_names(category=category)) - def get_xy(self, category: str, stacked: bool) -> np.ndarray: + def get_xy(self, category: str, stacked: bool = True) -> np.ndarray: """Return the x, y coordinates of the dataset. Parameters diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index d5b39bf7..1feec63d 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -19,7 +19,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): super().__init__(args, config=config, datastore=datastore) # Load graph with static features - graph_dir_path = datastore.root_path / "graph" / args.graph + graph_dir_path = datastore.root_path / "graphs" / args.graph self.hierarchical, graph_ldict = utils.load_graph( graph_dir_path=graph_dir_path ) diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py index f621d201..9d04f3e3 100644 --- a/neural_lam/plot_graph.py +++ b/neural_lam/plot_graph.py @@ -11,25 +11,20 @@ from . import utils from .config import load_config_and_datastore -MESH_HEIGHT = 0.1 -MESH_LEVEL_DIST = 0.2 -GRID_HEIGHT = 0 - def main(): """Plot graph structure in 3D using plotly.""" parser = ArgumentParser(description="Plot graph") parser.add_argument( - "--datastore_config_path", + "--config_path", type=str, - default="tests/datastore_examples/mdp/config.yaml", - help="Path for the datastore config", + help="Path to the configuration for neural-lam", ) parser.add_argument( - "--graph", + "--name", type=str, default="multiscale", - help="Graph to plot (default: multiscale)", + help="Name of saved graph to plot (default: multiscale)", ) parser.add_argument( "--save", @@ -43,12 +38,15 @@ def main(): ) args = parser.parse_args() - _, datastore = load_config_and_datastore( - config_path=args.datastore_config_path - ) + + assert ( + args.config_path is not None + ), "Specify your config with --config_path" + + _, datastore = load_config_and_datastore(config_path=args.config_path) # Load graph data - graph_dir_path = os.path.join(datastore.root_path, "graph", args.graph) + graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name) hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = ( graph_ldict["g2m_edge_index"], @@ -63,12 +61,18 @@ def main(): # Extract values needed, turn to numpy grid_pos = utils.get_reordered_grid_pos(datastore).numpy() - # Add in z-dimension - z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) + grid_scale = np.ptp(grid_pos) + + # Add in z-dimension for grid + z_grid = np.zeros((grid_pos.shape[0],)) # Grid sits at z=0 grid_pos = np.concatenate( (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 ) + # Compute z-coordinate height of mesh nodes + mesh_base_height = 0.05 * grid_scale + mesh_level_height_diff = 0.1 * grid_scale + # List of edges to plot, (edge_index, from_pos, to_pos, color, # line_width, label) edge_plot_list = [] @@ -79,8 +83,8 @@ def main(): np.concatenate( ( level_static_features.numpy(), - MESH_HEIGHT - + MESH_LEVEL_DIST + mesh_base_height + + mesh_level_height_diff * height_level * np.ones((level_static_features.shape[0], 1)), ), @@ -170,7 +174,8 @@ def main(): mesh_pos = mesh_static_features.numpy() mesh_degrees = pyg.utils.degree(m2m_edge_index[1]).numpy() - z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees + # 1% higher per neighbor + z_mesh = (1 + 0.01 * mesh_degrees) * mesh_base_height mesh_node_size = mesh_degrees / 2 mesh_pos = np.concatenate( diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 8e43fa40..6f910cee 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -295,7 +295,7 @@ def get_reordered_grid_pos(datastore): """ Interior nodes first, then boundary """ - xy_np = datastore.get_xy() # np, (num_grid, 2) + xy_np = datastore.get_xy("state") # np, (num_grid, 2) xy_torch = torch.tensor(xy_np, dtype=torch.float32) da_boundary_mask = datastore.boundary_mask From 4d853843e61dc64614a7415aa72c8f0ecbc63441 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 2 Dec 2024 16:25:01 +0100 Subject: [PATCH 87/90] Fix graph loading and boundary mask --- neural_lam/build_rectangular_graph.py | 10 ++++++---- neural_lam/models/ar_model.py | 4 ++-- neural_lam/models/base_graph_model.py | 2 +- neural_lam/plot_graph.py | 6 ++++-- neural_lam/train_model.py | 2 +- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/neural_lam/build_rectangular_graph.py b/neural_lam/build_rectangular_graph.py index e4570397..df7f8ba8 100644 --- a/neural_lam/build_rectangular_graph.py +++ b/neural_lam/build_rectangular_graph.py @@ -30,7 +30,7 @@ def main(input_args=None): help="Path to the configuration for neural-lam", ) parser.add_argument( - "--name", + "--graph_name", type=str, help="Name to save graph as (default: multiscale)", ) @@ -74,8 +74,8 @@ def main(input_args=None): args.config_path is not None ), "Specify your config with --config_path" assert ( - args.name is not None - ), "Specify the name to save graph as with --name" + args.graph_name is not None + ), "Specify the name to save graph as with --graph_name" _, datastore = load_config_and_datastore(config_path=args.config_path) @@ -124,7 +124,9 @@ def main(input_args=None): print(f"{name}: {subgraph}") # Save graph - graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name) + graph_dir_path = os.path.join( + datastore.root_path, "graphs", args.graph_name + ) os.makedirs(graph_dir_path, exist_ok=True) for component, graph in graph_comp.items(): # This seems like a bit of a hack, maybe better if saving in wmg diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index f8eef057..1a24136f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -69,12 +69,12 @@ def __init__( static_features_torch = torch.tensor(arr_static, dtype=torch.float32) self.register_buffer( "grid_static_features", - static_features_torch[self.boundary_mask.to(torch.bool)], + static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], persistent=False, ) self.register_buffer( "boundary_static_features", - static_features_torch[self.interior_mask.to(torch.bool)], + static_features_torch[self.interior_mask[:, 0].to(torch.bool)], persistent=False, ) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 1feec63d..52f2d7a3 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -19,7 +19,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): super().__init__(args, config=config, datastore=datastore) # Load graph with static features - graph_dir_path = datastore.root_path / "graphs" / args.graph + graph_dir_path = datastore.root_path / "graphs" / args.graph_name self.hierarchical, graph_ldict = utils.load_graph( graph_dir_path=graph_dir_path ) diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py index 9d04f3e3..11bd795a 100644 --- a/neural_lam/plot_graph.py +++ b/neural_lam/plot_graph.py @@ -21,7 +21,7 @@ def main(): help="Path to the configuration for neural-lam", ) parser.add_argument( - "--name", + "--graph_name", type=str, default="multiscale", help="Name of saved graph to plot (default: multiscale)", @@ -46,7 +46,9 @@ def main(): _, datastore = load_config_and_datastore(config_path=args.config_path) # Load graph data - graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name) + graph_dir_path = os.path.join( + datastore.root_path, "graphs", args.graph_name + ) hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = ( graph_ldict["g2m_edge_index"], diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 7e0b47c6..3c2dbece 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -78,7 +78,7 @@ def main(input_args=None): # Model architecture parser.add_argument( - "--graph", + "--graph_name", type=str, default="multiscale", help="Graph to load and use in graph-based model " From 9edfec37af343be4675de402a1b7d11f7731ddd7 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 2 Dec 2024 16:33:41 +0100 Subject: [PATCH 88/90] Fix boundary masking bug for static features --- neural_lam/models/ar_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 1a24136f..ceadb856 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -69,12 +69,12 @@ def __init__( static_features_torch = torch.tensor(arr_static, dtype=torch.float32) self.register_buffer( "grid_static_features", - static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], + static_features_torch[self.interior_mask[:, 0].to(torch.bool)], persistent=False, ) self.register_buffer( "boundary_static_features", - static_features_torch[self.interior_mask[:, 0].to(torch.bool)], + static_features_torch[self.boundary_mask[:, 0].to(torch.bool)], persistent=False, ) From 6e1c53ca70678c559d4a37324221105beb799cea Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Tue, 3 Dec 2024 11:35:16 +0100 Subject: [PATCH 89/90] Add flag making boundary forcing optional in models --- neural_lam/models/ar_model.py | 27 +++++++---- neural_lam/models/base_graph_model.py | 66 ++++++++++++++++----------- 2 files changed, 59 insertions(+), 34 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index ceadb856..ef766113 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -48,6 +48,10 @@ def __init__( num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps + # TODO: Set based on existing of boundary forcing datastore + # TODO: Adjust what is stored here based on self.boundary_forced + self.boundary_forced = False + # Set up boundary mask boundary_mask = torch.tensor( da_boundary_mask.values, dtype=torch.float32 @@ -125,12 +129,6 @@ def __init__( self.num_grid_nodes, grid_static_dim, ) = self.grid_static_features.shape - - ( - self.num_boundary_nodes, - boundary_static_dim, # TODO Will need for computation below - ) = self.boundary_static_features.shape - self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim @@ -139,7 +137,16 @@ def __init__( * num_forcing_vars * (num_past_forcing_steps + num_future_forcing_steps + 1) ) - self.boundary_dim = self.grid_dim # TODO Compute separately + if self.boundary_forced: + self.boundary_dim = self.grid_dim # TODO Compute separately + ( + self.num_boundary_nodes, + boundary_static_dim, # TODO Will need for computation below + ) = self.boundary_static_features.shape + self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes + else: + # Only interior grid nodes + self.num_input_nodes = self.num_grid_nodes # Instantiate loss function self.loss = metrics.get_metric(args.loss) @@ -241,7 +248,11 @@ def unroll_prediction(self, init_states, forcing, boundary_forcing): for i in range(pred_steps): forcing_step = forcing[:, i] - boundary_forcing_step = boundary_forcing[:, i] + + if self.boundary_forced: + boundary_forcing_step = boundary_forcing[:, i] + else: + boundary_forcing_step = None pred_state, pred_std = self.predict_step( prev_state, prev_prev_state, forcing_step, boundary_forcing_step diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 52f2d7a3..61c1a681 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -47,18 +47,22 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): self.grid_embedder = utils.make_mlp( [self.grid_dim] + self.mlp_blueprint_end ) - # Optional separate embedder for boundary nodes - if args.shared_grid_embedder: - assert self.grid_dim == self.boundary_dim, ( - "Grid and boundary input dimension must be the same when using " - f"the same embedder, got grid_dim={self.grid_dim}, " - f"boundary_dim={self.boundary_dim}" - ) - self.boundary_embedder = self.grid_embedder - else: - self.boundary_embedder = utils.make_mlp( - [self.boundary_dim] + self.mlp_blueprint_end - ) + + if self.boundary_forced: + # Define embedder for boundary nodes + # Optional separate embedder for boundary nodes + if args.shared_grid_embedder: + assert self.grid_dim == self.boundary_dim, ( + "Grid and boundary input dimension must " + "be the same when using " + f"the same embedder, got grid_dim={self.grid_dim}, " + f"boundary_dim={self.boundary_dim}" + ) + self.boundary_embedder = self.grid_embedder + else: + self.boundary_embedder = utils.make_mlp( + [self.boundary_dim] + self.mlp_blueprint_end + ) self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) @@ -136,27 +140,37 @@ def predict_step( ), dim=-1, ) - # Create full boundary node features of shape - # (B, num_boundary_nodes, boundary_dim) - boundary_features = torch.cat( - ( - boundary_forcing, - self.expand_to_batch(self.boundary_static_features, batch_size), - ), - dim=-1, - ) + + if self.boundary_forced: + # Create full boundary node features of shape + # (B, num_boundary_nodes, boundary_dim) + boundary_features = torch.cat( + ( + boundary_forcing, + self.expand_to_batch( + self.boundary_static_features, batch_size + ), + ), + dim=-1, + ) + + # Embed boundary features + boundary_emb = self.boundary_embedder(boundary_features) + # (B, num_boundary_nodes, d_h) # Embed all features grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) - boundary_emb = self.boundary_embedder(boundary_features) - # (B, num_boundary_nodes, d_h) g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) mesh_emb = self.embedd_mesh_nodes() - # Merge interior and boundary emb into input embedding - # We enforce ordering (interior, boundary) of nodes - input_emb = torch.cat((grid_emb, boundary_emb), dim=1) + if self.boundary_forced: + # Merge interior and boundary emb into input embedding + # We enforce ordering (interior, boundary) of nodes + input_emb = torch.cat((grid_emb, boundary_emb), dim=1) + else: + # Only maps from interior to mesh + input_emb = grid_emb # Map from grid to mesh mesh_emb_expanded = self.expand_to_batch( From 4bcaa4b48c9a70a753599423b72dc5cd889cbd52 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Tue, 3 Dec 2024 11:58:34 +0100 Subject: [PATCH 90/90] Linting --- neural_lam/weather_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index ed67b6f7..f02cfbd4 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -234,7 +234,6 @@ def __init__( self.da_boundary_mean = self.ds_boundary_stats.forcing_mean self.da_boundary_std = self.ds_boundary_stats.forcing_std - def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time