From 84cc69c59a49bee8a46e2622f70fa42613f31d7a Mon Sep 17 00:00:00 2001 From: lucas-aixplain Date: Mon, 4 Nov 2024 19:40:59 -0300 Subject: [PATCH 1/3] Add metric tests --- .../metric/data/metric_test_end2end.json | 3 + .../metric/metric_functional_test.py | 57 +++++++ tests/unit/metric_test.py | 143 ++++++++++++++++++ 3 files changed, 203 insertions(+) create mode 100644 tests/functional/metric/data/metric_test_end2end.json create mode 100644 tests/functional/metric/metric_functional_test.py create mode 100644 tests/unit/metric_test.py diff --git a/tests/functional/metric/data/metric_test_end2end.json b/tests/functional/metric/data/metric_test_end2end.json new file mode 100644 index 00000000..e359c610 --- /dev/null +++ b/tests/functional/metric/data/metric_test_end2end.json @@ -0,0 +1,3 @@ +[ + {"name": "rouge", "id": "65e1d5ac95487dea3023a0b8", "hypothesis": "test", "reference": "test", "score": 1.0} +] \ No newline at end of file diff --git a/tests/functional/metric/metric_functional_test.py b/tests/functional/metric/metric_functional_test.py new file mode 100644 index 00000000..0beb792c --- /dev/null +++ b/tests/functional/metric/metric_functional_test.py @@ -0,0 +1,57 @@ +__author__ = "lucaspavanelli" +""" +Copyright 2022 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import json +from dotenv import load_dotenv +import pytest + +load_dotenv() +from aixplain.factories import MetricFactory + +RUN_FILE = "tests/functional/metric/data/metric_test_end2end.json" + + +def read_data(data_path): + return json.load(open(data_path, "r")) + + +@pytest.fixture(scope="module", params=read_data(RUN_FILE)) +def run_input_map(request): + return request.param + + +def test_end2end(run_input_map): + metric_id = run_input_map["id"] + hypothesis = run_input_map["hypothesis"] + reference = run_input_map["reference"] + metric = MetricFactory.get(metric_id) + result = metric.run(hypothesis=hypothesis, reference=reference) + assert result is not None + assert result["status"] == "SUCCESS" + assert result["completed"] is True + assert "details" in result + assert "data" in result + assert len(result["data"]) == 1 + assert result["data"][0]["score"] == run_input_map["score"] + assert result["score"] == run_input_map["score"] + + +def test_list_metric(): + metric_list = MetricFactory.list()["results"] + assert len(metric_list) > 0 + + +# TODO test the following list: rouge, overlap f1, wder, der, wer, precision, recall, f1 and accuracy diff --git a/tests/unit/metric_test.py b/tests/unit/metric_test.py new file mode 100644 index 00000000..bd743bdd --- /dev/null +++ b/tests/unit/metric_test.py @@ -0,0 +1,143 @@ +__author__ = "lucaspavanelli" + +""" +Copyright 2022 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dotenv import load_dotenv +import requests_mock + +load_dotenv() +from aixplain.modules import Metric +from aixplain.factories import MetricFactory +from aixplain.enums import Supplier +from urllib.parse import urljoin + +import pytest + + +def test_metric_factory_get(): + metric_id = "1" + url = urljoin(MetricFactory.backend_url, f"sdk/metrics/{metric_id}") + json_response = { + "id": "1", + "name": "rouge", + "supplier": "aixplain", + "referenceRequired": True, + "sourceRequired": True, + "normalizedPrice": 0.0, + "function": "text_generation", + } + with requests_mock.Mocker() as m: + m.get(url, json=json_response) + metric = MetricFactory.get(metric_id) + assert metric.id == metric_id + assert metric.name == "rouge" + assert metric.supplier == Supplier.AIXPLAIN + assert metric.cost == 0.0 + assert metric.function == "text_generation" + assert metric.is_reference_required is True + assert metric.is_source_required is True + + +def test_metric_factory_get_exception(): + metric_id = "1" + url = urljoin(MetricFactory.backend_url, f"sdk/metrics/{metric_id}") + with requests_mock.Mocker() as m: + m.get(url, status_code=404) + with pytest.raises(Exception) as e: + MetricFactory.get(metric_id) + expected_message = "Status code: 404. Metric Creation: Unspecified error." + assert str(e) == expected_message + + +def test_metric_factory_list(): + url = urljoin(MetricFactory.backend_url, "sdk/metrics") + json_response = { + "results": [ + { + "id": "1", + "name": "rouge", + "supplier": "aixplain", + "referenceRequired": True, + "sourceRequired": True, + "normalizedPrice": 0.0, + "function": "text_generation", + } + ] + } + with requests_mock.Mocker() as m: + m.get(url, json=json_response) + metrics = MetricFactory.list() + returned_keys = ["results", "page_total", "total", "page_number"] + assert all(key in metrics for key in returned_keys) + results = metrics["results"] + assert len(results) == 1 + metric = results[0] + assert metric.id == "1" + assert metric.name == "rouge" + assert metric.supplier == Supplier.AIXPLAIN + assert metric.cost == 0.0 + assert metric.function == "text_generation" + assert metric.is_reference_required is True + assert metric.is_source_required is True + + +def test_metric_factory_list_exception(): + url = urljoin(MetricFactory.backend_url, "sdk/metrics") + with requests_mock.Mocker() as m: + m.get(url, status_code=404) + results = MetricFactory.list() + assert len(results) == 0 + + +def test_metric_constructor(): + metric = Metric("1", "rouge", "aixplain", True, True, 0.0, "text_generation") + assert metric.id == "1" + assert metric.name == "rouge" + assert metric.supplier == Supplier.AIXPLAIN + assert metric.cost == 0.0 + assert metric.function == "text_generation" + assert metric.is_reference_required is True + assert metric.is_source_required is True + assert metric.version == "1.0" + + +def test_metric_add_normalization_options(): + metric = Metric("1", "rouge", "aixplain", True, True, 0.0, "text_generation") + assert metric.normalization_options == [] + metric.add_normalization_options(["option1", "option2"]) + assert metric.normalization_options == [["option1", "option2"]] + + +def test_metric_run(mocker): + metric = Metric("1", "rouge", "aixplain", True, True, 0.0, "text_generation") + model = mocker.MagicMock() + model.run.return_value = "result" + model_mocker = mocker.patch("aixplain.factories.model_factory.ModelFactory.get", return_value=model) + response = metric.run("hypothesis", "source", "reference") + assert response == "result" + model_mocker.assert_called_with("1") + model.run.assert_called_once() + model.run.assert_called_with( + { + "function": "text_generation", + "supplier": Supplier.AIXPLAIN, + "version": "rouge", + "hypotheses": ["hypothesis"], + "sources": ["source"], + "references": [["reference"]], + } + ) From ce70208afa48a7edbbaf5aacac94d80c9f1a6206 Mon Sep 17 00:00:00 2001 From: lucas-aixplain Date: Fri, 8 Nov 2024 20:05:15 -0300 Subject: [PATCH 2/3] Update metric functional test data --- tests/functional/metric/data/metric_test_end2end.json | 5 +++-- .../metric/data/metric_test_end2end_full.json | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 tests/functional/metric/data/metric_test_end2end_full.json diff --git a/tests/functional/metric/data/metric_test_end2end.json b/tests/functional/metric/data/metric_test_end2end.json index e359c610..8c1f5be8 100644 --- a/tests/functional/metric/data/metric_test_end2end.json +++ b/tests/functional/metric/data/metric_test_end2end.json @@ -1,3 +1,4 @@ [ - {"name": "rouge", "id": "65e1d5ac95487dea3023a0b8", "hypothesis": "test", "reference": "test", "score": 1.0} -] \ No newline at end of file + {"name": "rouge", "id": "65e1d5ac95487dea3023a0b8", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "overlap f1", "id": "66df3e2d6eb56336b6628171", "hypothesis": "test", "reference": "test", "score": 1.0} +] diff --git a/tests/functional/metric/data/metric_test_end2end_full.json b/tests/functional/metric/data/metric_test_end2end_full.json new file mode 100644 index 00000000..a4517071 --- /dev/null +++ b/tests/functional/metric/data/metric_test_end2end_full.json @@ -0,0 +1,11 @@ +[ + {"name": "rouge", "id": "65e1d5ac95487dea3023a0b8", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "overlap f1", "id": "66df3e2d6eb56336b6628171", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "wder", "id": "66df3d0a6eb56302656a1c41", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "der", "id": "66155bc16eb56341b1291e61", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "wer", "id": "646d371caec2a04700e61945", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "precision", "id": "666c914d6eb5630a427d5ed1", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "recall", "id": "666c91706eb5630fbd6456a1", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "f1", "id": "666c91826eb5631382436291", "hypothesis": "test", "reference": "test", "score": 1.0}, + {"name": "accuracy", "id": "666c8ffe6eb5634c5779d191", "hypothesis": "test", "reference": "test", "score": 1.0} +] \ No newline at end of file From 7343da03bc4bac2f258ec4e48db1c6fda5a1247b Mon Sep 17 00:00:00 2001 From: lucas-aixplain Date: Mon, 11 Nov 2024 12:09:31 -0300 Subject: [PATCH 3/3] Update metric functional test --- tests/functional/metric/metric_functional_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/functional/metric/metric_functional_test.py b/tests/functional/metric/metric_functional_test.py index 0beb792c..b58ffcec 100644 --- a/tests/functional/metric/metric_functional_test.py +++ b/tests/functional/metric/metric_functional_test.py @@ -46,7 +46,6 @@ def test_end2end(run_input_map): assert "data" in result assert len(result["data"]) == 1 assert result["data"][0]["score"] == run_input_map["score"] - assert result["score"] == run_input_map["score"] def test_list_metric():