Skip to content

Commit

Permalink
Merge pull request #166 from georgian-io/smaller-example-model
Browse files: browse the repository at this point in the history
Smaller Model for Default Config
  • Loading branch information
benjaminye authored May 7, 2024
2 parents 838534b + 293d474 commit 6a442a8
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 24 deletions.
4 changes: 3 additions & 1 deletion llmtune/cli/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from llmtune.finetune.lora import LoRAFinetune
from llmtune.inference.lora import LoRAInference
from llmtune.pydantic_models.config_model import Config
from llmtune.qa.generics import LLMTestSuite, QaTestRegistry
from llmtune.qa.generics import LLMTestSuite
from llmtune.qa.qa_tests import QaTestRegistry
from llmtune.ui.rich_ui import RichUI
from llmtune.utils.ablation_utils import generate_permutations
from llmtune.utils.save_utils import DirectoryHelper
Expand Down Expand Up @@ -92,6 +93,7 @@ def run_one_experiment(config: Config, config_path: Path) -> None:
tests = QaTestRegistry.create_tests_from_list(llm_tests)
test_suite = LLMTestSuite.from_csv(results_file_path, tests)
test_suite.save_test_results(qa_file_path)
test_suite.print_test_results()


@app.command("run")
Expand Down
2 changes: 1 addition & 1 deletion llmtune/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ data:

# Model Definition -------------------
model:
hf_model_ckpt: "mistralai/Mistral-7B-Instruct-v0.2"
hf_model_ckpt: "facebook/opt-125m"
torch_dtype: "bfloat16"
#attn_implementation: "flash_attention_2"
quantize: true
Expand Down
1 change: 1 addition & 0 deletions llmtune/pydantic_models/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,4 @@ class Config(BaseModel):
lora: LoraConfig
training: TrainingConfig
inference: InferenceConfig
qa: QaConfig
38 changes: 17 additions & 21 deletions llmtune/qa/generics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import statistics
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Union

import pandas as pd
Expand All @@ -18,23 +19,6 @@ def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[f
pass


class QaTestRegistry:
    """Name-to-class registry for QA tests.

    Classes add themselves with the ``register`` decorator and are later
    instantiated by name through ``create_test`` /
    ``create_tests_from_list``.
    """

    # Class-level mapping shared by every user of the registry:
    # registered name -> test class.
    registry = {}

    @classmethod
    def register(cls, *names):
        """Return a class decorator that files the class under each name.

        The decorated class is returned unchanged, so the decorator can be
        stacked. Registering an already-used name replaces the earlier entry.
        """

        def inner_wrapper(wrapped_class):
            for name in names:
                cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create_test(cls, test_name: str) -> "LLMQaTest":
        """Instantiate the class registered under ``test_name``.

        The class is called with no arguments.

        Raises:
            KeyError: if nothing is registered under ``test_name``.
        """
        try:
            test_class = cls.registry[test_name]
        except KeyError:
            known = ", ".join(repr(name) for name in cls.registry) or "<none>"
            raise KeyError(
                f"Unknown QA test {test_name!r}; registered tests: {known}"
            ) from None
        return test_class()

    @classmethod
    def create_tests_from_list(cls, test_names: List[str]) -> List["LLMQaTest"]:
        """Instantiate one test per entry of ``test_names``, in order.

        Raises:
            KeyError: if any name is not registered.
        """
        # Previously this called ``cls.create_test`` without the class
        # defining it, so every non-empty list failed with AttributeError.
        return [cls.create_test(test_name) for test_name in test_names]


class LLMTestSuite:
def __init__(
self,
Expand All @@ -51,11 +35,17 @@ def __init__(
self._results = {}

# Build an LLMTestSuite from a CSV file: the file is read with
# pandas.read_csv, three columns are pulled out as Python lists (prompts,
# ground truths, model predictions) and passed, together with `tests`, to the
# LLMTestSuite constructor.
#
# NOTE(review): this span appears to show both sides of the diff interleaved —
# there are two `def from_csv` headers and two sets of column lookups. The
# first set reads the fixed column names "prompt" / "ground_truth" /
# "model_prediction"; the second reads the columns named by the
# `prompt_col` / `gold_col` / `pred_col` parameters (defaults "Prompt",
# "Ground Truth", "Predicted"). Confirm which side is current before editing.
@staticmethod
def from_csv(file_path: str, tests: List[LLMQaTest]) -> "LLMTestSuite":
def from_csv(
file_path: str,
tests: List[LLMQaTest],
prompt_col: str = "Prompt",
gold_col: str = "Ground Truth",
# NOTE(review): pred_col is the only column parameter without a `str` annotation.
pred_col="Predicted",
) -> "LLMTestSuite":
results_df = pd.read_csv(file_path)
# Lookups by fixed column name.
prompts = results_df["prompt"].tolist()
ground_truths = results_df["ground_truth"].tolist()
model_preds = results_df["model_prediction"].tolist()
# Lookups by caller-supplied column name; a missing column makes the
# DataFrame indexing raise.
prompts = results_df[prompt_col].tolist()
ground_truths = results_df[gold_col].tolist()
model_preds = results_df[pred_col].tolist()
return LLMTestSuite(tests, prompts, ground_truths, model_preds)

def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
Expand Down Expand Up @@ -84,5 +74,11 @@ def print_test_results(self):

def save_test_results(self, path: str):
    """Write the suite's test results to ``path`` as a CSV file.

    Any missing parent directories of ``path`` are created first. The CSV is
    built by passing ``self.test_results`` to ``pandas.DataFrame`` and is
    written without the index column.

    Args:
        path: Destination file path; anything ``pathlib.Path`` accepts.
    """
    output_path = Path(path)

    # exist_ok=True already makes this a no-op when the directory is present,
    # so a separate exists() check is unnecessary (and would be race-prone).
    output_path.parent.mkdir(parents=True, exist_ok=True)

    resultant_dataframe = pd.DataFrame(self.test_results)
    resultant_dataframe.to_csv(output_path, index=False)
19 changes: 18 additions & 1 deletion llmtune/qa/qa_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rouge_score import rouge_scorer
from transformers import DistilBertModel, DistilBertTokenizer

from llmtune.qa.generics import LLMQaTest, QaTestRegistry
from llmtune.qa.generics import LLMQaTest


model_name = "distilbert-base-uncased"
Expand All @@ -21,6 +21,23 @@
nltk.download("averaged_perceptron_tagger")


class QaTestRegistry:
    """Name-to-class registry for QA tests.

    Test classes add themselves with the ``register`` decorator (one or more
    names per class) and are instantiated by name with
    ``create_tests_from_list``.
    """

    # Class-level mapping shared by every user of the registry:
    # registered name -> test class.
    registry = {}

    @classmethod
    def register(cls, *names):
        """Return a class decorator that files the class under each name.

        The decorated class is returned unchanged. Registering an
        already-used name replaces the earlier entry.
        """

        def inner_wrapper(wrapped_class):
            for name in names:
                cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create_tests_from_list(cls, test_names: List[str]) -> List["LLMQaTest"]:
        """Instantiate one registered test per entry of ``test_names``, in order.

        Each registered class is called with no arguments.

        Raises:
            KeyError: if any name is not registered. All names are checked
                before anything is instantiated, so the error lists every
                unknown name and no test is constructed on failure.
        """
        unknown = [name for name in test_names if name not in cls.registry]
        if unknown:
            known = ", ".join(repr(name) for name in cls.registry) or "<none>"
            raise KeyError(
                f"Unknown QA test(s) {unknown!r}; registered tests: {known}"
            )
        return [cls.registry[name]() for name in test_names]


@QaTestRegistry.register("summary_length")
class LengthTest(LLMQaTest):
@property
Expand Down

0 comments on commit 6a442a8

Please sign in to comment.