Skip to content

Commit

Permalink
Fix incorrect semantics for BETWEEN on MV columns in the multi-stage …
Browse files Browse the repository at this point in the history
…query engine (#14135)
  • Loading branch information
yashmayya authored Oct 5, 2024
1 parent e47169c commit 8334add
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,8 @@ public void testLiteralOnlyFunc()
@Test
public void testMultiValueColumnGroupBy()
throws Exception {
String pinotQuery = "SELECT count(*), arrayToMV(RandomAirports) FROM mytable "
+ "GROUP BY arrayToMV(RandomAirports)";
String pinotQuery = "SELECT count(*), ARRAY_TO_MV(RandomAirports) FROM mytable "
+ "GROUP BY ARRAY_TO_MV(RandomAirports)";
JsonNode jsonNode = postQuery(pinotQuery);
Assert.assertEquals(jsonNode.get("resultTable").get("rows").size(), 154);
}
Expand Down Expand Up @@ -800,8 +800,8 @@ public void skipArrayToMvOptimization()
public void testMultiValueColumnGroupByOrderBy()
throws Exception {
String pinotQuery =
"SELECT count(*), arrayToMV(RandomAirports) FROM mytable " + "GROUP BY arrayToMV(RandomAirports) "
+ "ORDER BY arrayToMV(RandomAirports) DESC";
"SELECT count(*), ARRAY_TO_MV(RandomAirports) FROM mytable GROUP BY ARRAY_TO_MV(RandomAirports) "
+ "ORDER BY ARRAY_TO_MV(RandomAirports) DESC";
JsonNode jsonNode = postQuery(pinotQuery);
Assert.assertEquals(jsonNode.get("resultTable").get("rows").size(), 154);
}
Expand Down Expand Up @@ -896,9 +896,76 @@ public void testSearch()
assertNoError(jsonNode);
}

@Test
public void testBetween()
throws Exception {
String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ArrDelay BETWEEN 10 AND 50";
JsonNode jsonNode = postQuery(sqlQuery);
assertNoError(jsonNode);
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 18572);

String explainQuery = "EXPLAIN PLAN FOR " + sqlQuery;
jsonNode = postQuery(explainQuery);
assertNoError(jsonNode);
String plan = jsonNode.get("resultTable").get("rows").get(0).get(1).asText();
// Ensure that the BETWEEN filter predicate was converted to >= and <=
Assert.assertFalse(plan.contains("BETWEEN"));
Assert.assertTrue(plan.contains(">="));
Assert.assertTrue(plan.contains("<="));

// No rows should be returned since lower bound is greater than upper bound
sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(RandomAirports) BETWEEN 'SUN' AND 'GTR'";
jsonNode = postQuery(sqlQuery);
assertNoError(jsonNode);
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 0);

explainQuery = "EXPLAIN PLAN FOR " + sqlQuery;
jsonNode = postQuery(explainQuery);
assertNoError(jsonNode);
plan = jsonNode.get("resultTable").get("rows").get(0).get(1).asText();
// Ensure that the BETWEEN filter predicate was not converted to >= and <=
Assert.assertTrue(plan.contains("BETWEEN"));
Assert.assertFalse(plan.contains(">="));
Assert.assertFalse(plan.contains("<="));

// Expect a non-zero result this time since we're using BETWEEN SYMMETRIC
sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(RandomAirports) BETWEEN SYMMETRIC 'SUN' AND 'GTR'";
jsonNode = postQuery(sqlQuery);
assertNoError(jsonNode);
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 57007);

explainQuery = "EXPLAIN PLAN FOR " + sqlQuery;
jsonNode = postQuery(explainQuery);
assertNoError(jsonNode);
plan = jsonNode.get("resultTable").get("rows").get(0).get(1).asText();
// Ensure that the BETWEEN filter predicate was not converted to >= and <=
Assert.assertTrue(plan.contains("BETWEEN"));
Assert.assertFalse(plan.contains(">="));
Assert.assertFalse(plan.contains("<="));

// Test NOT BETWEEN
sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(RandomAirports) NOT BETWEEN 'GTR' AND 'SUN'";
jsonNode = postQuery(sqlQuery);
assertNoError(jsonNode);
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 58538);

explainQuery =
"SET " + CommonConstants.Broker.Request.QueryOptionKey.EXPLAIN_ASKING_SERVERS + "=true; EXPLAIN PLAN FOR "
+ sqlQuery;
jsonNode = postQuery(explainQuery);
assertNoError(jsonNode);
plan = jsonNode.get("resultTable").get("rows").get(0).get(1).asText();
// Ensure that the BETWEEN filter predicate was not converted to >= and <=. Also ensure that the NOT filter is
// added.
Assert.assertTrue(plan.contains("BETWEEN"));
Assert.assertTrue(plan.contains("FilterNot"));
Assert.assertFalse(plan.contains(">="));
Assert.assertFalse(plan.contains("<="));
}

@Test
public void testMVNumericCastInFilter() throws Exception {
String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE arrayToMV(CAST(DivAirportIDs AS BIGINT ARRAY)) > 0";
String sqlQuery = "SELECT COUNT(*) FROM mytable WHERE ARRAY_TO_MV(CAST(DivAirportIDs AS BIGINT ARRAY)) > 0";
JsonNode jsonNode = postQuery(sqlQuery);
assertNoError(jsonNode);
assertEquals(jsonNode.get("resultTable").get("rows").get(0).get(0).asInt(), 15482);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.fun.SqlBetweenOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql2rel.SqlRexContext;
import org.apache.calcite.sql2rel.SqlRexConvertlet;
import org.apache.calcite.sql2rel.SqlRexConvertletTable;
import org.apache.calcite.sql2rel.StandardConvertletTable;
import org.apache.calcite.util.Litmus;


/**
Expand All @@ -37,6 +39,13 @@
public class PinotConvertletTable implements SqlRexConvertletTable {

public static final PinotConvertletTable INSTANCE = new PinotConvertletTable();
private static final SqlBetweenOperator PINOT_BETWEEN =
new SqlBetweenOperator(SqlBetweenOperator.Flag.ASYMMETRIC, false) {
@Override
public boolean validRexOperands(int count, Litmus litmus) {
return litmus.succeed();
}
};

private PinotConvertletTable() {
}
Expand All @@ -49,6 +58,8 @@ public SqlRexConvertlet get(SqlCall call) {
return TimestampAddConvertlet.INSTANCE;
case TIMESTAMP_DIFF:
return TimestampDiffConvertlet.INSTANCE;
case BETWEEN:
return BetweenConvertlet.INSTANCE;
default:
return StandardConvertletTable.INSTANCE.get(call);
}
Expand Down Expand Up @@ -85,4 +96,45 @@ public RexNode convertCall(SqlRexContext cx, SqlCall call) {
cx.convertExpression(call.operand(2))));
}
}

/**
* Override the standard convertlet for BETWEEN to avoid the rewrite to >= AND <= for MV columns since that breaks
* the filter predicate's semantics.
*/
private static class BetweenConvertlet implements SqlRexConvertlet {
private static final BetweenConvertlet INSTANCE = new BetweenConvertlet();

@Override
public RexNode convertCall(SqlRexContext cx, SqlCall call) {
if (call.operand(0) instanceof SqlCall && ((SqlCall) call.operand(0)).getOperator().getName()
.equals("ARRAY_TO_MV")) {
RexBuilder rexBuilder = cx.getRexBuilder();

SqlBetweenOperator betweenOperator = (SqlBetweenOperator) call.getOperator();

RexNode rexNode = rexBuilder.makeCall(cx.getValidator().getValidatedNodeType(call), PINOT_BETWEEN,
List.of(cx.convertExpression(call.operand(0)), cx.convertExpression(call.operand(1)),
cx.convertExpression(call.operand(2))));

// Since Pinot only has support for ASYMMETRIC BETWEEN, we need to rewrite SYMMETRIC BETWEEN, ASYMMETRIC NOT
// BETWEEN, and SYMMETRIC NOT BETWEEN to the equivalent BETWEEN expressions.

// (val BETWEEN SYMMETRIC x AND y) is equivalent to (val BETWEEN x AND y OR val BETWEEN y AND x)
if (betweenOperator.flag == SqlBetweenOperator.Flag.SYMMETRIC) {
RexNode flipped = rexBuilder.makeCall(cx.getValidator().getValidatedNodeType(call), PINOT_BETWEEN,
List.of(cx.convertExpression(call.operand(0)), cx.convertExpression(call.operand(2)),
cx.convertExpression(call.operand(1))));
rexNode = rexBuilder.makeCall(SqlStdOperatorTable.OR, rexNode, flipped);
}

if (betweenOperator.isNegated()) {
rexNode = rexBuilder.makeCall(SqlStdOperatorTable.NOT, rexNode);
}

return rexNode;
} else {
return StandardConvertletTable.INSTANCE.convertBetween(cx, (SqlBetweenOperator) call.getOperator(), call);
}
}
}
}

0 comments on commit 8334add

Please sign in to comment.