From 85f3ebcd80cacd0ab0bd3bb136705001543b7a79 Mon Sep 17 00:00:00 2001 From: Ryo Kitagawa Date: Wed, 4 Sep 2024 14:30:28 +0900 Subject: [PATCH] feat: add TaskOnKart.dump type (#368) --- gokart/build.py | 38 ++++++++++++++++++++++++++++++++------ gokart/task.py | 19 ++++++++++++++----- poetry.lock | 4 +++- pyproject.toml | 1 + test/test_build.py | 21 +++++++++++++++------ 5 files changed, 65 insertions(+), 18 deletions(-) diff --git a/gokart/build.py b/gokart/build.py index 57f69898..46001e32 100644 --- a/gokart/build.py +++ b/gokart/build.py @@ -1,7 +1,7 @@ import logging from functools import partial from logging import getLogger -from typing import Any, Optional +from typing import Literal, Optional, TypeVar, cast, overload import backoff import luigi @@ -11,6 +11,8 @@ from gokart.target import TargetOnKart from gokart.task import TaskOnKart +T = TypeVar('T') + class LoggerConfig: def __init__(self, level: int): @@ -41,13 +43,13 @@ def __init__(self): self.flag: bool = False -def _get_output(task: TaskOnKart) -> Any: +def _get_output(task: TaskOnKart[T]) -> T: output = task.output() # FIXME: currently, nested output is not supported if isinstance(output, list) or isinstance(output, tuple): - return [t.load() for t in output if isinstance(t, TargetOnKart)] + return cast(T, [t.load() for t in output if isinstance(t, TargetOnKart)]) if isinstance(output, dict): - return {k: t.load() for k, t in output.items() if isinstance(t, TargetOnKart)} + return cast(T, {k: t.load() for k, t in output.items() if isinstance(t, TargetOnKart)}) if isinstance(output, TargetOnKart): return output.load() raise ValueError(f'output type is not supported: {type(output)}') @@ -65,15 +67,39 @@ def _reset_register(keep={'gokart', 'luigi'}): ] +@overload +def build( + task: TaskOnKart[T], + return_value: Literal[True] = True, + reset_register: bool = True, + log_level: int = logging.ERROR, + task_lock_exception_max_tries: int = 10, + task_lock_exception_max_wait_seconds: int = 600, + **env_params, +) -> T: ... + + +@overload +def build( + task: TaskOnKart[T], + return_value: Literal[False], + reset_register: bool = True, + log_level: int = logging.ERROR, + task_lock_exception_max_tries: int = 10, + task_lock_exception_max_wait_seconds: int = 600, + **env_params, +) -> None: ... + + def build( - task: TaskOnKart, + task: TaskOnKart[T], return_value: bool = True, reset_register: bool = True, log_level: int = logging.ERROR, task_lock_exception_max_tries: int = 10, task_lock_exception_max_wait_seconds: int = 600, **env_params, -) -> Optional[Any]: +) -> Optional[T]: """ Run gokart task for local interpreter. Sharing the most of its parameters with luigi.build (see https://luigi.readthedocs.io/en/stable/api/luigi.html?highlight=build#luigi.build) diff --git a/gokart/task.py b/gokart/task.py index bf47471a..93d334a7 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -5,7 +5,7 @@ import types from importlib import import_module from logging import getLogger -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, Generator, Generic, List, Optional, Set, TypeVar, Union, overload import luigi import pandas as pd @@ -25,7 +25,10 @@ logger = getLogger(__name__) -class TaskOnKart(luigi.Task): +T = TypeVar('T') + + +class TaskOnKart(luigi.Task, Generic[T]): """ This is a wrapper class of luigi.Task. @@ -282,7 +285,7 @@ def _load(targets): return list(data.values())[0] return data - def load_generator(self, target: Union[None, str, TargetOnKart] = None) -> Any: + def load_generator(self, target: Union[None, str, TargetOnKart] = None) -> Generator[Any, None, None]: def _load(targets): if isinstance(targets, list) or isinstance(targets, tuple): for t in targets: @@ -318,6 +321,12 @@ def _flatten_recursively(dfs): data = data[list(required_columns)] return data + @overload + def dump(self, obj: T, target: None = None) -> None: ... + + @overload + def dump(self, obj: Any, target: Union[str, TargetOnKart]) -> None: ... + def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None) -> None: PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace) if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame): @@ -336,13 +345,13 @@ def get_own_code(self): own_codes = self.get_code(self) return ''.join(sorted(list(own_codes - gokart_codes))) - def make_unique_id(self): + def make_unique_id(self) -> str: unique_id = self.task_unique_id or self._make_hash_id() if self.cache_unique_id: self.task_unique_id = unique_id return unique_id - def _make_hash_id(self): + def _make_hash_id(self) -> str: def _to_str_params(task): if isinstance(task, TaskOnKart): return str(task.make_unique_id()) if task.significant else None diff --git a/poetry.lock b/poetry.lock index c85fc928..8ab872eb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1334,6 +1334,7 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -1354,6 +1355,7 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -2259,4 +2261,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9,<4" -content-hash = "a5797aca91ef6d3e7af307d51ae506b07b26cd79fb903d3d61fab825377f1f3f" +content-hash = "e80c45330d11c23e7639cd54fbb141bee103b0e2215cb5cdd9b1908add1af139" diff --git a/pyproject.toml b/pyproject.toml index 1bbfecb1..d26b5d83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ fakeredis = "*" mypy = "*" types-redis = "*" matplotlib = "*" +typing-extensions = "^4.11.0" [tool.ruff] line-length = 160 diff --git a/test/test_build.py b/test/test_build.py index bb22616f..94df9b2f 100644 --- a/test/test_build.py +++ b/test/test_build.py @@ -1,7 +1,14 @@ import logging import os +import sys import unittest from copy import copy +from typing import Dict + +if sys.version_info >= (3, 11): + from typing import assert_type +else: + from typing_extensions import assert_type import luigi import luigi.mock @@ -11,9 +18,9 @@ from gokart.conflict_prevention_lock.task_lock import TaskLockException -class _DummyTask(gokart.TaskOnKart): +class _DummyTask(gokart.TaskOnKart[str]): task_namespace = __name__ - param = luigi.Parameter() + param: str = luigi.Parameter() def output(self): return self.make_target('./test/dummy.pkl') @@ -22,10 +29,10 @@ def run(self): self.dump(self.param) -class _DummyTaskTwoOutputs(gokart.TaskOnKart): +class _DummyTaskTwoOutputs(gokart.TaskOnKart[Dict[str, str]]): task_namespace = __name__ - param1 = luigi.Parameter() - param2 = luigi.Parameter() + param1: str = luigi.Parameter() + param2: str = luigi.Parameter() def output(self): return {'out1': self.make_target('./test/dummy1.pkl'), 'out2': self.make_target('./test/dummy2.pkl')} @@ -42,7 +49,7 @@ def run(self): raise RuntimeError -class _ParallelRunner(gokart.TaskOnKart): +class _ParallelRunner(gokart.TaskOnKart[str]): def requires(self): return [_DummyTask(param=str(i)) for i in range(10)] @@ -79,6 +86,7 @@ def test_read_config(self): config_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config', 'test_config.ini') gokart.utils.add_config(config_file_path) output = gokart.build(_DummyTask(), reset_register=False) + assert_type(output, str) self.assertEqual(output, 'test') def test_build_dict_outputs(self): @@ -87,6 +95,7 @@ def test_build_dict_outputs(self): 'out2': 'test2', } output = gokart.build(_DummyTaskTwoOutputs(param1=param_dict['out1'], param2=param_dict['out2']), reset_register=False) + assert_type(output, Dict[str, str]) self.assertEqual(output, param_dict) def test_failed_task(self):