
Commit

updated to gumbel softmax
shreshthtuli committed Jan 6, 2022
1 parent 186f0b9 commit 372fd69
Showing 7 changed files with 28 additions and 11 deletions.
6 changes: 4 additions & 2 deletions decider/SciNet_Decider.py
@@ -19,10 +19,12 @@ def load_model(self):
     def decision(self, workflowlist):
         if not self.model_loaded: self.load_model()
         memory = self.env.provisioner.memory
-        results = []
+        provision_scores = self.env.provisioner.scores
+        decisions = []; results = []
         for CreationID, interval, SLA, application in workflowlist:
             inp = one_hot(application, self.fn_names)
-            choice = self.choices[torch.argmax(self.model.forward_decider(memory, inp)).item()]
+            decision = self.model.forward_decider(memory, inp, provision_scores)
+            choice = self.choices[torch.argmax(decision).item()]
             tasklist = self.createTasks(CreationID, interval, SLA, application, choice)
             results += tasklist
         return results
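
The discrete choice is still recovered by an argmax over the decider's softmax output. A tiny illustration of that last step, with made-up scores and hypothetical choice labels (not the repository's actual self.choices):

import torch

choices = ['layer', 'semantic', 'compression']   # hypothetical labels for self.choices
decision = torch.tensor([0.1, 0.7, 0.2])         # e.g. the softmax output of forward_decider
choice = choices[torch.argmax(decision).item()]  # -> 'semantic'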
Binary file modified provisioner/src/checkpoints/SciNet.ckpt
Binary file not shown.
12 changes: 12 additions & 0 deletions provisioner/src/dlutils.py
@@ -475,3 +475,15 @@ def multi_logistic_loss(pred, label):
 def RMSE(pred, label):
     loss = torch.mean(torch.sum((pred - label) ** 2, 1), 0) ** 0.5
     return loss
+
+def sample_gumbel(shape, eps=1e-20):
+    U = torch.rand(shape)
+    return -torch.log(-torch.log(U + eps) + eps)
+
+def gumbel_softmax_sample(logits, temperature, dim):
+    y = logits #+ sample_gumbel(logits.size())
+    return F.softmax(y / temperature, dim=dim)
+
+def gumbel_softmax(logits, dim=-1, temperature=0.01, hard=False):
+    y = gumbel_softmax_sample(logits, temperature, dim)
+    return y.view(-1)
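
Note that, as committed, the Gumbel noise in gumbel_softmax_sample is commented out and the hard flag of gumbel_softmax is unused, so the function reduces to a sharp (temperature 0.01) softmax followed by a flatten to a 1-D vector. For comparison, a hedged sketch of the standard straight-through Gumbel-Softmax (functionally close to torch.nn.functional.gumbel_softmax), which keeps the noise and can emit a one-hot sample in the forward pass while passing soft gradients backward:

import torch
import torch.nn.functional as F

def gumbel_softmax_st(logits, temperature=1.0, dim=-1, hard=False, eps=1e-20):
    # Gumbel(0, 1) noise: argmax of the perturbed logits is a categorical sample.
    U = torch.rand_like(logits)
    g = -torch.log(-torch.log(U + eps) + eps)
    y_soft = F.softmax((logits + g) / temperature, dim=dim)
    if hard:
        # Straight-through estimator: one-hot forward, softmax gradients backward.
        index = y_soft.argmax(dim=dim, keepdim=True)
        y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    return y_soft

With the noise disabled, the committed version is deterministic given the provisioning scores, which keeps the decider and scheduler inputs reproducible but no longer samples from the provisioning distribution.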
14 changes: 8 additions & 6 deletions provisioner/src/models.py
@@ -182,8 +182,8 @@ def __init__(self, feats):
         self.transformer_decoder = TransformerDecoder(decoder_layers, 1)
         self.fcn = nn.Sigmoid()
         self.likelihood_1 = nn.Sequential(nn.Linear(1 + feats, 2), nn.Softmax())
-        self.likelihood_2 = nn.Sequential(nn.Linear(self.n_apps + feats, self.n_choices), nn.Softmax())
-        self.likelihood_3 = nn.Sequential(nn.Linear(self.n_apps + feats, feats), nn.Softmax())
+        self.likelihood_2 = nn.Sequential(nn.Linear(self.n_apps + feats + 2*feats, self.n_choices), nn.Softmax())
+        self.likelihood_3 = nn.Sequential(nn.Linear(self.n_apps + feats + 2*feats, feats), nn.Softmax())
 
     def predwindow(self, src, tgt):
         src = src * math.sqrt(self.n_feats)
@@ -197,10 +197,12 @@ def forward_provisioner(self, memory, hv):
         score_1 = self.likelihood_1(torch.cat((hv.view(-1), memory.view(-1))))
         return score_1
 
-    def forward_decider(self, memory, dv):
-        score_2 = self.likelihood_2(torch.cat((dv.view(-1), memory.view(-1))))
+    def forward_decider(self, memory, dv, p_score):
+        p_g = gumbel_softmax(p_score, dim=1)
+        score_2 = self.likelihood_2(torch.cat((dv.view(-1), memory.view(-1), p_g)))
         return score_2
 
-    def forward_scheduler(self, memory, sv):
-        score_3 = self.likelihood_3(torch.cat((sv.view(-1), memory.view(-1))))
+    def forward_scheduler(self, memory, sv, p_score):
+        p_g = gumbel_softmax(p_score, dim=1)
+        score_3 = self.likelihood_3(torch.cat((sv.view(-1), memory.view(-1), p_g)))
         return score_3
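
A quick shape check of the widened likelihood heads, under the assumption (suggested by likelihood_1's 1 + feats input and the stacking in backprop below) that memory flattens to feats values and that p_score carries one two-way provisioning score per host, with as many hosts as feats; gumbel_softmax(p_score, dim=1).view(-1) then contributes the extra 2*feats inputs:

import torch
import torch.nn as nn

feats, n_apps, n_choices = 5, 3, 2       # assumed sizes, for illustration only

memory = torch.rand(feats)               # flattened memory embedding
dv = torch.rand(n_apps)                  # one-hot application vector
p_score = torch.rand(feats, 2)           # stand-in provisioning scores: a two-way score per host

# Mirrors gumbel_softmax above (noise disabled): sharp softmax over dim=1, then flatten.
p_g = torch.softmax(p_score / 0.01, dim=1).view(-1)          # shape (2 * feats,)

likelihood_2 = nn.Sequential(nn.Linear(n_apps + feats + 2 * feats, n_choices), nn.Softmax(dim=-1))
score_2 = likelihood_2(torch.cat((dv.view(-1), memory.view(-1), p_g)))
print(score_2.shape)                     # torch.Size([2]), i.e. one probability per choice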
Binary file not shown.
4 changes: 2 additions & 2 deletions provisioner/src/utils.py
@@ -155,9 +155,9 @@ def backprop(epoch, model, optimizer, scheduler, data_cpu, data_provisioner, dat
         p_in, p_gold = data_provisioner[i]
         p_out = torch.stack([model.forward_provisioner(memory, i) for i in p_in])
         d_in, d_gold = data_decider[i]
-        d_out = torch.stack([model.forward_decider(memory, i) for i in d_in])
+        d_out = torch.stack([model.forward_decider(memory, i, p_out) for i in d_in])
         s_in, s_gold = data_scheduler[i]
-        s_out = torch.stack([model.forward_scheduler(memory, i) for i in d_in])
+        s_out = torch.stack([model.forward_scheduler(memory, i, p_out) for i in d_in])
         p_gold, d_gold, s_gold = torch.stack(p_gold), torch.stack(d_gold), torch.stack(s_gold)
         loss = loss + l2(p_out, p_gold) + l2(d_out, d_gold) + l2(s_out, s_gold)
         ls.append(torch.mean(loss).item())
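
Feeding p_out directly (rather than a hard argmax of it) into forward_decider and forward_scheduler appears intended to keep the pipeline differentiable end to end, so the provisioner also receives gradients from the decider and scheduler losses via the Gumbel-Softmax relaxation. A toy illustration of why a hard argmax would cut that gradient path while the soft relaxation preserves it:

import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 0.5], requires_grad=True)

# Hard selection: argmax is piecewise constant, so the graph is cut and no
# gradient can reach the logits through this path.
hard_choice = F.one_hot(logits.argmax(), num_classes=3).float()   # requires_grad == False

# Soft relaxation: a (Gumbel-)softmax keeps the computation differentiable.
soft_choice = F.softmax(logits, dim=-1)
loss = (soft_choice * torch.arange(3.0)).sum()   # any downstream loss
loss.backward()
print(logits.grad)                                # non-zero gradients flow back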
3 changes: 2 additions & 1 deletion scheduler/SciNet_Scheduler.py
@@ -20,10 +20,11 @@ def placement(self, tasks):
         if not self.model_loaded: self.load_model()
         start = time()
         memory = self.env.provisioner.memory
+        provision_scores = self.env.provisioner.scores
         decision = []
         for task in tasks:
             inp = one_hot(task.application, self.fn_names)
-            scores = self.model.forward_scheduler(memory, inp).tolist()
+            scores = self.model.forward_scheduler(memory, inp, provision_scores).tolist()
             # scores = scores + np.random.random(len(scores)) # debug
             # mask disabled hosts
             for hostID, host in enumerate(self.env.hostlist):
