diff --git a/.circleci/config.yml b/.circleci/config.yml
index 917c5f0ab..7c8445030 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -227,7 +227,6 @@ jobs:
           path: test-results
           destination: test-results-decoder
 
-
  trainer_test:
    <<: *defaults
    steps:
@@ -319,6 +318,7 @@
          name: Trainer1 tests
          when: always
          command: |
+            coverage run flood_forecast/trainer.py -p tests/anomaly_transformer.json
            coverage run flood_forecast/trainer.py -p tests/test_informer.json
            coverage run flood_forecast/trainer.py -p tests/test_iTransformer.json
            coverage run flood_forecast/trainer.py -p tests/tsmixer_test.json
diff --git a/flood_forecast/deployment/inference.py b/flood_forecast/deployment/inference.py
index 79347e558..a2cff9d8c 100644
--- a/flood_forecast/deployment/inference.py
+++ b/flood_forecast/deployment/inference.py
@@ -16,7 +16,7 @@ class InferenceMode(object):
     def __init__(self, forecast_steps: int, n_samp: int, model_params, csv_path: Union[str, pd.DataFrame], weight_path,
                  wandb_proj: str = None, torch_script=False):
-        """Class to handle inference for models,
+        """Class to handle inference for models.
 
        :param forecast_steps: Number of time-steps to forecast (doesn't have to be hours)
        :type forecast_steps: int
diff --git a/flood_forecast/meta_models/merging_model.py b/flood_forecast/meta_models/merging_model.py
index ac0e05702..f3204212a 100644
--- a/flood_forecast/meta_models/merging_model.py
+++ b/flood_forecast/meta_models/merging_model.py
@@ -5,7 +5,7 @@ class MergingModel(torch.nn.Module):
     def __init__(self, method: str, other_params: Dict):
-        """A model meant to help merge meta-data with the temporal data
+        """A model meant to help merge meta-data with the temporal data.
 
        :param method: The method you want to use (Bilinear, Bilinear2, MultiAttn, Concat)
        :type method: str
diff --git a/flood_forecast/model_dict_function.py b/flood_forecast/model_dict_function.py
index 598df9bc7..1d29560c5 100644
--- a/flood_forecast/model_dict_function.py
+++ b/flood_forecast/model_dict_function.py
@@ -22,12 +22,13 @@
 from flood_forecast.basic.d_n_linear import DLinear, NLinear
 from flood_forecast.transformer_xl.itransformer import ITransformer
 from flood_forecast.transformer_xl.cross_former import Crossformer as Crossformer10
+from flood_forecast.transformer_xl.anomaly_transformer import AnomalyTransformer
 from torchtsmixer import TSMixer
 from torchtsmixer import TSMixerExt
 
 
 """
-Utility dictionaries to map a string to a c class
+Utility dictionaries to map a string to a class
 """
 pytorch_model_dict = {
     "MultiAttnHeadSimple": MultiAttnHeadSimple,
@@ -48,7 +49,8 @@
     "NLinear": NLinear,
     "TSMixer": TSMixer,
     "TSMixerExt": TSMixerExt,
-    "ITransformer": ITransformer
+    "ITransformer": ITransformer,
+    "AnomalyTransformer": AnomalyTransformer
 }
 
 pytorch_criterion_dict = {
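With the registry entry above, a config file can select the new model purely by name. A minimal sketch of how the mapping is consumed (pytorch_model_dict is real; the direct lookup shown here is illustrative rather than the exact trainer code path):

    from flood_forecast.model_dict_function import pytorch_model_dict

    # "model_name" and "model_params" as they appear in tests/anomaly_transformer.json
    model_cls = pytorch_model_dict["AnomalyTransformer"]
    model = model_cls(win_size=100, enc_in=3, c_out=3)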
diff --git a/flood_forecast/series_id_helper.py b/flood_forecast/series_id_helper.py
index ed7fa0f09..9ed018dc7 100644
--- a/flood_forecast/series_id_helper.py
+++ b/flood_forecast/series_id_helper.py
@@ -32,7 +32,7 @@ def handle_csv_id_validation(src: Dict[int, torch.Tensor], trg: Dict[int, torch.
-    :type src: Dict[int, torchd
+    :type src: Dict[int, torch.Tensor]
     :param trg: _description_
     :type trg: Dict[int, torch.Tensor]
-    :param model: _description_
+    :param model: The model to run validation on
     :type model: torch.nn.Module
     :param criterion: _description_
     :type criterion: List
@@ -42,7 +42,7 @@ def handle_csv_id_validation(src: Dict[int, torch.Tensor], trg: Dict[int, torch.
     :type n_targs: int, optional
     :param max_seq_len: _description_, defaults to 100
     :type max_seq_len: int, optional
-    :return: Returns a dictionary of losses for each criterion
+    :return: Returns a dictionary of losses for each criterion.
     :rtype: Dict[str, float]
     """
     scaled_crit = dict.fromkeys(criterion, 0)
diff --git a/flood_forecast/temporal_decoding.py b/flood_forecast/temporal_decoding.py
index 4a72aeccd..46d04e0dd 100644
--- a/flood_forecast/temporal_decoding.py
+++ b/flood_forecast/temporal_decoding.py
@@ -61,11 +61,12 @@ def decoding_function(model, src: torch.Tensor, trg: torch.Tensor, forecast_length
         out = model(src, src_temp, filled_target, tar_temp[:, i:i + residual, :])
         residual1 = forecast_length if i + forecast_length <= max_len else max_len % forecast_length
         out1[:, i: i + residual1, :n_target] = out[:, -residual1:, :]
-        # Need better variable names
+        # Need better variable names.
         filled_target1 = torch.zeros_like(filled_target[:, 0:forecast_length * 2, :])
         if filled_target1.shape[1] == forecast_length * 2:
-            filled_target1[:, -forecast_length * 2:-forecast_length, :n_target] = out[:, -forecast_length:, :]
+            # Always slice the model output down to n_target channels as well.
+            filled_target1[:, -forecast_length * 2:-forecast_length, :n_target] = out[:, -forecast_length:, :n_target]
         filled_target = torch.cat((filled_target, filled_target1), dim=1)
     assert out1[0, 0, 0] != 0
     assert out1[0, 0, 0] != 0
-    return out1[:, -max_len:, :n_target]
+    return out1[:, -max_len:, :n_target]  # [B, L, D]
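A toy illustration of the slicing fix above (assumed shapes; this is not the real decoding loop). When the model output carries more channels than there are targets, omitting the trailing :n_target on the right-hand side makes the assignment fail with a shape mismatch:

    import torch

    forecast_length, n_target = 5, 2
    out = torch.randn(1, 10, 3)  # stand-in model output with 3 channels
    filled_target1 = torch.zeros(1, forecast_length * 2, 3)
    # Both sides are now (1, 5, 2); without :n_target on the right, the
    # right-hand side would be (1, 5, 3) and PyTorch would raise an error.
    filled_target1[:, -forecast_length * 2:-forecast_length, :n_target] = out[:, -forecast_length:, :n_target]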
diff --git a/flood_forecast/transformer_xl/anomaly_transformer.py b/flood_forecast/transformer_xl/anomaly_transformer.py
index 2550a7838..5d1c53e71 100644
--- a/flood_forecast/transformer_xl/anomaly_transformer.py
+++ b/flood_forecast/transformer_xl/anomaly_transformer.py
@@ -1,2 +1,91 @@
-class AnomalyTransformer():
-    pass
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flood_forecast.transformer_xl.attn import AnomalyAttention, AttentionLayer
+from flood_forecast.transformer_xl.data_embedding import DataEmbedding
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+        super(EncoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.attention = attention
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, attn_mask=None):
+        new_x, attn, mask, sigma = self.attention(
+            x, x, x,
+            attn_mask=attn_mask
+        )
+        x = x + self.dropout(new_x)
+        y = x = self.norm1(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm2(x + y), attn, mask, sigma
+
+
+class Encoder(nn.Module):
+    def __init__(self, attn_layers, norm_layer=None):
+        super(Encoder, self).__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None):
+        # x [B, L, D]
+        series_list = []
+        prior_list = []
+        sigma_list = []
+        for attn_layer in self.attn_layers:
+            x, series, prior, sigma = attn_layer(x, attn_mask=attn_mask)
+            series_list.append(series)
+            prior_list.append(prior)
+            sigma_list.append(sigma)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x, series_list, prior_list, sigma_list
+
+
+class AnomalyTransformer(nn.Module):
+    def __init__(self, win_size, enc_in, c_out, d_model=512, n_heads=8, e_layers=3, d_ff=512,
+                 dropout=0.0, activation='gelu', output_attention=True):
+        super(AnomalyTransformer, self).__init__()
+        self.output_attention = output_attention
+
+        # Encoding
+        self.embedding = DataEmbedding(enc_in, d_model, dropout)
+
+        # Encoder
+        self.encoder = Encoder(
+            [
+                EncoderLayer(
+                    AttentionLayer(
+                        AnomalyAttention(win_size, False, attention_dropout=dropout, output_attention=output_attention),
+                        d_model, n_heads),
+                    d_model,
+                    d_ff,
+                    dropout=dropout,
+                    activation=activation
+                ) for _ in range(e_layers)
+            ],
+            norm_layer=torch.nn.LayerNorm(d_model)
+        )
+
+        self.projection = nn.Linear(d_model, c_out, bias=True)
+
+    def forward(self, x):
+        enc_out = self.embedding(x)
+        enc_out, series, prior, sigmas = self.encoder(enc_out)
+        enc_out = self.projection(enc_out)
+
+        if self.output_attention:
+            return enc_out, series, prior, sigmas
+        else:
+            return enc_out  # [B, L, D]
\ No newline at end of file
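A quick smoke-test sketch for the new module (dimensions mirror tests/anomaly_transformer.json; it assumes DataEmbedding accepts a single tensor, as the forward above implies):

    import torch
    from flood_forecast.transformer_xl.anomaly_transformer import AnomalyTransformer

    model = AnomalyTransformer(win_size=100, enc_in=3, c_out=3)
    x = torch.randn(4, 100, 3)  # [batch, window, features]
    # output_attention defaults to True, so four values come back
    recon, series, prior, sigmas = model(x)
    print(recon.shape)  # expected: torch.Size([4, 100, 3])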
diff --git a/flood_forecast/transformer_xl/attn.py b/flood_forecast/transformer_xl/attn.py
index bc7f6fdbf..62c12a39e 100644
--- a/flood_forecast/transformer_xl/attn.py
+++ b/flood_forecast/transformer_xl/attn.py
@@ -3,6 +3,7 @@
 import numpy as np
 from math import sqrt
 from einops import rearrange
+import math
 
 
 class TriangularCausalMask():
@@ -16,6 +17,49 @@ def mask(self):
         return self._mask
 
 
+class AnomalyAttention(nn.Module):
+    def __init__(self, win_size, mask_flag=True, scale=None, attention_dropout=0.0, output_attention=False):
+        super(AnomalyAttention, self).__init__()
+        self.scale = scale
+        self.mask_flag = mask_flag
+        self.output_attention = output_attention
+        self.dropout = nn.Dropout(attention_dropout)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        window_size = win_size
+        self.distances = torch.zeros((window_size, window_size)).to(self.device)
+        for i in range(window_size):
+            for j in range(window_size):
+                self.distances[i][j] = abs(i - j)
+
+    def forward(self, queries, keys, values, sigma, attn_mask):
+        B, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        scale = self.scale or 1. / sqrt(E)
+
+        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+        if self.mask_flag:
+            if attn_mask is None:
+                attn_mask = TriangularCausalMask(B, L, device=queries.device)
+            scores.masked_fill_(attn_mask.mask, -np.inf)
+        attn = scale * scores
+
+        sigma = sigma.transpose(1, 2)  # B L H -> B H L
+        window_size = attn.shape[-1]
+        sigma = torch.sigmoid(sigma * 5) + 1e-5
+        sigma = torch.pow(3, sigma) - 1
+        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, window_size)  # B H L L
+        prior = self.distances.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1).to(self.device)
+        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-prior ** 2 / 2 / (sigma ** 2))
+
+        series = self.dropout(torch.softmax(attn, dim=-1))
+        V = torch.einsum("bhls,bshd->blhd", series, values)
+
+        if self.output_attention:
+            return (V.contiguous(), series, prior, sigma)
+        else:
+            return (V.contiguous(), None)
+
+
 class ProbMask():
     def __init__(self, B, H, L, index, scores, device="cpu"):
         _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
@@ -69,6 +113,7 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
 class FlashAttention(nn.Module):
     def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
         super(FlashAttention, self).__init__()
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.scale = scale
         self.mask_flag = mask_flag
         self.output_attention = output_attention
@@ -83,9 +128,9 @@ def flash_attention_forward(self, Q, K, V, mask=None):
         l3 = torch.zeros(Q.shape[:-1])[..., None]
         m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF
 
-        O1 = O1.to(device='cuda')
-        l3 = l3.to(device='cuda')
-        m = m.to(device='cuda')
+        O1 = O1.to(device=self.device)
+        l3 = l3.to(device=self.device)
+        m = m.to(device=self.device)
 
         Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
         KV_BLOCK_SIZE = BLOCK_SIZE
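For reference, the double loop in AnomalyAttention.__init__ builds a |i - j| distance matrix, and the prior line evaluates a Gaussian kernel over those distances. A standalone vectorized sketch of the same computation (illustrative only, not a drop-in replacement for the merged code):

    import math
    import torch

    window_size = 5
    idx = torch.arange(window_size, dtype=torch.float32)
    distances = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()  # |i - j|, shape [L, L]

    # Gaussian association prior: prior[i, j] = N(|i - j|; 0, sigma_i ** 2)
    sigma = torch.ones(window_size, 1)
    prior = torch.exp(-distances ** 2 / (2 * sigma ** 2)) / (math.sqrt(2 * math.pi) * sigma)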
diff --git a/tests/anomaly_transformer.json b/tests/anomaly_transformer.json
new file mode 100644
index 000000000..c8fd2dc97
--- /dev/null
+++ b/tests/anomaly_transformer.json
@@ -0,0 +1,49 @@
+{
+    "model_name": "AnomalyTransformer",
+    "model_type": "PyTorch",
+    "model_params": {
+        "win_size": 100,
+        "c_out": 3,
+        "enc_in": 3
+    },
+    "n_targets": 3,
+    "dataset_params":
+    {   "class": "AutoEncoder",
+        "training_path": "tests/test_data/keag_small.csv",
+        "validation_path": "tests/test_data/keag_small.csv",
+        "test_path": "tests/test_data/keag_small.csv",
+        "batch_size": 4,
+        "forecast_history": 100,
+        "train_end": 200,
+        "valid_start": 301,
+        "valid_end": 401,
+        "relevant_cols": ["cfs", "precip", "temp"],
+        "scaler": "StandardScaler",
+        "interpolate": false
+    },
+    "training_params":
+    {
+        "criterion": "MSE",
+        "optimizer": "Adam",
+        "lr": 0.03,
+        "epochs": 1,
+        "batch_size": 4,
+        "optim_params":
+        {
+        }
+    },
+    "GCS": false,
+
+    "wandb": {
+        "name": "flood_forecast_circleci",
+        "project": "repo-flood_forecast",
+        "tags": ["dummy_run", "circleci", "anomaly"]
+    },
+    "metrics": ["MSE"],
+
+    "inference_params": {
+        "hours_to_forecast": 1
+
+    }
+}
+
\ No newline at end of file
diff --git a/tests/classification_test.json b/tests/classification_test.json
index b36b1860b..7e58df48a 100644
--- a/tests/classification_test.json
+++ b/tests/classification_test.json
@@ -36,7 +36,7 @@
         "reduction": "sum"},
         "optim_params":
         {},
-        "lr": 0.3,
+        "lr": 0.03,
         "epochs": 1,
         "batch_size":4
     },
diff --git a/tests/cross_former.json b/tests/cross_former.json
index eb5c1bb6c..0837252b8 100644
--- a/tests/cross_former.json
+++ b/tests/cross_former.json
@@ -37,7 +37,7 @@
         {
         },
 
-        "lr": 0.03,
+        "lr": 0.003,
         "epochs": 1,
         "batch_size":4
 
diff --git a/tests/custom_encode.json b/tests/custom_encode.json
index 353910ab3..53d199fc4 100644
--- a/tests/custom_encode.json
+++ b/tests/custom_encode.json
@@ -36,7 +36,7 @@
         {
         },
 
-        "lr": 0.3,
+        "lr": 0.03,
         "epochs": 30,
         "batch_size":4
 
diff --git a/tests/test_informer.py b/tests/test_informer.py
index 4d672855f..82a2f559c 100644
--- a/tests/test_informer.py
+++ b/tests/test_informer.py
@@ -50,7 +50,7 @@ def test_temporal_loader(self):
         self.assertEqual(result[0][1].shape[1], 4)
         self.assertEqual(result[0][0].shape[1], 3)
         self.assertEqual(result[0][1].shape[0], 5)
-        # Test output right order
+        # Test that the output is in the right order. This is a bit of a hacky test.
         temporal_src_embd = result[0][1]
         second = temporal_src_embd[2, :]
         self.assertEqual(second[0], 5)
@@ -60,6 +60,7 @@ def test_temporal_loader(self):
         d = DataEmbedding(3, 128)
         embedding = d(result[0][0].unsqueeze(0), temporal_src_embd.unsqueeze(0))
         self.assertEqual(embedding.shape[2], 128)
+        # Disabled until the Informer forward pass below is fixed.
         """
         i = Informer(3, 3, 3, 5, 5, out_len=4, factor=1)
         r0 = result[0][0].unsqueeze(0)
@@ -127,6 +128,7 @@ def test_decoding_3(self):
         self.assertEqual(d.shape[1], 336)
 
     def test_t_loade2(self):
+        # Second temporal loader configuration test.
         s_wargs = {
             "file_path": "tests/test_data/keag_small.csv",
             "forecast_history": 39,
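To try the new config outside CI, the CircleCI step added above can be mirrored locally; a quick sanity check of the JSON first (only the trainer.py -p invocation is taken from the CI step, the rest is an illustrative snippet):

    import json

    with open("tests/anomaly_transformer.json") as f:
        params = json.load(f)
    assert params["model_name"] == "AnomalyTransformer"

    # then, mirroring the CI step:
    #   python flood_forecast/trainer.py -p tests/anomaly_transformer.json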