train_encoder_single_task.py
import argparse
import os
import random

import numpy as np
import torch
import torch.nn as nn

from models import *
from sprites_datagen.moving_sprites import MovingSpriteDataset
from sprites_datagen.rewards import *
from general_utils import AttrDict, parse_dataset, task_to_idx
from plotter import *

parser = argparse.ArgumentParser(description='Train the reward-prediction encoder on a single task')
parser.add_argument('-r', '--reward', required=True, help='Name of the reward/task to train on')
args = parser.parse_args()
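# Usage sketch (assuming the task name strings match task_to_idx's mapping):
#   python train_encoder_single_task.py --reward vertical_position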
spec = AttrDict(
    resolution=64,
    max_seq_len=30,
    max_speed=0.05,       # total image range [0, 1]
    obj_size=0.2,         # object size; the full image is 1.0
    shapes_per_traj=3,    # number of shapes per trajectory
    rewards=[VertPosReward, HorPosReward, AgentXReward, AgentYReward,
             TargetXReward, TargetYReward],  # 6 tasks in total
)
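# The six reward classes above define the candidate tasks; this script trains
# the encoder against exactly one of them, selected via --reward.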
# constants
N = 5       # number of conditioning frames (assumed from the FutureRewardsEstimator signature)
T = 30 - N  # remaining steps of each sequence for which rewards are predicted
lr = 0.001
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'current device is {device}')
# dataset
ds = MovingSpriteDataset(spec=spec)
dataset_size = 1000
print('preparing dataset')
# each buffer entry is a (tasks, input_images) pair produced by parse_dataset
buffer = [parse_dataset(ds[0], N, T) for _ in range(dataset_size)]
print('dataset prepared.')
batch_size = 16
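# Training samples mini-batches from this fixed buffer instead of streaming
# fresh trajectories, so every epoch reuses the same 1000 parsed sequences.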
####### model definition #######
encoder = Encoder(64)                         # shared
hidden = HiddenStateEncoder(encoder=encoder)  # shared
fre = FutureRewardsEstimator(hidden, N=N, T=T, Heads=1)
print(fre)

for m in fre.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:  # guard against bias-free conv layers
            nn.init.constant_(m.bias, 0.1)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)

optim = torch.optim.RAdam(params=fre.parameters(), lr=lr, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optim, lr_lambda=lambda epoch: 0.9999 ** epoch, last_epoch=-1)
####### end of the model definition #######
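# fre maps the conditioning frames to estimated rewards for the T future steps;
# with Heads=1 it emits a single reward signal whose shape is assumed to match
# the per-task target tensor selected below.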
Epochs = 1000
losses = []
task_name = args.reward  # e.g. 'vertical_position'

for e in range(Epochs):
    loss_epoch = 0
    batch = random.sample(buffer, batch_size)
    for data in batch:
        # preprocessing: pick the target signal for the requested task
        tasks, input_images = data
        task = tasks[task_to_idx(task_name)]
        optim.zero_grad()
        # train on the selected task only
        estimated_reward = fre(input_images)
        loss = nn.MSELoss(reduction='sum')(estimated_reward, task)
        loss.backward()
        loss_epoch += loss.item()
        optim.step()
        scheduler.step()
    losses.append(loss_epoch / batch_size)  # mean per-sample loss for this epoch
    print(f"epoch:{e} - loss:{np.mean(losses[-30:]):.5f} - lr:{optim.param_groups[0]['lr']:.8f}")
os.makedirs('Results/encoder', exist_ok=True)  # ensure the output directory exists
torch.save(encoder.state_dict(), f'Results/encoder/encoderv2_{task_name}.pth')
torch.save(fre.state_dict(), f'Results/encoder/frev2_{task_name}.pth')
plot_and_save_loss_per_epoch_1(losses, f'encoderv2_{task_name} pretraining', 'encoder')