From c0f958b4e6ccff60875f42947674b584c691a7f8 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Fri, 15 Nov 2024 11:44:29 -0500 Subject: [PATCH] Fix bug --- ults/ults.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/ults/ults.py b/ults/ults.py index ec6fe1a..3864b97 100644 --- a/ults/ults.py +++ b/ults/ults.py @@ -419,11 +419,6 @@ def search(self) -> tuple[torch.Tensor, float, int]: ) if self.stop_at_eos and child_tokens[0, -1] == self.eos_token: - child_samples = ( - children_observations[i].repeat(self.sample_size) - + self.ngram_penalty * penalty - ) - if self.use_full_budget: # make sure we don't select the eos node again in the next iteration # by setting it to -inf. @@ -432,6 +427,11 @@ def search(self) -> tuple[torch.Tensor, float, int]: float("-inf"), device=self.device, ) + else: + child_samples = ( + children_observations[i].repeat(self.sample_size) + + self.ngram_penalty * penalty + ) else: child_samples = ( children_samples[i] + self.ngram_penalty * penalty @@ -458,17 +458,19 @@ def search(self) -> tuple[torch.Tensor, float, int]: ): if self.use_full_budget: # we want to compare by average log likelihood - child_obs = child_obs / child_tokens.size(-1) - if ( - child_obs + self.ngram_penalty * penalty - > best_observed_value - ): - best_path = children_tokens[i][None, :] - best_observed_value = ( - child_obs.item() + self.ngram_penalty * penalty + observed_value = ( + child_obs / child_tokens.size(-1) + + self.ngram_penalty * penalty ) + + if observed_value > best_observed_value: + best_path = children_tokens[i][None, :] + best_observed_value = observed_value best_observed_loglike = child_obs.item() + if self.use_full_budget: + best_observed_loglike /= child_tokens.size(-1) + # Update optimal value distribution of parents self.backup(new_node_name) @@ -477,6 +479,7 @@ def search(self) -> tuple[torch.Tensor, float, int]: overall_max_samples = self.tree.nodes["0"]["max_samples"] else: overall_max_samples = self.tree.nodes["0"]["samples"] + prob_result_nodes = ( torch.sum(best_observed_value >= overall_max_samples) / self.sample_size )