diff --git a/new_data_example.py b/new_data_example.py
index c3d7b9d..1a0eae9 100644
--- a/new_data_example.py
+++ b/new_data_example.py
@@ -23,16 +23,15 @@ def main():
     #
     # Your own data paths....
     #
-
+    print("Extracting dataset features for training, and testing...")
     # extract_myo_all_csv('/home/skmiec/Documents/ex5/a/myo_all_data.csv', new_data, "s11", "E1")
     # extract_myo_all_csv('/home/skmiec/Documents/ex5/b/myo_all_data.csv', new_data, "s11", "E2")
     # extract_myo_all_csv('/home/skmiec/Documents/ex5/c/myo_all_data.csv', new_data, "s11", "E3")
-
+    #
     # extract_myo_all_csv('/home/skmiec/Documents/ex6/a/myo_all_data.csv', new_data, "s12", "E1")
     # extract_myo_all_csv('/home/skmiec/Documents/ex6/b/myo_all_data.csv', new_data, "s12", "E2")
     # extract_myo_all_csv('/home/skmiec/Documents/ex6/c/myo_all_data.csv', new_data, "s12", "E3")
 
-    print("Extracting dataset features for training, and testing...")
     dataset.create_dataset(new_data, False)
 
     print("Training classifier on training dataset...")
@@ -41,6 +40,7 @@ def main():
 
     print("Testing classifier on testing dataset...")
     print(classifier.perform_inference(dataset.test_features, dataset.test_labels))
+    classifier.save_model("/home/skmiec/Documents/")
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/ninaeval/models/model.py b/ninaeval/models/model.py
index 2fc46ea..9cec576 100644
--- a/ninaeval/models/model.py
+++ b/ninaeval/models/model.py
@@ -68,17 +68,20 @@ def perform_inference_helper(self, test_features):
     def save_figure(self, path):
         pass
 
-    def save_model(self, path = None):
+    def save_model(self, dir_path = None):
         """
             Serializes (this) object for future loading and use
 
-            :param path: Path to save this object
+            :param dir_path: Directory path to save this object
         """
-        if path is None:
-            path = os.path.join(self.models_path, self.__class__.__name__ + "_" +
-                                    self.feat_extractor.__class__.__name__ )
+        model_feat_name = self.__class__.__name__ + "_" + self.feat_extractor.__class__.__name__
 
-        with open(path, 'wb') as f:
+        if dir_path is None:
+            dir_path = os.path.join(self.models_path, model_feat_name)
+        else:
+            dir_path = os.path.join(dir_path, model_feat_name)
+
+        with open(dir_path, 'wb') as f:
             pickle.dump(self, f, 2)
 
     def load_model(path):
@@ -337,10 +340,16 @@ def on_end(self, state):
             :param state: A state object used by the Torch engine
         """
         print('Training' if state['train'] else 'Testing', 'accuracy')
-        print(self.classerr.value())
 
-        if not state["train"]:
+        if (not state["train"]) and (self.valid_features is not None) and (self.valid_labels is not None):
+            print(self.classerr.value())
             self.test_accs.append(self.classerr.value())
+
+        elif state["train"]:
+            print(self.classerr.value())
+        else:
+            print(None)
+
         self.reset_meters()
 
     def save_checkpoint(self, epoch, loss):
@@ -431,6 +440,44 @@ def save_figure(self, path):
             plt.show()
             figure.savefig(path)
 
+    def update_training(self, train_features, train_labels, update_epochs):
+        '''
+            Performs further training on additional train_features/train_labels, for update_epochs-many epochs
+
+            :param train_features: Additional training features
+            :param train_labels: Additional training labels
+            :param update_epochs: Number of epochs used in additional training
+        '''
+        self.valid_features = None
+        self.valid_labels = None
+
+        dim_in = train_features[0].shape[0]
+
+        if self.needs_train_feat:
+            self.model = self.define_model(dim_in, train_features)
+        else:
+            self.model = self.define_model(dim_in)
+
+        # Use model for training
+        self.model.to(self.device)
+        self.model.train()
+        self.optimizer = torch.optim.Adam(
+            self.model.parameters(),
+            lr=self.learning_rate,
+            weight_decay=self.weight_decay
+        )
+
+        torch_dataset = torchnet.dataset.TensorDataset([
+            (torch.from_numpy(train_features)).float(),
+            torch.from_numpy(train_labels)
+        ])
+        iterator = torch_dataset.parallel(
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True
+        )
+        self.engine.train(self.forward_pass, iterator, maxepoch=update_epochs, optimizer=self.optimizer)
+
     @abstractmethod
     def forward_pass(self, sample):
         """
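
Usage note (not part of the patch): a minimal sketch of how the new update_training() and the directory-based save_model() introduced above might be called, assuming `classifier` and `dataset` are built as in new_data_example.py. The arrays `new_features` and `new_labels` are hypothetical placeholders, not part of the repository.

    import numpy as np

    # Hypothetical additional samples: a 2-D float feature array and 1-D integer labels,
    # matching what update_training() passes to torch.from_numpy(...); sizes are made up.
    new_features = np.random.rand(128, 8).astype(np.float32)
    new_labels = np.random.randint(0, 10, size=128)

    # Further training on the additional samples for 5 epochs. update_training() clears
    # valid_features/valid_labels, hence the extra guard added to on_end() above.
    classifier.update_training(new_features, new_labels, update_epochs=5)

    # With no argument the model is written under self.models_path; with a directory it is
    # written there instead. In both cases the file name is <ModelClass>_<FeatureExtractorClass>.
    classifier.save_model()
    classifier.save_model("/home/skmiec/Documents/")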