Skip to content

Commit

Permalink
Added support for online training.
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiankmiec committed May 27, 2019
1 parent fdc99d7 commit 19e4b7e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 11 deletions.
6 changes: 3 additions & 3 deletions new_data_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@ def main():
#
# Your own data paths....
#

print("Extracting dataset features for training, and testing...")
# extract_myo_all_csv('/home/skmiec/Documents/ex5/a/myo_all_data.csv', new_data, "s11", "E1")
# extract_myo_all_csv('/home/skmiec/Documents/ex5/b/myo_all_data.csv', new_data, "s11", "E2")
# extract_myo_all_csv('/home/skmiec/Documents/ex5/c/myo_all_data.csv', new_data, "s11", "E3")

#
# extract_myo_all_csv('/home/skmiec/Documents/ex6/a/myo_all_data.csv', new_data, "s12", "E1")
# extract_myo_all_csv('/home/skmiec/Documents/ex6/b/myo_all_data.csv', new_data, "s12", "E2")
# extract_myo_all_csv('/home/skmiec/Documents/ex6/c/myo_all_data.csv', new_data, "s12", "E3")

print("Extracting dataset features for training, and testing...")
dataset.create_dataset(new_data, False)

print("Training classifier on training dataset...")
Expand All @@ -41,6 +40,7 @@ def main():
print("Testing classifier on testing dataset...")
print(classifier.perform_inference(dataset.test_features, dataset.test_labels))

classifier.save_model("/home/skmiec/Documents/")

if __name__ == "__main__":
main()
63 changes: 55 additions & 8 deletions ninaeval/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,20 @@ def perform_inference_helper(self, test_features):
def save_figure(self, path):
pass

def save_model(self, path = None):
def save_model(self, dir_path = None):
"""
Serializes (this) object for future loading and use
:param path: Path to save this object
:param path: Directory path to save this object
"""
if path is None:
path = os.path.join(self.models_path, self.__class__.__name__ + "_" +
self.feat_extractor.__class__.__name__ )
model_feat_name = self.__class__.__name__ + "_" + self.feat_extractor.__class__.__name__

with open(path, 'wb') as f:
if dir_path is None:
dir_path = os.path.join(self.models_path, model_feat_name)
else:
dir_path = os.path.join(dir_path, model_feat_name)

with open(dir_path, 'wb') as f:
pickle.dump(self, f, 2)

def load_model(path):
Expand Down Expand Up @@ -337,10 +340,16 @@ def on_end(self, state):
:param state: A state object used by the Torch engine
"""
print('Training' if state['train'] else 'Testing', 'accuracy')
print(self.classerr.value())

if not state["train"]:
if (not state["train"]) and (self.valid_features is not None) and (self.valid_labels is not None):
print(self.classerr.value())
self.test_accs.append(self.classerr.value())

elif state["train"]:
print(self.classerr.value())
else:
print(None)

self.reset_meters()

def save_checkpoint(self, epoch, loss):
Expand Down Expand Up @@ -431,6 +440,44 @@ def save_figure(self, path):
plt.show()
figure.savefig(path)

def update_training(self, train_features, train_labels, update_epochs):
'''
Performs further training on additional train_features/train_labels, for update_epochs-many epochs
:param train_features: Additional training features
:param train_labels: Additional training labels
:param update_epochs: Number of epochs used in additional training
'''
self.valid_features = None
self.valid_labels = None

dim_in = train_features[0].shape[0]

if self.needs_train_feat:
self.model = self.define_model(dim_in, train_features)
else:
self.model = self.define_model(dim_in)

# Use model for training
self.model.to(self.device)
self.model.train()
self.optimizer = torch.optim.Adam(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay
)

torch_dataset = torchnet.dataset.TensorDataset([
(torch.from_numpy(train_features)).float(),
torch.from_numpy(train_labels)
])
iterator = torch_dataset.parallel(
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True
)
self.engine.train(self.forward_pass, iterator, maxepoch=update_epochs, optimizer=self.optimizer)

@abstractmethod
def forward_pass(self, sample):
"""
Expand Down

0 comments on commit 19e4b7e

Please sign in to comment.