Skip to content

Commit

Permalink
Fix torch.fill issue
Browse files Browse the repository at this point in the history
  • Loading branch information
wiseodd committed Nov 15, 2024
1 parent 62c2d76 commit 88f8a56
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions ults/ults.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,15 @@ def search(self) -> tuple[torch.Tensor, float, int]:
children_observations[i].repeat(self.sample_size)
+ self.ngram_penalty * penalty
)

if self.use_full_budget:
# make sure we don't select the eos node again in the next iteration
# by setting it to -inf.
child_samples = torch.full(self.sample_size, float('-inf'))
child_samples = torch.full(
(self.sample_size,),
float("-inf"),
device=self.device,
)
else:
child_samples = (
children_samples[i] + self.ngram_penalty * penalty
Expand All @@ -451,9 +456,6 @@ def search(self) -> tuple[torch.Tensor, float, int]:
if child_depth == self.depth or (
self.stop_at_eos and child_tokens[0, -1] == self.eos_token
):



if self.use_full_budget:
# we want to compare by average log likelihood
child_obs = child_obs / child_tokens.size(-1)
Expand Down Expand Up @@ -486,5 +488,4 @@ def search(self) -> tuple[torch.Tensor, float, int]:
if self.use_full_budget:
best_observed_value = best_observed_value * best_path.size(-1)


return best_path, best_observed_value, n_llm_calls

0 comments on commit 88f8a56

Please sign in to comment.