Skip to content

Commit

Permalink
Modify coco class labels
Browse files Browse the repository at this point in the history
  • Loading branch information
kuangliu committed May 10, 2018
1 parent 41dcece commit 6291f3e
Show file tree
Hide file tree
Showing 5 changed files with 107,270 additions and 107,175 deletions.
15 changes: 15 additions & 0 deletions examples/retinanet/scripts/get_state_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os
import math
import torch

from torchcv.models.retinanet import RetinaNet


model_dir = './examples/retinanet/model'
params = torch.load(os.path.join(model_dir, 'resnet50-19c8e357.pth'))

net = RetinaNet(num_classes=90)
net.fpn.load_state_dict(params, strict=False)

torch.nn.init.constant_(net.cls_head[-1].bias, -math.log(1-0.01)/0.01)
torch.save(net.state_dict(), os.path.join(model_dir, 'retinanet_resnet50.pth'))
4 changes: 2 additions & 2 deletions examples/retinanet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def transform_test(img, boxes, labels):
# Model
print('==> Building model..')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = RetinaNet(num_classes=90).to(device)
net = RetinaNet(num_classes=80).to(device)
# net.load_state_dict(torch.load(args.model))
if device == 'cuda':
net = torch.nn.DataParallel(net)
Expand All @@ -82,7 +82,7 @@ def transform_test(img, boxes, labels):
best_loss = checkpoint['loss']
start_epoch = checkpoint['epoch']

criterion = FocalLoss(num_classes=90)
criterion = FocalLoss(num_classes=80)
# optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
optimizer = optim.Adam(net.parameters(), lr=args.lr, amsgrad=True)

Expand Down
Loading

0 comments on commit 6291f3e

Please sign in to comment.