Skip to content

Commit

Permalink
save loss_ave to checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhou-wang committed Feb 7, 2022
1 parent 8568630 commit 39ad741
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def parse_args():
loss_cp = checkpoint['loss']
if 'iter_count' in checkpoint:
iter_count = checkpoint['iter_count']
if 'loss_ave' in checkpoint:
loss_ave = checkpoint['loss_ave']
else:
rodnet.load_state_dict(checkpoint)

Expand Down Expand Up @@ -281,7 +283,8 @@ def parse_args():
'iter': iter,
'model_state_dict': rodnet.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_confmap,
'loss': loss_confmap.item(),
'loss_ave': loss_ave,
'iter_count': iter_count,
}
save_model_path = '%s/epoch_%02d_iter_%010d.pkl' % (model_dir, epoch + 1, iter_count + 1)
Expand All @@ -298,7 +301,8 @@ def parse_args():
'iter': iter,
'model_state_dict': rodnet.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_confmap,
'loss': loss_confmap.item(),
'loss_ave': loss_ave,
'iter_count': iter_count,
}
save_model_path = '%s/epoch_%02d_final.pkl' % (model_dir, epoch + 1)
Expand Down

0 comments on commit 39ad741

Please sign in to comment.