# cifar100.py
from torchvision import models
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from use_pretrained_model import use_pretrained_model, Network
# ---------------------------------------------------------------------------
# Experiment configuration (shared by every model trained in this script)
# ---------------------------------------------------------------------------

# Optimiser settings handed to use_pretrained_model.
learning_rate = 0.001
momentum = 0.9

# Training-loop settings handed to use_pretrained_model.
num_epochs = 25
batch_size = 128
print_every = 10

# Hidden-layer widths of the replacement classifier head used for the
# ResNet and DenseNet runs.
hidden_layers = [512]

# When True, the pretrained parameters get requires_grad = False before the
# new head is attached; when False the loaded weights are trained as well.
frozen = False

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
# CIFAR-100: 60000 32x32 colour images in 100 classes (600 images per class).
#
# ToTensor turns each PIL image into a float tensor in [0, 1]; Normalize with
# mean 0.5 and std 0.5 per channel then rescales it to [-1, 1].
# NOTE(review): the torchvision pretrained models used later in this script were
# trained with ImageNet mean/std normalisation, not 0.5/0.5 -- confirm that this
# choice is intentional.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Build the training split first, then the test split; both are downloaded to
# ./data on first use and share the same preprocessing.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR100(
        root='./data',
        train=is_train_split,
        download=True,
        transform=transform,
    )
    for is_train_split in (True, False)
)

# Class names as exposed by the dataset object.
classes = train_dataset.classes
# ---------------------------------------------------------------------------
# ResNet-18: replace the final fully connected layer (model.fc) with a new
# Network head sized for the dataset's classes, then train and evaluate.
# ---------------------------------------------------------------------------
# The banner follows the `frozen` flag so the log matches what is actually trained.
print("RESNET FROZEN" if frozen else "RESNET NOT FROZEN")
model = torchvision.models.resnet18(pretrained=True)
# Freeze parameters so we don't backprop through them
if frozen:
    for param in model.parameters():
        param.requires_grad = False
# Read the head's input width from the layer being replaced instead of
# hard-coding it per architecture (512 for ResNet-18).
input_size_last_layer = model.fc.in_features
# The new head is attached after the freeze loop, so that loop never visits
# its parameters.
model.fc = Network(input_size_last_layer, len(classes), hidden_layers)
acc_resnet = use_pretrained_model(
    model, classes, learning_rate, momentum,
    train_dataset, test_dataset, batch_size, device, num_epochs, print_every,
)
# ---------------------------------------------------------------------------
# DenseNet-121: replace the classifier layer (model.classifier) with a new
# Network head sized for the dataset's classes, then train and evaluate.
# ---------------------------------------------------------------------------
# The banner follows the `frozen` flag so the log matches what is actually trained.
print("DENSENET FROZEN" if frozen else "DENSENET NOT FROZEN")
model = models.densenet121(pretrained=True)
# Freeze parameters so we don't backprop through them
if frozen:
    for param in model.parameters():
        param.requires_grad = False
# Read the head's input width from the layer being replaced instead of
# hard-coding it per architecture (1024 for DenseNet-121).
input_size_last_layer = model.classifier.in_features
# The new head is attached after the freeze loop, so that loop never visits
# its parameters.
model.classifier = Network(input_size_last_layer, len(classes), hidden_layers)
acc_densenet = use_pretrained_model(
    model, classes, learning_rate, momentum,
    train_dataset, test_dataset, batch_size, device, num_epochs, print_every,
)
# ---------------------------------------------------------------------------
# VGG-16: replace the last layer of the classifier stack (model.classifier[6])
# with a new Network head sized for the dataset's classes, then train and
# evaluate. This run uses a deeper head ([2048, 1024, 512]) than the shared
# `hidden_layers` setting.
# ---------------------------------------------------------------------------
# The banner follows the `frozen` flag so the log matches what is actually trained.
print("VGG FROZEN" if frozen else "VGG NOT FROZEN")
model = models.vgg16(pretrained=True)
# Freeze parameters so we don't backprop through them
if frozen:
    for param in model.parameters():
        param.requires_grad = False
# Read the head's input width from the layer being replaced instead of
# hard-coding it per architecture (4096 for VGG-16).
input_size_last_layer = model.classifier[6].in_features
# The new head is attached after the freeze loop, so that loop never visits
# its parameters.
model.classifier[6] = Network(input_size_last_layer, len(classes), [2048, 1024, 512])
acc_vgg = use_pretrained_model(
    model, classes, learning_rate, momentum,
    train_dataset, test_dataset, batch_size, device, num_epochs, print_every,
)
# ---------------------------------------------------------------------------
# ResNeXt-50 (32x4d): replace the final fully connected layer (model.fc) with
# a new Network head sized for the dataset's classes, then train and evaluate.
# This run uses a [1024, 512] head rather than the shared `hidden_layers`.
# ---------------------------------------------------------------------------
# The banner follows the `frozen` flag so the log matches what is actually trained.
print("RESNEXT FROZEN" if frozen else "RESNEXT NOT FROZEN")
model = torchvision.models.resnext50_32x4d(pretrained=True)
# Freeze parameters so we don't backprop through them
if frozen:
    for param in model.parameters():
        param.requires_grad = False
# Read the head's input width from the layer being replaced instead of
# hard-coding it per architecture (2048 for ResNeXt-50 32x4d).
input_size_last_layer = model.fc.in_features
# The new head is attached after the freeze loop, so that loop never visits
# its parameters.
model.fc = Network(input_size_last_layer, len(classes), [1024, 512])
acc_resnext = use_pretrained_model(
    model, classes, learning_rate, momentum,
    train_dataset, test_dataset, batch_size, device, num_epochs, print_every,
)
# create plot
# Grouped bar chart: one group per x position, one bar per network inside each
# group. Only the first `n_groups` entries of every accuracy sequence are drawn;
# the groups are labelled "Total" followed by the first class names.
# NOTE(review): this assumes use_pretrained_model returns
# [total, class_0, class_1, ...] -- confirm against its implementation.
n_groups = 11
fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.2
opacity = 0.8

# (accuracies, bar colour, legend label) for each network, in drawing order.
plot_series = [
    (acc_resnet, 'b', 'ResNet'),
    (acc_densenet, 'g', 'DenseNet'),
    (acc_vgg, 'r', 'VGG'),
    (acc_resnext, 'y', 'ResNext'),
]
for slot, (accuracies, bar_color, legend_label) in enumerate(plot_series):
    # Each network's bars are shifted right by one bar width per slot.
    plt.bar(index + slot * bar_width, accuracies[:n_groups], bar_width,
            alpha=opacity,
            color=bar_color,
            label=legend_label)

label_names = ["Total"] + list(classes[:n_groups - 1])

# NOTE(review): the x positions are "Total" plus class names while the networks
# appear in the legend, so 'Network' may not be the intended axis label -- confirm.
plt.xlabel('Network')
plt.ylabel('Accuracies')
plt.title('CIFAR 100')
# Bars are centred on their x positions, so the middle of a group lies half-way
# between the first and the last bar; put the tick there.
plt.xticks(index + bar_width * (len(plot_series) - 1) / 2, label_names)
plt.legend()
plt.savefig("cifar100.png")
plt.show()