
Commit

batch training
Pharaun85 committed Oct 24, 2020
1 parent 88bce6f commit 5ac82f3
Showing 3 changed files with 11 additions and 9 deletions.
2 changes: 2 additions & 0 deletions batch_train.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+python3 train.py --batch_size=32 --lr=0.00001 --optimizer adam --dataset ../DualBiSeNet/data_raw_bev_mask/ --train --embedding --weighted --num_epoch 300 --validation_step 5 --telegram --patience 2 --patience_start 50 --dataloader generatedDataset --lossfunction MSE --teacher_path ./trainedmodels/teacher/teacher_model_27e66514.pth
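
The new batch_train.sh is a one-line wrapper around a single train.py run. As a rough sketch (an assumption, since the full parser in train.py is not shown in this diff), the flags on that command line map onto an argparse parser along these lines; names mirror the command, defaults are illustrative only:

    import argparse

    # Hypothetical reconstruction of the options consumed by batch_train.sh.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.00001)
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--dataset', type=str, help='path to the BEV mask data')
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--embedding', action='store_true')
    parser.add_argument('--weighted', action='store_true')
    parser.add_argument('--num_epoch', type=int, default=300)
    parser.add_argument('--validation_step', type=int, default=5)
    parser.add_argument('--telegram', action='store_true')
    parser.add_argument('--patience', type=int, default=2)
    parser.add_argument('--patience_start', type=int, default=50)
    parser.add_argument('--dataloader', type=str, default='generatedDataset')
    parser.add_argument('--lossfunction', type=str, default='MSE')
    parser.add_argument('--teacher_path', type=str, help='pretrained teacher checkpoint (.pth)')
    args = parser.parse_args()
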
16 changes: 8 additions & 8 deletions teacher_train.py
@@ -454,7 +454,7 @@ def train(args, model, optimizer, dataloader_train, dataloader_val, dataset_trai
wandb.log({"Val/loss": loss_val,
"Val/Acc": acc_val,
"random_rate": random_rate,
"conf-matrix_{}_{}".format(wandb.run.name, epoch): wandb.Image(plt)}, step=epoch)
"conf-matrix_{}_{}".format(wandb.run.id, epoch): wandb.Image(plt)}, step=epoch)
if args.nowandb:
if args.triplet:
print('Saving model: ',
@@ -472,17 +472,17 @@ def train(args, model, optimizer, dataloader_train, dataloader_val, dataset_trai
         else:
             if args.triplet:
                 print('Saving model: ',
-                      os.path.join(args.save_model_path, 'teacher_model_{}.pth'.format(wandb.run.name)))
+                      os.path.join(args.save_model_path, 'teacher_model_{}.pth'.format(wandb.run.id)))
                 torch.save(bestModel, os.path.join(args.save_model_path,
-                                                   'teacher_model_{}.pth'.format(wandb.run.name)))
-                savepath = os.path.join(args.save_model_path, 'teacher_model_{}.pth'.format(wandb.run.name))
+                                                   'teacher_model_{}.pth'.format(wandb.run.id)))
+                savepath = os.path.join(args.save_model_path, 'teacher_model_{}.pth'.format(wandb.run.id))
             else:
                 print('Saving model: ',
-                      os.path.join(args.save_model_path, 'teacher_model_class_{}.pth'.format(wandb.run.name)))
+                      os.path.join(args.save_model_path, 'teacher_model_class_{}.pth'.format(wandb.run.id)))
                 torch.save(bestModel, os.path.join(args.save_model_path,
-                                                   'teacher_model_class_{}.pth'.format(wandb.run.name)))
+                                                   'teacher_model_class_{}.pth'.format(wandb.run.id)))
                 savepath = os.path.join(args.save_model_path,
-                                        'teacher_model_class_{}.pth'.format(wandb.run.name))
+                                        'teacher_model_class_{}.pth'.format(wandb.run.id))

     elif epoch < args.patience_start:
         patience = 0
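
Why the switch from wandb.run.name to wandb.run.id in the logging key and the checkpoint filenames: run.id is the run's unique, stable identifier, while run.name is the human-readable display name, which can be edited in the W&B UI, so filenames built from it are not guaranteed to stay stable. A minimal standalone sketch of the distinction (not project code; the project name is hypothetical):

    import os
    import wandb

    run = wandb.init(project='demo-project')  # hypothetical project name
    print(run.id)    # unique identifier, e.g. '27e66514' - stable for the run
    print(run.name)  # display name, e.g. 'vivid-sunset-12' - editable in the UI

    # Checkpoint names built from run.id stay stable, matching the teacher
    # checkpoint referenced in batch_train.sh above.
    ckpt = os.path.join('./trainedmodels/teacher',
                        'teacher_model_{}.pth'.format(run.id))
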
@@ -533,7 +533,7 @@ def train(args, model, optimizer, dataloader_train, dataloader_val, dataset_trai
                                              'in https://docs.wandb.com/sweeps/configuration#command')
     parser.add_argument('--telegram', action='store_true', help='Send info through Telegram')

-    parser.add_argument('--triplet', type=bool, default=True, help='Triplet Loss')
+    parser.add_argument('--triplet', action='store_true', help='Triplet Loss')
     parser.add_argument('--swap', action='store_true', help='Triplet Loss swap')
     parser.add_argument('--margin', type=float, default=2.0, help='margin in triplet')
     parser.add_argument('--no_noise', action='store_true', help='In case you want to disable the noise injection in '
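
The --triplet change fixes a common argparse pitfall: with type=bool, argparse calls bool() on the raw string, and any non-empty string (including 'False') is truthy, so combined with default=True the option could never actually be switched off from the command line. action='store_true' makes it a proper flag; note this also flips the default, so triplet loss is now opt-in rather than on by default. A standalone illustration (not project code):

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--old_triplet', type=bool, default=True)   # old style
    p.add_argument('--triplet', action='store_true')           # new style

    print(p.parse_args(['--old_triplet', 'False']).old_triplet)  # True: bool('False') is truthy
    print(p.parse_args([]).triplet)                              # False unless --triplet is passed
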
2 changes: 1 addition & 1 deletion train.py
@@ -952,7 +952,7 @@ def main(args, model=None):
     if args.wandb_group_id:
         group_id = args.wandb_group_id
     else:
-        group_id = 'Kitti2011_Homography'
+        group_id = 'Kitti2011_mask'

     print(args)
     warnings.filterwarnings("ignore")
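
The default group id changes from 'Kitti2011_Homography' to 'Kitti2011_mask', presumably to match the masked BEV data directory (data_raw_bev_mask) used by batch_train.sh. Assuming main() forwards group_id to wandb.init (group= is a standard wandb.init argument), related runs end up grouped together in the W&B UI; a minimal sketch with a hypothetical project name:

    import wandb

    group_id = 'Kitti2011_mask'
    wandb.init(project='demo-project', group=group_id)  # hypothetical project name
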
