Skip to content

Commit

Permalink
Fix bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
wiseodd committed Nov 15, 2024
1 parent 88f8a56 commit c0f958b
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions ults/ults.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,6 @@ def search(self) -> tuple[torch.Tensor, float, int]:
)

if self.stop_at_eos and child_tokens[0, -1] == self.eos_token:
child_samples = (
children_observations[i].repeat(self.sample_size)
+ self.ngram_penalty * penalty
)

if self.use_full_budget:
# Make sure the EOS node is not selected again in the next iteration
# by setting its samples to -inf.
Expand All @@ -432,6 +427,11 @@ def search(self) -> tuple[torch.Tensor, float, int]:
float("-inf"),
device=self.device,
)
else:
child_samples = (
children_observations[i].repeat(self.sample_size)
+ self.ngram_penalty * penalty
)
else:
child_samples = (
children_samples[i] + self.ngram_penalty * penalty
Expand All @@ -458,17 +458,19 @@ def search(self) -> tuple[torch.Tensor, float, int]:
):
if self.use_full_budget:
# we want to compare by average log likelihood
child_obs = child_obs / child_tokens.size(-1)
if (
child_obs + self.ngram_penalty * penalty
> best_observed_value
):
best_path = children_tokens[i][None, :]
best_observed_value = (
child_obs.item() + self.ngram_penalty * penalty
observed_value = (
child_obs / child_tokens.size(-1)
+ self.ngram_penalty * penalty
)

if observed_value > best_observed_value:
best_path = children_tokens[i][None, :]
best_observed_value = observed_value
best_observed_loglike = child_obs.item()

if self.use_full_budget:
best_observed_loglike /= child_tokens.size(-1)

# Update optimal value distribution of parents
self.backup(new_node_name)

Expand All @@ -477,6 +479,7 @@ def search(self) -> tuple[torch.Tensor, float, int]:
overall_max_samples = self.tree.nodes["0"]["max_samples"]
else:
overall_max_samples = self.tree.nodes["0"]["samples"]

prob_result_nodes = (
torch.sum(best_observed_value >= overall_max_samples) / self.sample_size
)
Expand Down

0 comments on commit c0f958b

Please sign in to comment.