From 88f8a56b37e51a6f330d9305f4ed0e766b5d4538 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Fri, 15 Nov 2024 11:03:46 -0500 Subject: [PATCH] Fix torch.fill issue --- ults/ults.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ults/ults.py b/ults/ults.py index 9398cb5..ec6fe1a 100644 --- a/ults/ults.py +++ b/ults/ults.py @@ -423,10 +423,15 @@ def search(self) -> tuple[torch.Tensor, float, int]: children_observations[i].repeat(self.sample_size) + self.ngram_penalty * penalty ) + if self.use_full_budget: # make sure we don't select the eos node again in the next iteration # by setting it to -inf. - child_samples = torch.full(self.sample_size, float('-inf')) + child_samples = torch.full( + (self.sample_size,), + float("-inf"), + device=self.device, + ) else: child_samples = ( children_samples[i] + self.ngram_penalty * penalty @@ -451,9 +456,6 @@ def search(self) -> tuple[torch.Tensor, float, int]: if child_depth == self.depth or ( self.stop_at_eos and child_tokens[0, -1] == self.eos_token ): - - - if self.use_full_budget: # we want to compare by average log likelihood child_obs = child_obs / child_tokens.size(-1) @@ -486,5 +488,4 @@ def search(self) -> tuple[torch.Tensor, float, int]: if self.use_full_budget: best_observed_value = best_observed_value * best_path.size(-1) - return best_path, best_observed_value, n_llm_calls