
Merge pull request #13 from Ackuq/feature/improved-expliding
Add tests and improvements to exploding PIT
Ackuq authored Feb 28, 2022
2 parents fa91673 + 8abeafe commit e34c66c
Showing 8 changed files with 356 additions and 34 deletions.
18 changes: 12 additions & 6 deletions README.md
@@ -79,10 +79,12 @@ pit_join = df1.join(df2, pit_context.pit_udf(df1.ts, df2.ts) & (df1.id == df2.i

```py
pit_join = pit_context.union_as_of(
- df1,
- df2,
+ left=df1,
+ right=df2,
left_prefix="df1_",
right_prefix="df2_",
+ left_ts_column = "ts",
+ right_ts_column = "ts",
partition_cols=["id"],
)
```
@@ -91,9 +93,11 @@ pit_join = pit_context.union_as_of(

```py
pit_join = pit_context.exploding(
- df1,
- df2,
- partition_cols = ["id"],
+ left=df1,
+ right=df2,
+ left_ts_column=df1["ts"],
+ right_ts_column=df2["ts"],
+ partition_cols = [df1["id"], df2["id"]],
)
```
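
For reference, a minimal end-to-end sketch of the new `exploding` call (a sketch only: it assumes a local `SparkSession` and a `pit_context` initialized as shown earlier in this README, with toy data made up for illustration; the tuple form of `partition_cols` follows the Python signature and tests in this commit):

```py
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Toy feature tables with the conventional (id, ts, value) layout
df1 = spark.createDataFrame([(1, 4, "1z"), (1, 5, "1x")], ["id", "ts", "value"])
df2 = spark.createDataFrame([(1, 1, "a"), (1, 5, "b")], ["id", "ts", "value"])

# Timestamp and partitioning columns are now passed as Column objects;
# each partition entry pairs a left column with its right counterpart.
pit_join = pit_context.exploding(
    left=df1,
    right=df2,
    left_ts_column=df1["ts"],
    right_ts_column=df2["ts"],
    partition_cols=[(df1["id"], df2["id"])],
)
```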

@@ -134,6 +138,8 @@ import io.github.ackuq.pit.Exploding
val pitJoin = Exploding.join(
df1,
df2,
- partitionCols = Seq("id")
+ leftTSColumn = df1("ts"),
+ rightTSColumn = df2("ts"),
+ partitionCols = Seq((df1("id"), df2("id")))
)
```
14 changes: 9 additions & 5 deletions python/README.md
@@ -36,10 +36,12 @@ pit_join = df1.join(df2, pit_context.pit_udf(df1.ts, df2.ts) & (df1.id == df2.i

```py
pit_join = pit_context.union_as_of(
- df1,
- df2,
+ left=df1,
+ right=df2,
left_prefix="df1_",
right_prefix="df2_",
+ left_ts_column = "ts",
+ right_ts_column = "ts",
partition_cols=["id"],
)
```
@@ -48,8 +50,10 @@ pit_join = pit_context.union_as_of(

```py
pit_join = pit_context.exploding(
- df1,
- df2,
- partition_cols = ["id"],
+ left=df1,
+ right=df2,
+ left_ts_column=df1["ts"],
+ right_ts_column=df2["ts"],
+ partition_cols = [df1["id"], df2["id"]],
)
```
23 changes: 15 additions & 8 deletions python/ackuq/pit/context.py
@@ -22,7 +22,7 @@
# SOFTWARE.
#

- from typing import Any, List, Optional
+ from typing import Any, List, Optional, Tuple

from py4j.java_gateway import JavaPackage
from pyspark.sql import Column, DataFrame, SQLContext
@@ -82,8 +82,14 @@ def _to_scala_seq(self, list_like: List):
.toSeq()
)

+ def _to_scala_tuple(self, tuple: Tuple):
+     """
+     Converts a Python tuple to Scala tuple
+     """
+     return self._jvm.__getattr__("scala.Tuple" + str(len(tuple)))(*tuple)
+
def _scala_none(self):
- return getattr(getattr(self._jvm.scala, "None$"), "MODULE$")
+ return self._jvm.scala.__getattr__("None$").__getattr__("MODULE$")

def _to_scala_option(self, x: Optional[Any]):
return self._jvm.scala.Some(x) if x is not None else self._scala_none()
@@ -137,9 +143,9 @@ def exploding(
self,
left: DataFrame,
right: DataFrame,
- left_ts_column: str = "ts",
- right_ts_column: str = "ts",
- partition_cols: List[str] = [],
+ left_ts_column: Column,
+ right_ts_column: Column,
+ partition_cols: List[Tuple[Column, Column]] = [],
):
"""
Perform a backward asof join using the left table for event times.
@@ -150,13 +156,14 @@
:param right_ts_column The column used for timestamps in right DF
:param partition_cols The columns used for partitioning, if used
"""
+ _partition_cols = map(lambda p: self._to_scala_tuple((p[0]._jc, p[1]._jc)), partition_cols) # type: ignore
return DataFrame(
self._exploding.join(
left._jdf,
right._jdf,
- left_ts_column,
- right_ts_column,
- self._to_scala_seq(partition_cols),
+ left_ts_column._jc, # type: ignore
+ right_ts_column._jc, # type: ignore
+ self._to_scala_seq(list(_partition_cols)),
),
self._sql_context,
)
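
As a rough illustration of the py4j pattern behind `_to_scala_tuple` (an assumption-laden sketch: it presumes an active PySpark session whose JVM gateway is reachable through the internal `spark.sparkContext._jvm` handle; the dotted-name lookup mirrors the `"scala.Tuple" + str(len(tuple))` construction above):

```py
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
jvm = spark.sparkContext._jvm  # py4j view of the driver-side JVM

# py4j resolves JVM classes by dotted name; calling the class reference
# invokes its constructor, creating a scala.Tuple2 instance on the JVM side.
pair = jvm.__getattr__("scala.Tuple2")("left_id", "right_id")
print(pair.toString())  # prints: (left_id,right_id)
```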
60 changes: 60 additions & 0 deletions python/tests/data.py
@@ -182,3 +182,63 @@ def __init__(self, spark: SparkSession) -> None:
self.PIT_1_2_3 = spark.createDataFrame(
spark.sparkContext.parallelize(self.PIT_1_2_3_RAW), self.PIT_3_schema
)


class SmallDataExploding(SmallData):
PIT_1_2_RAW = [
[1, 4, "1z", 1, 4, "1z"],
[1, 5, "1x", 1, 5, "1x"],
[1, 7, "1y", 1, 7, "1y"],
[2, 6, "2x", 2, 6, "2x"],
[2, 8, "2y", 2, 8, "2y"],
]
PIT_1_3_RAW = [
[1, 4, "1z", 1, 1, "f3-1-1"],
[1, 5, "1x", 1, 1, "f3-1-1"],
[1, 7, "1y", 1, 6, "f3-1-6"],
[2, 6, "2x", 2, 2, "f3-2-2"],
[2, 8, "2y", 2, 8, "f3-2-8"],
]
PIT_1_2_3_RAW = [
[1, 4, "1z", 1, 4, "1z", 1, 1, "f3-1-1"],
[1, 5, "1x", 1, 5, "1x", 1, 1, "f3-1-1"],
[1, 7, "1y", 1, 7, "1y", 1, 6, "f3-1-6"],
[2, 6, "2x", 2, 6, "2x", 2, 2, "f3-2-2"],
[2, 8, "2y", 2, 8, "2y", 2, 8, "f3-2-8"],
]

PIT_2_schema: StructType = StructType(
[
StructField("id", IntegerType(), nullable=False),
StructField("ts", IntegerType(), nullable=False),
StructField("value", StringType(), nullable=False),
StructField("id", IntegerType(), nullable=False),
StructField("ts", IntegerType(), nullable=False),
StructField("value", StringType(), nullable=False),
]
)
PIT_3_schema: StructType = StructType(
[
StructField("id", IntegerType(), nullable=False),
StructField("ts", IntegerType(), nullable=False),
StructField("value", StringType(), nullable=False),
StructField("id", IntegerType(), nullable=False),
StructField("ts", IntegerType(), nullable=False),
StructField("value", StringType(), nullable=False),
StructField("id", IntegerType(), nullable=False),
StructField("ts", IntegerType(), nullable=False),
StructField("value", StringType(), nullable=False),
]
)

def __init__(self, spark: SparkSession) -> None:
super().__init__(spark)
self.PIT_1_2 = spark.createDataFrame(
spark.sparkContext.parallelize(self.PIT_1_2_RAW), self.PIT_2_schema
)
self.PIT_1_3 = spark.createDataFrame(
spark.sparkContext.parallelize(self.PIT_1_3_RAW), self.PIT_2_schema
)
self.PIT_1_2_3 = spark.createDataFrame(
spark.sparkContext.parallelize(self.PIT_1_2_3_RAW), self.PIT_3_schema
)
47 changes: 47 additions & 0 deletions python/tests/test_exploding_pit.py
@@ -21,3 +21,50 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

from tests.data import SmallDataExploding
from tests.utils import SparkTests


class ExplodingTest(SparkTests):
def setUp(self) -> None:
super().setUp()
self.small_data = SmallDataExploding(self.spark)

def test_two_aligned(self):
fg1 = self.small_data.fg1
fg2 = self.small_data.fg2

pit_join = self.pit_context.exploding(
fg1, fg2, fg1["ts"], fg2["ts"], [(fg1["id"], fg2["id"])]
)

self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_2.schema)
self.assertEqual(pit_join.collect(), self.small_data.PIT_1_2.collect())

def test_two_misaligned(self):
fg1 = self.small_data.fg1
fg2 = self.small_data.fg3

pit_join = self.pit_context.exploding(
fg1, fg2, fg1["ts"], fg2["ts"], [(fg1["id"], fg2["id"])]
)

self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_3.schema)
self.assertEqual(pit_join.collect(), self.small_data.PIT_1_3.collect())

def test_three_misaligned(self):
fg1 = self.small_data.fg1
fg2 = self.small_data.fg2
fg3 = self.small_data.fg3

left = self.pit_context.exploding(
fg1, fg2, fg1["ts"], fg2["ts"], [(fg1["id"], fg2["id"])]
)

pit_join = self.pit_context.exploding(
left, fg3, fg1["ts"], fg3["ts"], [(fg1["id"], fg3["id"])]
)

self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_2_3.schema)
self.assertEqual(pit_join.collect(), self.small_data.PIT_1_2_3.collect())
31 changes: 16 additions & 15 deletions scala/src/main/scala/Exploding.scala
@@ -24,11 +24,9 @@

package io.github.ackuq.pit

- import utils.ColumnUtils.assertColumnsInDF
-
- import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}
+ import org.apache.spark.sql.{Column, DataFrame}

object Exploding {

@@ -43,27 +41,27 @@ object Exploding {
* @param rightTSColumn
* The column used for timestamps in right DF
* @param partitionCols
- * The columns used for partitioning, if used
+ * The columns used for partitioning, is a tuple consisting of left
+ * partition column and right partition column
* @return
* The PIT-correct view of the joined dataframes
*/
def join(
left: DataFrame,
right: DataFrame,
- leftTSColumn: String = "ts",
- rightTSColumn: String = "ts",
- partitionCols: Seq[String] = Seq()
+ leftTSColumn: Column,
+ rightTSColumn: Column,
+ partitionCols: Seq[(Column, Column)] = Seq()
): DataFrame = {
- if (partitionCols.nonEmpty) {
-   assertColumnsInDF(partitionCols, left, right)
- }
-
+ // Create the equality conditions of the partitioning column
val partitionConditions =
- partitionCols.map(colName => left(colName) === right(colName))
+ partitionCols.map(col => col._1 === col._2)

+ // Combine the partitioning conditions with the PIT condition
val joinConditions =
- partitionConditions :+ (left(leftTSColumn) >= right(rightTSColumn))
+ partitionConditions :+ (leftTSColumn >= rightTSColumn)

+ // Reduce the sequence of conditions to a single one
val joinCondition =
joinConditions.reduce((current, previous) => current.and(previous))

@@ -73,14 +71,17 @@
joinCondition
)

- val windowPartitionCols = partitionCols.map(left(_)) :+ left(leftTSColumn)
+ // Partition each window using the partitioning columns of the left DataFrame
+ val windowPartitionCols = partitionCols.map(_._1) :+ leftTSColumn

+ // Create the Window specification
val windowSpec =
Window
.partitionBy(windowPartitionCols: _*)
- .orderBy(left(leftTSColumn).desc, right(rightTSColumn).desc)
+ .orderBy(rightTSColumn.desc)

combined
+ // Take only the row with the highest timestamps within each window frame
.withColumn("rn", row_number().over(windowSpec))
.where(col("rn") === 1)
.drop("rn")
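
The rewritten join reads as two steps: pair every left row with all right rows at or before its timestamp, then keep only the newest right row for each left row. Below is a PySpark sketch of that same flow (illustrative only, not this library's API; the `id` and `ts` column names and the sample rows are assumptions):

```py
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

df1 = spark.createDataFrame([(1, 5, "1x"), (1, 7, "1y")], ["id", "ts", "value"])
df2 = spark.createDataFrame([(1, 1, "a"), (1, 6, "b")], ["id", "ts", "value2"])

# "Explode": each left row joins to every right row at or before its timestamp
combined = df1.join(df2, (df1["id"] == df2["id"]) & (df1["ts"] >= df2["ts"]))

# Rank right rows newest-first within each (left id, left ts) group,
# then keep only the top-ranked row: the point-in-time correct match
w = Window.partitionBy(df1["id"], df1["ts"]).orderBy(df2["ts"].desc())
pit = (
    combined.withColumn("rn", F.row_number().over(w))
    .where(F.col("rn") == 1)
    .drop("rn")
)
pit.show()
```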
