Add seq2seq eval benchmark callback (axolotl-ai-cloud#1274)
* Add CausalLMBenchEvalCallback for measuring seq2seq performance

* Fix code for pre-commit

* Fix typing and improve logging

* eval_sample_packing must be false with CausalLMBenchEvalCallback
LeonardoEmili authored Feb 13, 2024
1 parent 8430db2 commit 5a5d474
Showing 15 changed files with 228 additions and 14 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -784,7 +784,8 @@ save_total_limit: # Checkpoints saved at a time
max_steps:

eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf"]

loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
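
Note: the metric names accepted by eval_causal_lm_metrics are resolved at runtime through the Hugging Face evaluate library, exactly as the new callback below does with evaluate.load. A minimal standalone sketch of that lookup and scoring (illustrative only; the example strings are made up, and comet additionally needs sources= plus a model download, so it may fail to load here):

import evaluate

# load the default metric set; each metric module is fetched on first use
metrics = {}
for name in ["sacrebleu", "comet", "ter", "chrf"]:
    try:
        metrics[name] = evaluate.load(name)
    except Exception as exc:
        print(f"could not load {name}: {exc}")

predictions = ["The cat sat on the mat."]
references = [["A cat sat on the mat."]]  # sacrebleu/ter/chrf take a list of references per prediction

for name, metric in metrics.items():
    try:
        result = metric.compute(predictions=predictions, references=references)
        # sacrebleu/ter/chrf report "score"; comet reports "mean_score"
        print(name, result.get("score", result.get("mean_score")))
    except Exception as exc:
        print(f"{name} failed: {exc}")
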
2 changes: 1 addition & 1 deletion examples/llama-2/loftq.yml
@@ -60,7 +60,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
2 changes: 1 addition & 1 deletion examples/llama-2/lora.yml
@@ -57,7 +57,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
2 changes: 1 addition & 1 deletion examples/mamba/config.yml
@@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
2 changes: 1 addition & 1 deletion examples/mistral/Mistral-7b-example/config.yml
@@ -61,7 +61,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
#default deepspeed, can use more aggressive if needed like zero2, zero3
2 changes: 1 addition & 1 deletion examples/mistral/config.yml
@@ -49,7 +49,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
2 changes: 1 addition & 1 deletion examples/mistral/mixtral.yml
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero2.json
2 changes: 1 addition & 1 deletion examples/mistral/qlora.yml
@@ -68,7 +68,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
2 changes: 1 addition & 1 deletion examples/qwen/lora.yml
@@ -58,7 +58,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
2 changes: 1 addition & 1 deletion examples/qwen/qlora.yml
@@ -58,7 +58,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
2 changes: 1 addition & 1 deletion examples/yi-34B-chat/qlora.yml
@@ -29,7 +29,7 @@ num_epochs: 1
val_set_size: 0.1
evals_per_epoch: 5
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
eval_sample_packing: false
eval_batch_size: 1

2 changes: 1 addition & 1 deletion requirements.txt
@@ -23,7 +23,7 @@ numba
numpy>=1.24.4
mlflow
# qlora things
evaluate==0.4.0
evaluate==0.4.1
scipy
scikit-learn==1.2.2
pynvml
11 changes: 11 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -38,6 +38,7 @@
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import (
@@ -148,6 +149,9 @@ class AxolotlTrainingArguments(TrainingArguments):
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
@@ -664,6 +668,11 @@ def get_post_trainer_create_callbacks(self, trainer):

if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))

if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
@@ -812,6 +821,8 @@ def build(self, total_num_steps):
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
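
Note: the registration above follows the same factory pattern as the existing bench eval callback: the factory closes over the live trainer and tokenizer, and the returned TrainerCallback subclass is instantiated with the axolotl cfg. A stripped-down sketch of that shape (illustrative only, not the full implementation):

from transformers import TrainerCallback

def causal_lm_bench_eval_callback_factory(trainer, tokenizer):
    class CausalLMBenchEvalCallback(TrainerCallback):
        def __init__(self, cfg):
            self.cfg = cfg

        def on_evaluate(self, args, state, control, **kwargs):
            # trainer/tokenizer are captured from the enclosing scope, so the
            # callback can run generation against the live model at eval time
            _ = (trainer, tokenizer)
            return control

    return CausalLMBenchEvalCallback

# registration mirrors get_post_trainer_create_callbacks:
#   if cfg.do_causal_lm_eval:
#       CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(trainer, tokenizer)
#       callbacks.append(CausalLMBenchEvalCallback(cfg))
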
183 changes: 182 additions & 1 deletion src/axolotl/utils/callbacks.py
@@ -361,6 +361,187 @@ def on_evaluate(
return BenchEvalCallback


def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
class CausalLMBenchEvalCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""

def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.metrics = self.__maybe_load_metrics()

def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics

def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model.eval()
device = torch.device(self.cfg.device)

# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)

def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges

def compute(metric: evaluate.Metric, **kwargs):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
return metric_score

def evaluate_preds(sources, predictions, references):
scores = {}

for metric_name, metric in self.metrics.items():
score = compute(
metric,
references=references,
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores

def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []

for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)

if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])

prompt_token_ids_list = []
completion_token_ids_list = []

for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)

for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue

input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]

tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)

prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)

completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)

prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)

with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)

prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)

predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)

eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)

return eval_src, eval_pred, eval_ref

if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))

return control

return CausalLMBenchEvalCallback


def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
@@ -388,7 +569,7 @@ def on_evaluate(

# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_table_max_new_tokens,
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
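
Note: two details of the callback above are easy to miss. First, find_ranges splits sample-packed batches back into individual examples by cutting wherever position_ids resets to 0; second, the prompt and completion tokens of each example are recovered from the labels, since prompt (and padding) tokens are masked with IGNORE_INDEX (-100). A small worked example of both, standalone and with made-up token ids:

import torch

IGNORE_INDEX = -100  # label value for tokens that carry no loss (the prompt)
pad_token_id = 0     # illustrative pad id

# position_ids for two packed sequences of lengths 4 and 3
position_ids = [0, 1, 2, 3, 0, 1, 2]

def find_ranges(lst):
    # same logic as in the callback: a reset to 0 starts a new packed example
    ranges, start = [], 0
    for i in range(1, len(lst)):
        if lst[i] == 0:
            ranges.append((start, i - 1))
            start = i
    ranges.append((start, len(lst) - 1))
    return ranges

print(find_ranges(position_ids))  # [(0, 3), (4, 6)]

# splitting one example into prompt vs. completion tokens via the label mask
input_ids = torch.tensor([11, 12, 13, 21, 22, 0, 0])
labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 21, 22, IGNORE_INDEX, IGNORE_INDEX])

prompt_mask = (labels == IGNORE_INDEX) & (input_ids != pad_token_id)
completion_mask = labels != IGNORE_INDEX

print(input_ids[prompt_mask])      # tensor([11, 12, 13]) -> decoded as the prompt
print(input_ids[completion_mask])  # tensor([21, 22])     -> decoded as the reference completion
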
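predict_with_generate then re-tokenizes the recovered prompts, generates greedily with the configured eval_max_new_tokens, and slices the prompt prefix off each output before decoding. A minimal standalone sketch of that step, using gpt2 purely as a placeholder model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "gpt2"  # placeholder; the callback uses the model being trained
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained(model_name)

generation_config = GenerationConfig(
    max_new_tokens=128,  # cfg.eval_max_new_tokens in the callback
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
)

prompt_texts = ["Question: What is the capital of France?\nAnswer:"]

with torch.no_grad():
    enc = tokenizer(prompt_texts, padding=True, return_tensors="pt")
    out = model.generate(**enc, generation_config=generation_config)

# drop the prompt tokens from each generated row before decoding; the callback
# does this per example using the unpadded prompt length instead
prompt_len = enc["input_ids"].shape[1]
predicted_texts = tokenizer.batch_decode(
    out["sequences"][:, prompt_len:], skip_special_tokens=True
)
print(predicted_texts)
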
23 changes: 22 additions & 1 deletion src/axolotl/utils/config.py
@@ -56,7 +56,13 @@ def normalize_config(cfg):
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
"sacrebleu",
"comet",
"ter",
"chrf",
]
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
@@ -550,6 +556,21 @@ def validate_config(cfg):
if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")

if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)

if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
)

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
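
Note: the config changes above do two things: normalize_config fills in the defaults (eval_max_new_tokens=128 and the four default metrics), and validate_config rejects unsupported metric names as well as enabling do_causal_lm_eval together with eval_sample_packing. A condensed standalone sketch of those checks, with cfg shown as a plain dict for illustration:

SUPPORTED_METRICS = ["sacrebleu", "comet", "ter", "chrf"]

def check_causal_lm_eval(cfg: dict) -> None:
    # mirrors the new validation rules added to validate_config
    if cfg.get("do_causal_lm_eval") and cfg.get("eval_sample_packing"):
        raise ValueError(
            "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
        )
    metrics = cfg.get("eval_causal_lm_metrics")
    if metrics:
        if not isinstance(metrics, list):
            raise ValueError("eval_causal_lm_metrics must be a list")
        if set(metrics) - set(SUPPORTED_METRICS):
            raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}")

# passes: packing disabled and only supported metrics requested
check_causal_lm_eval({
    "do_causal_lm_eval": True,
    "eval_sample_packing": False,
    "eval_causal_lm_metrics": ["sacrebleu", "chrf"],
})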