train_captioning.py
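"""Train a CNN-RNN image captioning model.

An EncoderCNN embeds each image and a DecoderRNN generates the caption
token by token. Per-step loss and perplexity are written to a log file,
encoder/decoder checkpoints are saved every `save_every` epochs, and the
training and validation loss curves are plotted at the end.
"""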
import math
import os
import ssl
import sys

import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

# Project-local modules.
from data_loader import get_loader
from models import EncoderCNN, DecoderRNN
import utils
import validation
if __name__ == "__main__":
    # Allow NLTK to download resources behind self-signed certificates.
    ssl._create_default_https_context = ssl._create_unverified_context
    nltk.download('punkt')

    batch_size = 64         # batch size
    vocab_threshold = 6     # minimum word count threshold
    vocab_from_file = True  # if True, load the existing vocab file
    embed_size = 512        # dimensionality of image and word embeddings
    hidden_size = 512       # number of features in the hidden state of the RNN decoder
    num_epochs = 5          # number of training epochs (use 1 for a quick test)
    save_every = 1          # frequency (in epochs) of saving model weights
    print_every = 200       # window (in steps) for printing the running loss
    log_file = 'training_log.txt'  # file that stores training loss and perplexity
    transform_train = transforms.Compose([
        transforms.Resize(256),             # resize the smaller edge to 256
        transforms.RandomCrop(224),         # take a 224x224 crop from a random location
        transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
        transforms.ToTensor(),              # convert the PIL Image to a tensor
        transforms.Normalize((0.485, 0.456, 0.406),   # normalize with the ImageNet statistics
                             (0.229, 0.224, 0.225))]) # expected by the pre-trained encoder

    # Validation uses deterministic preprocessing: no random crop or flip.
    transform_val = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Build the data loaders.
    data_loader = get_loader(transform=transform_train, mode='train',
                             batch_size=batch_size,
                             vocab_threshold=vocab_threshold,
                             vocab_from_file=vocab_from_file)
    val_data_loader = get_loader(transform=transform_val, mode='val')

    # The size of the vocabulary.
    vocab_size = len(data_loader.dataset.vocab)
    print('Vocabulary size:', vocab_size)
    # Use the GPU if CUDA is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the encoder and decoder, and move them to the device.
    encoder = EncoderCNN(embed_size)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
    encoder.to(device)
    decoder.to(device)

    # Define the loss function.
    criterion = nn.CrossEntropyLoss().to(device)

    # Train the decoder and the encoder's embedding layer only;
    # the pre-trained CNN backbone stays frozen.
    params = list(decoder.parameters()) + list(encoder.embed.parameters())
    optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
    # Alternatives tried during experimentation:
    # optimizer = torch.optim.Adam(params, lr=0.01, betas=(0.9, 0.999), eps=1e-08)
    # optimizer = torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08)

    # Set the total number of training steps per epoch.
    total_step = math.ceil(len(data_loader.dataset.caption_lengths)
                           / data_loader.batch_sampler.batch_size)
    # Make sure the checkpoint directory exists before training starts.
    os.makedirs('./models', exist_ok=True)

    # Open the training log file.
    f = open(log_file, 'w')

    # Collect per-epoch losses in these lists.
    training_loss_per_epoch = []
    val_loss_per_epoch = []
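    # Batching note: get_train_indices() samples a caption length first and
    # then draws indices of captions with exactly that length, so all
    # captions in a batch share one length and stack without padding.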
    for epoch in range(1, num_epochs + 1):
        avg_batch_loss = 0.0
        encoder.train()
        decoder.train()
        for i_step in range(1, total_step + 1):
            # Randomly sample a caption length, and sample indices with that length.
            indices = data_loader.dataset.get_train_indices()
            # Create and assign a batch sampler to retrieve a batch with the sampled indices.
            new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            data_loader.batch_sampler.sampler = new_sampler

            # Obtain the batch.
            images, captions = next(iter(data_loader))

            # Move the batch of images and captions to the device.
            images = images.to(device)
            captions = captions.to(device)

            # Zero the gradients.
            decoder.zero_grad()
            encoder.zero_grad()

            # Pass the inputs through the CNN-RNN model.
            features = encoder(images)
            outputs = decoder(features, captions)

            # Calculate the batch loss.
            loss = criterion(outputs.contiguous().view(-1, vocab_size), captions.view(-1))
            # Accumulate the scalar value only, so the computation graph is freed.
            avg_batch_loss += loss.item()

            # Backward pass.
            loss.backward()

            # Update the parameters in the optimizer.
            optimizer.step()

            # Get training statistics.
            stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f' % (
                epoch, num_epochs, i_step, total_step, loss.item(), np.exp(loss.item()))

            # Print training statistics (on the same line).
            print('\r' + stats, end="")
            sys.stdout.flush()

            # Write training statistics to the log file.
            f.write(stats + '\n')
            f.flush()

            # Print training statistics (on a new line).
            if i_step % print_every == 0:
                print('\r' + stats)
        # Save the weights.
        if epoch % save_every == 0:
            torch.save(decoder.state_dict(), os.path.join('./models', 'decoder-%d.pkl' % epoch))
            torch.save(encoder.state_dict(), os.path.join('./models', 'encoder-%d.pkl' % epoch))

        # Record the average training loss for this epoch.
        avg_batch_loss /= total_step
        training_loss_per_epoch.append(avg_batch_loss)

        # Evaluate on the validation set and record the validation loss.
        val_loss = validation.validate(encoder, decoder, criterion, val_data_loader,
                                       vocab_size, epoch, device)
        val_loss_per_epoch.append(val_loss)
    # Close the training log file.
    f.close()

    utils.plotLosses(training_loss_per_epoch,
                     val_loss_per_epoch,
                     'Cross Entropy Loss (per Epoch)')
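# After a run: checkpoints are in ./models/, per-step statistics are in
# training_log.txt, and utils.plotLosses produces the loss curves.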