Skip to content

Commit

Permalink
not assert task.load_data_frame(required_columns={}) (#180)
Browse files Browse the repository at this point in the history
* not assert task.load_data_frame(required_columns={})

* change dict check method

* add unit test

* change unit test name
  • Loading branch information
vaaaaanquish authored Mar 16, 2021
1 parent f8d0d31 commit 7056c10
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
6 changes: 5 additions & 1 deletion gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ def _flatten_recursively(dfs):
else:
return dfs

data = _flatten_recursively(self.load(target=target))
dfs = self.load(target=target)
if isinstance(dfs, dict) and len(dfs) == 1:
dfs = list(dfs.values())[0]

data = _flatten_recursively(dfs)

required_columns = required_columns or set()
if data.empty and len(data.index) == 0 and len(required_columns - set(data.columns)) > 0:
Expand Down
10 changes: 9 additions & 1 deletion test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def test_add_cofigureation_evaluation_order(self, mock_cmdline: MagicMock):
class DummyTaskAddConfiguration(gokart.TaskOnKart):
aa = luigi.IntParameter()

luigi.configuration.get_config().set(f'DummyTaskAddConfiguration', 'aa', '3')
luigi.configuration.get_config().set('DummyTaskAddConfiguration', 'aa', '3')
mock_cmdline.return_value = luigi.cmdline_parser.CmdlineParser(['DummyTaskAddConfiguration'])
self.assertEqual(DummyTaskAddConfiguration().aa, 3)

Expand All @@ -329,6 +329,14 @@ def test_load_list_of_list_pandas(self):
self.assertIsInstance(df, pd.DataFrame)
self.assertEqual(3, df.shape[0])

def test_load_single_value_dict_of_dataframe(self):
    """load_data_frame returns a DataFrame when load() yields a one-entry dict."""
    # Stub load() so it hands back a dict wrapping a single one-row frame.
    single_entry = {'a': pd.DataFrame({'a': [1]})}
    dummy = _DummyTask()
    dummy.load = MagicMock(return_value=single_entry)

    loaded = dummy.load_data_frame()

    self.assertIsInstance(loaded, pd.DataFrame)
    row_count, _ = loaded.shape
    self.assertEqual(1, row_count)

def test_load_data_frame_drop_columns(self):
task = _DummyTask()
task.load = MagicMock(return_value=pd.DataFrame(dict(a=[1], b=[2], c=[3])))
Expand Down

0 comments on commit 7056c10

Please sign in to comment.