Anomaly Transformer (Training Code) #736

Open · wants to merge 10 commits into master
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -227,7 +227,6 @@ jobs:
          path: test-results
          destination: test-results-decoder


  trainer_test:
    <<: *defaults
    steps:
@@ -319,6 +318,7 @@ jobs:
          name: Trainer1 tests
          when: always
          command: |
            coverage run flood_forecast/trainer.py -p tests/anomaly_transformer.json
            coverage run flood_forecast/trainer.py -p tests/test_informer.json
            coverage run flood_forecast/trainer.py -p tests/test_iTransformer.json
            coverage run flood_forecast/trainer.py -p tests/tsmixer_test.json
2 changes: 1 addition & 1 deletion flood_forecast/deployment/inference.py
@@ -16,7 +16,7 @@
class InferenceMode(object):
    def __init__(self, forecast_steps: int, n_samp: int, model_params, csv_path: Union[str, pd.DataFrame], weight_path,
                 wandb_proj: str = None, torch_script=False):
        """Class to handle inference for models,
        """Class to handle inference for models

        :param forecast_steps: Number of time-steps to forecast (doesn't have to be hours)
        :type forecast_steps: int
2 changes: 1 addition & 1 deletion flood_forecast/meta_models/merging_model.py
@@ -5,7 +5,7 @@

class MergingModel(torch.nn.Module):
    def __init__(self, method: str, other_params: Dict):
        """A model meant to help merge meta-data with the temporal data
        """A model meant to help merge meta-data with the temporal data.

        :param method: The method you want to use (Bilinear, Bilinear2, MultiAttn, Concat)
        :type method: str
6 changes: 4 additions & 2 deletions flood_forecast/model_dict_function.py
@@ -22,12 +22,13 @@
from flood_forecast.basic.d_n_linear import DLinear, NLinear
from flood_forecast.transformer_xl.itransformer import ITransformer
from flood_forecast.transformer_xl.cross_former import Crossformer as Crossformer10
from flood_forecast.transformer_xl.anomaly_transformer import AnomalyTransformer
from torchtsmixer import TSMixer
from torchtsmixer import TSMixerExt


"""
Utility dictionaries to map a string to a c class
Utility dictionaries to map a string to a class
"""
pytorch_model_dict = {
    "MultiAttnHeadSimple": MultiAttnHeadSimple,
@@ -48,7 +49,8 @@
"NLinear": NLinear,
"TSMixer": TSMixer,
"TSMixerExt": TSMixerExt,
"ITransformer": ITransformer
"ITransformer": ITransformer,
"AnomalyTransformer": AnomalyTransformer
}

pytorch_criterion_dict = {
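With the import and dictionary entry above, the model becomes constructible from the `"model_name"` string in a config file. A minimal sketch of the lookup (assuming the trainer resolves names against `pytorch_model_dict`; the parameter values mirror `tests/anomaly_transformer.json` later in this diff):

```python
# Hypothetical direct use of the registry entry added above.
from flood_forecast.model_dict_function import pytorch_model_dict

model_cls = pytorch_model_dict["AnomalyTransformer"]
# model_params taken from tests/anomaly_transformer.json below
model = model_cls(win_size=100, enc_in=3, c_out=3)
```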
4 changes: 2 additions & 2 deletions flood_forecast/series_id_helper.py
@@ -32,7 +32,7 @@ def handle_csv_id_validation(src: Dict[int, torch.Tensor], trg: Dict[int, torch.
    :type src: Dict[int, torch.Tensor]
    :param trg: _description_
    :type trg: Dict[int, torch.Tensor]
    :param model: _description_
    :param model: _description
    :type model: torch.nn.Module
    :param criterion: _description_
    :type criterion: List
@@ -42,7 +42,7 @@ def handle_csv_id_validation(src: Dict[int, torch.Tensor], trg: Dict[int, torch.
    :type n_targs: int, optional
    :param max_seq_len: _description_, defaults to 100
    :type max_seq_len: int, optional
    :return: Returns a dictionary of losses for each criterion
    :return: Returns a dictionary of losses for each criterion.
    :rtype: Dict[str, float]
    """
    scaled_crit = dict.fromkeys(criterion, 0)
7 changes: 4 additions & 3 deletions flood_forecast/temporal_decoding.py
@@ -61,11 +61,12 @@ def decoding_function(model, src: torch.Tensor, trg: torch.Tensor, forecast_leng
        out = model(src, src_temp, filled_target, tar_temp[:, i:i + residual, :])
        residual1 = forecast_length if i + forecast_length <= max_len else max_len % forecast_length
        out1[:, i: i + residual1, :n_target] = out[:, -residual1:, :]
        # Need better variable names
        # Need better variable names.
        filled_target1 = torch.zeros_like(filled_target[:, 0:forecast_length * 2, :])
        if filled_target1.shape[1] == forecast_length * 2:
            filled_target1[:, -forecast_length * 2:-forecast_length, :n_target] = out[:, -forecast_length:, :]
            # always use n_target
            filled_target1[:, -forecast_length * 2:-forecast_length, :n_target] = out[:, -forecast_length:, :n_target]
        filled_target = torch.cat((filled_target, filled_target1), dim=1)
    assert out1[0, 0, 0] != 0
    assert out1[0, 0, 0] != 0
    return out1[:, -max_len:, :n_target]
    return out1[:, -max_len:, :n_target]  # [B, L, D]
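The `:n_target` slice added on the right-hand side keeps the assignment well-formed when the model emits more channels than there are targets. A toy sketch of the mismatch (sizes are illustrative, not from the repo):

```python
# Toy reproduction of the shape mismatch the ":n_target" slice avoids.
import torch

B, F, n_target, d_out = 2, 4, 3, 5       # hypothetical: model emits 5 channels
out = torch.randn(B, F, d_out)           # stand-in for the decoder output
filled = torch.zeros(B, 2 * F, d_out)

# Old form: filled[:, -2 * F:-F, :n_target] = out[:, -F:, :]  -> RuntimeError,
# the target slice is (B, F, 3) but the source is (B, F, 5).
filled[:, -2 * F:-F, :n_target] = out[:, -F:, :n_target]  # fixed form always matches
```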
93 changes: 91 additions & 2 deletions flood_forecast/transformer_xl/anomaly_transformer.py
@@ -1,2 +1,91 @@
class AnomalyTransformer():
    pass
import torch
import torch.nn as nn
import torch.nn.functional as F
from flood_forecast.transformer_xl.attn import AnomalyAttention, AttentionLayer
from flood_forecast.transformer_xl.data_embedding import DataEmbedding


class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn, mask, sigma = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)
        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y), attn, mask, sigma


class Encoder(nn.Module):
    def __init__(self, attn_layers, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        # x [B, L, D]
        series_list = []
        prior_list = []
        sigma_list = []
        for attn_layer in self.attn_layers:
            x, series, prior, sigma = attn_layer(x, attn_mask=attn_mask)
            series_list.append(series)
            prior_list.append(prior)
            sigma_list.append(sigma)

        if self.norm is not None:
            x = self.norm(x)

        return x, series_list, prior_list, sigma_list


class AnomalyTransformer(nn.Module):
    def __init__(self, win_size, enc_in, c_out, d_model=512, n_heads=8, e_layers=3, d_ff=512,
                 dropout=0.0, activation='gelu', output_attention=True):
        super(AnomalyTransformer, self).__init__()
        self.output_attention = output_attention

        # Encoding
        self.embedding = DataEmbedding(enc_in, d_model, dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        AnomalyAttention(win_size, False, attention_dropout=dropout, output_attention=output_attention),
                        d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )

        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x):
        enc_out = self.embedding(x)
        enc_out, series, prior, sigmas = self.encoder(enc_out)
        enc_out = self.projection(enc_out)

        if self.output_attention:
            return enc_out, series, prior, sigmas
        else:
            return enc_out  # [B, L, D]
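A hedged sketch of a forward pass through the class above (it assumes `DataEmbedding` accepts a single tensor as called in `forward`, and that `AttentionLayer` plumbs through the four values `EncoderLayer` unpacks; shrunken hyperparameters keep it cheap):

```python
# Sketch only: exercises AnomalyTransformer as defined above.
import torch
from flood_forecast.transformer_xl.anomaly_transformer import AnomalyTransformer

model = AnomalyTransformer(win_size=100, enc_in=3, c_out=3,
                           d_model=64, n_heads=4, e_layers=2, d_ff=64)
x = torch.randn(4, 100, 3)                  # [batch, window, input features]
recon, series, prior, sigmas = model(x)     # output_attention=True by default
# recon: [4, 100, 3] reconstruction; series/prior: one [B, H, L, L]
# association matrix per encoder layer (len == e_layers == 2).
```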
51 changes: 48 additions & 3 deletions flood_forecast/transformer_xl/attn.py
@@ -3,6 +3,7 @@
import numpy as np
from math import sqrt
from einops import rearrange
import math


class TriangularCausalMask():
@@ -16,6 +17,49 @@ def mask(self):
        return self._mask


class AnomalyAttention(nn.Module):
    def __init__(self, win_size, mask_flag=True, scale=None, attention_dropout=0.0, output_attention=False):
        super(AnomalyAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        window_size = win_size
        self.distances = torch.zeros((window_size, window_size)).to(self.device)
        for i in range(window_size):
            for j in range(window_size):
                self.distances[i][j] = abs(i - j)

    def forward(self, queries, keys, values, sigma, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        attn = scale * scores

        sigma = sigma.transpose(1, 2)  # B L H -> B H L
        window_size = attn.shape[-1]
        sigma = torch.sigmoid(sigma * 5) + 1e-5
        sigma = torch.pow(3, sigma) - 1
        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, window_size)  # B H L L
        prior = self.distances.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1).to(self.device)
        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-prior ** 2 / 2 / (sigma ** 2))

        series = self.dropout(torch.softmax(attn, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", series, values)

        if self.output_attention:
            return (V.contiguous(), series, prior, sigma)
        else:
            return (V.contiguous(), None)


class ProbMask():
    def __init__(self, B, H, L, index, scores, device="cpu"):
        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
@@ -69,6 +113,7 @@ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
class FlashAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FlashAttention, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
@@ -83,9 +128,9 @@ def flash_attention_forward(self, Q, K, V, mask=None):
        l3 = torch.zeros(Q.shape[:-1])[..., None]
        m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF

        O1 = O1.to(device='cuda')
        l3 = l3.to(device='cuda')
        m = m.to(device='cuda')
        O1 = O1.to(device=self.device)
        l3 = l3.to(device=self.device)
        m = m.to(device=self.device)

        Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
        KV_BLOCK_SIZE = BLOCK_SIZE
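The core of `AnomalyAttention` is the learnable Gaussian prior: each position weights its neighbours by a normal density in `|i - j|`, with the head-predicted `sigma` squashed and stretched before use. A self-contained sketch of just that computation (toy window size; tensor names are illustrative):

```python
# Standalone sketch of the Gaussian prior from AnomalyAttention (toy sizes).
import math
import torch

B, H, win = 1, 2, 5
sigma = torch.rand(B, H, win)                       # stand-in for the sigma head
idx = torch.arange(win).float()
distances = (idx[:, None] - idx[None, :]).abs()     # |i - j| matrix

sigma = torch.sigmoid(sigma * 5) + 1e-5             # keep sigma strictly positive
sigma = torch.pow(3, sigma) - 1                     # stretch the usable range
sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, win)    # [B, H, L, L]
d = distances.unsqueeze(0).unsqueeze(0)             # broadcast over batch/heads
prior = torch.exp(-d ** 2 / (2 * sigma ** 2)) / (math.sqrt(2 * math.pi) * sigma)
# Each row of prior concentrates mass near the diagonal; sharper for small sigma.
```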
49 changes: 49 additions & 0 deletions tests/anomaly_transformer.json
@@ -0,0 +1,49 @@
{
   "model_name": "AnomalyTransformer",
   "model_type": "PyTorch",
   "model_params": {
      "win_size": 100,
      "c_out": 3,
      "enc_in": 3
   },
   "n_targets": 3,
   "dataset_params":
   {  "class": "AutoEncoder",
      "training_path": "tests/test_data/keag_small.csv",
      "validation_path": "tests/test_data/keag_small.csv",
      "test_path": "tests/test_data/keag_small.csv",
      "batch_size": 4,
      "forecast_history": 100,
      "train_end": 200,
      "valid_start": 301,
      "valid_end": 401,
      "relevant_cols": ["cfs", "precip", "temp"],
      "scaler": "StandardScaler",
      "interpolate": false
   },
   "training_params":
   {
      "criterion": "MSE",
      "optimizer": "Adam",
      "lr": 0.03,
      "epochs": 1,
      "batch_size": 4,
      "optim_params":
      {
      }
   },
   "GCS": false,

   "wandb": {
      "name": "flood_forecast_circleci",
      "project": "repo-flood_forecast",
      "tags": ["dummy_run", "circleci", "anomaly"]
   },
   "metrics": ["MSE"],

   "inference_params": {
      "hours_to_forecast": 1
   }
}
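One relationship worth calling out in this config: `win_size` must equal `forecast_history` (the window the loader emits is the window the model embeds), and `enc_in`/`c_out`/`n_targets` all equal the number of `relevant_cols`, since the model reconstructs every input channel. A small sketch that checks those assumed invariants after loading the file:

```python
# Sketch: sanity-check the invariants assumed between the sections above.
import json

with open("tests/anomaly_transformer.json") as f:
    cfg = json.load(f)

mp, dp = cfg["model_params"], cfg["dataset_params"]
assert mp["win_size"] == dp["forecast_history"]   # window fed to the model
assert mp["enc_in"] == len(dp["relevant_cols"])   # input channel count
assert mp["c_out"] == cfg["n_targets"]            # reconstructed target count
```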

2 changes: 1 addition & 1 deletion tests/classification_test.json
@@ -36,7 +36,7 @@
"reduction": "sum"},
"optim_params":
{},
"lr": 0.3,
"lr": 0.03,
"epochs": 1,
"batch_size":4
},
2 changes: 1 addition & 1 deletion tests/cross_former.json
@@ -37,7 +37,7 @@
      {

      },
      "lr": 0.03,
      "lr": 0.003,
      "epochs": 1,
      "batch_size":4

2 changes: 1 addition & 1 deletion tests/custom_encode.json
@@ -36,7 +36,7 @@
      {

      },
      "lr": 0.3,
      "lr": 0.03,
      "epochs": 30,
      "batch_size":4

4 changes: 3 additions & 1 deletion tests/test_informer.py
@@ -50,7 +50,7 @@ def test_temporal_loader(self):
        self.assertEqual(result[0][1].shape[1], 4)
        self.assertEqual(result[0][0].shape[1], 3)
        self.assertEqual(result[0][1].shape[0], 5)
        # Test output right order
        # Test output right order. This is a bit of a hacky test.
        temporal_src_embd = result[0][1]
        second = temporal_src_embd[2, :]
        self.assertEqual(second[0], 5)
@@ -60,6 +60,7 @@ def test_temporal_loader(self):
        d = DataEmbedding(3, 128)
        embedding = d(result[0][0].unsqueeze(0), temporal_src_embd.unsqueeze(0))
        self.assertEqual(embedding.shape[2], 128)
        # Until fixed
        """
        i = Informer(3, 3, 3, 5, 5, out_len=4, factor=1)
        r0 = result[0][0].unsqueeze(0)
@@ -127,6 +128,7 @@ def test_decoding_3(self):
        self.assertEqual(d.shape[1], 336)

    def test_t_loade2(self):
        # a.
        s_wargs = {
            "file_path": "tests/test_data/keag_small.csv",
            "forecast_history": 39,