Load best model + evaluate best model on devset #11

Open · wants to merge 2 commits into master
49 changes: 40 additions & 9 deletions pie/scripts/train.py
@@ -22,11 +22,14 @@ def get_targets(settings):
     return [task['name'] for task in settings.tasks if task.get('target')]


-def get_fname_infix(settings):
+def get_fname_infix(settings, epoch=None):
     # fname
     fname = os.path.join(settings.modelpath, settings.modelname)
     timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
-    infix = '+'.join(get_targets(settings)) + '-' + timestamp
+    infix = '+'.join(get_targets(settings))
+    if epoch:
+        infix += f"-{epoch}"
+    infix += '-' + timestamp
     return fname, infix
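For reference, a runnable sketch of the naming scheme this change produces; the target task names ("lemma", "pos") and the timestamp are hypothetical placeholders:

    # Worked example of get_fname_infix's new behaviour (hypothetical values).
    targets = ["lemma", "pos"]
    timestamp = "2024_01_01-12_00_00"

    def infix_for(epoch=None):
        infix = '+'.join(targets)
        if epoch:
            infix += f"-{epoch}"
        return infix + '-' + timestamp

    assert infix_for() == "lemma+pos-2024_01_01-12_00_00"
    assert infix_for(epoch="best") == "lemma+pos-best-2024_01_01-12_00_00"
    assert infix_for(epoch="last") == "lemma+pos-last-2024_01_01-12_00_00"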


@@ -152,21 +155,35 @@ def run(settings):
     model.eval()
     running_time = time.time() - running_time

+    # evaluate best model on devset
+    if settings.dev_path:
+        print()
+        print("Evaluating best model on dev set...")
+        print()
+        model.eval()
+        stored_scores = {}
+        with torch.no_grad():
+            dev_loss = trainer.evaluate(devset)
+            print()
+            print("::: Dev losses :::")
+            print()
+            print('\n'.join('{}: {:.4f}'.format(k, v) for k, v in dev_loss.items()))
+            print()
+            summary = model.evaluate(devset, trainer.dataset)
+            for task_name, scorer in summary.items():
+                stored_scores[task_name] = scorer.get_scores()
+                scorer.print_summary(scores=stored_scores[task_name])
+
     # evaluate best model on test set
     if settings.test_path:
-        print("Evaluating model on test set")
+        print("Evaluating best model on test set")
         try:
             testset = Dataset(settings, Reader(settings, settings.test_path), label_encoder)
             for task in model.evaluate(testset, trainset).values():
                 task.print_summary()
         except Exception as E:
             print(E)

-    # save model
-    fpath, infix = get_fname_infix(settings)
-    if not settings.run_test:
-        fpath = model.save(fpath, infix=infix, settings=settings)
-        print("Saved best model to: [{}]".format(fpath))

     if devset is not None and not settings.run_test:
         scorers = model.evaluate(devset, trainset)
         scores = []
@@ -181,10 +198,24 @@ def run(settings):
         path = '{}.results.{}.csv'.format(
             settings.modelname, '-'.join(get_targets(settings)))
         with open(path, 'a') as f:
             _, infix = get_fname_infix(settings)
             line = [infix, str(seed), str(running_time)]
             line += scores
             f.write('{}\n'.format('\t'.join(line)))

+    # save model
+    if not settings.run_test:
+        # Save best model
+        fpath, infix = get_fname_infix(settings, epoch="best")
+        fpath = model.save(fpath, infix=infix, settings=settings)
+        print("Saved best model to: [{}]".format(fpath))
+        # Save last model
+        if "last_state_dict" in model.__dict__:
+            model.load_state_dict(model.last_state_dict)
+            fpath, infix = get_fname_infix(settings, epoch="last")
+            fpath = model.save(fpath, infix=infix, settings=settings)
+            print("Saved last model to: [{}]".format(fpath))
+
     print("Bye!")


35 changes: 28 additions & 7 deletions pie/trainer.py
@@ -6,6 +6,7 @@
 import collections
 import random
 import tempfile
+import warnings
 from typing import ClassVar

 import tqdm
@@ -84,6 +85,7 @@ def __init__(self, settings):
         self.threshold = settings.threshold
         self.min_weight = settings.min_weight
         self.fid = os.path.join(tempfile.gettempdir(), str(uuid.uuid1()))
+        self.best_epoch = 0

     def __repr__(self):
         # task scheduler
@@ -111,7 +113,7 @@ def is_best(self, task, value):
         else:
             raise ValueError("Wrong mode value [{}] for task: {}".format(mode, task))

-    def step(self, scores, model):
+    def step(self, scores, model, epoch):
         """
         Advance schedule step based on dev scores
         """
@@ -129,6 +131,7 @@
             if is_target:
                 # serialize model params
                 torch.save(model.state_dict(), self.fid)
+                self.best_epoch = epoch
             else:
                 self.tasks[task]['steps'] += 1

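A minimal, self-contained sketch of the checkpointing pattern TaskScheduler relies on (the toy module and dev scores are placeholders): whenever the target metric improves, the full state_dict is serialized to a temp path and the epoch is recorded, so the best weights can be restored after training:

    import os
    import tempfile
    import uuid

    import torch
    import torch.nn as nn

    fid = os.path.join(tempfile.gettempdir(), str(uuid.uuid1()))
    model = nn.Linear(4, 2)                        # placeholder module
    best_score, best_epoch = float('-inf'), 0

    for epoch, score in enumerate([0.31, 0.47, 0.42], start=1):  # dummy dev scores
        if score > best_score:                     # "is_best" for a max-mode task
            best_score, best_epoch = score, epoch
            torch.save(model.state_dict(), fid)    # serialize best params

    # later (cf. train_epochs below): restore the best weights, guarding
    # against the case where the model never improved
    if os.path.exists(fid):
        model.load_state_dict(torch.load(fid))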
@@ -224,6 +227,8 @@ def __init__(self, settings, model, dataset, num_instances):
         else:
             self.check_freq = 0  # no checks

+        self.current_epoch = 0
+
         self.task_scheduler = TaskScheduler(settings)
         self.lr_scheduler = LRScheduler(
             self.optimizer,
@@ -329,7 +334,7 @@ def run_check(self, devset):
             dev_scores['lm_fwd'] = dev_loss['lm_fwd']
             dev_scores['lm_bwd'] = dev_loss['lm_bwd']

-        self.task_scheduler.step(dev_scores, self.model)
+        self.task_scheduler.step(dev_scores, self.model, self.current_epoch)
         self.lr_scheduler.step(dev_scores[self.target_task])

         if self.verbose:
@@ -340,7 +345,7 @@

         return dev_scores

-    def train_epoch(self, devset, epoch):
+    def train_epoch(self, devset):
         rep_loss = collections.defaultdict(float)
         rep_batches = collections.defaultdict(int)
         rep_items, rep_start = 0, time.time()
@@ -395,23 +400,39 @@ def train_epochs(self, epochs, devset=None):
        scores = None

        try:
-            for epoch in range(1, epochs + 1):
+            for epoch in range(self.current_epoch + 1, epochs + 1):
                # train epoch
+                self.current_epoch += 1
                epoch_start = time.time()
                logging.info("Starting epoch [{}]".format(epoch))
-                self.train_epoch(devset, epoch)
+                self.train_epoch(devset)
                epoch_total = time.time() - epoch_start
                logging.info("Finished epoch [{}] in [{:.0f}] secs".format(
                    epoch, epoch_total))

        except EarlyStopException as e:
            logging.info("Early stopping training: "
                         "task [{}] with best score {:.4f}".format(e.task, e.loss))

+            print(f"Loading best model (epoch {self.task_scheduler.best_epoch}) for target task {self.target_task}")
+            self.model.last_state_dict = self.model.state_dict()
            self.model.load_state_dict(e.best_state_dict)
            scores = {e.task: e.loss}
+        else:
+            # Load best model
+            print(f"Loading best model (epoch {self.task_scheduler.best_epoch}) for target task {self.target_task}")
+            self.model.last_state_dict = self.model.state_dict()
+            if os.path.exists(self.task_scheduler.fid):
+                best_state_dict = torch.load(self.task_scheduler.fid)
+                self.model.load_state_dict(best_state_dict)
+            else:
+                warnings.warn(
+                    f"Temp path with best model weights doesn't exist ({self.task_scheduler.fid}). "
+                    "Maybe the model never improved over training?"
+                )
+            scores = {self.target_task: self.task_scheduler.tasks[self.target_task]['best']}

        logging.info("Finished training in [{:.0f}] secs".format(time.time() - start))

-        # will be None if no dev set was provided or the model failed to converge
+        # will be None if no dev set was provided
        return scores
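With these changes, one call to train_epochs() leaves both sets of weights reachable. A hedged usage sketch, reusing the trainer, model, devset, and settings objects set up earlier in run() (settings.epochs is assumed to be the configured epoch count):

    # After train_epochs() returns, the model holds the best weights; the
    # final-epoch weights are stashed on the model by the code above.
    scores = trainer.train_epochs(settings.epochs, devset=devset)
    best_weights = model.state_dict()                        # best checkpoint, already loaded
    last_weights = getattr(model, "last_state_dict", None)   # final-epoch weights, or None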