diff --git a/flood_forecast/multi_models/crossvivit.py b/flood_forecast/multi_models/crossvivit.py
index 6aedca8ec..1c21aa051 100644
--- a/flood_forecast/multi_models/crossvivit.py
+++ b/flood_forecast/multi_models/crossvivit.py
@@ -674,7 +674,7 @@ def forward(
             [mlp(transformed_timeseries) for mlp in self.mlp_heads], dim=2
         )
-        # Generate quantile mask
+        # Generate quantile masks
         # (Discussed in Section 3.3, subsection on uncertainty estimation)
         quantile_mask = self.quantile_masker(
             rearrange(transformed_timeseries.detach(), "b t c -> b c t")
         )
diff --git a/flood_forecast/preprocessing/pytorch_loaders.py b/flood_forecast/preprocessing/pytorch_loaders.py
index f15e32394..fdf96b382 100644
--- a/flood_forecast/preprocessing/pytorch_loaders.py
+++ b/flood_forecast/preprocessing/pytorch_loaders.py
@@ -40,10 +40,12 @@ def __init__(
         :param scaling: (highly reccomended) If provided should be a subclass of sklearn.base.BaseEstimator
         and sklearn.base.TransformerMixin) i.e StandardScaler, MaxAbsScaler, MinMaxScaler, etc)
         Note without a scaler the loss is likely to explode and cause infinite loss which will corrupt weights
-        :param start_stamp int: Optional if you want to only use part of a CSV for training, validation
+        :param start_stamp: Optional if you want to only use part of a CSV for training, validation
         or testing supply these
-        :param end_stamp int: Optional if you want to only use part of a CSV for training, validation,
-        or testing supply these
+        :type start_stamp: int, optional
+        :param end_stamp: Optional if you want to only use part of a CSV for training, validation,
+        or testing supply these
+        :type end_stamp: int, optional
         :param sort_column str: The column to sort the time series on prior to forecast.
         :param scaled_cols: The columns you want scaling applied to (if left blank will default to all columns)
         :param feature_params: These are the datetime features you want to create.
@@ -120,7 +122,7 @@ def __len__(self) -> int:
             len(self.df.index) - self.forecast_history - self.forecast_length - 1
         )
 
-    def __sample_and_track_series__(self, idx, series_id=None):
+    def __sample_and_track_series__(self, idx: int, series_id=None):
         pass
 
     def inverse_scale(
@@ -159,16 +161,16 @@ def inverse_scale(
 
 
 class CSVSeriesIDLoader(CSVDataLoader):
-    def __init__(self, series_id_col: str, main_params: dict, return_method: str, return_all=True):
+    def __init__(self, series_id_col: str, main_params: dict, return_method: str, return_all: bool = True):
         """A data-loader for a CSV file that contains a series ID column.
 
         :param series_id_col: The id column of the series you want to forecast.
         :type series_id_col: str
         :param main_params: The central set of parameters
         :type main_params: dict
-        :param return_method: The method of return
+        :param return_method: The method of return (e.g. all series at once, one at a time, or a random sample)
        :type return_method: str
-        :param return_all: Whether to return all items, defaults to True
+        :param return_all: Whether to return all items; if set to True then __validate_data__in_df is called, defaults to True
         :type return_all: bool, optional
         """
         main_params1 = deepcopy(main_params)
@@ -201,7 +203,7 @@ def __init__(self, series_id_col: str, main_params: dict, return_method: str, re
         print("unique dict")
 
     def __validate_data__in_df(self):
-        """Makes sure the data in the data-frame is the proper length for each series e."""
+        """Makes sure the data in the data-frame is the proper length for each series."""
         if self.return_all_series:
             len_first = len(self.listed_vals[0])
             print("Length of first series is:" + str(len_first))
@@ -228,7 +230,6 @@ def __getitem__(self, idx: int) -> Tuple[Dict, Dict]:
             targ_list = {}
             for va in self.listed_vals:
                 # We need to exclude the index column on one end and the series id column on the other
-                targ_start_idx = idx + self.forecast_history
                 idx2 = va[self.series_id_col].iloc[0]
                 va_returned = va[va.columns.difference([self.series_id_col], sort=False)]
@@ -264,9 +265,9 @@ def __init__(
     ):
         """
         A data loader for the test data and plotting code it is a subclass of CSVDataLoader.
-        :param str df_path: The path to the CSV file you want to use (GCS compatible) or a Pandas DataFrame
+        :param str df_path: The path to the CSV file you want to use (GCS compatible) or a Pandas DataFrame.
         :type df_path: str
-        :param int forecast_total: The total length of the forecast
+        :param int forecast_total: The total length of the forecast.
         :type forecast_total: int
         """
         if "file_path" not in kwargs:
@@ -308,7 +309,7 @@ def __getitem__(self, idx):
         historical_rows = self.df.iloc[idx: self.forecast_history + idx]
         target_idx_start = self.forecast_history + idx
         # Why aren't we using these
-        # targ_rows = self.df.ilo c[
+        # targ_rows = self.df.iloc[
         #     target_idx_start : self.forecast_total + target_idx_start
         # ]
         all_rows_orig = self.original_df.iloc[
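
For reference, a minimal usage sketch of the CSVSeriesIDLoader constructor documented in this patch. The constructor signature and the scaling/start_stamp parameters are taken from the diff above; the target_col and relevant_cols keys, the file path, the column names, and the "all" return_method value are illustrative assumptions not confirmed by the patch.

from sklearn.preprocessing import StandardScaler
from flood_forecast.preprocessing.pytorch_loaders import CSVSeriesIDLoader

# The central set of parameters forwarded to CSVDataLoader.
main_params = {
    "file_path": "data/multi_site_flow.csv",  # hypothetical CSV with a series ID column
    "forecast_history": 30,                   # look-back window length in rows
    "forecast_length": 7,                     # number of future steps per sample
    "scaling": StandardScaler(),              # highly recommended; unscaled data can make the loss explode
    "start_stamp": 0,                         # optional: first row of the CSV to use
    "target_col": ["flow"],                   # assumed key name, not shown in this diff
    "relevant_cols": ["flow", "precip"],      # assumed key name, not shown in this diff
}

loader = CSVSeriesIDLoader(
    series_id_col="site_id",  # hypothetical ID column
    main_params=main_params,
    return_method="all",      # assumed value; the diff does not enumerate the options
    return_all=True,          # True triggers the __validate_data__in_df length check
)

# With return_all=True, __getitem__ returns a (source, target) pair of dicts keyed by series ID.
src, targ = loader[0]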