Skip to content

Commit

Permalink
Variable #key won't be added to context after window nodes - tests fixed
Browse files Browse the repository at this point in the history
If two or more window nodes are not separated by a union, then the #key variable is created with a null value
  • Loading branch information
Szymon Bogusz committed Jan 28, 2025
1 parent 9f5ee5a commit eab62f6
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
}

private val processValidator: ProcessValidator = ProcessValidator.default(modelData())
private val idForTest: String = "fragmentResult"

test("aggregates are properly validated") {
validateOk("#AGG.approxCardinality", "#input.str", Typed[Long])
Expand Down Expand Up @@ -140,7 +141,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b")))
val testProcess = sliding("#AGG.sum", "#input.eId", emitWhenEventLeft = false)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(1, 3, 7)
}

Expand All @@ -151,7 +152,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 0, "a"), TestRecordHours(id, 1, 1, "b"), TestRecordHours(id, 2, 0, "b")))
val testProcess = sliding("#AGG.sum", "#input.eId", emitWhenEventLeft = false)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(0, 1, 1)
}

Expand All @@ -163,7 +164,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
val testProcess =
sliding("#AGG.countWhen", """#input.str == "a" || #input.str == "b" """, emitWhenEventLeft = false)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(1, 2, 1)
}

Expand All @@ -174,7 +175,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b")))
val testProcess = sliding("#AGG.average", "#input.eId", emitWhenEventLeft = false)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(1.0d, 1.5, 3.5)
}

Expand All @@ -185,7 +186,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b")))
val testProcess = sliding("#AGG.median", "#input.eId", emitWhenEventLeft = false)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(1.0d, 1.5, 3.5)
}

Expand All @@ -205,7 +206,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b")))
val testProcess = sliding(aggregationName, "#input.eId", emitWhenEventLeft = false)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
val mapped = aggregateVariables
.map(e => e.asInstanceOf[Double])
mapped.size shouldBe 2
Expand All @@ -222,7 +223,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
val testProcess =
sliding("#AGG.sum", "#input.eId", emitWhenEventLeft = false, afterAggregateExpression = "#input.eId")

val nodeResults = runCollectOutputVariables(id, model, testProcess)
val nodeResults = runCollectOutputVariables(idForTest, model, testProcess)
nodeResults.map(_.variableTyped[Number]("fooVar").get) shouldBe List(1, 2, 5)
}

Expand All @@ -239,7 +240,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
)
val testProcess = sliding("#AGG.sum", "#input.eId", emitWhenEventLeft = false)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(1, 3, 7, 4)
}

Expand All @@ -255,7 +256,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
)
val testProcess = sliding("#AGG.sum", "#input.eId", emitWhenEventLeft = true)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(1, 3, 7, 5, 0)
}

Expand All @@ -282,7 +283,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b")))
val testProcess = tumbling("#AGG.sum", "#input.eId", emitWhen = TumblingWindowTrigger.OnEnd)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(3, 5)
}

Expand All @@ -298,7 +299,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
)
val testProcess = tumbling("#AGG.set", "#input.eId", emitWhen = TumblingWindowTrigger.OnEnd)

val aggregateVariables = runCollectOutputAggregate[Set[Number]](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Set[Number]](idForTest, model, testProcess)
aggregateVariables shouldBe List(Set(1, 2), Set(5), Set(6)).map(_.asJava)
}

Expand Down Expand Up @@ -328,7 +329,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
Map("windowLength" -> "T(java.time.Duration).parse('P1D')")
)

val aggregateVariables = runCollectOutputAggregate[java.util.Set[Number]](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[java.util.Set[Number]](idForTest, model, testProcess)
var expected = List(Set(1), Set(2, 5), Set(7))
if (trigger == TumblingWindowTrigger.OnEndWithExtraWindow) {
expected = expected :+ Set()
Expand Down Expand Up @@ -363,7 +364,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
val model =
modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b")))

val aggregateVariables = runCollectOutputAggregate[java.util.Map[String, Any]](id, model, resolvedScenario)
val aggregateVariables = runCollectOutputAggregate[java.util.Map[String, Any]](idForTest, model, resolvedScenario)
aggregateVariables.map(_.asScala("aggresult")) shouldBe List(3, 5)
}

Expand Down Expand Up @@ -407,7 +408,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
afterAggregateExpression = "#input.eId"
)

val nodeResults = runCollectOutputVariables(id, model, testProcess)
val nodeResults = runCollectOutputVariables(idForTest, model, testProcess)

nodeResults.map(_.variableTyped[Number]("fooVar").get) shouldBe List(1, 2, 5)

Expand All @@ -429,7 +430,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
)
val testProcess = tumbling("#AGG.sum", "#input.eId", emitWhen = TumblingWindowTrigger.OnEnd)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(4, 5)
}

Expand All @@ -446,7 +447,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
) // lost because watermark advanced to 2
val testProcess = tumbling("#AGG.sum", "#input.eId", emitWhen = TumblingWindowTrigger.OnEnd)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(3, 5)
}

Expand All @@ -457,7 +458,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a"), TestRecordHours(id, 1, 2, "b"), TestRecordHours(id, 2, 5, "b")))
val testProcess = tumbling("#AGG.sum", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(3, 5, 0)
}

Expand All @@ -468,7 +469,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a")))
val testProcess = tumbling("#AGG.average", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables.length shouldEqual (2)
aggregateVariables(0) shouldEqual 1.0
aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true
Expand All @@ -481,7 +482,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a")))
val testProcess = tumbling("#AGG.median", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables.length shouldEqual (2)
aggregateVariables(0) shouldEqual 1.0
aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true
Expand All @@ -505,7 +506,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
modelData(List(TestRecordHours(id, 0, 1, "a")))
val testProcess = tumbling(aggregatorName, "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables.length shouldEqual (2)
aggregateVariables(0) shouldEqual 0.0
aggregateVariables(1).asInstanceOf[Double].isNaN shouldBe true
Expand All @@ -520,7 +521,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
val testProcess =
tumbling("#AGG.average", """T(java.math.BigDecimal).ONE""", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null)
}

Expand All @@ -532,7 +533,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
val testProcess =
tumbling("#AGG.median", """T(java.math.BigDecimal).ONE""", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldEqual List(new java.math.BigDecimal("1"), null)
}

Expand All @@ -559,7 +560,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow
)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldEqual List(new java.math.BigDecimal("0"), null)
}
}
Expand All @@ -577,7 +578,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
)
val testProcess = tumbling("#AGG.sum", "#input.eId", emitWhen = TumblingWindowTrigger.OnEndWithExtraWindow)

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(4, 5, 0)
}

Expand All @@ -601,10 +602,10 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
)
val testProcess = session("#AGG.list", "#input.eId", SessionWindowTrigger.OnEnd, "#input.str == 'stop'")

val aggregateVariables = runCollectOutputAggregate[Number](id, model, testProcess)
val aggregateVariables = runCollectOutputAggregate[Number](idForTest, model, testProcess)
aggregateVariables shouldBe List(asList(4, 3, 2, 1), asList(7, 6, 5), asList(8))

val nodeResults = runCollectOutputVariables(id, model, testProcess)
val nodeResults = runCollectOutputVariables(idForTest, model, testProcess)
nodeResults.flatMap(_.variableTyped[TestRecordHours]("input")) shouldBe Nil
}

Expand All @@ -624,7 +625,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
val model = modelData(testRecords)
val testProcess = session("#AGG.list", "#input.eId", SessionWindowTrigger.OnEvent, "#input.str == 'stop'")

val outputVariables = runCollectOutputVariables(id, model, testProcess)
val outputVariables = runCollectOutputVariables(idForTest, model, testProcess)
outputVariables.map(_.variableTyped[java.util.List[Number]]("fragmentResult").get) shouldBe List(
asList(1),
asList(2, 1),
Expand All @@ -651,21 +652,23 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
)
)
val testProcess = sliding(
"#AGG.map({sum: #AGG.sum, first: #AGG.first, last: #AGG.last, set: #AGG.set, hll: #AGG.approxCardinality})",
"{sum: #input.eId, first: #input.eId, last: #input.eId, set: #input.str, hll: #input.str}",
"#AGG.map({id: #AGG.first, sum: #AGG.sum, first: #AGG.first, last: #AGG.last, set: #AGG.set, hll: #AGG.approxCardinality})",
"{id: #input.id, sum: #input.eId, first: #input.eId, last: #input.eId, set: #input.str, hll: #input.str}",
emitWhenEventLeft = false
)

val aggregateVariables = runCollectOutputAggregate[util.Map[String, Any]](id, model, testProcess).map(_.asScala)
val aggregateVariables = runCollectOutputAggregate[util.Map[String, Any]](idForTest, model, testProcess)
.map(_.asScala)
.filter(_("id") == id)

aggregateVariables shouldBe List(
Map("first" -> 1, "last" -> 1, "hll" -> 1, "sum" -> 1, "set" -> Set("a").asJava),
Map("first" -> 1, "last" -> 2, "hll" -> 2, "sum" -> 3, "set" -> Set("a", "b").asJava),
Map("first" -> 2, "last" -> 3, "hll" -> 2, "sum" -> 5, "set" -> Set("b", "c").asJava),
Map("first" -> 2, "last" -> 4, "hll" -> 3, "sum" -> 9, "set" -> Set("b", "c", "d").asJava),
Map("first" -> 3, "last" -> 6, "hll" -> 3, "sum" -> 13, "set" -> Set("c", "d", "e").asJava),
Map("first" -> 6, "last" -> 7, "hll" -> 2, "sum" -> 13, "set" -> Set("e", "a").asJava),
Map("first" -> 6, "last" -> 8, "hll" -> 3, "sum" -> 21, "set" -> Set("e", "a", "b").asJava)
Map("id" -> id, "first" -> 1, "last" -> 1, "hll" -> 1, "sum" -> 1, "set" -> Set("a").asJava),
Map("id" -> id, "first" -> 1, "last" -> 2, "hll" -> 2, "sum" -> 3, "set" -> Set("a", "b").asJava),
Map("id" -> id, "first" -> 2, "last" -> 3, "hll" -> 2, "sum" -> 5, "set" -> Set("b", "c").asJava),
Map("id" -> id, "first" -> 2, "last" -> 4, "hll" -> 3, "sum" -> 9, "set" -> Set("b", "c", "d").asJava),
Map("id" -> id, "first" -> 3, "last" -> 6, "hll" -> 3, "sum" -> 13, "set" -> Set("c", "d", "e").asJava),
Map("id" -> id, "first" -> 6, "last" -> 7, "hll" -> 2, "sum" -> 13, "set" -> Set("e", "a").asJava),
Map("id" -> id, "first" -> 6, "last" -> 8, "hll" -> 3, "sum" -> 21, "set" -> Set("e", "a", "b").asJava)
)
}

Expand Down Expand Up @@ -698,8 +701,8 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
.map(name =>
AggregateData(
"aggregate-sliding",
s"#AGG.$name",
"#input.eId",
s"#AGG.map({id: #AGG.first, name: #AGG.$name})",
"{id: #input.id, name: #input.eId}",
"windowLength",
Map("emitWhenEventLeft" -> "false"),
name
Expand All @@ -708,9 +711,9 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
)

runProcess(model, testProcess)
val lastResult = variablesForKey(collectingListener, id).last
val lastResult = variablesForKey(collectingListener, "id").last
aggregates.foreach { case (name, expected) =>
lastResult.variableTyped[AnyRef](s"fragmentResult$name").get shouldBe expected
lastResult.variableTyped[java.util.Map[String, Any]](s"fragmentResult$name").get.get("name") shouldBe expected
}
}

Expand Down Expand Up @@ -747,7 +750,7 @@ class TransformersTest extends AnyFunSuite with FlinkSpec with Matchers with Ins
): List[TestProcess.ResultContext[Any]] = {
collectingListener.results
.nodeResults("end")
.filter(_.variableTyped[String](VariableConstants.KeyVariableName).contains(key))
.filter(_.variableTyped[String](key).isDefined)
}

private def validateError(aggregator: String, aggregateBy: String, error: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class EmitExtraWindowWhenNoDataTumblingAggregatorFunction[MapT[K, V]](
out.collect(
ValueWithContext(
finalVal,
KeyEnricher.enrichWithKey(NkContext(contextIdGenerator.nextContextId()), ctx.getCurrentKey)
NkContext(contextIdGenerator.nextContextId())
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class EmitWhenEventLeftAggregatorFunction[MapT[K, V]](
out.collect(
ValueWithContext(
finalVal,
KeyEnricher.enrichWithKey(NkContext(contextIdGenerator.nextContextId()), ctx.getCurrentKey)
NkContext(contextIdGenerator.nextContextId())
)
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package pl.touk.nussknacker.engine.flink.util.transformer.aggregate

import org.apache.flink.api.common.functions.{OpenContext, RuntimeContext}
import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.util.Collector
import pl.touk.nussknacker.engine.api.runtimecontext.{ContextIdGenerator, EngineRuntimeContext}
import pl.touk.nussknacker.engine.api.{Context => NkContext, ValueWithContext}
import pl.touk.nussknacker.engine.flink.api.process.FlinkCustomNodeContext
import pl.touk.nussknacker.engine.flink.util.keyed.KeyEnricher

import java.lang

/**
 * Window function that forwards every element of a window unchanged, wrapping each one in a
 * [[ValueWithContext]] whose context is built solely from a newly generated context id.
 *
 * The window key passed to `process` is not used when building the emitted context.
 *
 * @param convertToEngineRuntimeContext converts Flink's [[RuntimeContext]] into the engine runtime
 *                                      context from which the context id generator is obtained
 * @param nodeId                        id of the node for which context ids are generated
 */
class NoOpFunction(convertToEngineRuntimeContext: RuntimeContext => EngineRuntimeContext, nodeId: String)
    extends ProcessWindowFunction[AnyRef, ValueWithContext[AnyRef], String, TimeWindow] {

  // Assigned in `open` rather than in the constructor; @transient keeps it out of the serialized function.
  @transient
  private var contextIdGenerator: ContextIdGenerator = _

  /** Obtains the context id generator for `nodeId` from the converted runtime context. */
  override def open(openContext: OpenContext): Unit = {
    val engineRuntimeContext = convertToEngineRuntimeContext(getRuntimeContext)
    contextIdGenerator = engineRuntimeContext.contextIdGenerator(nodeId)
  }

  /**
   * Emits each element of `values` to `out`, in iteration order, paired with a context created
   * from the next generated context id. Neither `key` nor `context` is read.
   */
  override def process(
      key: String,
      context: ProcessWindowFunction[AnyRef, ValueWithContext[AnyRef], String, TimeWindow]#Context,
      values: lang.Iterable[AnyRef],
      out: Collector[ValueWithContext[AnyRef]]
  ): Unit =
    values.forEach { element =>
      // One new context id per emitted element.
      val freshContext = NkContext(contextIdGenerator.nextContextId())
      out.collect(ValueWithContext(element, freshContext))
    }

}

object NoOpFunction {

  /**
   * Builds a [[NoOpFunction]] from a custom node context, taking the runtime-context converter
   * and the node id from `fctx`.
   *
   * @param fctx custom node context supplying `convertToEngineRuntimeContext` and `nodeId`
   * @return a new [[NoOpFunction]] configured for that node
   */
  def apply(fctx: FlinkCustomNodeContext): NoOpFunction = {
    new NoOpFunction(
      convertToEngineRuntimeContext = fctx.convertToEngineRuntimeContext,
      nodeId = fctx.nodeId
    )
  }

}
Loading

0 comments on commit eab62f6

Please sign in to comment.