Skip to content

Commit

Permalink
renaming metrics from tests to metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
SinclairHudson committed Jun 16, 2024
1 parent 4a20cb2 commit 8a192c3
Show file tree
Hide file tree
Showing 10 changed files with 17 additions and 11 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pip install -e .

1. Use `ruff check --fix` to check and fix lint errors
2. Use `ruff format` to apply formatting
3. Run `pytest` at the top level directory to run unit tests

NOTE: Ruff linting and formatting checks are done when PR is raised via Git Action. Before raising a PR, it is a good practice to check and fix lint errors, as well as apply formatting.

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ lora:

```yaml
qa:
llm_tests:
llm_metrics:
- length_test
- word_overlap_test
```
Expand Down
5 changes: 3 additions & 2 deletions llmtune/cli/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ def run_one_experiment(config: Config, config_path: Path) -> None:
else:
RichUI.results_found(results_path)

# Quality Assurance -------------------------
RichUI.before_qa()
qa_file_path = dir_helper.save_paths.qa_file
if not qa_file_path.exists():
llm_tests = config.qa.llm_tests
tests = QaTestRegistry.create_tests_from_list(llm_tests)
llm_metrics = config.qa.llm_metrics
tests = QaTestRegistry.create_tests_from_list(llm_metrics)
test_suite = LLMTestSuite.from_csv(results_file_path, tests)
test_suite.save_test_results(qa_file_path)
test_suite.print_test_results()
Expand Down
2 changes: 1 addition & 1 deletion llmtune/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inference:
temperature: 0.8

qa:
llm_tests:
llm_metrics:
- jaccard_similarity
- dot_product
- rouge_score
Expand Down
2 changes: 1 addition & 1 deletion llmtune/pydantic_models/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class QaConfig(BaseModel):
llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected")
llm_metrics: Optional[List[str]] = Field([], description="list of metrics that needs to be connected")


class DataConfig(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion llmtune/qa/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def print_test_results(self):
median_values = {key: statistics.median(column_data[key]) for key in column_data}
stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data}
# Use the RichUI class to display the table
RichUI.qa_display_table(result_dictionary, mean_values, median_values, stdev_values)
RichUI.qa_display_metric_table(result_dictionary, mean_values, median_values, stdev_values)

def save_test_results(self, path: str):
# TODO: save these!
Expand Down
2 changes: 1 addition & 1 deletion llmtune/ui/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,5 @@ def qa_found(cls):
pass

@abstractstaticmethod
def qa_display_table(cls):
def qa_display_metric_table(cls):
pass
4 changes: 2 additions & 2 deletions llmtune/ui/rich_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def qa_found():
pass

@staticmethod
def qa_display_table(result_dictionary, mean_values, median_values, stdev_values):
def qa_display_metric_table(result_dictionary, mean_values, median_values, stdev_values):
# Create a table
table = Table(show_header=True, header_style="bold", title="Test Results")
table = Table(show_header=True, header_style="bold", title="Test Set Metric Results")

# Add columns to the table
table.add_column("Metric", style="cyan")
Expand Down
6 changes: 5 additions & 1 deletion test_utils/test_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Defines a configuration that can be used for unit testing.
"""

from llmtune.pydantic_models.config_model import (
AblationConfig,
BitsAndBytesConfig,
Expand Down Expand Up @@ -72,7 +76,7 @@ def get_sample_config():
train_test_split_seed=42,
),
qa=QaConfig(
llm_tests=[
llm_metrics=[
"jaccard_similarity",
"dot_product",
"rouge_score",
Expand Down
2 changes: 1 addition & 1 deletion tests/qa/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_save_test_results(mock_csv, mock_tests, mocker):
# def test_print_test_results(mock_csv, mock_tests, mock_rich_ui):
# test_suite = LLMTestSuite.from_csv("dummy_path.csv", mock_tests)
# test_suite.print_test_results()
# assert mock_rich_ui.qa_display_table.called
# assert mock_rich_ui.qa_display_metric_table.called


def test_print_test_results(capfd, example_data):
Expand Down

0 comments on commit 8a192c3

Please sign in to comment.