diff --git a/src/anemoi/models/layers/bounding.py b/src/anemoi/models/layers/bounding.py
index b168591d..c0e8e459 100644
--- a/src/anemoi/models/layers/bounding.py
+++ b/src/anemoi/models/layers/bounding.py
@@ -11,6 +11,7 @@
 from abc import ABC
 from abc import abstractmethod
+from typing import Optional
 
 import torch
 from torch import nn
 
@@ -30,12 +31,29 @@ def __init__(
         *,
         variables: list[str],
         name_to_index: dict,
+        statistics: Optional[dict] = None,
+        name_to_index_stats: Optional[dict] = None,
     ) -> None:
+        """Initializes the bounding strategy.
+
+        Parameters
+        ----------
+        variables : list[str]
+            A list of strings representing the variables that will be bounded.
+        name_to_index : dict
+            A dictionary mapping the variable names to their corresponding indices.
+        statistics : dict, optional
+            A dictionary containing the statistics of the variables.
+        name_to_index_stats : dict, optional
+            A dictionary mapping the variable names to their corresponding indices in the statistics dictionary.
+        """
         super().__init__()
         self.name_to_index = name_to_index
         self.variables = variables
         self.data_index = self._create_index(variables=self.variables)
+        self.statistics = statistics
+        self.name_to_index_stats = name_to_index_stats
 
     def _create_index(self, variables: list[str]) -> InputTensorIndex:
         return InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)._only
 
@@ -63,7 +81,94 @@ class ReluBounding(BaseBounding):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index])
         return x
+
+
+class NormalizedReluBounding(BaseBounding):
+    """Bounding variable with a ReLU activation and customizable normalized thresholds."""
+
+    def __init__(
+        self,
+        *,
+        variables: list[str],
+        name_to_index: dict,
+        min_val: list[float],
+        normalizer: list[str],
+        statistics: dict,
+        name_to_index_stats: dict,
+    ) -> None:
+        """Initializes the NormalizedReluBounding with the specified parameters.
+
+        Parameters
+        ----------
+        variables : list[str]
+            A list of strings representing the variables that will be bounded.
+        name_to_index : dict
+            A dictionary mapping the variable names to their corresponding indices.
+        min_val : list[float]
+            The minimum values for the ReLU activation, given in the same order as the variables.
+        normalizer : list[str]
+            A list of normalization types to apply, one per variable. Options: 'mean-std', 'min-max', 'max', 'std'.
+        statistics : dict
+            A dictionary containing the statistics of the variables (mean, stdev, min, max, etc.).
+        name_to_index_stats : dict
+            A dictionary mapping the variable names to their corresponding indices in the statistics dictionary.
+        """
+        super().__init__(
+            variables=variables,
+            name_to_index=name_to_index,
+            statistics=statistics,
+            name_to_index_stats=name_to_index_stats,
+        )
+        self.min_val = min_val
+        self.normalizer = normalizer
+
+        # Validate normalizer input
+        if not all(norm in {"mean-std", "min-max", "max", "std"} for norm in self.normalizer):
+            raise ValueError("Each normalizer must be one of: 'mean-std', 'min-max', 'max', 'std' in NormalizedReluBounding.")
+        if len(self.normalizer) != len(variables):
+            raise ValueError("The length of the normalizer list must match the number of variables in NormalizedReluBounding.")
+        if len(self.min_val) != len(variables):
+            raise ValueError("The length of the min_val list must match the number of variables in NormalizedReluBounding.")
+
+        # Translate each raw threshold into the normalized space of its variable.
+        self.norm_min_val = torch.zeros(len(variables))
+        for ii, variable in enumerate(variables):
+            stat_index = self.name_to_index_stats[variable]
+            if self.normalizer[ii] == "mean-std":
+                mean = self.statistics["mean"][stat_index]
+                std = self.statistics["stdev"][stat_index]
+                self.norm_min_val[ii] = (min_val[ii] - mean) / std
+            elif self.normalizer[ii] == "min-max":
+                min_stat = self.statistics["min"][stat_index]
+                max_stat = self.statistics["max"][stat_index]
+                self.norm_min_val[ii] = (min_val[ii] - min_stat) / (max_stat - min_stat)
+            elif self.normalizer[ii] == "max":
+                max_stat = self.statistics["max"][stat_index]
+                self.norm_min_val[ii] = min_val[ii] / max_stat
+            elif self.normalizer[ii] == "std":
+                std = self.statistics["stdev"][stat_index]
+                self.norm_min_val[ii] = min_val[ii] / std
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies the ReLU activation with the normalized minimum values to the input tensor.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The input tensor to process.
+
+        Returns
+        -------
+        torch.Tensor
+            The processed tensor with bounding applied.
+        """
+        # Keep the thresholds on the same device as the input tensor.
+        self.norm_min_val = self.norm_min_val.to(x.device)
+        x[..., self.data_index] = (
+            torch.nn.functional.relu(x[..., self.data_index] - self.norm_min_val) + self.norm_min_val
+        )
+        return x
 
 
 class HardtanhBounding(BaseBounding):
     """Initializes the bounding with specified minimum and maximum values for bounding.