Skip to content
This repository has been archived by the owner on Jun 30, 2022. It is now read-only.

Commit

Permalink
Generalize base PTransform._extract_input_pvalues
Browse files Browse the repository at this point in the history
The basecase now understands tuples and dicts of pvalues, which
eases writing multi-input composite transforms.

----Release Notes----
[]
-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=121960544
  • Loading branch information
robertwb authored and aaltay committed May 11, 2016
1 parent f0467cb commit 2843cf9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
15 changes: 14 additions & 1 deletion google/cloud/dataflow/transforms/ptransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ def _extract_input_pvalues(self, pvalueish):
Returns pvalueish as well as the flat inputs list as the input may have to
be copied as inspection may be destructive.
By default, recursively extracts tuple components and dict values.
Generally only needs to be overriden for multi-input PTransforms.
"""
# pylint: disable=g-import-not-at-top
Expand All @@ -429,7 +431,18 @@ def _extract_input_pvalues(self, pvalueish):
if isinstance(pvalueish, pipeline.Pipeline):
pvalueish = pvalue.PBegin(pvalueish)

return pvalueish, (pvalueish,)
def _dict_tuple_leaves(pvalueish):
if isinstance(pvalueish, tuple):
for a in pvalueish:
for p in _dict_tuple_leaves(a):
yield p
elif isinstance(pvalueish, dict):
for a in pvalueish.values():
for p in _dict_tuple_leaves(a):
yield p
else:
yield pvalueish
return pvalueish, tuple(_dict_tuple_leaves(pvalueish))


class ChainedPTransform(PTransform):
Expand Down
10 changes: 10 additions & 0 deletions google/cloud/dataflow/transforms/ptransform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,16 @@ def test_apply_to_list(self):
self.assertEqual([('k', (['a'], ['b', 'c']))],
join_input | df.CoGroupByKey('join'))

def test_multi_input_ptransform(self):
class DisjointUnion(PTransform):
def apply(self, pcollections):
return (pcollections
| df.Flatten()
| df.Map(lambda x: (x, None))
| df.GroupByKey()
| df.Map(lambda (x, _): x))
self.assertEqual([1, 2, 3], sorted(([1, 2], [2, 3]) | DisjointUnion()))

def test_apply_to_crazy_pvaluish(self):
class NestedFlatten(PTransform):
"""A PTransform taking and returning nested PValueish.
Expand Down

0 comments on commit 2843cf9

Please sign in to comment.