Skip to content

Commit

Permalink
Merge pull request #81 from zouharvi/patch-unnecessary-loop
Browse files Browse the repository at this point in the history
remove unnecessary loop branch
  • Loading branch information
jplalor authored Dec 16, 2024
2 parents ef317c1 + a910bf7 commit eaed714
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions py_irt/models/amortized_1pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,11 @@ def guide_irt(self, models, items, obs):
irt_batch_size = 256
loc_diffs_all, scale_diffs_all = [], []
for i in range(0, len(items), irt_batch_size):
if len(items[i:]) < irt_batch_size:
batch_xs = items[i:]
loc_diffs, scale_diffs = self.encoder.forward(batch_xs)
loc_diffs_all.extend(loc_diffs)
scale_diffs_all.extend(scale_diffs)
else:
# pick out the appropriate images from xs based on items idx
batch_xs = items[i:i+irt_batch_size]
loc_diffs, scale_diffs = self.encoder.forward(batch_xs)
loc_diffs_all.extend(loc_diffs)
scale_diffs_all.extend(scale_diffs)
# pick out the appropriate images from xs based on items idx
batch_xs = items[i:i+irt_batch_size]
loc_diffs, scale_diffs = self.encoder.forward(batch_xs)
loc_diffs_all.extend(loc_diffs)
scale_diffs_all.extend(scale_diffs)
loc_diffs_all = torch.tensor(loc_diffs_all, **options).unsqueeze(1).float()
scale_diffs_all = torch.tensor(scale_diffs_all, **options).unsqueeze(1).float()
dist_b = dist.Normal(loc_diffs_all, scale_diffs_all)
Expand Down

0 comments on commit eaed714

Please sign in to comment.