Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ Runtime validation of TaskInstanceParameter() and ListTaskInstanceParameter() by subclass bound #305

Merged
merged 13 commits into from
Mar 2, 2023
29 changes: 29 additions & 0 deletions gokart/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,22 @@
import luigi
from luigi import task_register

import gokart

logger = getLogger(__name__)


class TaskInstanceParameter(luigi.Parameter):

def __init__(self, expected_type=None, *args, **kwargs):
    """Create the parameter and record the class its values must belong to.

    Args:
        expected_type: Class that values of this parameter are checked
            against in ``_warn_on_wrong_param_type``. ``None`` (the
            default) selects ``gokart.TaskOnKart``.
        *args: Forwarded unchanged to ``luigi.Parameter.__init__``.
        **kwargs: Forwarded unchanged to ``luigi.Parameter.__init__``.

    Raises:
        TypeError: If ``expected_type`` is neither ``None`` nor a class.
    """
    # Reject anything that is not a class up front; None is the only
    # non-class value allowed because it means "use the default bound".
    if expected_type is not None and not isinstance(expected_type, type):
        raise TypeError(f'expected_type must be a type, not {type(expected_type)}')
    self.expected_type = gokart.TaskOnKart if expected_type is None else expected_type
    super().__init__(*args, **kwargs)

@staticmethod
def _recursive(param_dict):
params = param_dict['params']
Expand All @@ -36,6 +47,10 @@ def serialize(self, x):
values = dict(type=x.get_task_family(), params=params)
return luigi.DictParameter().serialize(values)

def _warn_on_wrong_param_type(self, param_name, param_value):
    """Reject values that are not instances of ``self.expected_type``.

    Despite the "warn" in its name, this implementation raises instead of
    emitting a warning.

    Args:
        param_name: Name of the parameter being checked (not used here).
        param_value: Value assigned to the parameter.

    Raises:
        TypeError: If ``param_value`` is not an instance of
            ``self.expected_type``.
    """
    if isinstance(param_value, self.expected_type):
        return
    raise TypeError(f'{param_value} is not an instance of {self.expected_type}')


class _TaskInstanceEncoder(json.JSONEncoder):

Expand All @@ -48,12 +63,26 @@ def default(self, obj):

class ListTaskInstanceParameter(luigi.Parameter):
    """Parameter whose value is a collection of task instances.

    Every element of an assigned value is checked against
    ``expected_elements_type`` in ``_warn_on_wrong_param_type``.
    """

    def __init__(self, expected_elements_type=None, *args, **kwargs):
        """Create the parameter and record the class its elements must belong to.

        Args:
            expected_elements_type: Class that each element of the value is
                checked against. ``None`` (the default) selects
                ``gokart.TaskOnKart``.
            *args: Forwarded unchanged to ``luigi.Parameter.__init__``.
            **kwargs: Forwarded unchanged to ``luigi.Parameter.__init__``.

        Raises:
            TypeError: If ``expected_elements_type`` is neither ``None`` nor
                a class.
        """
        if expected_elements_type is None:
            self.expected_elements_type = gokart.TaskOnKart
        elif isinstance(expected_elements_type, type):
            self.expected_elements_type = expected_elements_type
        else:
            raise TypeError(f'expected_elements_type must be a type, not {type(expected_elements_type)}')
        super().__init__(*args, **kwargs)

    def parse(self, s):
        """Decode the JSON text ``s`` and parse each element as a task instance.

        Each decoded element is handed to ``TaskInstanceParameter.parse``;
        the results are returned as a list in the decoded order.
        """
        # One element parser is enough for the whole list, and the decoded
        # JSON value can be iterated directly without copying it into a list.
        element_parameter = TaskInstanceParameter()
        return [element_parameter.parse(x) for x in json.loads(s)]

    def serialize(self, x):
        """Return ``x`` encoded as JSON text using ``_TaskInstanceEncoder``."""
        return json.dumps(x, cls=_TaskInstanceEncoder)

    def _warn_on_wrong_param_type(self, param_name, param_value):
        """Reject values containing an element of the wrong type.

        Despite the "warn" in its name, this implementation raises instead of
        emitting a warning.

        Args:
            param_name: Name of the parameter being checked (not used here).
            param_value: Iterable of elements assigned to the parameter.

        Raises:
            TypeError: At the first element that is not an instance of
                ``self.expected_elements_type``.
        """
        expected = self.expected_elements_type
        for v in param_value:
            if not isinstance(v, expected):
                raise TypeError(f'{v} is not an instance of {expected}')


class ExplicitBoolParameter(luigi.BoolParameter):

Expand Down
58 changes: 58 additions & 0 deletions test/test_task_instance_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ class _DummySubTask(TaskOnKart):
pass


class _DummyCorrectSubClassTask(_DummySubTask):
    """Test fixture: a task that subclasses ``_DummySubTask``."""

    task_namespace = __name__


class _DummyInvalidSubClassTask(TaskOnKart):
    """Test fixture: a task that derives from ``TaskOnKart`` directly, not ``_DummySubTask``."""

    task_namespace = __name__


class _DummyTask(TaskOnKart):
task_namespace = __name__
param = luigi.IntParameter()
Expand Down Expand Up @@ -40,6 +50,54 @@ def test_serialize_and_parse_list_params(self):
parsed = gokart.TaskInstanceParameter().parse(s)
self.assertEqual(parsed.task_id, original.task_id)

def test_invalid_class(self):
    """A non-class ``expected_type`` (here an int) is rejected with TypeError."""
    with self.assertRaises(TypeError):
        gokart.TaskInstanceParameter(expected_type=1)  # not type instance

def test_params_with_correct_param_type(self):
    """A subclass instance of the expected type is accepted and exposed via ``requires()``."""

    class _DummyPipelineA(TaskOnKart):
        task_namespace = __name__
        subtask = gokart.TaskInstanceParameter(expected_type=_DummySubTask)

    pipeline = _DummyPipelineA(subtask=_DummyCorrectSubClassTask())
    required = pipeline.requires()
    self.assertEqual(required['subtask'], _DummyCorrectSubClassTask())

def test_params_with_invalid_param_type(self):
    """Passing a task that is not a ``_DummySubTask`` raises TypeError at construction."""

    class _DummyPipelineB(TaskOnKart):
        task_namespace = __name__
        subtask = gokart.TaskInstanceParameter(expected_type=_DummySubTask)

    # The whole construction, including building the argument, runs inside
    # the callable checked by assertRaises.
    self.assertRaises(TypeError, lambda: _DummyPipelineB(subtask=_DummyInvalidSubClassTask()))


class ListTaskInstanceParameterTest(unittest.TestCase):
    """Tests for the element-type validation of ``gokart.ListTaskInstanceParameter``."""

    def setUp(self):
        # Clear _DummyTask's instance cache before every test.
        _DummyTask.clear_instance_cache()

    def test_invalid_class(self):
        """A non-class ``expected_elements_type`` (here an int) is rejected with TypeError."""
        with self.assertRaises(TypeError):
            gokart.ListTaskInstanceParameter(expected_elements_type=1)  # not type instance

    def test_list_params_with_correct_param_types(self):
        """A list of subclass instances is accepted and exposed via ``requires()`` as a tuple."""

        class _DummyPipelineC(TaskOnKart):
            task_namespace = __name__
            subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

        task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()])
        self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(), ))

    def test_list_params_with_invalid_param_types(self):
        """A list containing an element of the wrong type raises TypeError at construction."""

        class _DummyPipelineD(TaskOnKart):
            task_namespace = __name__
            subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

        with self.assertRaises(TypeError):
            # Both elements are instances: one of the wrong type, one valid.
            # The valid one was previously passed as the class object itself
            # (missing call parentheses), which made it a second invalid element.
            _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()])


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()