diff --git a/README.md b/README.md
index 52d2f2b..d1fb477 100644
--- a/README.md
+++ b/README.md
@@ -79,10 +79,12 @@ pit_join = df1.join(df2, pit_context.pit_udf(df1.ts, df2.ts) & (df1.id == df2.i
 ```py
 pit_join = pit_context.union_as_of(
-    df1,
-    df2,
+    left=df1,
+    right=df2,
     left_prefix="df1_",
     right_prefix="df2_",
+    left_ts_column="ts",
+    right_ts_column="ts",
     partition_cols=["id"],
 )
 ```
@@ -91,9 +93,11 @@ pit_join = pit_context.union_as_of(
 ```py
 pit_join = pit_context.exploding(
-    df1,
-    df2,
-    partition_cols = ["id"],
+    left=df1,
+    right=df2,
+    left_ts_column=df1["ts"],
+    right_ts_column=df2["ts"],
+    partition_cols=[(df1["id"], df2["id"])],
 )
 ```
@@ -134,6 +138,8 @@ import io.github.ackuq.pit.Exploding
 val pitJoin = Exploding.join(
   df1,
   df2,
-  partitionCols = Seq("id")
+  leftTSColumn = df1("ts"),
+  rightTSColumn = df2("ts"),
+  partitionCols = Seq((df1("id"), df2("id")))
 )
 ```
diff --git a/python/README.md b/python/README.md
index 8e7b4f9..c40f3c3 100644
--- a/python/README.md
+++ b/python/README.md
@@ -36,10 +36,12 @@ pit_join = df1.join(df2, pit_context.pit_udf(df1.ts, df2.ts) & (df1.id == df2.i
 ```py
 pit_join = pit_context.union_as_of(
-    df1,
-    df2,
+    left=df1,
+    right=df2,
     left_prefix="df1_",
     right_prefix="df2_",
+    left_ts_column="ts",
+    right_ts_column="ts",
     partition_cols=["id"],
 )
 ```
@@ -48,8 +50,10 @@ pit_join = pit_context.union_as_of(
 ```py
 pit_join = pit_context.exploding(
-    df1,
-    df2,
-    partition_cols = ["id"],
+    left=df1,
+    right=df2,
+    left_ts_column=df1["ts"],
+    right_ts_column=df2["ts"],
+    partition_cols=[(df1["id"], df2["id"])],
 )
 ```
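Reviewer note: since both READMEs now pass `Column` objects and `(left, right)` tuples to `exploding` (the examples above are corrected to match the `List[Tuple[Column, Column]]` signature in `context.py`), here is a minimal, self-contained sketch for trying the change locally. The `PitContext(sql_context)` construction and the import path are assumptions based on the package layout, not part of this diff.

```py
from pyspark.sql import SparkSession, SQLContext

from ackuq.pit import PitContext  # import path assumed from the package layout

spark = SparkSession.builder.master("local[*]").getOrCreate()
pit_context = PitContext(SQLContext(spark.sparkContext))  # construction assumed

df1 = spark.createDataFrame([(1, 4, "1z"), (1, 7, "1y")], ["id", "ts", "value"])
df2 = spark.createDataFrame([(1, 1, "f2-1-1"), (1, 6, "f2-1-6")], ["id", "ts", "value"])

# Each left row is paired with the latest right row whose ts <= the left ts.
pit_join = pit_context.exploding(
    left=df1,
    right=df2,
    left_ts_column=df1["ts"],
    right_ts_column=df2["ts"],
    partition_cols=[(df1["id"], df2["id"])],
)
pit_join.show()
```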
diff --git a/python/ackuq/pit/context.py b/python/ackuq/pit/context.py
index 26b8ae7..aaf05bc 100644
--- a/python/ackuq/pit/context.py
+++ b/python/ackuq/pit/context.py
@@ -22,7 +22,7 @@
 # SOFTWARE.
 #
 
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
 
 from py4j.java_gateway import JavaPackage
 from pyspark.sql import Column, DataFrame, SQLContext
@@ -82,8 +82,14 @@ def _to_scala_seq(self, list_like: List):
             .toSeq()
         )
 
+    def _to_scala_tuple(self, tup: Tuple):
+        """
+        Converts a Python tuple to a Scala tuple
+        """
+        return self._jvm.__getattr__("scala.Tuple" + str(len(tup)))(*tup)
+
     def _scala_none(self):
-        return getattr(getattr(self._jvm.scala, "None$"), "MODULE$")
+        return self._jvm.scala.__getattr__("None$").__getattr__("MODULE$")
 
     def _to_scala_option(self, x: Optional[Any]):
         return self._jvm.scala.Some(x) if x is not None else self._scala_none()
@@ -137,9 +143,9 @@ def exploding(
         self,
         left: DataFrame,
         right: DataFrame,
-        left_ts_column: str = "ts",
-        right_ts_column: str = "ts",
-        partition_cols: List[str] = [],
+        left_ts_column: Column,
+        right_ts_column: Column,
+        partition_cols: List[Tuple[Column, Column]] = [],
 ):
         """
         Perform a backward asof join using the left table for event times.
@@ -150,13 +156,18 @@ def exploding(
         :param right_ts_column The column used for timestamps in right DF
-        :param partition_cols The columns used for partitioning, if used
+        :param partition_cols The columns used for partitioning, each entry
+            being a tuple of (left DF column, right DF column)
         """
+        _partition_cols = [
+            self._to_scala_tuple((left_col._jc, right_col._jc))  # type: ignore
+            for left_col, right_col in partition_cols
+        ]
         return DataFrame(
             self._exploding.join(
                 left._jdf,
                 right._jdf,
-                left_ts_column,
-                right_ts_column,
-                self._to_scala_seq(partition_cols),
+                left_ts_column._jc,  # type: ignore
+                right_ts_column._jc,  # type: ignore
+                self._to_scala_seq(_partition_cols),
             ),
             self._sql_context,
         )
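Reviewer note: `_to_scala_tuple` and `_scala_none` lean on py4j's attribute protocol, which can look opaque in review. Below is a standalone illustration of the same pattern, assuming an active PySpark session; `spark.sparkContext._jvm` is the same gateway view the context stores as `self._jvm`.

```py
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
jvm = spark.sparkContext._jvm

# Looking up a class via __getattr__ lets the name be built dynamically,
# as in "scala.Tuple" + str(len(tup)) above; calling it constructs an instance.
pair = jvm.__getattr__("scala.Tuple2")("id", 42)
print(pair.toString())  # (id,42)

# A Scala `object` compiles to a class with a trailing '$' whose singleton
# instance lives in the static MODULE$ field.
scala_none = jvm.scala.__getattr__("None$").__getattr__("MODULE$")
```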
diff --git a/python/tests/data.py b/python/tests/data.py
index d6c98eb..42c0d58 100644
--- a/python/tests/data.py
+++ b/python/tests/data.py
@@ -182,3 +182,63 @@ def __init__(self, spark: SparkSession) -> None:
         self.PIT_1_2_3 = spark.createDataFrame(
             spark.sparkContext.parallelize(self.PIT_1_2_3_RAW), self.PIT_3_schema
         )
+
+
+class SmallDataExploding(SmallData):
+    PIT_1_2_RAW = [
+        [1, 4, "1z", 1, 4, "1z"],
+        [1, 5, "1x", 1, 5, "1x"],
+        [1, 7, "1y", 1, 7, "1y"],
+        [2, 6, "2x", 2, 6, "2x"],
+        [2, 8, "2y", 2, 8, "2y"],
+    ]
+    PIT_1_3_RAW = [
+        [1, 4, "1z", 1, 1, "f3-1-1"],
+        [1, 5, "1x", 1, 1, "f3-1-1"],
+        [1, 7, "1y", 1, 6, "f3-1-6"],
+        [2, 6, "2x", 2, 2, "f3-2-2"],
+        [2, 8, "2y", 2, 8, "f3-2-8"],
+    ]
+    PIT_1_2_3_RAW = [
+        [1, 4, "1z", 1, 4, "1z", 1, 1, "f3-1-1"],
+        [1, 5, "1x", 1, 5, "1x", 1, 1, "f3-1-1"],
+        [1, 7, "1y", 1, 7, "1y", 1, 6, "f3-1-6"],
+        [2, 6, "2x", 2, 6, "2x", 2, 2, "f3-2-2"],
+        [2, 8, "2y", 2, 8, "2y", 2, 8, "f3-2-8"],
+    ]
+
+    PIT_2_schema: StructType = StructType(
+        [
+            StructField("id", IntegerType(), nullable=False),
+            StructField("ts", IntegerType(), nullable=False),
+            StructField("value", StringType(), nullable=False),
+            StructField("id", IntegerType(), nullable=False),
+            StructField("ts", IntegerType(), nullable=False),
+            StructField("value", StringType(), nullable=False),
+        ]
+    )
+    PIT_3_schema: StructType = StructType(
+        [
+            StructField("id", IntegerType(), nullable=False),
+            StructField("ts", IntegerType(), nullable=False),
+            StructField("value", StringType(), nullable=False),
+            StructField("id", IntegerType(), nullable=False),
+            StructField("ts", IntegerType(), nullable=False),
+            StructField("value", StringType(), nullable=False),
+            StructField("id", IntegerType(), nullable=False),
+            StructField("ts", IntegerType(), nullable=False),
+            StructField("value", StringType(), nullable=False),
+        ]
+    )
+
+    def __init__(self, spark: SparkSession) -> None:
+        super().__init__(spark)
+        self.PIT_1_2 = spark.createDataFrame(
+            spark.sparkContext.parallelize(self.PIT_1_2_RAW), self.PIT_2_schema
+        )
+        self.PIT_1_3 = spark.createDataFrame(
+            spark.sparkContext.parallelize(self.PIT_1_3_RAW), self.PIT_2_schema
+        )
+        self.PIT_1_2_3 = spark.createDataFrame(
+            spark.sparkContext.parallelize(self.PIT_1_2_3_RAW), self.PIT_3_schema
+        )
diff --git a/python/tests/test_exploding_pit.py b/python/tests/test_exploding_pit.py
index 199461f..a181832 100644
--- a/python/tests/test_exploding_pit.py
+++ b/python/tests/test_exploding_pit.py
@@ -21,3 +21,50 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 #
+
+from tests.data import SmallDataExploding
+from tests.utils import SparkTests
+
+
+class ExplodingTest(SparkTests):
+    def setUp(self) -> None:
+        super().setUp()
+        self.small_data = SmallDataExploding(self.spark)
+
+    def test_two_aligned(self):
+        fg1 = self.small_data.fg1
+        fg2 = self.small_data.fg2
+
+        pit_join = self.pit_context.exploding(
+            fg1, fg2, fg1["ts"], fg2["ts"], [(fg1["id"], fg2["id"])]
+        )
+
+        self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_2.schema)
+        self.assertEqual(pit_join.collect(), self.small_data.PIT_1_2.collect())
+
+    def test_two_misaligned(self):
+        fg1 = self.small_data.fg1
+        fg2 = self.small_data.fg3
+
+        pit_join = self.pit_context.exploding(
+            fg1, fg2, fg1["ts"], fg2["ts"], [(fg1["id"], fg2["id"])]
+        )
+
+        self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_3.schema)
+        self.assertEqual(pit_join.collect(), self.small_data.PIT_1_3.collect())
+
+    def test_three_misaligned(self):
+        fg1 = self.small_data.fg1
+        fg2 = self.small_data.fg2
+        fg3 = self.small_data.fg3
+
+        left = self.pit_context.exploding(
+            fg1, fg2, fg1["ts"], fg2["ts"], [(fg1["id"], fg2["id"])]
+        )
+
+        pit_join = self.pit_context.exploding(
+            left, fg3, fg1["ts"], fg3["ts"], [(fg1["id"], fg3["id"])]
+        )
+
+        self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_2_3.schema)
+        self.assertEqual(pit_join.collect(), self.small_data.PIT_1_2_3.collect())
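Reviewer note: the expected fixtures above encode the backward as-of rule. A plain-Python reference (no Spark) that reproduces `PIT_1_3_RAW`; the right-hand rows are inferred from the fixture, since `fg3` itself is defined elsewhere in `SmallData`.

```py
# For each left row, pick the right row with the greatest ts <= the left ts,
# matching on id. Illustration only; mirrors the PIT_1_3 fixture above.
left_rows = [(1, 4, "1z"), (1, 5, "1x"), (1, 7, "1y"), (2, 6, "2x"), (2, 8, "2y")]
right_rows = [(1, 1, "f3-1-1"), (1, 6, "f3-1-6"), (2, 2, "f3-2-2"), (2, 8, "f3-2-8")]

expected = []
for lid, lts, lval in left_rows:
    candidates = [r for r in right_rows if r[0] == lid and r[1] <= lts]
    if candidates:
        rid, rts, rval = max(candidates, key=lambda r: r[1])
        expected.append([lid, lts, lval, rid, rts, rval])

assert expected == [
    [1, 4, "1z", 1, 1, "f3-1-1"],
    [1, 5, "1x", 1, 1, "f3-1-1"],
    [1, 7, "1y", 1, 6, "f3-1-6"],
    [2, 6, "2x", 2, 2, "f3-2-2"],
    [2, 8, "2y", 2, 8, "f3-2-8"],
]
```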
diff --git a/scala/src/main/scala/Exploding.scala b/scala/src/main/scala/Exploding.scala
index ab7e44e..d2139a3 100644
--- a/scala/src/main/scala/Exploding.scala
+++ b/scala/src/main/scala/Exploding.scala
@@ -24,11 +24,9 @@ package io.github.ackuq.pit
 
-import utils.ColumnUtils.assertColumnsInDF
-
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.{col, row_number}
+import org.apache.spark.sql.{Column, DataFrame}
 
 object Exploding {
@@ -43,27 +41,27 @@ object Exploding {
    * @param rightTSColumn
    *   The column used for timestamps in right DF
    * @param partitionCols
-   *   The columns used for partitioning, if used
+   *   The columns used for partitioning; each entry is a tuple of the left
+   *   partition column and the matching right partition column
    * @return
    *   The PIT-correct view of the joined dataframes
    */
   def join(
       left: DataFrame,
       right: DataFrame,
-      leftTSColumn: String = "ts",
-      rightTSColumn: String = "ts",
-      partitionCols: Seq[String] = Seq()
+      leftTSColumn: Column,
+      rightTSColumn: Column,
+      partitionCols: Seq[(Column, Column)] = Seq()
   ): DataFrame = {
-    if (partitionCols.nonEmpty) {
-      assertColumnsInDF(partitionCols, left, right)
-    }
-
+    // Create the equality conditions of the partitioning columns
     val partitionConditions =
-      partitionCols.map(colName => left(colName) === right(colName))
+      partitionCols.map { case (leftCol, rightCol) => leftCol === rightCol }
+    // Combine the partitioning conditions with the PIT condition
     val joinConditions =
-      partitionConditions :+ (left(leftTSColumn) >= right(rightTSColumn))
+      partitionConditions :+ (leftTSColumn >= rightTSColumn)
+    // Reduce the sequence of conditions to a single one
     val joinCondition =
       joinConditions.reduce((current, previous) => current.and(previous))
@@ -73,14 +71,17 @@ object Exploding {
       joinCondition
     )
 
-    val windowPartitionCols = partitionCols.map(left(_)) :+ left(leftTSColumn)
+    // Partition each window using the partitioning columns of the left DataFrame
+    val windowPartitionCols = partitionCols.map(_._1) :+ leftTSColumn
 
+    // Create the Window specification
     val windowSpec = Window
       .partitionBy(windowPartitionCols: _*)
-      .orderBy(left(leftTSColumn).desc, right(rightTSColumn).desc)
+      .orderBy(rightTSColumn.desc)
 
     combined
+      // Keep only the row with the highest right timestamp within each window
      .withColumn("rn", row_number().over(windowSpec))
      .where(col("rn") === 1)
      .drop("rn")
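Reviewer note: the window logic above is compact, so here is an equivalent sketch of the same strategy in PySpark (illustration only; `df_left`/`df_right` and their column names are assumed): explode with a `>=` join, then keep one row per left record via `row_number`.

```py
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Assumed inputs: df_left and df_right both have columns (id, ts, value).
combined = df_left.join(
    df_right,
    (df_left["id"] == df_right["id"]) & (df_left["ts"] >= df_right["ts"]),
)

# Partition by the left key AND the left timestamp, so each left row forms its
# own window; order by the right timestamp descending and keep the first row.
w = Window.partitionBy(df_left["id"], df_left["ts"]).orderBy(df_right["ts"].desc())

result = (
    combined.withColumn("rn", F.row_number().over(w))
    .where(F.col("rn") == 1)
    .drop("rn")
)
```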
diff --git a/scala/src/test/scala/ExplodingTests.scala b/scala/src/test/scala/ExplodingTests.scala
new file mode 100644
index 0000000..fd67297
--- /dev/null
+++ b/scala/src/test/scala/ExplodingTests.scala
@@ -0,0 +1,102 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2022 Axel Pettersson
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+package io.github.ackuq.pit
+
+import data.SmallDataExploding
+
+import org.scalatest.flatspec.AnyFlatSpec
+
+class ExplodingTests extends AnyFlatSpec with SparkSessionTestWrapper {
+  val smallData = new SmallDataExploding(spark)
+
+  it should "Perform a PIT join with two dataframes, aligned timestamps" in {
+    val fg1 = smallData.fg1
+    val fg2 = smallData.fg2
+
+    val pitJoin =
+      Exploding.join(
+        fg1,
+        fg2,
+        leftTSColumn = fg1("ts"),
+        rightTSColumn = fg2("ts"),
+        partitionCols = Seq((fg1("id"), fg2("id")))
+      )
+
+    assert(!pitJoin.isEmpty)
+    // Assert same schema
+    assert(pitJoin.schema.equals(smallData.PIT_1_2.schema))
+    // Assert same elements
+    assert(pitJoin.collect().sameElements(smallData.PIT_1_2.collect()))
+  }
+
+  it should "Perform a PIT join with two dataframes, misaligned timestamps" in {
+    val fg1 = smallData.fg1
+    val fg2 = smallData.fg3
+
+    val pitJoin =
+      Exploding.join(
+        fg1,
+        fg2,
+        leftTSColumn = fg1("ts"),
+        rightTSColumn = fg2("ts"),
+        partitionCols = Seq((fg1("id"), fg2("id")))
+      )
+
+    assert(!pitJoin.isEmpty)
+    // Assert same schema
+    assert(pitJoin.schema.equals(smallData.PIT_1_3.schema))
+    // Assert same elements
+    assert(pitJoin.collect().sameElements(smallData.PIT_1_3.collect()))
+  }
+
+  it should "Perform a PIT join with three dataframes, misaligned timestamps" in {
+    val fg1 = smallData.fg1
+    val fg2 = smallData.fg2
+    val fg3 = smallData.fg3
+
+    val left =
+      Exploding.join(
+        fg1,
+        fg2,
+        leftTSColumn = fg1("ts"),
+        rightTSColumn = fg2("ts"),
+        partitionCols = Seq((fg1("id"), fg2("id")))
+      )
+
+    val pitJoin = Exploding.join(
+      left,
+      fg3,
+      leftTSColumn = fg1("ts"),
+      rightTSColumn = fg3("ts"),
+      partitionCols = Seq((fg1("id"), fg3("id")))
+    )
+
+    assert(!pitJoin.isEmpty)
+    // Assert same schema
+    assert(pitJoin.schema.equals(smallData.PIT_1_2_3.schema))
+    // Assert same elements
+    assert(pitJoin.collect().sameElements(smallData.PIT_1_2_3.collect()))
+  }
+}
diff --git a/scala/src/test/scala/data/SmallDataExploding.scala b/scala/src/test/scala/data/SmallDataExploding.scala
new file mode 100644
index 0000000..cc395df
--- /dev/null
+++ b/scala/src/test/scala/data/SmallDataExploding.scala
@@ -0,0 +1,94 @@
+/*
+ * MIT License
+ *
+ * Copyright (c) 2022 Axel Pettersson
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+package io.github.ackuq.pit
+package data
+
+import org.apache.spark.sql.types.{
+  IntegerType,
+  StringType,
+  StructField,
+  StructType
+}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+
+class SmallDataExploding(spark: SparkSession) extends SmallData(spark) {
+  private val PIT_1_2_RAW = Seq(
+    Row(1, 4, "1z", 1, 4, "1z"),
+    Row(1, 5, "1x", 1, 5, "1x"),
+    Row(1, 7, "1y", 1, 7, "1y"),
+    Row(2, 6, "2x", 2, 6, "2x"),
+    Row(2, 8, "2y", 2, 8, "2y")
+  )
+  private val PIT_1_3_RAW = Seq(
+    Row(1, 4, "1z", 1, 1, "f3-1-1"),
+    Row(1, 5, "1x", 1, 1, "f3-1-1"),
+    Row(1, 7, "1y", 1, 6, "f3-1-6"),
+    Row(2, 6, "2x", 2, 2, "f3-2-2"),
+    Row(2, 8, "2y", 2, 8, "f3-2-8")
+  )
+  private val PIT_1_2_3_RAW = Seq(
+    Row(1, 4, "1z", 1, 4, "1z", 1, 1, "f3-1-1"),
+    Row(1, 5, "1x", 1, 5, "1x", 1, 1, "f3-1-1"),
+    Row(1, 7, "1y", 1, 7, "1y", 1, 6, "f3-1-6"),
+    Row(2, 6, "2x", 2, 6, "2x", 2, 2, "f3-2-2"),
+    Row(2, 8, "2y", 2, 8, "2y", 2, 8, "f3-2-8")
+  )
+  private val PIT_2_schema: StructType = StructType(
+    Seq(
+      StructField("id", IntegerType, nullable = false),
+      StructField("ts", IntegerType, nullable = false),
+      StructField("value", StringType, nullable = false),
+      StructField("id", IntegerType, nullable = false),
+      StructField("ts", IntegerType, nullable = false),
+      StructField("value", StringType, nullable = false)
+    )
+  )
+  private val PIT_3_schema: StructType = StructType(
+    Seq(
+      StructField("id", IntegerType, nullable = false),
+      StructField("ts", IntegerType, nullable = false),
+      StructField("value", StringType, nullable = false),
+      StructField("id", IntegerType, nullable = false),
+      StructField("ts", IntegerType, nullable = false),
+      StructField("value", StringType, nullable = false),
+      StructField("id", IntegerType, nullable = false),
+      StructField("ts", IntegerType, nullable = false),
+      StructField("value", StringType, nullable = false)
+    )
+  )
+
+  val PIT_1_2: DataFrame = spark.createDataFrame(
+    spark.sparkContext.parallelize(PIT_1_2_RAW),
+    PIT_2_schema
+  )
+  val PIT_1_3: DataFrame = spark.createDataFrame(
+    spark.sparkContext.parallelize(PIT_1_3_RAW),
+    PIT_2_schema
+  )
+  val PIT_1_2_3: DataFrame = spark.createDataFrame(
+    spark.sparkContext.parallelize(PIT_1_2_3_RAW),
+    PIT_3_schema
+  )
+}
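Reviewer note: the three-dataframe tests work by chaining two exploding joins. A sketch of that pattern with the Python API, using the fixture names from the tests. Note that the second call still passes `fg1["ts"]` and `fg1["id"]`: those references resolve against `left` because the joined frame carries the original left columns through.

```py
# First join fg1 with fg2, then join that result with fg3.
left = pit_context.exploding(
    fg1, fg2, fg1["ts"], fg2["ts"], [(fg1["id"], fg2["id"])]
)
pit_join = pit_context.exploding(
    left, fg3, fg1["ts"], fg3["ts"], [(fg1["id"], fg3["id"])]
)
```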