Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
goodwanghan committed Aug 20, 2024
1 parent bceef07 commit ff0df2e
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tests/tune/noniterative/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import pandas as pd
from fugue import FugueWorkflow
from pytest import raises

from tune import optimize_noniterative, suggest_for_noniterative_objective
from tune.concepts.dataset import TuneDatasetBuilder
from tune.concepts.flow import Monitor
from tune.concepts.logger import MetricLogger, get_current_metric_logger
from tune.concepts.space import Grid, Space
from tune.constants import TUNE_REPORT, TUNE_REPORT_METRIC
from tune.exceptions import TuneInterrupted
Expand All @@ -14,23 +16,29 @@


def objective(a: float, b: pd.DataFrame) -> float:
return a ** 2 + b.shape[0]
return a**2 + b.shape[0]


def objective_with_logger(a: float, b: pd.DataFrame) -> float:
m = get_current_metric_logger()
assert m.mock
return a**2 + b.shape[0]


def objective2(a: float, b: pd.DataFrame) -> float:
return -(a ** 2 + b.shape[0])
return -(a**2 + b.shape[0])


def objective3(a: float, b: pd.DataFrame) -> float:
if a == -2:
raise TuneInterrupted()
return a ** 2 + b.shape[0]
return a**2 + b.shape[0]


def objective4(a: float, b: pd.DataFrame) -> float:
if a == -2:
raise ValueError("expected")
return a ** 2 + b.shape[0]
return a**2 + b.shape[0]


def assert_metric(df: pd.DataFrame, metrics: List[float]) -> None:
Expand All @@ -56,12 +64,15 @@ def test_study(tmpdir):
# no data partition
builder = TuneDatasetBuilder(space, str(tmpdir)).add_df("b", dag.df(input_df))
dataset = builder.build(dag, 1)
logger = MetricLogger()
logger.mock = True
for distributed in [True, False, None]:
# min_better = True
result = optimize_noniterative(
objective=objective,
objective=objective_with_logger,
dataset=dataset,
distributed=distributed,
logger=logger,
)
result.result()[[TUNE_REPORT, TUNE_REPORT_METRIC]].output(
assert_metric, params=dict(metrics=[3.0, 4.0, 7.0])
Expand Down

0 comments on commit ff0df2e

Please sign in to comment.