Skip to content

Commit

Permalink
cleaning up the data-loader
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Oct 7, 2024
1 parent edbd8d1 commit 19c5768
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion flood_forecast/multi_models/crossvivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def forward(
[mlp(transformed_timeseries) for mlp in self.mlp_heads], dim=2
)

# Generate quantile mask
# Generate quantile masks
# (Discussed in Section 3.3, subsection on uncertainty estimation)
quantile_mask = self.quantile_masker(
rearrange(transformed_timeseries.detach(), "b t c -> b c t")
Expand Down
25 changes: 13 additions & 12 deletions flood_forecast/preprocessing/pytorch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ def __init__(
:param scaling: (highly recommended) If provided should be a subclass of sklearn.base.BaseEstimator
and sklearn.base.TransformerMixin) i.e StandardScaler, MaxAbsScaler, MinMaxScaler, etc) Note without
a scaler the loss is likely to explode and cause infinite loss which will corrupt weights
:param start_stamp int: Optional if you want to only use part of a CSV for training, validation
:param start_stamp: Optional if you want to only use part of a CSV for training, validation
or testing supply these
:param end_stamp int: Optional if you want to only use part of a CSV for training, validation,
or testing supply these
:type start_stamp: int, optional
:param end_stamp: Optional if you want to only use part of a CSV for training, validation,
or testing supply these
:type end_stamp: int, optional
:param sort_column str: The column to sort the time series on prior to forecast.
:param scaled_cols: The columns you want scaling applied to (if left blank will default to all columns)
:param feature_params: These are the datetime features you want to create.
Expand Down Expand Up @@ -120,7 +122,7 @@ def __len__(self) -> int:
len(self.df.index) - self.forecast_history - self.forecast_length - 1
)

def __sample_and_track_series__(self, idx, series_id=None):
def __sample_and_track_series__(self, idx: int, series_id=None):
pass

def inverse_scale(
Expand Down Expand Up @@ -159,16 +161,16 @@ def inverse_scale(


class CSVSeriesIDLoader(CSVDataLoader):
def __init__(self, series_id_col: str, main_params: dict, return_method: str, return_all=True):
def __init__(self, series_id_col: str, main_params: dict, return_method: str, return_all: bool = True):
"""A data-loader for a CSV file that contains a series ID column.
:param series_id_col: The id column of the series you want to forecast.
:type series_id_col: str
:param main_params: The central set of parameters
:type main_params: dict
:param return_method: The method of return
:param return_method: The method of return (e.g. all series at once, one at a time, or a random sample)
:type return_method: str
:param return_all: Whether to return all items, defaults to True
:param return_all: Whether to return all items; if set to True then __validate_data_in_df__ is run, defaults to True
:type return_all: bool, optional
"""
main_params1 = deepcopy(main_params)
Expand Down Expand Up @@ -201,7 +203,7 @@ def __init__(self, series_id_col: str, main_params: dict, return_method: str, re
print("unique dict")

def __validate_data__in_df(self):
"""Makes sure the data in the data-frame is the proper length for each series e."""
"""Makes sure the data in the data-frame is the proper length for each series."""
if self.return_all_series:
len_first = len(self.listed_vals[0])
print("Length of first series is:" + str(len_first))
Expand All @@ -228,7 +230,6 @@ def __getitem__(self, idx: int) -> Tuple[Dict, Dict]:
targ_list = {}
for va in self.listed_vals:
# We need to exclude the index column on one end and the series id column on the other

targ_start_idx = idx + self.forecast_history
idx2 = va[self.series_id_col].iloc[0]
va_returned = va[va.columns.difference([self.series_id_col], sort=False)]
Expand Down Expand Up @@ -264,9 +265,9 @@ def __init__(
):
"""
A data loader for the test data and plotting code; it is a subclass of CSVDataLoader.
:param str df_path: The path to the CSV file you want to use (GCS compatible) or a Pandas DataFrame
:param str df_path: The path to the CSV file you want to use (GCS compatible) or a Pandas DataFrame.
:type df_path: str
:param int forecast_total: The total length of the forecast
:param int forecast_total: The total length of the forecast.
:type forecast_total: int
"""
if "file_path" not in kwargs:
Expand Down Expand Up @@ -308,7 +309,7 @@ def __getitem__(self, idx):
historical_rows = self.df.iloc[idx: self.forecast_history + idx]
target_idx_start = self.forecast_history + idx
# Why aren't we using these
# targ_rows = self.df.iloc[
# targ_rows = self.df.iloc[
# target_idx_start : self.forecast_total + target_idx_start
# ]
all_rows_orig = self.original_df.iloc[
Expand Down

0 comments on commit 19c5768

Please sign in to comment.