Added an automatic learning-rate adjustment method and DCGAN's discriminator network #34

Open
wants to merge 2 commits into master
5 changes: 4 additions & 1 deletion README.md
@@ -1,4 +1,4 @@
- # Train CIFAR10 with PyTorch
+ #Train CIFAR10 with PyTorch

I'm playing with [PyTorch](http://pytorch.org/) on the CIFAR10 dataset.

@@ -25,6 +25,9 @@ Cons:
| [DenseNet121](https://arxiv.org/abs/1608.06993) | 95.04% |
| [PreActResNet18](https://arxiv.org/abs/1603.05027) | 95.11% |
| [DPN92](https://arxiv.org/abs/1707.01629) | 95.16% |
| [DCGAN's netD] | 96% |

The accuracy for the DCGAN netD version looks strange, but I validated it in `test.py`. **Hope anyone can help me verify it!**

## Learning rate adjustment
I manually change the `lr` during training:
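The concrete schedule is driven by `adjust_learning_rate` in `train.py` below; as a minimal sketch of the mechanism (the helper name here is illustrative, not part of the code):

```python
# Minimal sketch: overwrite the learning rate of every parameter group in place.
# adjust_learning_rate in train.py drives this with a value decayed by 10x per scheduled step.
def set_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
```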
1 change: 1 addition & 0 deletions models/__init__.py
@@ -11,3 +11,4 @@
from preact_resnet import *
from mobilenet import *
from mobilenetv2 import *
from dcgan import _DCGANConf
Binary file added models/__init__.pyc
Binary file not shown.
84 changes: 84 additions & 0 deletions models/dcgan.py
@@ -0,0 +1,84 @@
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # print m.weight.data.size()
        m.weight.data.normal_(0.0, 0.02)
    if classname.find('Linear') != -1:
        m.bias.data.fill_(0)
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class _DCGAND(nn.Module):
    def __init__(self, ngpu=0, ndf=64, nc=3):
        super(_DCGAND, self).__init__()
        self.ngpu = ngpu
        self.layer1 = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer2 = nn.Sequential(
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer3 = nn.Sequential(
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.pred = nn.Sequential(
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            # Multi-GPU path: split the batch across GPUs stage by stage.
            x = nn.parallel.data_parallel(self.layer1, input, range(self.ngpu))
            x = nn.parallel.data_parallel(self.layer2, x, range(self.ngpu))
            x = nn.parallel.data_parallel(self.layer3, x, range(self.ngpu))
            output = nn.parallel.data_parallel(self.pred, x, range(self.ngpu))
        else:
            x = self.layer1(input)
            x = self.layer2(x)
            x = self.layer3(x)
            output = self.pred(x)
        # Return the real/fake score and the (ndf*8) x 4 x 4 feature map.
        return output.view(-1, 1).squeeze(1), x

class _DCGANConf(nn.Module):
    def __init__(self, ngpu=0, ndf=64, nc=3, num_classes=21):
        super(_DCGANConf, self).__init__()
        self.dcgan = _DCGAND(ngpu, ndf, nc)
        # The discriminator's feature map is (ndf*8) x 4 x 4 for 64x64 inputs.
        self.conf = nn.Linear(ndf * 8 * 4 * 4, num_classes)

    def forward(self, input):
        _, y = self.dcgan(input)
        y = y.view(y.size(0), -1)
        conf = self.conf(y)
        return conf


def build_dcganconf(ngpu=0, ndf=64, nc=3, num_classes=21):
    net = _DCGANConf(ngpu=ngpu, ndf=ndf, nc=nc, num_classes=num_classes)
    net.apply(weights_init)
    return net
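
A quick shape check for the classifier head defined above (a hedged sketch; it assumes the repository root is on the import path so `models.dcgan` resolves):

```python
import torch
from torch.autograd import Variable
from models.dcgan import build_dcganconf

# CIFAR10: 10 classes; train.py/test.py resize inputs to 3 x 64 x 64.
net = build_dcganconf(num_classes=10)
x = Variable(torch.randn(2, 3, 64, 64))
logits = net(x)
print(logits.size())  # torch.Size([2, 10])
```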


Binary file added models/dcgan.pyc
Binary file not shown.
Binary file added models/densenet.pyc
Binary file not shown.
Binary file added models/dpn.pyc
Binary file not shown.
Binary file added models/googlenet.pyc
Binary file not shown.
Binary file added models/lenet.pyc
Binary file not shown.
Binary file added models/mobilenet.pyc
Binary file not shown.
Binary file added models/mobilenetv2.pyc
Binary file not shown.
Binary file added models/pnasnet.pyc
Binary file not shown.
Binary file added models/preact_resnet.pyc
Binary file not shown.
Binary file added models/resnet.pyc
Binary file not shown.
Binary file added models/resnext.pyc
Binary file not shown.
Binary file added models/senet.pyc
Binary file not shown.
Binary file added models/shufflenet.pyc
Binary file not shown.
Binary file added models/vgg.pyc
Binary file not shown.
104 changes: 104 additions & 0 deletions test.py
@@ -0,0 +1,104 @@
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from models.dcgan import build_dcganconf
from utils import progress_bar
from torch.autograd import Variable


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.t7')
net = checkpoint['net']
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

def test(epoch):
    global best_acc
    net.eval()
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)

        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        progress_bar(batch_idx, len(testloader), 'Acc: {:.3f} {} {}'.format(100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc


# Evaluate the loaded checkpoint once; no optimizer is built in this script,
# so the training-time learning-rate schedule does not apply here.
for epoch in range(start_epoch, start_epoch + 1):
    test(epoch)
164 changes: 164 additions & 0 deletions train.py
@@ -0,0 +1,164 @@
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from models.dcgan import build_dcganconf
from utils import progress_bar
from torch.autograd import Variable


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Use the CIFAR10 training split here (train=True); the test split is loaded below for evaluation.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    # net = VGG('VGG19')
    # net = ResNet18()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    # net = MobileNet()
    # net = MobileNetV2()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()
    net = build_dcganconf(num_classes=10)
print(net)

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc

def adjust_learning_rate(optimizer, gamma, step):
    """Sets the learning rate to the initial LR decayed by 10 at every specified step.

    Adapted from the PyTorch ImageNet example:
    https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    lr = args.lr * (gamma ** step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        print(param_group['lr'])

step_index = 0
# lr schedule: 0.1 for epochs [0, 150), 0.01 for [150, 250), 0.001 for [250, 350).
for epoch in range(start_epoch, start_epoch + 350):
    train(epoch)
    test(epoch)
    if epoch in [150, 250]:
        step_index += 1
        adjust_learning_rate(optimizer, 0.1, step_index)
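
For reference, `adjust_learning_rate` above computes `lr = args.lr * gamma ** step`; a small standalone check of the values this yields with the defaults used here (`args.lr = 0.1`, `gamma = 0.1`):

```python
base_lr, gamma = 0.1, 0.1
for step in range(3):
    # step 0 -> 0.1, step 1 -> 0.01, step 2 -> 0.001 (up to float rounding)
    print(step, base_lr * gamma ** step)
```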
Binary file added utils.pyc
Binary file not shown.