Commit
Merge pull request #115 from m3dev/add-column-drop-option-to-load-data-frame

add drop_columns option
nishiba authored Jan 23, 2020
2 parents aacbc0e + 2032b9e commit d0b700c
Showing 2 changed files with 17 additions and 4 deletions.
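To illustrate the intent of the change, here is a minimal sketch of how a downstream task might use the new flag. The task classes, column names, and data are hypothetical and not part of this commit; only load_data_frame, dump, and the required_columns/drop_columns parameters come from the diff below.

import gokart
import pandas as pd


class MakeData(gokart.TaskOnKart):
    def run(self):
        # Hypothetical upstream task that dumps a frame with an extra column 'b'.
        self.dump(pd.DataFrame(dict(a=[1], b=[2], c=[3])))


class UseData(gokart.TaskOnKart):
    def requires(self):
        return MakeData()

    def run(self):
        # With drop_columns=True the loaded frame keeps only required_columns;
        # with the default drop_columns=False, extra columns such as 'b' remain.
        df = self.load_data_frame(required_columns={'a', 'c'}, drop_columns=True)
        self.dump(df)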
gokart/task.py: 12 changes (8 additions, 4 deletions)
@@ -37,7 +37,7 @@ class TaskOnKart(luigi.Task):
     modification_time_check = luigi.BoolParameter(
         default=False,
         description='If this is true, this task will not run only if all input and output files exist,'
-        ' and all input files are modified before output file are modified.',
+        ' and all input files are modified before output file are modified.',
         significant=False)
     delete_unnecessary_output_files = luigi.BoolParameter(default=False, description='If this is true, delete unnecessary output files.', significant=False)
     significant = luigi.BoolParameter(
@@ -112,7 +112,7 @@ def make_target(self, relative_file_path: str, use_unique_id: bool = True, proce
         unique_id = self.make_unique_id() if use_unique_id else None
         return gokart.target.make_target(file_path=file_path, unique_id=unique_id, processor=processor)

-    def make_large_data_frame_target(self, relative_file_path: str, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
+    def make_large_data_frame_target(self, relative_file_path: str, use_unique_id: bool = True, max_byte=int(2 ** 26)) -> TargetOnKart:
         file_path = os.path.join(self.workspace_directory, relative_file_path)
         unique_id = self.make_unique_id() if use_unique_id else None
         return gokart.target.make_model_target(
@@ -168,21 +168,25 @@ def _load(targets):
         return _load(self._get_input_targets(target))

-    def load_data_frame(self, target: Union[None, str, TargetOnKart] = None, required_columns: Optional[Set[str]] = None) -> pd.DataFrame:
+    def load_data_frame(self, target: Union[None, str, TargetOnKart] = None, required_columns: Optional[Set[str]] = None,
+                        drop_columns: bool = False) -> pd.DataFrame:
         data = self.load(target=target)
         if isinstance(data, list):
             def _pd_concat(dfs):
                 if isinstance(dfs, list):
                     return pd.concat([_pd_concat(df) for df in dfs])
                 else:
                     return dfs

             data = _pd_concat(data)

         required_columns = required_columns or set()
         if data.empty:
             return pd.DataFrame(columns=required_columns)

         assert required_columns.issubset(set(data.columns)), f'data must have columns {required_columns}, but actually have only {data.columns}.'
+        if drop_columns:
+            data = data[required_columns]
         return data

     def dump(self, obj, target: Union[None, str, TargetOnKart] = None) -> None:
@@ -292,7 +296,7 @@ def _get_module_versions(self) -> str:
         for x in set([x.split('.')[0] for x in sys.modules.keys() if '_' not in x]):
             module = import_module(x)
             if '__version__' in dir(module):
-                if type(module.__version__)==str:
+                if type(module.__version__) == str:
                     version = module.__version__.split(" ")[0]
                 else:
                     version = '.'.join([str(v) for v in module.__version__])
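Read on its own, the changed branch of load_data_frame behaves as in the following standalone sketch with illustrative column names. The committed code indexes the frame with the required_columns set directly; sorted(...) is used here only so the example selects with a plain list.

import pandas as pd

data = pd.DataFrame(dict(a=[1], b=[2], c=[3]))
required_columns = {'a', 'c'}

# load_data_frame first asserts that every required column is present ...
assert required_columns.issubset(set(data.columns))

# ... and with drop_columns=True the frame is then cut down to those columns only.
data = data[sorted(required_columns)]
print(list(data.columns))  # ['a', 'c']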
test/test_task_on_kart.py: 9 changes (9 additions, 0 deletions)
@@ -235,6 +235,15 @@ def test_load_list_of_list_pandas(self):
         self.assertIsInstance(df, pd.DataFrame)
         self.assertEqual(3, df.shape[0])

+    def test_load_data_frame_drop_columns(self):
+        task = _DummyTask()
+        task.load = MagicMock(return_value=pd.DataFrame(dict(a=[1], b=[2], c=[3])))
+
+        df = task.load_data_frame(required_columns={'a', 'c'}, drop_columns=True)
+        self.assertIsInstance(df, pd.DataFrame)
+        self.assertEqual(1, df.shape[0])
+        self.assertSetEqual({'a', 'c'}, set(df.columns))
+
     def test_use_rerun_with_inherits(self):
         # All tasks are completed.
         task_c = _DummyTaskC()
