Create unit and functional tests for metrics #306

Draft · wants to merge 5 commits into base: development
4 changes: 4 additions & 0 deletions tests/functional/metric/data/metric_test_end2end.json
@@ -0,0 +1,4 @@
[
{"name": "rouge", "id": "65e1d5ac95487dea3023a0b8", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "overlap f1", "id": "66df3e2d6eb56336b6628171", "hypothesis": "test", "reference": "test", "score": 1.0}
]
11 changes: 11 additions & 0 deletions tests/functional/metric/data/metric_test_end2end_full.json
@@ -0,0 +1,11 @@
[
{"name": "rouge", "id": "65e1d5ac95487dea3023a0b8", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "overlap f1", "id": "66df3e2d6eb56336b6628171", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "wder", "id": "66df3d0a6eb56302656a1c41", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "der", "id": "66155bc16eb56341b1291e61", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "wer", "id": "646d371caec2a04700e61945", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "precision", "id": "666c914d6eb5630a427d5ed1", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "recall", "id": "666c91706eb5630fbd6456a1", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "f1", "id": "666c91826eb5631382436291", "hypothesis": "test", "reference": "test", "score": 1.0},
{"name": "accuracy", "id": "666c8ffe6eb5634c5779d191", "hypothesis": "test", "reference": "test", "score": 1.0}
]
56 changes: 56 additions & 0 deletions tests/functional/metric/metric_functional_test.py
@@ -0,0 +1,56 @@
__author__ = "lucaspavanelli"
"""
Copyright 2022 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from dotenv import load_dotenv
import pytest

load_dotenv()  # load env vars (e.g. the API key) before importing the SDK
from aixplain.factories import MetricFactory

RUN_FILE = "tests/functional/metric/data/metric_test_end2end.json"


def read_data(data_path):
    with open(data_path, "r") as f:
        return json.load(f)


@pytest.fixture(scope="module", params=read_data(RUN_FILE))
def run_input_map(request):
return request.param


def test_end2end(run_input_map):
metric_id = run_input_map["id"]
hypothesis = run_input_map["hypothesis"]
reference = run_input_map["reference"]
metric = MetricFactory.get(metric_id)
result = metric.run(hypothesis=hypothesis, reference=reference)
assert result is not None
assert result["status"] == "SUCCESS"
assert result["completed"] is True
assert "details" in result
assert "data" in result
assert len(result["data"]) == 1
assert result["data"][0]["score"] == run_input_map["score"]


def test_list_metric():
metric_list = MetricFactory.list()["results"]
assert len(metric_list) > 0


# TODO: cover the full metric list (rouge, overlap f1, wder, der, wer, precision, recall, f1, accuracy); see the sketch below.
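
A minimal sketch of how the TODO above could be covered, reusing the fixture pattern from this file together with the full data file added in this PR (tests/functional/metric/data/metric_test_end2end_full.json). The names RUN_FILE_FULL, run_input_map_full, and test_end2end_full are hypothetical, not part of the change:

RUN_FILE_FULL = "tests/functional/metric/data/metric_test_end2end_full.json"


@pytest.fixture(scope="module", params=read_data(RUN_FILE_FULL))
def run_input_map_full(request):
    # hypothetical fixture: parametrizes over all nine metrics in the full data file
    return request.param


def test_end2end_full(run_input_map_full):
    # hypothetical test mirroring test_end2end above
    metric = MetricFactory.get(run_input_map_full["id"])
    result = metric.run(hypothesis=run_input_map_full["hypothesis"], reference=run_input_map_full["reference"])
    assert result["completed"] is True
    assert result["data"][0]["score"] == run_input_map_full["score"]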
143 changes: 143 additions & 0 deletions tests/unit/metric_test.py
@@ -0,0 +1,143 @@
__author__ = "lucaspavanelli"

"""
Copyright 2022 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dotenv import load_dotenv
import requests_mock

load_dotenv()
from aixplain.modules import Metric
from aixplain.factories import MetricFactory
from aixplain.enums import Supplier
from urllib.parse import urljoin

import pytest


def test_metric_factory_get():
metric_id = "1"
url = urljoin(MetricFactory.backend_url, f"sdk/metrics/{metric_id}")
json_response = {
"id": "1",
"name": "rouge",
"supplier": "aixplain",
"referenceRequired": True,
"sourceRequired": True,
"normalizedPrice": 0.0,
"function": "text_generation",
}
with requests_mock.Mocker() as m:
m.get(url, json=json_response)
metric = MetricFactory.get(metric_id)
assert metric.id == metric_id
assert metric.name == "rouge"
assert metric.supplier == Supplier.AIXPLAIN
assert metric.cost == 0.0
assert metric.function == "text_generation"
assert metric.is_reference_required is True
assert metric.is_source_required is True


def test_metric_factory_get_exception():
metric_id = "1"
url = urljoin(MetricFactory.backend_url, f"sdk/metrics/{metric_id}")
with requests_mock.Mocker() as m:
m.get(url, status_code=404)
with pytest.raises(Exception) as e:
MetricFactory.get(metric_id)
expected_message = "Status code: 404. Metric Creation: Unspecified error."
    assert str(e.value) == expected_message


def test_metric_factory_list():
url = urljoin(MetricFactory.backend_url, "sdk/metrics")
json_response = {
"results": [
{
"id": "1",
"name": "rouge",
"supplier": "aixplain",
"referenceRequired": True,
"sourceRequired": True,
"normalizedPrice": 0.0,
"function": "text_generation",
}
]
}
with requests_mock.Mocker() as m:
m.get(url, json=json_response)
metrics = MetricFactory.list()
returned_keys = ["results", "page_total", "total", "page_number"]
assert all(key in metrics for key in returned_keys)
results = metrics["results"]
assert len(results) == 1
metric = results[0]
assert metric.id == "1"
assert metric.name == "rouge"
assert metric.supplier == Supplier.AIXPLAIN
assert metric.cost == 0.0
assert metric.function == "text_generation"
assert metric.is_reference_required is True
assert metric.is_source_required is True


def test_metric_factory_list_exception():
url = urljoin(MetricFactory.backend_url, "sdk/metrics")
with requests_mock.Mocker() as m:
m.get(url, status_code=404)
results = MetricFactory.list()
assert len(results) == 0


def test_metric_constructor():
metric = Metric("1", "rouge", "aixplain", True, True, 0.0, "text_generation")
assert metric.id == "1"
assert metric.name == "rouge"
assert metric.supplier == Supplier.AIXPLAIN
assert metric.cost == 0.0
assert metric.function == "text_generation"
assert metric.is_reference_required is True
assert metric.is_source_required is True
assert metric.version == "1.0"


def test_metric_add_normalization_options():
metric = Metric("1", "rouge", "aixplain", True, True, 0.0, "text_generation")
assert metric.normalization_options == []
metric.add_normalization_options(["option1", "option2"])
assert metric.normalization_options == [["option1", "option2"]]


def test_metric_run(mocker):
metric = Metric("1", "rouge", "aixplain", True, True, 0.0, "text_generation")
model = mocker.MagicMock()
model.run.return_value = "result"
model_mocker = mocker.patch("aixplain.factories.model_factory.ModelFactory.get", return_value=model)
response = metric.run("hypothesis", "source", "reference")
assert response == "result"
model_mocker.assert_called_with("1")
model.run.assert_called_once()
model.run.assert_called_with(
{
"function": "text_generation",
"supplier": Supplier.AIXPLAIN,
"version": "rouge",
"hypotheses": ["hypothesis"],
"sources": ["source"],
"references": [["reference"]],
}
)
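
For context, a minimal usage sketch of the code path this unit test mocks, mirroring the functional test above. The metric id is the rouge id taken from the PR's test data, and the sketch assumes credentials are loaded via dotenv as in the tests:

from dotenv import load_dotenv

load_dotenv()  # credentials, as in the tests
from aixplain.factories import MetricFactory

# rouge metric id from tests/functional/metric/data/metric_test_end2end.json
metric = MetricFactory.get("65e1d5ac95487dea3023a0b8")
result = metric.run(hypothesis="test", reference="test")
print(result["data"][0]["score"])  # the end-to-end test expects 1.0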