diff --git a/tadpole_algorithms/models/ecmeb/__init__.py b/tadpole_algorithms/models/ecmeb/__init__.py index 69fc982..78e3ca6 100644 --- a/tadpole_algorithms/models/ecmeb/__init__.py +++ b/tadpole_algorithms/models/ecmeb/__init__.py @@ -156,8 +156,7 @@ def fill_nans_by_older_values(train_df): train_df[df_filled_nans.columns] = df_filled_nans return train_df - def train(self, train_set_path): - train_df = pd.read_csv(train_set_path) + def train(self, train_df): train_df = self.preprocess(train_df) futures = self.get_futures(train_df) diff --git a/tadpole_algorithms/models/simple_svm/__init__.py b/tadpole_algorithms/models/simple_svm/__init__.py index 066e49c..cf6fa1f 100644 --- a/tadpole_algorithms/models/simple_svm/__init__.py +++ b/tadpole_algorithms/models/simple_svm/__init__.py @@ -71,8 +71,7 @@ def set_futures(self, train_df): train_df = train_df.drop(train_df.groupby('RID').tail(1).index.values) return train_df - def train(self, train_set_path): - train_df = pd.read_csv(train_set_path) + def train(self, train_df): train_df = self.preprocess(train_df) train_df = self.set_futures(train_df) diff --git a/tadpole_algorithms/models/tadpole_model.py b/tadpole_algorithms/models/tadpole_model.py index 1ba16b4..592b5b4 100644 --- a/tadpole_algorithms/models/tadpole_model.py +++ b/tadpole_algorithms/models/tadpole_model.py @@ -3,11 +3,11 @@ class TadpoleModel(ABC): @abstractmethod - def train(self, train_set_path): + def train(self, train_df): pass @abstractmethod - def predict(self, test_set_path): + def predict(self, test_df): pass def save(self, path):