forked from microsoft/Oscar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_retrieval.py
664 lines (607 loc) · 33.8 KB
/
run_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
# Copyright (c) 2020 Microsoft Corporation. Licensed under the MIT license.
from __future__ import absolute_import, division, print_function
import argparse
import os
import base64
import os.path as op
import random, json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from oscar.utils.tsv_file import TSVFile
from oscar.utils.logger import setup_logger
from oscar.utils.misc import mkdir, set_seed
from oscar.modeling.modeling_bert import ImageBertForSequenceClassification
from transformers.pytorch_transformers import BertTokenizer, BertConfig
from transformers.pytorch_transformers import AdamW, WarmupLinearSchedule, WarmupConstantSchedule
class RetrievalDataset(Dataset):
    """Image/Text Retrieval Dataset.

    Pairs captions loaded from ``<data_dir>/<split>_captions.pt`` with image
    region features read from a TSV file and tensorizes each pair (optionally
    with object-detection labels as a second text segment).

    In training mode each item is a (positive, negative) example pair; in
    evaluation mode each item is one example followed by its 0/1 match label.
    """

    def __init__(self, tokenizer, args, split='train', is_train=True):
        """
        tokenizer: tokenizer to process caption text.
        args: configuration parameters including max_seq_length, etc.
        split: used to infer the data used for training or testing.
             All files are in .pt format of a dictionary with image keys and
             image features (pytorch tensors), captions (list of str, support multiple
             captions per image), labels (list of dictionary or str of all labels),
        is_train: selects the training item format and caption count; the
             evaluation-only options (key subset, caption index file) are
             only honoured when this is False.
        """
        super(RetrievalDataset, self).__init__()
        self.img_file = args.img_feat_file
        caption_file = op.join(args.data_dir, '{}_captions.pt'.format(split))
        self.img_tsv = TSVFile(self.img_file)
        self.captions = torch.load(caption_file)
        self.img_keys = list(self.captions.keys())  # img_id as int
        # Captions may be stored as JSON strings instead of lists.
        if not isinstance(self.captions[self.img_keys[0]], list):
            self.captions = {k: json.loads(self.captions[k]) for k in self.img_keys}

        # get the image image_id to index map
        imgid2idx_file = op.join(op.dirname(self.img_file), 'imageid2idx.json')
        with open(imgid2idx_file) as fp:
            self.image_id2idx = json.load(fp)  # img_id as string

        if args.add_od_labels:
            label_data_dir = op.dirname(self.img_file)
            label_file = os.path.join(label_data_dir, "predictions.tsv")
            self.label_tsv = TSVFile(label_file)
            self.labels = {}
            # Set lookup keeps the scan linear in the number of TSV rows.
            wanted_keys = set(self.img_keys)
            for line_no in range(self.label_tsv.num_rows()):
                row = self.label_tsv.seek(line_no)
                image_id = int(row[0])
                if image_id not in wanted_keys:
                    continue
                results = json.loads(row[1])
                if isinstance(results, dict):
                    objects = results['objects']
                    image_h, image_w = results["image_h"], results["image_w"]
                else:
                    # Bare list of objects: fall back to the fixed 600x800 size.
                    objects = results
                    image_h, image_w = 600, 800
                self.labels[image_id] = {
                    "image_h": image_h,
                    "image_w": image_w,
                    "class": [cur_d['class'] for cur_d in objects],
                    "boxes": np.array([cur_d['rect'] for cur_d in objects],
                                      dtype=np.float32)
                }
            # Release the label file handle once everything is in memory.
            # NOTE(review): relies on TSVFile keeping its handle in `_fp`;
            # the guard covers the case where it was never opened.
            if getattr(self.label_tsv, '_fp', None) is not None:
                self.label_tsv._fp.close()
            self.label_tsv._fp = None

        # Defined for every mode so the attribute can always be read.
        self.has_caption_indexs = False
        if is_train:
            self.num_captions_per_img = args.num_captions_per_img_train
        else:
            self.num_captions_per_img = args.num_captions_per_img_val
            if args.eval_img_keys_file:
                # select a subset of image keys for evaluation. eg. COCO 1k and 5k
                # eval_img_keys_file is a list of image keys saved in tsv file
                with open(op.join(args.data_dir, args.eval_img_keys_file), 'r') as f:
                    # Skip blank lines (e.g. a trailing newline at end of file).
                    self.img_keys = [int(k.strip()) for k in f if k.strip()]
                self.captions = {k: self.captions[k] for k in self.img_keys}
                if args.add_od_labels:
                    self.labels = {k: self.labels[k] for k in self.img_keys}

            if args.eval_caption_index_file:
                # hard negative image/caption indexs for retrieval re-rank setting.
                # useful for mini val set to monitor the performance during training.
                # However, it cannot be used together with cross image evaluation.
                self.has_caption_indexs = True
                assert not args.cross_image_eval
                caption_index_file = op.join(args.data_dir, args.eval_caption_index_file)
                self.caption_indexs = torch.load(caption_index_file)
                if not isinstance(self.caption_indexs[self.img_keys[0]], list):
                    self.caption_indexs = {k: json.loads(self.caption_indexs[k])
                                           for k in self.img_keys}

        self.is_train = is_train
        self.output_mode = args.output_mode
        self.tokenizer = tokenizer
        self.max_seq_len = args.max_seq_length
        self.max_img_seq_len = args.max_img_seq_length
        self.args = args

    def get_image_caption_index(self, index):
        """Map a flat dataset index to (img_idx, [img_key, cap_idx]).

        img_idx indexes self.img_keys and selects the image features; the
        [img_key, cap_idx] pair selects the caption in self.captions.
        """
        if not self.is_train and self.args.cross_image_eval:
            # Every image is paired with every caption of every image.
            captions_total = self.num_captions_per_img * len(self.img_keys)
            img_idx, flat_cap_idx = divmod(index, captions_total)
            img_idx1, cap_idx1 = divmod(flat_cap_idx, self.num_captions_per_img)
            return img_idx, [self.img_keys[img_idx1], cap_idx1]
        img_idx, cap_idx = divmod(index, self.num_captions_per_img)
        if not self.is_train and self.has_caption_indexs:
            # Re-rank setting: candidates come from the caption index file.
            img_key1, cap_idx1 = self.caption_indexs[self.img_keys[img_idx]][cap_idx]
            return img_idx, [img_key1, cap_idx1]
        return img_idx, [self.img_keys[img_idx], cap_idx]

    def get_label(self, index):
        """Return 1 when the caption at `index` belongs to its image, else 0."""
        img_idx, cap_idx = self.get_image_caption_index(index)
        return 1 if self.img_keys[img_idx] == cap_idx[0] else 0

    def get_od_labels(self, img_key):
        """Return the object labels of `img_key` as one string.

        Returns None when args.add_od_labels is off.
        """
        if not self.args.add_od_labels:
            return None
        entry = self.labels[img_key]
        if isinstance(entry, str):
            return entry
        return ' '.join(entry['class'])

    def tensorize_example(self, text_a, img_feat, text_b=None,
            cls_token_segment_id=0, pad_token_segment_id=0,
            sequence_a_segment_id=0, sequence_b_segment_id=1):
        """Build fixed-length model inputs for one caption/image pair.

        text_a: caption text.
        img_feat: 2-D tensor of region features, one row per region.
        text_b: optional second text segment (object labels).

        Returns (input_ids, attention_mask, segment_ids, img_feat). The text
        part is padded to self.max_seq_len and the image part is truncated or
        zero-padded to self.max_img_seq_len rows. The mask is 1-D for the
        'CLR' mask type and 2-D otherwise.

        Raises ValueError for an unsupported args.att_mask_type.
        """
        tokenizer = self.tokenizer
        tokens_a = tokenizer.tokenize(text_a)
        # Reserve two positions for [CLS] and [SEP].
        tokens_a = tokens_a[:max(0, self.max_seq_len - 2)]

        tokens = [tokenizer.cls_token] + tokens_a + [tokenizer.sep_token]
        segment_ids = [cls_token_segment_id] + [sequence_a_segment_id] * (len(tokens_a) + 1)
        seq_a_len = len(tokens)
        if text_b:
            # Room left for text_b tokens after reserving its trailing [SEP].
            room = self.max_seq_len - len(tokens) - 1
            # When the caption already fills the sequence (room < 0) nothing
            # is appended: slicing with a negative bound would keep almost
            # all of text_b and overflow max_seq_len.
            if room >= 0:
                tokens_b = tokenizer.tokenize(text_b)[:room]
                tokens += tokens_b + [tokenizer.sep_token]
                segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

        seq_len = len(tokens)
        seq_padding_len = self.max_seq_len - seq_len
        tokens += [tokenizer.pad_token] * seq_padding_len
        segment_ids += [pad_token_segment_id] * seq_padding_len
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # image features
        img_len = img_feat.shape[0]
        if img_len > self.max_img_seq_len:
            img_feat = img_feat[0 : self.max_img_seq_len, :]
            img_len = img_feat.shape[0]
            img_padding_len = 0
        else:
            img_padding_len = self.max_img_seq_len - img_len
            padding_matrix = torch.zeros((img_padding_len, img_feat.shape[1]),
                                         dtype=img_feat.dtype)
            img_feat = torch.cat((img_feat, padding_matrix), 0)

        # generate attention_mask
        att_mask_type = self.args.att_mask_type
        if att_mask_type == "CLR":
            attention_mask = torch.tensor(
                [1] * seq_len + [0] * seq_padding_len +
                [1] * img_len + [0] * img_padding_len,
                dtype=torch.long)
        else:
            # use 2D mask to represent the attention
            max_len = self.max_seq_len + self.max_img_seq_len
            attention_mask = torch.zeros((max_len, max_len), dtype=torch.long)
            # full attention of C-C, L-L, R-R
            c_start, c_end = 0, seq_a_len
            l_start, l_end = seq_a_len, seq_len
            r_start, r_end = self.max_seq_len, self.max_seq_len + img_len
            attention_mask[c_start : c_end, c_start : c_end] = 1
            attention_mask[l_start : l_end, l_start : l_end] = 1
            attention_mask[r_start : r_end, r_start : r_end] = 1
            if att_mask_type == 'CL':
                attention_mask[c_start : c_end, l_start : l_end] = 1
                attention_mask[l_start : l_end, c_start : c_end] = 1
            elif att_mask_type == 'CR':
                attention_mask[c_start : c_end, r_start : r_end] = 1
                attention_mask[r_start : r_end, c_start : c_end] = 1
            elif att_mask_type == 'LR':
                attention_mask[l_start : l_end, r_start : r_end] = 1
                attention_mask[r_start : r_end, l_start : l_end] = 1
            else:
                raise ValueError("Unsupported attention mask type {}".format(att_mask_type))

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        return (input_ids, attention_mask, segment_ids, img_feat)

    def __getitem__(self, index):
        """Return (index, example) for the flat dataset index.

        Training: example is the positive tensors + [1] followed by the
        tensors of a randomly drawn negative + [0].
        Evaluation: example is the tensors followed by the 0/1 match label.
        """
        img_idx, cap_idxs = self.get_image_caption_index(index)
        img_key = self.img_keys[img_idx]
        feature = self.get_image(img_key)
        caption = self.captions[cap_idxs[0]][cap_idxs[1]]
        od_labels = self.get_od_labels(img_key)
        example = self.tensorize_example(caption, feature, text_b=od_labels)

        if not self.is_train:
            label = 1 if img_key == cap_idxs[0] else 0
            return index, tuple(list(example) + [label])

        # select a negative pair
        num_imgs = len(self.img_keys)
        if num_imgs < 2:
            raise ValueError("Negative sampling needs at least two images.")
        # Draw uniformly from the other images without building the whole
        # candidate list: sample from n-1 slots and skip over img_idx.
        img_idx_neg = random.randrange(num_imgs - 1)
        if img_idx_neg >= img_idx:
            img_idx_neg += 1
        neg_key = self.img_keys[img_idx_neg]

        if random.random() <= 0.5:
            # randomly select a negative caption from a different image.
            cap_idx_neg = random.randint(0, self.num_captions_per_img - 1)
            caption_neg = self.captions[neg_key][cap_idx_neg]
            example_neg = self.tensorize_example(caption_neg, feature, text_b=od_labels)
        else:
            # randomly select a negative image
            feature_neg = self.get_image(neg_key)
            od_labels_neg = self.get_od_labels(neg_key)
            example_neg = self.tensorize_example(caption, feature_neg, text_b=od_labels_neg)

        example_pair = tuple(list(example) + [1] + list(example_neg) + [0])
        return index, example_pair

    def get_image(self, image_id):
        """Decode the base64 float32 region features of `image_id` from the TSV.

        The row's second column gives the number of boxes and its last column
        the encoded features; the result has one row per box.
        """
        image_idx = self.image_id2idx[str(image_id)]
        row = self.img_tsv.seek(image_idx)
        num_boxes = int(row[1])
        features = np.frombuffer(base64.b64decode(row[-1]),
                                 dtype=np.float32).reshape((num_boxes, -1))
        t_features = torch.from_numpy(features)
        return t_features

    def __len__(self):
        """Number of (image, caption) pairs exposed by the dataset."""
        if not self.is_train and self.args.cross_image_eval:
            return len(self.img_keys) ** 2 * self.num_captions_per_img
        return len(self.img_keys) * self.num_captions_per_img
def compute_score_with_logits(logits, labels):
    """Return a per-example correctness indicator for a batch.

    logits: 2-D tensor with one row per example. With more than one column
        the prediction is the argmax column; with a single column the
        prediction is sigmoid(logit) >= 0.5.
    labels: 1-D tensor of targets, one per example.

    Returns a boolean tensor in the multi-column case and a tensor shaped and
    typed like `labels` (1 where correct, 0 otherwise) in the single-column
    case. The result lives on the same device as `labels`, so the function
    also works on CPU-only runs.
    """
    if logits.shape[1] > 1:
        predicted = torch.max(logits, 1)[1].data  # argmax
        return predicted == labels

    positive = torch.sigmoid(logits[:, 0]) >= 0.5
    # Only labels that are exactly 1 or 0 can be counted as correct.
    correct = (positive & (labels == 1)) | (~positive & (labels == 0))
    scores = torch.zeros_like(labels)
    scores[correct] = 1
    return scores
def compute_ranks(dataset, results):
    """Rank the first matching candidate for every query.

    dataset: provides get_label(i), __len__, has_caption_indexs, img_keys and
        num_captions_per_img.
    results: mapping from dataset index to similarity score.

    Returns (i2t_ranks, t2i_ranks): 0-based positions of the first positive
    candidate when candidates are sorted by descending similarity. A row with
    no positive gets the row width as its rank. t2i_ranks stays empty when the
    dataset uses a caption index file.
    """
    total = len(dataset)
    label_arr = np.array([dataset.get_label(i) for i in range(total)])
    sim_arr = np.array([results[i] for i in range(total)])

    if dataset.has_caption_indexs:
        num_captions_per_img = dataset.num_captions_per_img
    else:
        num_captions_per_img = len(dataset.img_keys) * dataset.num_captions_per_img

    label_arr = np.reshape(label_arr, [-1, num_captions_per_img])
    sim_arr = np.reshape(sim_arr, [-1, num_captions_per_img])

    def first_hit_ranks(label_rows, sim_rows):
        # One rank per row: position of the best-scored positive candidate.
        ranks = []
        for row_labels, row_sims in zip(label_rows, sim_rows):
            order = np.argsort(row_sims)[::-1]
            ranks.append(next(
                (pos for pos, cand in enumerate(order) if row_labels[cand] == 1),
                num_captions_per_img))
        return ranks

    i2t_ranks = first_hit_ranks(label_arr, sim_arr)
    t2i_ranks = []
    if not dataset.has_caption_indexs:
        # Transpose so that every caption becomes a query over all images.
        t2i_ranks = first_hit_ranks(np.swapaxes(label_arr, 0, 1),
                                    np.swapaxes(sim_arr, 0, 1))
    return i2t_ranks, t2i_ranks
def save_checkpoint(model, tokenizer, args, epoch, global_step):
    """Write model, tokenizer and training args to a checkpoint directory.

    The directory is ``<args.output_dir>/checkpoint-<epoch>-<global_step>``.
    Saving is best-effort: it is attempted up to 10 times, each failure is
    logged with its cause, and the function returns without raising if every
    attempt fails. KeyboardInterrupt/SystemExit are not swallowed.
    """
    checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
        epoch, global_step))
    mkdir(checkpoint_dir)
    # Unwrap DataParallel-style wrappers so the underlying model is saved.
    model_to_save = model.module if hasattr(model, 'module') else model
    max_attempts = 10
    for attempt in range(1, max_attempts + 1):
        try:
            model_to_save.save_pretrained(checkpoint_dir)
            torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
            tokenizer.save_pretrained(checkpoint_dir)
        except Exception as err:
            logger.warning("Attempt %d/%d to save checkpoint to %s failed: %s",
                           attempt, max_attempts, checkpoint_dir, err)
        else:
            logger.info("Save checkpoint to {}".format(checkpoint_dir))
            return
    logger.info("Failed to save checkpoint after 10 trials.")
    return
def train(args, train_dataset, val_dataset, model, tokenizer):
    """Fine-tune `model` on positive/negative caption-image pairs.

    args: parsed options; train_batch_size (and num_train_epochs when
        max_steps > 0) are written back onto it.
    train_dataset: dataset yielding (index, 10-tuple) items, i.e. the positive
        example tensors + label followed by the negative ones.
    val_dataset: dataset used when args.evaluate_during_training is set.
    model, tokenizer: model to optimize and the tokenizer saved with it.

    Checkpoints are written every args.save_steps optimizer steps and at the
    last scheduled step; with evaluate_during_training the retrieval scores
    are appended to ``<output_dir>/eval_logs.json`` at the same points.

    Returns (global_step, average loss per optimizer step).
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
            batch_size=args.train_batch_size, num_workers=args.num_workers)

    accum_steps = args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        # Guard against an epoch shorter than one accumulation window.
        steps_per_epoch = max(1, len(train_dataloader) // accum_steps)
        args.num_train_epochs = args.max_steps // steps_per_epoch + 1
    else:
        t_total = len(train_dataloader) // accum_steps \
                * args.num_train_epochs

    # Prepare optimizer and scheduler
    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not \
            any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if \
            any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    if args.scheduler == "constant":
        scheduler = WarmupConstantSchedule(
                optimizer, warmup_steps=args.warmup_steps)
    elif args.scheduler == "linear":
        scheduler = WarmupLinearSchedule(
                optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    else:
        raise ValueError("Unknown scheduler type: {}".format(args.scheduler))

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(" Total train batch size (w. parallel, & accumulation) = %d",
                   args.train_batch_size * accum_steps)
    logger.info(" Gradient Accumulation steps = %d", accum_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step, global_loss, global_acc = 0, 0.0, 0.0
    num_batches = 0  # micro-batches seen; denominator of the running accuracy
    model.zero_grad()
    log_json = []
    best_score = 0
    reached_max_steps = False
    for epoch in range(int(args.num_train_epochs)):
        for step, (_, batch) in enumerate(train_dataloader):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            # Positive (0-4) and negative (5-9) examples are stacked into one
            # batch of twice the size.
            inputs = {
                'input_ids':      torch.cat((batch[0], batch[5]), dim=0),
                'attention_mask': torch.cat((batch[1], batch[6]), dim=0),
                'token_type_ids': torch.cat((batch[2], batch[7]), dim=0),
                'img_feats':      torch.cat((batch[3], batch[8]), dim=0),
                'labels':         torch.cat((batch[4], batch[9]), dim=0)
            }
            outputs = model(**inputs)
            loss, logits = outputs[:2]
            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel training
            if accum_steps > 1:
                loss = loss / accum_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            batch_score = compute_score_with_logits(logits, inputs['labels']).sum()
            # Use the real number of examples: the last batch may be smaller
            # than train_batch_size.
            batch_acc = batch_score.item() / inputs['labels'].size(0)
            loss_value = loss.item()
            global_loss += loss_value
            global_acc += batch_acc
            num_batches += 1
            if (step + 1) % accum_steps == 0:
                global_step += 1
                # Update the weights first, then advance the schedule, so the
                # first value of the learning-rate schedule is not skipped.
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                if global_step % args.logging_steps == 0:
                    logger.info("Epoch: {}, global_step: {}, lr: {:.6f}, loss: {:.4f} ({:.4f}), " \
                        "score: {:.4f} ({:.4f})".format(epoch, global_step,
                        optimizer.param_groups[0]["lr"], loss_value, global_loss / global_step,
                        batch_acc, global_acc / num_batches)
                    )

                if (args.save_steps > 0 and global_step % args.save_steps == 0) or \
                        global_step == t_total:
                    save_checkpoint(model, tokenizer, args, epoch, global_step)
                    # evaluation
                    if args.evaluate_during_training:
                        logger.info("Perform evaluation at step: %d" % (global_step))
                        test_result = test(args, model, val_dataset)
                        eval_result = evaluate(val_dataset, test_result)
                        rank_accs = eval_result['i2t_retrieval']
                        if rank_accs['R@1'] > best_score:
                            best_score = rank_accs['R@1']
                        epoch_log = {'epoch': epoch, 'global_step': global_step,
                                     'R1': rank_accs['R@1'], 'R5': rank_accs['R@5'],
                                     'R10': rank_accs['R@10'], 'best_R1':best_score}
                        log_json.append(epoch_log)
                        with open(args.output_dir + '/eval_logs.json', 'w') as fp:
                            json.dump(log_json, fp)

                # max_steps overrides num_train_epochs: stop once it is reached.
                if args.max_steps > 0 and global_step >= t_total:
                    reached_max_steps = True
                    break
        if reached_max_steps:
            break
    # max() avoids a ZeroDivisionError when no optimizer step was taken.
    return global_step, global_loss / max(global_step, 1)
def test(args, model, eval_dataset):
    """Score every example of `eval_dataset` with `model`.

    args: parsed options; eval_batch_size is written back onto it.
    eval_dataset: dataset yielding (index, (input_ids, attention_mask,
        token_type_ids, img_feats, label)) items.

    Returns a dict mapping dataset index to a float score: the softmax
    probability of class 1 when args.num_labels == 2, otherwise the single
    logit produced by the model.

    Raises ValueError when the model yields more than one value per example
    in the non-binary case.
    """
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler,
            batch_size=args.eval_batch_size, num_workers=args.num_workers)

    logger.info("Num examples = {}".format(len(eval_dataset)))
    logger.info("Evaluation batch size = {}".format(args.eval_batch_size))
    model.eval()
    results = {}
    softmax = nn.Softmax(dim=1)
    for indexs, batch in tqdm(eval_dataloader):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {
                'input_ids':      batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2],
                'img_feats':      batch[3],
                'labels':         batch[4]
            }
            logits = model(**inputs)[1]
            if args.num_labels == 2:
                scores = softmax(logits)[:, 1]  # the confidence to be a matched pair
            else:
                scores = logits
            scores = scores.reshape(scores.shape[0], -1)
            if scores.shape[1] != 1:
                raise ValueError(
                    "Expected one score per example, got {}".format(scores.shape[1]))
            # One device-to-host transfer per batch instead of one per example.
            batch_scores = scores[:, 0].cpu().tolist()
            results.update(zip(indexs.tolist(), batch_scores))
    return results
def evaluate(eval_dataset, test_results):
    """Compute and log recall at 1/5/10 for the retrieval results.

    Returns a dict with an "i2t_retrieval" entry and, when compute_ranks
    produced text-to-image ranks, a "t2i_retrieval" entry; each maps "R@1",
    "R@5" and "R@10" to the fraction of queries whose first hit ranks below
    that cutoff.
    """
    cutoffs = (1, 5, 10)

    def recall_at_cutoffs(ranks):
        # Fraction of queries whose first positive is ranked inside each cutoff.
        num_queries = len(ranks)
        return [sum(1 for pos in ranks if pos < k) / num_queries for k in cutoffs]

    i2t_ranks, t2i_ranks = compute_ranks(eval_dataset, test_results)

    r1, r5, r10 = recall_at_cutoffs(i2t_ranks)
    logger.info("I2T Retrieval: {:.4f} @ R1, {:.4f} @ R5, {:.4f} @ R10".format(
                r1, r5, r10))
    eval_result = {"i2t_retrieval": {"R@1": r1, "R@5": r5, "R@10": r10}}

    if not t2i_ranks:
        return eval_result

    r1, r5, r10 = recall_at_cutoffs(t2i_ranks)
    logger.info("T2I Retrieval: {:.4f} @ R1, {:.4f} @ R5, {:.4f} @ R10".format(
                r1, r5, r10))
    eval_result["t2i_retrieval"] = {"R@1": r1, "R@5": r5, "R@10": r10}
    return eval_result
def get_predict_file(args):
    """Build the path of the prediction file inside args.eval_model_dir.

    The file name joins, with dots: the data directory's base name (omitted
    when it is 'coco_ir'), the test split and, when OD labels are enabled,
    'wlabels<od_label_type>'; it ends with '.results.pt'.
    """
    # join(..., '') guarantees a trailing separator, which [:-1] then drops.
    dataset_name = op.basename(op.join(args.data_dir, '')[:-1])
    name_parts = [] if dataset_name == 'coco_ir' else [dataset_name]
    name_parts.append(args.test_split)
    if args.add_od_labels:
        name_parts.append('wlabels{}'.format(args.od_label_type))
    file_name = '{}.results.pt'.format('.'.join(name_parts))
    return op.join(args.eval_model_dir, file_name)
def restore_training_settings(args):
    """Align evaluation options with those stored at training time.

    Loads ``training_args.bin`` from args.eval_model_dir and, for a fixed list
    of model-related options, replaces the value in `args` with the stored one
    whenever the two differ (logging a warning for each override).
    Only valid for test/eval runs without --do_train. Returns `args`.
    """
    assert not args.do_train and (args.do_test or args.do_eval)
    saved_args = torch.load(op.join(args.eval_model_dir, 'training_args.bin'))
    tracked_options = ('do_lower_case', 'img_feature_type', 'max_seq_length',
                       'max_img_seq_length', 'add_od_labels', 'od_label_type',
                       'use_img_layernorm', 'img_layer_norm_eps')
    for option in tracked_options:
        # Options missing from the stored args are left as given.
        if not hasattr(saved_args, option):
            continue
        trained_value = getattr(saved_args, option)
        current_value = getattr(args, option)
        if trained_value != current_value:
            logger.warning('Override {} with train args: {} -> {}'.format(option,
                current_value, trained_value))
            setattr(args, option, trained_value)
    return args
def main():
    """Command-line entry point: train and/or test/evaluate a retrieval model.

    Parses the options, sets up the module-level `logger`, loads config,
    tokenizer and model (from --model_name_or_path for training, otherwise
    from --eval_model_dir), then runs training and/or inference and
    evaluation as requested by --do_train / --do_test / --do_eval.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default='datasets/coco_ir', type=str, required=False,
                        help="The input data dir with all required files.")
    parser.add_argument("--img_feat_file", default='datasets/coco_ir/features.tsv', type=str, required=False,
                        help="The absolute address of the image feature file.")
    parser.add_argument("--model_name_or_path", default=None, type=str, required=False,
                        help="Path to pre-trained model or model type. required for training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--loss_type", default='sfmx', type=str,
                        help="Loss function types: support kl, sfmx")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name.")
    parser.add_argument("--max_seq_length", default=70, type=int,
                        help="The maximum total input sequence length after tokenization. "
                             "Sequences longer than this will be truncated, "
                             "sequences shorter will be padded."
                             "This number is calculated on COCO dataset"
                             "If add object detection labels, the suggested length should be 70.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_test", action='store_true', help="Whether to run inference.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run performance valuation."
                       "do not activate if we want to inference on dataset without gt labels.")
    parser.add_argument("--test_split", default='test', type=str, help='data split name.')
    parser.add_argument("--eval_img_keys_file", default='', type=str,
                        help="image key tsv to select a subset of images for evaluation. "
                        "This is useful in 5-folds evaluation. The topn index file is not "
                        "needed in this case.")
    parser.add_argument("--eval_caption_index_file", default='', type=str,
                        help="index of a list of (img_key, cap_idx) for each image."
                        "this is used to perform re-rank using hard negative samples."
                        "useful for validation set to monitor the performance during training.")
    parser.add_argument("--cross_image_eval", action='store_true',
                        help="perform cross image inference, ie. each image with all texts from other images.")
    parser.add_argument("--add_od_labels", default=False, action='store_true',
                        help="Whether to add object detection labels or not.")
    parser.add_argument("--od_label_type", default='vg', type=str,
                        help="label type, support vg, gt, oid")
    parser.add_argument("--att_mask_type", default='CLR', type=str,
                        help="attention mask type, support ['CL', 'CR', 'LR', 'CLR']"
                        "C: caption, L: labels, R: image regions; CLR is full attention by default."
                        "CL means attention between caption and labels."
                        "please pay attention to the order CLR, which is the default concat order.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--drop_out", default=0.1, type=float, help="Drop out in BERT.")
    parser.add_argument("--max_img_seq_length", default=50, type=int,
                        help="The maximum total input image sequence length.")
    parser.add_argument("--img_feature_dim", default=2054, type=int,
                        help="The Image Feature Dimension.")
    parser.add_argument("--img_feature_type", default='frcnn', type=str,
                        help="Image feature type.")
    parser.add_argument("--use_img_layernorm", type=int, default=1,
                        help="Normalize image features with bertlayernorm")
    parser.add_argument("--img_layer_norm_eps", default=1e-12, type=float,
                        help="The eps in image feature laynorm layer")
    parser.add_argument("--per_gpu_train_batch_size", default=32, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--output_mode", default='classification', type=str,
                        help="output mode, support classification or regression.")
    parser.add_argument("--num_labels", default=2, type=int,
                        help="num_labels is 2 for classification and 1 for regression.")
    parser.add_argument("--num_captions_per_img_train", default=5, type=int,
                        help="number of positive matched captions for each training image.")
    parser.add_argument("--num_captions_per_img_val", default=5, type=int,
                        help="number of captions for each testing image.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before backward.")
    parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial lr.")
    parser.add_argument("--weight_decay", default=0.05, type=float, help="Weight deay.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup.")
    parser.add_argument("--scheduler", default='linear', type=str, help="constant or linear.")
    parser.add_argument("--num_workers", default=4, type=int, help="Workers in dataloader.")
    parser.add_argument("--num_train_epochs", default=20, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="Total number of training steps. Override num_train_epochs.")
    parser.add_argument('--logging_steps', type=int, default=20, help="Log every X steps.")
    parser.add_argument('--save_steps', type=int, default=-1,
                        help="Save checkpoint every X steps. Will also perform evaluatin.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each save_steps.")
    parser.add_argument("--eval_model_dir", type=str, default='',
                        help="Model directory for evaluation.")
    parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA.")
    parser.add_argument('--seed', type=int, default=88, help="random seed for initialization.")
    args = parser.parse_args()

    # Fail early with a usage message instead of a TypeError further down.
    if args.do_train and not args.model_name_or_path:
        parser.error("--model_name_or_path is required when --do_train is set.")

    # The other functions of this module log through this module-level name.
    global logger
    mkdir(args.output_dir)
    logger = setup_logger("vlpretrain", args.output_dir, 0)

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    # Count GPUs only when CUDA is actually used; otherwise --no_cuda on a
    # multi-GPU host would still wrap the CPU model in DataParallel.
    args.n_gpu = torch.cuda.device_count() if args.device.type == "cuda" else 0
    set_seed(args.seed, args.n_gpu)
    logger.warning("Device: %s, n_gpu: %s", args.device, args.n_gpu)
    logger.info('output_mode: {}, #Labels: {}'.format(args.output_mode, args.num_labels))

    config_class, tokenizer_class = BertConfig, BertTokenizer
    model_class = ImageBertForSequenceClassification
    if args.do_train:
        config = config_class.from_pretrained(args.config_name if args.config_name else \
            args.model_name_or_path, num_labels=args.num_labels, finetuning_task='ir')
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name \
            else args.model_name_or_path, do_lower_case=args.do_lower_case)
        config.img_feature_dim = args.img_feature_dim
        config.img_feature_type = args.img_feature_type
        config.hidden_dropout_prob = args.drop_out
        config.loss_type = args.loss_type
        config.img_layer_norm_eps = args.img_layer_norm_eps
        config.use_img_layernorm = args.use_img_layernorm
        model = model_class.from_pretrained(args.model_name_or_path,
            from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
    else:
        checkpoint = args.eval_model_dir
        assert op.isdir(checkpoint)
        config = config_class.from_pretrained(checkpoint)
        tokenizer = tokenizer_class.from_pretrained(checkpoint)
        logger.info("Evaluate the following checkpoint: %s", checkpoint)
        model = model_class.from_pretrained(checkpoint, config=config)

    model.to(args.device)
    logger.info("Training/evaluation parameters %s", args)
    if args.do_train:
        train_dataset = RetrievalDataset(tokenizer, args, 'train', is_train=True)
        if args.evaluate_during_training:
            val_dataset = RetrievalDataset(tokenizer, args, 'minival', is_train=False)
        else:
            val_dataset = None
        global_step, avg_loss = train(args, train_dataset, val_dataset, model, tokenizer)
        logger.info("Training done: total_step = %s, avg loss = %s", global_step, avg_loss)

    # inference and evaluation
    if args.do_test or args.do_eval:
        args = restore_training_settings(args)
        test_dataset = RetrievalDataset(tokenizer, args, args.test_split, is_train=False)
        checkpoint = args.eval_model_dir
        assert op.isdir(checkpoint)
        logger.info("Evaluate the following checkpoint: %s", checkpoint)
        model = model_class.from_pretrained(checkpoint, config=config)
        model.to(args.device)
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        pred_file = get_predict_file(args)
        if op.isfile(pred_file):
            # Reuse cached predictions; they are only needed for evaluation.
            logger.info("Prediction file exist, skip inference.")
            if args.do_eval:
                test_result = torch.load(pred_file)
        else:
            test_result = test(args, model, test_dataset)
            torch.save(test_result, pred_file)
            logger.info("Prediction results saved to {}.".format(pred_file))

        if args.do_eval:
            eval_result = evaluate(test_dataset, test_result)
            result_file = op.splitext(pred_file)[0] + '.eval.json'
            with open(result_file, 'w') as f:
                json.dump(eval_result, f)
            logger.info("Evaluation results saved to {}.".format(result_file))


if __name__ == "__main__":
    main()