forked from facebookresearch/EGG
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patharchs.py
82 lines (63 loc) · 2.62 KB
/
archs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import torchvision
def get_vision_module(encoder_arch: str):
    """Build a torchvision ResNet encoder and report its feature width.

    Only the requested architecture is instantiated; the name is validated
    before any model is constructed, so an invalid name fails fast.

    Note: because the other ResNet variants are no longer built as a side
    effect, fewer random numbers are drawn during initialisation than
    before; runs that depend on a fixed seed may see different initial
    weights.

    Args:
        encoder_arch: one of "resnet18", "resnet34", "resnet50",
            "resnet101" or "resnet152".

    Returns:
        A ``(model, features_dim)`` tuple: the ResNet built with its default
        constructor arguments and its ``fc`` layer replaced by
        ``nn.Identity``, and the ``in_features`` of the ``fc`` layer that
        was replaced.

    Raises:
        KeyError: if ``encoder_arch`` is not one of the supported names.
    """
    supported_archs = (
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnet152",
    )
    if encoder_arch not in supported_archs:
        raise KeyError(f"{encoder_arch} is not a valid ResNet version")

    # The supported names match the torchvision.models builder functions,
    # so resolve the builder by name and construct just that one network.
    build_model = getattr(torchvision.models, encoder_arch)
    model = build_model()

    # Remember the width of the features feeding the classifier, then drop
    # the classifier so the model outputs those features directly.
    features_dim = model.fc.in_features
    model.fc = nn.Identity()
    return model, features_dim
class VisionModule(nn.Module):
    """Applies one shared encoder to a pair of inputs."""

    def __init__(self, vision_module: nn.Module):
        """Store ``vision_module`` as the shared ``encoder``."""
        super().__init__()
        self.encoder = vision_module

    def forward(self, x_i, x_j):
        """Encode ``x_i`` and then ``x_j`` with the same encoder.

        Returns:
            A tuple ``(encoder(x_i), encoder(x_j))``; the first element is
            the sender-side encoding, the second the receiver-side one.
        """
        # Tuple elements are evaluated left to right, so x_i is encoded
        # before x_j.
        return self.encoder(x_i), self.encoder(x_j)
class VisionGameWrapper(nn.Module):
    """Runs a vision module over an input pair before invoking a game."""

    def __init__(self, game: nn.Module, vision_module: nn.Module):
        """Keep references to the wrapped ``game`` and ``vision_module``."""
        super().__init__()
        self.vision_module = vision_module
        self.game = game

    def forward(self, sender_input, labels, receiver_input=None, _aux_input=None):
        """Encode the two inputs and forward the encodings to the game.

        Args:
            sender_input: a pair that unpacks into two inputs, both of
                which are passed to the vision module.
            labels: passed through to the game unchanged.
            receiver_input: accepted but not used; the game's receiver
                input is the second output of the vision module.
            _aux_input: accepted but not used.

        Returns:
            Whatever ``self.game`` returns.
        """
        first_view, second_view = sender_input
        encoded_for_sender, encoded_for_receiver = self.vision_module(
            first_view, second_view
        )
        game_output = self.game(
            sender_input=encoded_for_sender,
            labels=labels,
            receiver_input=encoded_for_receiver,
        )
        return game_output
class Sender(nn.Module):
    """Two-layer head: Linear -> BatchNorm1d -> ReLU -> Linear (no bias)."""

    def __init__(self, visual_features_dim: int, output_dim: int):
        """Build the head.

        Args:
            visual_features_dim: size of the input features and of the
                hidden layer.
            output_dim: size of the output produced by the final,
                bias-free linear layer.
        """
        super().__init__()
        hidden_dim = visual_features_dim
        layers = [
            nn.Linear(visual_features_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=False),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``self.fc`` applied to ``x``."""
        projected = self.fc(x)
        return projected
class Receiver(nn.Module):
    """Two-layer head: Linear -> BatchNorm1d -> ReLU -> Linear (no bias)."""

    def __init__(self, visual_features_dim: int, output_dim: int):
        """Build the head.

        Args:
            visual_features_dim: size of the input features and of the
                hidden layer.
            output_dim: size of the output produced by the final,
                bias-free linear layer.
        """
        super().__init__()
        width = visual_features_dim
        input_projection = nn.Linear(visual_features_dim, width)
        normalisation = nn.BatchNorm1d(width)
        activation = nn.ReLU()
        output_projection = nn.Linear(width, output_dim, bias=False)
        self.fc = nn.Sequential(
            input_projection,
            normalisation,
            activation,
            output_projection,
        )

    def forward(self, x, _input):
        """Return ``self.fc`` applied to ``_input``.

        The first argument ``x`` is accepted but never read; only
        ``_input`` contributes to the result.
        """
        result = self.fc(_input)
        return result