Merge branch 'ngram' into keep-searching
JuliaGrosse committed Nov 15, 2024
2 parents 443b28e + a1e172c commit 14cb82a
Showing 3 changed files with 62 additions and 5 deletions.
7 changes: 7 additions & 0 deletions ults/functional.py
@@ -32,6 +32,8 @@ def generate(
stop_at_eos: bool = True,
use_full_budget: bool = True,
acquisition_function: str = "posterior",
ngram_penalty: float = 0.0,
ngram_order: int = 4,
) -> ULTSOutput:
"""ULTS: Uncertainty-guided Likelihood-Tree Search.
@@ -59,6 +61,9 @@ def generate(
"posterior": pick child node based on posterior over max loglik.
"posterior_descendant": pick child node based on posterior over max loglik
of the best descendant.
ngram_penalty: penalty coefficient for discouraging repetitive sequences.
ngram_order: highest order of n-grams taken into account when penalizing
repetitive sequences; must be > 1.
Returns:
@@ -80,6 +85,8 @@ def generate(
stop_at_eos=stop_at_eos,
use_full_budget=use_full_budget,
acquisition_function=acquisition_function,
ngram_penalty=ngram_penalty,
ngram_order=ngram_order,
)

# Generation results --- full sequence and total_loglik include context
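For orientation, a minimal usage sketch of the new arguments. This is illustrative only: `model_inputs` and `max_tokens` are assumed names for parts of the `generate` signature hidden by the truncated diff above; only the keyword arguments visible in this diff are confirmed by the commit.

import ults
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model_inputs = tokenizer("The quick brown fox", return_tensors="pt")

# `model_inputs` and `max_tokens` are hypothetical argument names; the
# n-gram arguments below are the ones added by this commit.
output = ults.generate(
    model=model,
    model_inputs=model_inputs,
    max_tokens=40,
    acquisition_function="posterior",
    ngram_penalty=0.1,  # 0.0 (the default) disables the penalty
    ngram_order=4,      # penalize duplicate 2-, 3-, and 4-grams
)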
49 changes: 44 additions & 5 deletions ults/ults.py
@@ -10,6 +10,7 @@
from scipy import stats
from torch.distributions.beta import Beta
from transformers import BatchEncoding
from ults import utils


class ULTS:
@@ -37,6 +38,9 @@ class ULTS:
"posterior": pick child node based on posterior over max loglik.
"posterior_descendant": pick child node based on posterior over mx loglik
of best descendant.
ngram_penalty: penalty coefficient for discouraging repetitive sequences.
ngram_order: highest order of n-grams taken into account when penalizing
repetitive sequences; must be > 1.
"""

def __init__(
@@ -56,6 +60,8 @@ def __init__(
stop_at_eos: bool = True,
use_full_budget: bool = False,
acquisition_function: str = "posterior",
ngram_penalty: float = 0.0,
ngram_order: int = 4,
):
if prior_kind == "empirical" and prior_empirical_dataset_name is None:
raise ValueError(
@@ -67,6 +73,9 @@
"`acquisition_function` can only be `posterior` or `posterior_descendant`."
)

if ngram_order < 2:
raise ValueError("ngram_order can only be > 1.")

self.model = model
self.is_encoder_decoder = model.config.is_encoder_decoder
self.epsilon = epsilon
@@ -101,6 +110,9 @@ def __init__(
self.encoder_inputs = None
self.encoder_outputs = None

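# Scale the penalty coefficient by the maximum total sequence length
# (context tokens + search depth), presumably to keep the penalty on the
# same scale as a sequence-level total log-likelihood.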
self.ngram_penalty = ngram_penalty * (self.depth + tokens.size(-1))
self.ngram_order = ngram_order

self.tree = nx.DiGraph()
self.tree.add_node(
"0",
@@ -258,6 +270,15 @@ def budget_left(self) -> bool:
"""
return self.max_beam_size >= self.used_max_beam_size[-1]

def log_diversity(self, tokens: list[int]) -> float:
"""Log-diversity of a token sequence (also see: https://arxiv.org/pdf/2202.06417).
Each term log(1 - rep_n / 100) is <= 0, so the score is non-positive and
decreases as the sequence becomes more repetitive."""
return np.sum(
[
np.log(1 - utils.rep_n(tokens, n) / 100)
for n in range(2, self.ngram_order + 1)
]
)

def set_nodes_to_inactive(self) -> None:
"""Check the number of expanded nodes per level of the tree. If this number
exceeds the constraint on the maximal number, set all nodes on this level and
@@ -392,17 +413,25 @@ def search(self) -> tuple[torch.Tensor, float, int]:
child_name = new_node_name + "*" + str(i)
child_tokens = children_tokens[i][None, :]
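# log_diversity is non-positive, so adding self.ngram_penalty * penalty
# below lowers the scores of repetitive candidates; a zero penalty
# coefficient disables the computation entirely.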
penalty = (
0
if self.ngram_penalty == 0
else self.log_diversity(child_tokens[0].tolist())
)

if self.stop_at_eos and child_tokens[0, -1] == self.eos_token:
child_samples = (
children_observations[i].repeat(self.sample_size)
+ self.ngram_penalty * penalty
)
if self.use_full_budget:
# make sure we don't select the eos node again in the next iteration
# by setting it to -inf.
child_samples = torch.full((self.sample_size,), float("-inf"))
else:
child_samples = children_samples[i] + self.ngram_penalty * penalty

self.tree.add_node(
child_name,
@@ -425,12 +454,19 @@ def search(self) -> tuple[torch.Tensor, float, int]:
):

if self.use_full_budget:
# we want to compare by average log likelihood
child_obs = child_obs / child_tokens.size(-1)
if (
child_obs + self.ngram_penalty * penalty
> best_observed_value
):
best_path = children_tokens[i][None, :]
best_observed_value = (
child_obs.item() + self.ngram_penalty * penalty
)
best_observed_loglike = child_obs.item()

# Update optimal value distribution of parents
self.backup(new_node_name)
@@ -444,6 +480,9 @@ def search(self) -> tuple[torch.Tensor, float, int]:
torch.sum(best_observed_value >= overall_max_samples) / self.sample_size
)

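# The n-gram penalty was only used to rank candidates during the search;
# report the unpenalized log-likelihood of the selected sequence.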
if self.ngram_penalty > 0:
best_observed_value = best_observed_loglike

# translate to total log-likelihood again
if self.use_full_budget:
best_observed_value = best_observed_value * best_path.size(-1)
11 changes: 11 additions & 0 deletions ults/utils.py
@@ -0,0 +1,11 @@
"""Utilities."""

def n_grams(tokens: list[int], n: int) -> list[tuple]:
"""All n-grams in the token sequence."""
return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]

def rep_n(tokens: list[int], n: int) -> float:
"""Percentage of duplicate n-grams (also see: https://arxiv.org/pdf/2202.06417)."""
total_ngrams = n_grams(tokens, n)
if not total_ngrams:  # sequences shorter than n have no n-grams
return 0.0
unique_ngrams = set(total_ngrams)
return 100 * (1 - len(unique_ngrams) / len(total_ngrams))
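A quick worked example of these helpers (illustration only, not part of the commit):

from ults.utils import n_grams, rep_n

tokens = [1, 2, 1, 2, 1]
print(n_grams(tokens, 2))  # [(1, 2), (2, 1), (1, 2), (2, 1)]
print(rep_n(tokens, 2))    # 4 bigrams, 2 of them unique -> 50.0
# With ngram_order=2, log_diversity would contribute
# log(1 - 50 / 100) = log(0.5) ≈ -0.693 to the candidate's penalty.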
