-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
37 lines (29 loc) · 971 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from fit import fit
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import math
import os
import yaml
# Load the base training configuration; the sweep below overrides
# `max_samples` and `nepoch` per run and reads `schedule_interval`.
with open('config.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)
# yaml.safe_load returns None for an empty file (and a scalar/list for other
# documents); the sweep indexes `config` by key, so fail early and clearly.
if not isinstance(config, dict):
    raise ValueError(
        f"config.yaml must contain a YAML mapping, got {type(config).__name__}"
    )
# Sweep the sample budget (`max_samples`, plotted as "Number of Gaussians"),
# fit once per budget, and record the final MSE between fit()'s two outputs.
msamples = [20, 50, 100, 150, 200, 300, 500, 700, 1000, 1500, 2000, 2500, 3000]
losses = []
nepochs = []
# Create the output directory before any fitting so a missing `log/` does
# not crash the script only after the first (expensive) run has finished.
os.makedirs('log', exist_ok=True)
for msample in msamples:
    config['max_samples'] = msample
    # ceil(log2(msample)) + 1 schedule intervals, plus one extra epoch.
    # math.log2 is exact at powers of two; math.log(x, 2) can overshoot
    # (e.g. math.log(2**29, 2) > 29) and make ceil() add a whole interval.
    config['nepoch'] = (math.ceil(math.log2(msample)) + 1) * config['schedule_interval'] + 1
    nepochs.append(config['nepoch'])
    rc, gt = fit(config)
    with torch.no_grad():
        # .item() keeps a plain Python float: the log shows numbers instead of
        # `tensor(...)` reprs, and plotting works even for GPU tensors.
        loss = nn.MSELoss()(rc, gt).item()
    losses.append(loss)
    # Rewrite the log after every run so results gathered so far survive a
    # crash or interrupt partway through the sweep.
    with open('log/log_exp.txt', 'w', encoding='utf-8') as log:
        log.write(str(msamples[:len(losses)]) + '\n')
        log.write(str(nepochs) + '\n')
        log.write(str(losses) + '\n')
# Plot final MSE against the sample budget for every completed run.
# float() accepts plain floats as well as 0-d tensors (including CUDA ones,
# which matplotlib cannot convert to NumPy by itself).
plt.plot(msamples[:len(losses)], [float(value) for value in losses])
plt.xlabel('Number of Gaussians')
plt.ylabel('MSE Loss')
plt.savefig('log/loss_exp.png')
# The script never shows the figure, so release it once saved.
plt.close()