diff --git a/ults/functional.py b/ults/functional.py
index 3ff6cd2..086112e 100644
--- a/ults/functional.py
+++ b/ults/functional.py
@@ -32,6 +32,8 @@ def generate(
     stop_at_eos: bool = True,
     use_full_budget: bool = True,
     acquisition_function: str = "posterior",
+    ngram_penalty: float = 0.0,
+    ngram_order: int = 4,
 ) -> ULTSOutput:
     """ULTS: Uncertainty-guided Likelihood-Tree Search.
@@ -59,6 +61,9 @@
             "posterior": pick child node based on posterior over max loglik.
             "posterior_descendant": pick child node based on posterior over
                 max loglik of the best descendant.
+        ngram_penalty: Penalty strength for punishing repetitive sequences.
+        ngram_order: Highest order of the n-grams taken into account when
+            punishing repetitive sequences; must be > 1.

     Returns:
@@ -80,6 +85,8 @@
         stop_at_eos=stop_at_eos,
         use_full_budget=use_full_budget,
         acquisition_function=acquisition_function,
+        ngram_penalty=ngram_penalty,
+        ngram_order=ngram_order,
     )

     # Generation results --- full sequence and total_loglik include context
diff --git a/ults/ults.py b/ults/ults.py
index 7fcfcb3..d9af0ce 100644
--- a/ults/ults.py
+++ b/ults/ults.py
@@ -10,6 +10,7 @@
 from scipy import stats
 from torch.distributions.beta import Beta
 from transformers import BatchEncoding
+from ults import utils


 class ULTS:
@@ -37,6 +38,9 @@
             "posterior": pick child node based on posterior over max loglik.
             "posterior_descendant": pick child node based on posterior over
                 max loglik of best descendant.
+        ngram_penalty: Penalty strength for punishing repetitive sequences.
+        ngram_order: Highest order of the n-grams taken into account when
+            punishing repetitive sequences; must be > 1.
     """

     def __init__(
@@ -56,6 +60,8 @@
         stop_at_eos: bool = True,
         use_full_budget: bool = False,
         acquisition_function: str = "posterior",
+        ngram_penalty: float = 0.0,
+        ngram_order: int = 4,
     ):
         if prior_kind == "empirical" and prior_empirical_dataset_name is None:
             raise ValueError(
@@ -67,6 +73,9 @@
                 "`acquisition_function` can only be `posterior` or `posterior_descendant`."
             )

+        if ngram_order < 2:
+            raise ValueError("`ngram_order` must be > 1.")
+
         self.model = model
         self.is_encoder_decoder = model.config.is_encoder_decoder
         self.epsilon = epsilon
@@ -101,6 +110,9 @@
         self.encoder_inputs = None
         self.encoder_outputs = None

+        self.ngram_penalty = ngram_penalty * (self.depth + tokens.size(-1))
+        self.ngram_order = ngram_order
+
         self.tree = nx.DiGraph()
         self.tree.add_node(
             "0",
@@ -258,6 +270,15 @@ def budget_left(self) -> bool:
         """
         return self.max_beam_size >= self.used_max_beam_size[-1]

+    def log_diversity(self, tokens: list[int]) -> float:
+        """Log-diversity of a token sequence (see https://arxiv.org/pdf/2202.06417)."""
+        return np.sum(
+            [
+                np.log(1 - utils.rep_n(tokens, n) / 100)
+                for n in range(2, self.ngram_order + 1)
+            ]
+        )
+
     def set_nodes_to_inactive(self) -> None:
         """Check the number of expanded nodes per level of the tree.
         If this number exceeds the constraint on the maximal number, set all nodes on this level and
@@ -392,17 +413,25 @@ def search(self) -> tuple[torch.Tensor, float, int]:
                     print(child_obs)
                     child_name = new_node_name + "*" + str(i)
                     child_tokens = children_tokens[i][None, :]
+                    penalty = (
+                        0
+                        if self.ngram_penalty == 0
+                        else self.log_diversity(child_tokens[0].tolist())
+                    )

                     if self.stop_at_eos and child_tokens[0, -1] == self.eos_token:
-                        child_samples = children_observations[i].repeat(
-                            self.sample_size
+                        child_samples = (
+                            children_observations[i].repeat(self.sample_size)
+                            + self.ngram_penalty * penalty
                         )
                         if self.use_full_budget:
                             # make sure we don't select the eos node again in the next iteration
                             # by setting it to -inf.
-                            child_samples = torch.full(self.sample_size, float("-inf"))
+                            child_samples = torch.full((self.sample_size,), float("-inf"))
                     else:
-                        child_samples = children_samples[i]
+                        child_samples = (
+                            children_samples[i] + self.ngram_penalty * penalty
+                        )

                     self.tree.add_node(
                         child_name,
@@ -425,12 +454,19 @@ def search(self) -> tuple[torch.Tensor, float, int]:
                     ):
+                        if self.use_full_budget:
                             # we want to compare by average log likelihood
                             child_obs = child_obs / child_tokens.size(-1)
-                        if child_obs > best_observed_value:
+                        if (
+                            child_obs + self.ngram_penalty * penalty
+                            > best_observed_value
+                        ):
                             best_path = children_tokens[i][None, :]
-                            best_observed_value = child_obs.item()
+                            best_observed_value = (
+                                child_obs.item() + self.ngram_penalty * penalty
+                            )
+                            best_observed_loglike = child_obs.item()

                     # Update optimal value distribution of parents
                     self.backup(new_node_name)
@@ -444,6 +480,9 @@
                     torch.sum(best_observed_value >= overall_max_samples)
                     / self.sample_size
                 )
+        if self.ngram_penalty > 0:
+            best_observed_value = best_observed_loglike
+
         # translate to total log-likelihood again
         if self.use_full_budget:
             best_observed_value = best_observed_value * best_path.size(-1)
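
For intuition, here is a small worked example of the penalty term that `log_diversity` aggregates from the `rep_n` helper added in `ults/utils.py` below. This is a sketch; the token values are made up, and the numbers follow directly from the definitions in this diff:

    from ults import utils

    tokens = [1, 2, 3, 1, 2, 3, 1, 2]  # a highly repetitive toy sequence

    utils.rep_n(tokens, 2)  # 7 bigrams, 3 unique  -> 100 * (1 - 3/7) ~= 57.1
    utils.rep_n(tokens, 3)  # 6 trigrams, 3 unique -> 50.0
    utils.rep_n(tokens, 4)  # 5 4-grams, 3 unique  -> 40.0

    # log_diversity sums np.log(1 - rep_n / 100) over n = 2..ngram_order:
    # log(3/7) + log(1/2) + log(3/5) ~= -0.85 - 0.69 - 0.51 ~= -2.05
    # A sequence with no repeated n-grams scores 0; repetition drives the sum
    # negative, so adding `self.ngram_penalty * penalty` to a node's samples
    # lowers the score of repetitive candidates.
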
diff --git a/ults/utils.py b/ults/utils.py
new file mode 100644
index 0000000..f2fc2ac
--- /dev/null
+++ b/ults/utils.py
@@ -0,0 +1,14 @@
+"""Utilities."""
+
+def n_grams(tokens: list[int], n: int) -> list[tuple[int, ...]]:
+    """All n-grams in the token sequence."""
+    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]
+
+def rep_n(tokens: list[int], n: int) -> float:
+    """Percentage of duplicate n-grams (see https://arxiv.org/pdf/2202.06417)."""
+    total_ngrams = n_grams(tokens, n)
+    if not total_ngrams:
+        # A sequence shorter than n contains no n-grams, hence no repetition.
+        return 0.0
+    unique_ngrams = set(total_ngrams)
+    return 100 * (1 - len(unique_ngrams) / len(total_ngrams))
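
A minimal usage sketch of the new parameters, assuming the public `ults.generate` entry point. The model choice, the prompt, the `model_inputs`/`max_tokens` arguments, and the `output.sequence` field are assumptions about the surrounding API, not part of this diff:

    import ults
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder model; any Hugging Face causal LM should work the same way.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model_inputs = tokenizer("The quick brown fox", return_tensors="pt")

    output = ults.generate(
        model=model,                # assumed pre-existing arguments
        model_inputs=model_inputs,
        max_tokens=40,
        ngram_penalty=0.1,  # new: weight of the log-diversity penalty (0.0 disables it)
        ngram_order=4,      # new: penalize repeated 2-, 3-, and 4-grams
    )
    print(tokenizer.decode(output.sequence[0]))  # `sequence` field assumed from ULTSOutput

Note that the effective penalty weight is rescaled internally by the total sequence length (`ngram_penalty * (self.depth + tokens.size(-1))`), so small values of `ngram_penalty` already have a noticeable effect.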