Moving tests to metrics and adding pass/fail LLM tests #186

Merged · 3 commits · Jul 4, 2024
12 changes: 6 additions & 6 deletions llmtune/cli/toolkit.py
@@ -15,8 +15,8 @@
 from llmtune.finetune.lora import LoRAFinetune
 from llmtune.inference.lora import LoRAInference
 from llmtune.pydantic_models.config_model import Config
-from llmtune.qa.generics import LLMTestSuite
-from llmtune.qa.qa_tests import QaTestRegistry
+from llmtune.qa.generics import LLMMetricSuite
+from llmtune.qa.qa_metrics import QaMetricRegistry
 from llmtune.ui.rich_ui import RichUI
 from llmtune.utils.ablation_utils import generate_permutations
 from llmtune.utils.save_utils import DirectoryHelper
@@ -91,10 +91,10 @@ def run_one_experiment(config: Config, config_path: Path) -> None:
     qa_file_path = dir_helper.save_paths.qa_file
     if not qa_file_path.exists():
         llm_metrics = config.qa.llm_metrics
-        tests = QaTestRegistry.create_tests_from_list(llm_metrics)
-        test_suite = LLMTestSuite.from_csv(results_file_path, tests)
-        test_suite.save_test_results(qa_file_path)
-        test_suite.print_test_results()
+        tests = QaMetricRegistry.create_tests_from_list(llm_metrics)
+        test_suite = LLMMetricSuite.from_csv(results_file_path, tests)
+        test_suite.save_metric_results(qa_file_path)
+        test_suite.print_metric_results()


 @app.command("run")
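For orientation, the renamed flow above reduces to the following standalone sketch. The CSV path and metric names here are hypothetical stand-ins; in the CLI they come from DirectoryHelper's save paths and config.qa.llm_metrics.

from llmtune.qa.generics import LLMMetricSuite
from llmtune.qa.qa_metrics import QaMetricRegistry

# Hypothetical metric list; the real one is read from config.qa.llm_metrics.
llm_metrics = ["jaccard_similarity", "rouge_score"]
metrics = QaMetricRegistry.create_tests_from_list(llm_metrics)

# Hypothetical results path; the CSV needs "Prompt", "Ground Truth", and "Predicted" columns.
suite = LLMMetricSuite.from_csv("experiment/results/results.csv", metrics)
suite.print_metric_results()  # rich table with mean/median/stdev per metric
suite.save_metric_results("experiment/qa/qa_results.csv")  # hypothetical output path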
58 changes: 26 additions & 32 deletions llmtune/qa/generics.py
@@ -1,84 +1,78 @@
 import statistics
-from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, List, Union

 import pandas as pd

+from llmtune.qa.qa_metrics import LLMQaMetric
 from llmtune.ui.rich_ui import RichUI


-class LLMQaTest(ABC):
-    @property
-    @abstractmethod
-    def test_name(self) -> str:
-        pass
-
-    @abstractmethod
-    def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int, bool]:
-        pass

+class LLMMetricSuite:
+    """
+    Represents and runs a suite of metrics on a set of prompts,
+    golden responses, and model predictions.
+    """

-class LLMTestSuite:
     def __init__(
         self,
-        tests: List[LLMQaTest],
+        metrics: List[LLMQaMetric],
         prompts: List[str],
         ground_truths: List[str],
         model_preds: List[str],
     ) -> None:
-        self.tests = tests
+        self.metrics = metrics
         self.prompts = prompts
         self.ground_truths = ground_truths
         self.model_preds = model_preds

-        self._results = {}
+        self._results: Dict[str, List[Union[float, int]]] = {}

     @staticmethod
     def from_csv(
         file_path: str,
-        tests: List[LLMQaTest],
+        metrics: List[LLMQaMetric],
         prompt_col: str = "Prompt",
         gold_col: str = "Ground Truth",
         pred_col="Predicted",
-    ) -> "LLMTestSuite":
+    ) -> "LLMMetricSuite":
         results_df = pd.read_csv(file_path)
         prompts = results_df[prompt_col].tolist()
         ground_truths = results_df[gold_col].tolist()
         model_preds = results_df[pred_col].tolist()
-        return LLMTestSuite(tests, prompts, ground_truths, model_preds)
+        return LLMMetricSuite(metrics, prompts, ground_truths, model_preds)

-    def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
-        test_results = {}
-        for test in self.tests:
-            metrics = []
+    def compute_metrics(self) -> Dict[str, List[Union[float, int]]]:
+        results = {}
+        for metric in self.metrics:
+            metric_results = []
             for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds):
-                metrics.append(test.get_metric(prompt, ground_truth, model_pred))
-            test_results[test.test_name] = metrics
+                metric_results.append(metric.get_metric(prompt, ground_truth, model_pred))
+            results[metric.metric_name] = metric_results

-        self._results = test_results
-        return test_results
+        self._results = results
+        return results

     @property
-    def test_results(self):
-        return self._results if self._results else self.run_tests()
+    def metric_results(self) -> Dict[str, List[Union[float, int]]]:
+        return self._results if self._results else self.compute_metrics()

-    def print_test_results(self):
-        result_dictionary = self.test_results
+    def print_metric_results(self):
+        result_dictionary = self.metric_results
         column_data = {key: list(result_dictionary[key]) for key in result_dictionary}
         mean_values = {key: statistics.mean(column_data[key]) for key in column_data}
         median_values = {key: statistics.median(column_data[key]) for key in column_data}
         stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data}
         # Use the RichUI class to display the table
         RichUI.qa_display_metric_table(result_dictionary, mean_values, median_values, stdev_values)

-    def save_test_results(self, path: str):
+    def save_metric_results(self, path: str):
         # TODO: save these!
         path = Path(path)
         dir = path.parent

         if not dir.exists():
             dir.mkdir(parents=True, exist_ok=True)

-        resultant_dataframe = pd.DataFrame(self.test_results)
+        resultant_dataframe = pd.DataFrame(self.metric_results)
         resultant_dataframe.to_csv(path, index=False)
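To make the new interface concrete, here is a minimal sketch that builds an LLMMetricSuite from in-memory lists instead of a CSV; the strings and metric choice are illustrative only.

from llmtune.qa.generics import LLMMetricSuite
from llmtune.qa.qa_metrics import JaccardSimilarityMetric, LengthMetric

suite = LLMMetricSuite(
    metrics=[LengthMetric(), JaccardSimilarityMetric()],
    prompts=["Summarize: the cat sat on the mat."],
    ground_truths=["A cat sat on a mat."],
    model_preds=["The cat was sitting on the mat."],
)

# compute_metrics returns one value per row for each metric, keyed by metric_name,
# e.g. {"summary_length": [12], "jaccard_similarity": [...]}
results = suite.compute_metrics()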
210 changes: 210 additions & 0 deletions llmtune/qa/qa_metrics.py
@@ -0,0 +1,210 @@
from abc import ABC, abstractmethod
from typing import List, Union

import nltk
import numpy as np
import torch
from langchain.evaluation import JsonValidityEvaluator
from nltk import pos_tag
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from rouge_score import rouge_scorer
from transformers import DistilBertModel, DistilBertTokenizer


json_validity_evaluator = JsonValidityEvaluator()

nltk.download("stopwords")
nltk.download("punkt")
nltk.download("averaged_perceptron_tagger")


class LLMQaMetric(ABC):
    """
    Abstract base class for a metric. A metric can be computed over a single
    data instance, and outputs a scalar value (integer or float).
    """

    @property
    @abstractmethod
    def metric_name(self) -> str:
        pass

    @abstractmethod
    def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int]:
        pass


class QaMetricRegistry:
    registry = {}

    @classmethod
    def register(cls, *names):
        def inner_wrapper(wrapped_class):
            for name in names:
                cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create_tests_from_list(cls, metric_names: List[str]) -> List[LLMQaMetric]:
        return [cls.registry[test]() for test in metric_names]


@QaMetricRegistry.register("summary_length")
class LengthMetric(LLMQaMetric):
    @property
    def metric_name(self) -> str:
        return "summary_length"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
        return abs(len(ground_truth) - len(model_prediction))


@QaMetricRegistry.register("jaccard_similarity")
class JaccardSimilarityMetric(LLMQaMetric):
    @property
    def metric_name(self) -> str:
        return "jaccard_similarity"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
        set_ground_truth = set(ground_truth.lower())
        set_model_prediction = set(model_prediction.lower())

        intersection_size = len(set_ground_truth.intersection(set_model_prediction))
        union_size = len(set_ground_truth.union(set_model_prediction))

        similarity = intersection_size / union_size if union_size != 0 else 0
        return float(similarity)


@QaMetricRegistry.register("dot_product")
class DotProductSimilarityMetric(LLMQaMetric):
    """Encodes both the ground truth and model prediction using DistilBERT, and
    computes the dot product similarity between the two embeddings."""

    def __init__(self):
        model_name = "distilbert-base-uncased"
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertModel.from_pretrained(model_name)

    @property
    def metric_name(self) -> str:
        return "dot_product"

    def _encode_sentence(self, sentence):
        tokens = self.tokenizer(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**tokens)
        return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
        embedding_ground_truth = self._encode_sentence(ground_truth)
        embedding_model_prediction = self._encode_sentence(model_prediction)
        dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction)
        return float(dot_product_similarity)


@QaMetricRegistry.register("rouge_score")
class RougeScoreMetric(LLMQaMetric):
    @property
    def metric_name(self) -> str:
        return "rouge_score"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
        scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
        scores = scorer.score(model_prediction, ground_truth)
        return float(scores["rouge1"].precision)


@QaMetricRegistry.register("word_overlap")
class WordOverlapMetric(LLMQaMetric):
    @property
    def metric_name(self) -> str:
        return "word_overlap"

    def _remove_stopwords(self, text: str) -> str:
        stop_words = set(stopwords.words("english"))
        word_tokens = word_tokenize(text)
        filtered_text = [word for word in word_tokens if word.lower() not in stop_words]
        return " ".join(filtered_text)

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
        cleaned_model_prediction = self._remove_stopwords(model_prediction)
        cleaned_ground_truth = self._remove_stopwords(ground_truth)

        words_model_prediction = set(cleaned_model_prediction.split())
        words_ground_truth = set(cleaned_ground_truth.split())

        common_words = words_model_prediction.intersection(words_ground_truth)
        overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100
        return float(overlap_percentage)


@QaMetricRegistry.register("json_valid")
class JSONValidityMetric(LLMQaMetric):
    """
    Checks to see if valid json can be parsed from the model output, according
    to langchain_core.utils.json.parse_json_markdown
    The JSON can be wrapped in markdown and this test will still pass
    """

    @property
    def metric_name(self) -> str:
        return "json_valid"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
        result = json_validity_evaluator.evaluate_strings(prediction=model_prediction)
        binary_res = result["score"]
        return float(binary_res)


class PosCompositionMetric(LLMQaMetric):
    def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
        words = word_tokenize(text)
        tags = pos_tag(words)
        pos_words = [word for word, tag in tags if tag in pos_tags]
        total_words = len(text.split(" "))
        return round(len(pos_words) / total_words, 2)


@QaMetricRegistry.register("verb_percent")
class VerbPercentMetric(PosCompositionMetric):
    @property
    def metric_name(self) -> str:
        return "verb_percent"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
        return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"])


@QaMetricRegistry.register("adjective_percent")
class AdjectivePercentMetric(PosCompositionMetric):
    @property
    def metric_name(self) -> str:
        return "adjective_percent"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
        return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"])


@QaMetricRegistry.register("noun_percent")
class NounPercentMetric(PosCompositionMetric):
    @property
    def metric_name(self) -> str:
        return "noun_percent"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
        return self._get_pos_percent(model_prediction, ["NN", "NNS", "NNP", "NNPS"])


# Instantiate tests
# length_test = LengthMetric()
# jaccard_similarity_test = JaccardSimilarityMetric()
# dot_product_similarity_test = DotProductSimilarityMetric()
# rouge_score_test = RougeScoreMetric()
# word_overlap_test = WordOverlapMetric()
# verb_percent_test = VerbPercentMetric()
# adjective_percent_test = AdjectivePercentMetric()
# noun_percent_test = NounPercentMetric()
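As a usage note, QaMetricRegistry is a plain name-to-class registry, so project-specific metrics can be added with the same decorator the built-ins use. The ExactMatchMetric below is purely illustrative and not part of this PR.

from llmtune.qa.qa_metrics import LLMQaMetric, QaMetricRegistry


@QaMetricRegistry.register("exact_match")  # hypothetical metric name, not shipped in this PR
class ExactMatchMetric(LLMQaMetric):
    @property
    def metric_name(self) -> str:
        return "exact_match"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
        # 1.0 when the prediction matches the golden response exactly (ignoring surrounding whitespace)
        return float(ground_truth.strip() == model_prediction.strip())


# Anything registered by name can then be instantiated alongside the built-in metrics:
metrics = QaMetricRegistry.create_tests_from_list(["exact_match", "rouge_score"])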