
Commit

Implementation of NormalizedReluBounding for all the types of normalizations.
lzampier committed Dec 16, 2024
1 parent 0deb66b commit dfae71e
Showing 1 changed file with 102 additions and 0 deletions.
102 changes: 102 additions & 0 deletions src/anemoi/models/layers/bounding.py
@@ -11,6 +11,7 @@

from abc import ABC
from abc import abstractmethod
from typing import Optional

import torch
from torch import nn
@@ -30,12 +31,28 @@ def __init__(
        *,
        variables: list[str],
        name_to_index: dict,
        statistics: Optional[dict] = None,
        name_to_index_stats: Optional[dict] = None,
    ) -> None:
        """Initializes the bounding strategy.

        Parameters
        ----------
        variables : list[str]
            A list of strings representing the variables that will be bounded.
        name_to_index : dict
            A dictionary mapping the variable names to their corresponding indices.
        statistics : dict, optional
            A dictionary containing the statistics of the variables.
        name_to_index_stats : dict, optional
            A dictionary mapping the variable names to their corresponding indices in the statistics dictionary.
        """
        super().__init__()

        self.name_to_index = name_to_index
        self.variables = variables
        self.data_index = self._create_index(variables=self.variables)
        self.statistics = statistics
        self.name_to_index_stats = name_to_index_stats

    def _create_index(self, variables: list[str]) -> InputTensorIndex:
        return InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)._only
@@ -63,7 +80,92 @@ class ReluBounding(BaseBounding):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index])
        return x
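A minimal usage sketch, not part of this commit: it shows how a bounding module built on BaseBounding is applied to the variable columns of a tensor. The variable names, index mapping, and tensor values are illustrative assumptions.

import torch

from anemoi.models.layers.bounding import ReluBounding

# Bound only "tp"; the index mapping and values are made up for illustration.
bounding = ReluBounding(variables=["tp"], name_to_index={"tp": 0, "t2m": 1})
x = torch.tensor([[-0.3, 1.2], [0.7, -2.0]])
x = bounding(x)  # the "tp" column (index 0) is clamped at zero; "t2m" is left unchanged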


class NormalizedReluBounding(BaseBounding):
"""Bounding variable with a ReLU activation and customizable normalized thresholds."""

    def __init__(
        self,
        *,
        variables: list[str],
        name_to_index: dict,
        min_val: list[float],
        normalizer: list[str],
        statistics: dict,
        name_to_index_stats: dict,
    ) -> None:
        """Initializes the NormalizedReluBounding with the specified parameters.

        Parameters
        ----------
        variables : list[str]
            A list of strings representing the variables that will be bounded.
        name_to_index : dict
            A dictionary mapping the variable names to their corresponding indices.
        min_val : list[float]
            The minimum values for the ReLU activation, given in the same order as the variables.
        normalizer : list[str]
            A list of normalization types to apply, one per variable. Options: 'mean-std', 'min-max', 'max', 'std'.
        statistics : dict
            A dictionary containing the statistics of the variables (mean, std, min, max, etc.).
        name_to_index_stats : dict
            A dictionary mapping the variable names to their corresponding indices in the statistics dictionary.
        """
        super().__init__(
            variables=variables,
            name_to_index=name_to_index,
            statistics=statistics,
            name_to_index_stats=name_to_index_stats,
        )
        self.min_val = min_val
        self.normalizer = normalizer

        # Validate normalizer input
        if not all(norm in {"mean-std", "min-max", "max", "std"} for norm in self.normalizer):
            raise ValueError("Each normalizer must be one of: 'mean-std', 'min-max', 'max', 'std' in NormalizedReluBounding.")
        if len(self.normalizer) != len(variables):
            raise ValueError("The length of the normalizer list must match the number of variables in NormalizedReluBounding.")
        if len(self.min_val) != len(variables):
            raise ValueError("The length of the min_val list must match the number of variables in NormalizedReluBounding.")

        self.norm_min_val = torch.zeros(len(variables))
        for ii, variable in enumerate(variables):
            stat_index = self.name_to_index_stats[variable]
            if self.normalizer[ii] == "mean-std":
                mean = self.statistics["mean"][stat_index]
                std = self.statistics["stdev"][stat_index]
                self.norm_min_val[ii] = (min_val[ii] - mean) / std
            elif self.normalizer[ii] == "min-max":
                min_stat = self.statistics["min"][stat_index]
                max_stat = self.statistics["max"][stat_index]
                self.norm_min_val[ii] = (min_val[ii] - min_stat) / (max_stat - min_stat)
            elif self.normalizer[ii] == "max":
                max_stat = self.statistics["max"][stat_index]
                self.norm_min_val[ii] = min_val[ii] / max_stat
            elif self.normalizer[ii] == "std":
                std = self.statistics["stdev"][stat_index]
                self.norm_min_val[ii] = min_val[ii] / std
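        # Worked example (illustrative numbers, not from this commit): a variable
        # bounded below by min_val = 0.0 with mean = 2.0 and stdev = 4.0 under the
        # "mean-std" option gets norm_min_val = (0.0 - 2.0) / 4.0 = -0.5, i.e. the
        # physical lower bound expressed in normalized units.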

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the ReLU activation with the normalized minimum values to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to process.

        Returns
        -------
        torch.Tensor
            The processed tensor with bounding applied.
        """
        self.norm_min_val = self.norm_min_val.to(x.device)
        x[..., self.data_index] = (
            torch.nn.functional.relu(x[..., self.data_index] - self.norm_min_val)
            + self.norm_min_val
        )
        return x
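A minimal sketch, not part of this commit, of how NormalizedReluBounding might be used. The shifted ReLU in forward is equivalent to max(x, norm_min_val), so each bounded variable is clamped at its physical minimum expressed in normalized space. The variable names, statistics, and tensor values below are illustrative assumptions.

import torch

from anemoi.models.layers.bounding import NormalizedReluBounding

bounding = NormalizedReluBounding(
    variables=["t2m"],
    name_to_index={"t2m": 0},
    min_val=[200.0],                      # hypothetical physical lower bound, e.g. in K
    normalizer=["mean-std"],
    statistics={"mean": [285.0], "stdev": [15.0]},
    name_to_index_stats={"t2m": 0},
)
# norm_min_val = (200.0 - 285.0) / 15.0 ≈ -5.67 in normalized units
x = torch.tensor([[-7.0], [0.5]])
x = bounding(x)  # values below -5.67 are clamped to -5.67; others pass through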

class HardtanhBounding(BaseBounding):
"""Initializes the bounding with specified minimum and maximum values for bounding.
