diff --git a/llmtune/cli/toolkit.py b/llmtune/cli/toolkit.py index 7451722..36fbe2e 100644 --- a/llmtune/cli/toolkit.py +++ b/llmtune/cli/toolkit.py @@ -15,7 +15,8 @@ from llmtune.finetune.lora import LoRAFinetune from llmtune.inference.lora import LoRAInference from llmtune.pydantic_models.config_model import Config -from llmtune.qa.generics import LLMTestSuite, QaTestRegistry +from llmtune.qa.generics import LLMTestSuite +from llmtune.qa.qa_tests import QaTestRegistry from llmtune.ui.rich_ui import RichUI from llmtune.utils.ablation_utils import generate_permutations from llmtune.utils.save_utils import DirectoryHelper @@ -92,6 +93,7 @@ def run_one_experiment(config: Config, config_path: Path) -> None: tests = QaTestRegistry.create_tests_from_list(llm_tests) test_suite = LLMTestSuite.from_csv(results_file_path, tests) test_suite.save_test_results(qa_file_path) + test_suite.print_test_results() @app.command("run") diff --git a/llmtune/config.yml b/llmtune/config.yml index f77eab6..ed25e8d 100644 --- a/llmtune/config.yml +++ b/llmtune/config.yml @@ -23,7 +23,7 @@ data: # Model Definition ------------------- model: - hf_model_ckpt: "mistralai/Mistral-7B-Instruct-v0.2" + hf_model_ckpt: "facebook/opt-125m" torch_dtype: "bfloat16" #attn_implementation: "flash_attention_2" quantize: true diff --git a/llmtune/pydantic_models/config_model.py b/llmtune/pydantic_models/config_model.py index ab39a3a..fe1ab82 100644 --- a/llmtune/pydantic_models/config_model.py +++ b/llmtune/pydantic_models/config_model.py @@ -244,3 +244,4 @@ class Config(BaseModel): lora: LoraConfig training: TrainingConfig inference: InferenceConfig + qa: QaConfig diff --git a/llmtune/qa/generics.py b/llmtune/qa/generics.py index 8efabc2..0f9d6d3 100644 --- a/llmtune/qa/generics.py +++ b/llmtune/qa/generics.py @@ -1,5 +1,6 @@ import statistics from abc import ABC, abstractmethod +from pathlib import Path from typing import Dict, List, Union import pandas as pd @@ -18,23 +19,6 @@ def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[f pass -class QaTestRegistry: - registry = {} - - @classmethod - def register(cls, *names): - def inner_wrapper(wrapped_class): - for name in names: - cls.registry[name] = wrapped_class - return wrapped_class - - return inner_wrapper - - @classmethod - def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]: - return [cls.create_test(test) for test in test_names] - - class LLMTestSuite: def __init__( self, @@ -51,11 +35,17 @@ def __init__( self._results = {} @staticmethod - def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite": + def from_csv( + file_path: str, + tests: List[LLMQaTest], + prompt_col: str = "Prompt", + gold_col: str = "Ground Truth", + pred_col="Predicted", + ) -> "LLMTestSuite": results_df = pd.read_csv(file_path) - prompts = results_df["prompt"].tolist() - ground_truths = results_df["ground_truth"].tolist() - model_preds = results_df["model_prediction"].tolist() + prompts = results_df[prompt_col].tolist() + ground_truths = results_df[gold_col].tolist() + model_preds = results_df[pred_col].tolist() return LLMTestSuite(tests, prompts, ground_truths, model_preds) def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]: @@ -84,5 +74,11 @@ def print_test_results(self): def save_test_results(self, path: str): # TODO: save these! + path = Path(path) + dir = path.parent + + if not dir.exists(): + dir.mkdir(parents=True, exist_ok=True) + resultant_dataframe = pd.DataFrame(self.test_results) resultant_dataframe.to_csv(path, index=False) diff --git a/llmtune/qa/qa_tests.py b/llmtune/qa/qa_tests.py index 46236d4..9a01861 100644 --- a/llmtune/qa/qa_tests.py +++ b/llmtune/qa/qa_tests.py @@ -9,7 +9,7 @@ from rouge_score import rouge_scorer from transformers import DistilBertModel, DistilBertTokenizer -from llmtune.qa.generics import LLMQaTest, QaTestRegistry +from llmtune.qa.generics import LLMQaTest model_name = "distilbert-base-uncased" @@ -21,6 +21,23 @@ nltk.download("averaged_perceptron_tagger") +class QaTestRegistry: + registry = {} + + @classmethod + def register(cls, *names): + def inner_wrapper(wrapped_class): + for name in names: + cls.registry[name] = wrapped_class + return wrapped_class + + return inner_wrapper + + @classmethod + def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]: + return [cls.registry[test]() for test in test_names] + + @QaTestRegistry.register("summary_length") class LengthTest(LLMQaTest): @property