forked from facebookresearch/EGG
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
86 lines (73 loc) · 2.18 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import egg.core as core
def add_weight_decay(model, weight_decay=1e-5, skip_name=""):
    """Split a model's trainable parameters into two optimizer param groups.

    Parameters whose ``requires_grad`` is false are left out entirely. A
    parameter goes to the no-decay group when its shape is one-dimensional
    or when ``skip_name`` is a non-empty substring of its name; every other
    parameter goes to the decay group.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where each ``param`` has
            ``requires_grad`` and ``shape`` attributes.
        weight_decay: Value stored under ``"weight_decay"`` for the decay
            group.
        skip_name: Substring of parameter names to exempt from decay. The
            default empty string disables name-based filtering.

    Returns:
        A two-element list of dicts: first the no-decay group (with
        ``"weight_decay": 0.0``), then the decay group (with
        ``"weight_decay": weight_decay``).
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # ``"" in name`` is always True, so an empty ``skip_name`` must be
        # treated as "no name filter"; otherwise the default arguments would
        # send every parameter to the no-decay group.
        skipped_by_name = bool(skip_name) and skip_name in name
        if len(param.shape) == 1 or skipped_by_name:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]
def get_opts(params):
    """Declare this module's command-line options and parse ``params``.

    The options are registered on a fresh ``argparse.ArgumentParser``, which
    is then handed to ``core.init`` together with ``params``.

    Args:
        params: Forwarded unchanged to ``core.init`` as ``params``
            (presumably the list of command-line tokens -- confirm against
            ``egg.core.init``).

    Returns:
        Whatever ``core.init`` returns for the configured parser.
    """
    # (flag, add_argument keyword arguments), listed in registration order.
    option_table = (
        # Data opts
        (
            "--dataset_dir",
            {
                "type": str,
                "default": "/datasets01/imagenet_full_size/061417/train",
                "help": "Dataset location",
            },
        ),
        (
            "--dataset_name",
            {
                "type": str,
                "default": "imagenet",
                "choices": ["cifar10", "imagenet"],
                "help": "Dataset used for training",
            },
        ),
        ("--image_size", {"type": int, "default": 224, "help": "Image size"}),
        (
            "--num_workers",
            {"type": int, "default": 4, "help": "Workers used in the dataloader"},
        ),
        # Vision module opts
        (
            "--model_name",
            {
                "type": str,
                "default": "resnet50",
                "choices": [
                    "resnet18",
                    "resnet34",
                    "resnet50",
                    "resnet101",
                    "resnet152",
                ],
                "help": "Model name for the encoder",
            },
        ),
        # Loss opts
        (
            "--ntxent_tau",
            {"type": float, "default": 0.1, "help": "Temperature for NT XEnt loss"},
        ),
        # Arch opts
        (
            "--output_size",
            {"type": int, "default": 128, "help": "Sender/Receiver output size"},
        ),
        # Misc opts
        (
            "--weight_decay",
            {"type": float, "default": 10e-6, "help": "Weight decay used for SGD"},
        ),
        (
            "--pdb",
            {
                "action": "store_true",
                "default": False,
                "help": "Run the game with pdb enabled",
            },
        ),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        arg_parser.add_argument(flag, **settings)

    return core.init(arg_parser=arg_parser, params=params)