From d62dbe2cffb32f7873d6c5bc51fe4c2f25b7be74 Mon Sep 17 00:00:00 2001 From: Axel Pettersson Date: Wed, 4 May 2022 18:25:54 +0200 Subject: [PATCH] Fix zipping of keys --- .../main/scala/execution/PITJoinExec.scala | 69 ++++++++++--------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/scala/src/main/scala/execution/PITJoinExec.scala b/scala/src/main/scala/execution/PITJoinExec.scala index 35a130a..bb225ec 100644 --- a/scala/src/main/scala/execution/PITJoinExec.scala +++ b/scala/src/main/scala/execution/PITJoinExec.scala @@ -270,10 +270,11 @@ protected[pit] case class PITJoinExec( private def copyKeys( ctx: CodegenContext, - vars: Seq[ExprCode] + vars: Seq[ExprCode], + keys: Seq[Expression] ): Seq[ExprCode] = { vars.zipWithIndex.map { case (ev, i) => - ctx.addBufferedState(leftKeys(i).dataType, "value", ev.value) + ctx.addBufferedState(keys(i).dataType, "value", ev.value) } } @@ -299,7 +300,7 @@ protected[pit] case class PITJoinExec( val toleranceCheck = leftPIT.zip(rightPIT).zipWithIndex.map { case ((l, r), i) => s""" - | (${l.value} - ${r.value} > ${tolerance}) + | (${l.value} - ${r.value} > $tolerance) |""".stripMargin } toleranceCheck.mkString(" && ") @@ -400,7 +401,7 @@ protected[pit] case class PITJoinExec( ): Seq[ExprCode] = { ctx.INPUT_ROW = rightRow right.output.zipWithIndex.map { case (a, i) => - val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx); + val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) if (returnNulls) { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") @@ -456,14 +457,14 @@ protected[pit] case class PITJoinExec( rightEquiKeyTmpVars.map(_.isNull).mkString(" || ") ) // Copy the right key as class members so they could be used in next function call. - val rightPITKeyVars = copyKeys(ctx, rightPITKeyTmpVars) - val rightEquiKeyVars = copyKeys(ctx, rightEquiKeyTmpVars) + val rightPITKeyVars = copyKeys(ctx, rightPITKeyTmpVars, leftPitKeys) + val rightEquiKeyVars = copyKeys(ctx, rightEquiKeyTmpVars, leftEquiKeys) val matched = ctx.addMutableState("InternalRow", "matched", forceInline = true) - val matchedPITKeyVars = copyKeys(ctx, leftPITKeyVars) - val matchedEquiKeyVars = copyKeys(ctx, leftEquiKeyVars) + val matchedPITKeyVars = copyKeys(ctx, leftPITKeyVars, leftPitKeys) + val matchedEquiKeyVars = copyKeys(ctx, leftEquiKeyVars, leftEquiKeys) ctx.addNewFunction( "findNextInnerJoinRows", @@ -626,32 +627,32 @@ protected[pit] case class PITJoinExec( val thisPlan = ctx.addReferenceObj("plan", this) val eagerCleanup = s"$thisPlan.cleanupResources();" - returnNulls match { - case false => s""" - |while (findNextInnerJoinRows($leftInput, $rightInput)) { - | ${leftVarDecl.mkString("\n")} - | ${beforeLoop.trim} - | InternalRow $rightRow = (InternalRow) $matched; - | ${condCheck.trim} - | $numOutput.add(1); - | ${consume(ctx, leftVars ++ rightVars)} - | if (shouldStop()) return; - |} - |$eagerCleanup - |""".stripMargin - case true => - s""" - |while($leftInput.hasNext()) { - | findNextInnerJoinRows($leftInput, $rightInput); - | ${leftVarDecl.mkString("\n")} - | ${beforeLoop.trim} - | InternalRow $rightRow = (InternalRow) $matched; - | ${condCheck.trim} - | $numOutput.add(1); - | ${consume(ctx, leftVars ++ rightVars)}; - | if (shouldStop()) return; - |} - |""".stripMargin + if (returnNulls) { + s""" + |while($leftInput.hasNext()) { + | findNextInnerJoinRows($leftInput, $rightInput); + | ${leftVarDecl.mkString("\n")} + | ${beforeLoop.trim} + | InternalRow $rightRow = (InternalRow) $matched; + | ${condCheck.trim} + | $numOutput.add(1); + | ${consume(ctx, leftVars ++ rightVars)}; + | if (shouldStop()) return; + |} + |""".stripMargin + } else { + s""" + |while (findNextInnerJoinRows($leftInput, $rightInput)) { + | ${leftVarDecl.mkString("\n")} + | ${beforeLoop.trim} + | InternalRow $rightRow = (InternalRow) $matched; + | ${condCheck.trim} + | $numOutput.add(1); + | ${consume(ctx, leftVars ++ rightVars)} + | if (shouldStop()) return; + |} + |$eagerCleanup + |""".stripMargin } } }