Skip to content

Commit

Permalink
Support MERGE for MySQL connector
Browse files Browse the repository at this point in the history
  • Loading branch information
chenjian2664 committed Jan 15, 2025
1 parent 1b1359c commit 5b93e35
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1855,47 +1855,47 @@ public void testUpdateWithPredicates()
public void testConstantUpdateWithVarcharEqualityPredicates()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = newTrinoTable("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"))) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY)) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 = 'A'", MODIFYING_ROWS_MESSAGE);
return;
}
assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 = 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 'a'), (20, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 'a', 1), (20, 'A', 2)");
}
}

@Test
public void testConstantUpdateWithVarcharInequalityPredicates()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"), "col2")) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 != 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 != 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (2, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (2, 'A', 2)");
}
}

@Test
public void testConstantUpdateWithVarcharGreaterAndLowerPredicate()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"), "col2")) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 > 'A'", MODIFYING_ROWS_MESSAGE);
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 < 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 > 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (2, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (2, 'A', 2)");

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 < 'a'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (20, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (20, 'A', 2)");
}
}

Expand All @@ -1921,14 +1921,14 @@ public void testDeleteWithVarcharEqualityPredicate()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_ROW_LEVEL_DELETE));
// TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated
try (TestTable table = newTrinoTable("test_delete_varchar", "(col varchar(1))", ImmutableList.of("'a'", "'A'", "null"))) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY)) {
try (TestTable table = createTestTableForWrites( "test_delete_varchar", "(col varchar(1), pk INT)", ImmutableList.of("'a', 1", "'A', 2", "null, 3"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("DELETE FROM " + table.getName() + " WHERE col = 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("DELETE FROM " + table.getName() + " WHERE col = 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES 'a', null");
assertQuery("SELECT col FROM " + table.getName(), "VALUES 'a', null");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@
import java.time.OffsetDateTime;
import java.util.AbstractMap.SimpleEntry;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Stream;

Expand Down Expand Up @@ -975,6 +977,12 @@ public boolean isLimitGuaranteed(ConnectorSession session)
return true;
}

@Override
public boolean supportsMerge()
{
return true;
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
{
Expand Down Expand Up @@ -1159,6 +1167,50 @@ private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableH
}
}

@Override
public List<JdbcColumnHandle> getPrimaryKeys(ConnectorSession session, RemoteTableName remoteTableName)
{
SchemaTableName tableName = new SchemaTableName(remoteTableName.getCatalogName().orElse(null), remoteTableName.getTableName());
List<JdbcColumnHandle> columns = getColumns(session, tableName, remoteTableName);
try (Connection connection = connectionFactory.openConnection(session)) {
DatabaseMetaData metaData = connection.getMetaData();

ResultSet primaryKeys = metaData.getPrimaryKeys(remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName());

Set<String> primaryKeyNames = new HashSet<>();
while (primaryKeys.next()) {
String columnName = primaryKeys.getString("COLUMN_NAME");
primaryKeyNames.add(columnName);
}
if (primaryKeyNames.isEmpty()) {
return ImmutableList.of();
}
ImmutableList.Builder<JdbcColumnHandle> primaryKeysBuilder = ImmutableList.builder();
for (JdbcColumnHandle columnHandle : columns) {
String name = columnHandle.getColumnName();
if (!primaryKeyNames.contains(name)) {
continue;
}
JdbcTypeHandle handle = columnHandle.getJdbcTypeHandle();
CaseSensitivity caseSensitivity = handle.caseSensitivity().orElse(CASE_INSENSITIVE);
if (caseSensitivity == CASE_INSENSITIVE) {
primaryKeysBuilder.add(new JdbcColumnHandle(
name,
// make sure the primary keys that are varchar/char relate types can be pushdown
new JdbcTypeHandle(handle.jdbcType(), handle.jdbcTypeName(), handle.columnSize(), handle.decimalDigits(), handle.arrayDimensions(), Optional.of(CASE_SENSITIVE)),
columnHandle.getColumnType()));
}
else {
primaryKeysBuilder.add(columnHandle);
}
}
return primaryKeysBuilder.build();
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

private static Optional<ColumnHistogram> getColumnHistogram(Map<String, String> columnHistograms, String columnName)
{
return Optional.ofNullable(columnHistograms.get(columnName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@
import java.sql.ResultSet;
import java.time.LocalDate;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Strings.nullToEmpty;
import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE;
import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.NON_TRANSACTIONAL_MERGE;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.trino.testing.MaterializedResult.resultBuilder;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static io.trino.testing.TestingSession.testSessionBuilder;
import static java.lang.String.format;
import static java.util.stream.Collectors.joining;
Expand All @@ -56,7 +60,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
{
return switch (connectorBehavior) {
case SUPPORTS_AGGREGATION_PUSHDOWN,
SUPPORTS_JOIN_PUSHDOWN -> true;
SUPPORTS_JOIN_PUSHDOWN,
SUPPORTS_MERGE,
SUPPORTS_ROW_LEVEL_UPDATE -> true;
case SUPPORTS_ADD_COLUMN_WITH_COMMENT,
SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION,
SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT,
Expand All @@ -79,6 +85,32 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
};
}

@Override
protected Session getSession()
{
Session session = super.getSession();
return Session.builder(session)
.setCatalogSessionProperty(session.getCatalog().orElseThrow(), NON_TRANSACTIONAL_MERGE, "true")
.build();
}

@Test
public void testMergeTargetWithNoPrimaryKeys()
{
String tableName = "test_merge_target_no_pks_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (a int, b int)");
assertUpdate("INSERT INTO " + tableName + " VALUES(1, 1), (2, 2)", 2);

assertQueryFails(format("DELETE FROM %s WHERE a IS NOT NULL AND abs(a + b) > 10", tableName), "The connector can not perform merge on the target table without primary keys");
assertQueryFails(format("UPDATE %s SET a = a+b WHERE a IS NOT NULL AND (a + b) > 10", tableName), "The connector can not perform merge on the target table without primary keys");
assertQueryFails(format("MERGE INTO %s t USING (VALUES (3, 3)) AS s(x, y) " +
" ON t.a = s.x " +
" WHEN MATCHED THEN UPDATE SET b = y " +
" WHEN NOT MATCHED THEN INSERT (a, b) VALUES (s.x, s.y) ", tableName), "The connector can not perform merge on the target table without primary keys");

assertUpdate("DROP TABLE " + tableName);
}

@Override
protected TestTable createTableWithDefaultColumns()
{
Expand Down Expand Up @@ -181,14 +213,6 @@ public void testShowCreateTable()
")");
}

@Test
@Override
public void testDeleteWithLike()
{
assertThatThrownBy(super::testDeleteWithLike)
.hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
}

@Test
public void testViews()
{
Expand Down Expand Up @@ -628,4 +652,45 @@ public void verifyMySqlJdbcDriverNegativeDateHandling()
}
}
}

@Override
protected void createTableForWrites(String createTable, String tableName, Optional<String> primaryKey, OptionalInt updateCount)
{
super.createTableForWrites(createTable, tableName, primaryKey, updateCount);
primaryKey.ifPresent(key -> addPrimaryKey(createTable, tableName, key));
}

private void addPrimaryKey(String createTable, String tableName, String primaryKey)
{
Matcher matcher = Pattern.compile("CREATE TABLE .* \\(.*\\b" + primaryKey + "\\b\\s+([a-zA-Z0-9()]+).*", Pattern.CASE_INSENSITIVE).matcher(createTable);
if (matcher.matches()) {
String type = matcher.group(1).toLowerCase(Locale.ENGLISH);
if (type.contains("varchar") || type.contains("char")) {
// Mysql requires the primary keys must hava a fixed length, here use the 255 length that is just long enough for the test
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s(255))", tableName, primaryKey));
return;
}
}

// ctas or the type is not varchar/char
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", tableName, primaryKey));
}

@Override
protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, String primaryKey)
{
TestTable testTable = super.createTestTableForWrites(namePrefix, tableDefinition, primaryKey);
String tableName = testTable.getName();
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", tableName, primaryKey));
return testTable;
}

@Override
protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, List<String> rowsToInsert, String primaryKey)
{
TestTable testTable = super.createTestTableForWrites(namePrefix, tableDefinition, rowsToInsert, primaryKey);
String tableName = testTable.getName();
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", tableName, primaryKey));
return testTable;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableMap;
import com.google.inject.Module;
import io.trino.Session;
import io.trino.operator.RetryPolicy;
import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin;
import io.trino.plugin.jdbc.BaseJdbcFailureRecoveryTest;
Expand All @@ -26,9 +27,14 @@
import java.util.Map;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assumptions.abort;

public abstract class BaseMySqlFailureRecoveryTest
extends BaseJdbcFailureRecoveryTest
{
private TestingMySqlServer mySqlServer;

public BaseMySqlFailureRecoveryTest(RetryPolicy retryPolicy)
{
super(retryPolicy);
Expand All @@ -42,7 +48,8 @@ protected QueryRunner createQueryRunner(
Module failureInjectionModule)
throws Exception
{
return MySqlQueryRunner.builder(closeAfterClass(new TestingMySqlServer()))
this.mySqlServer = new TestingMySqlServer();
return MySqlQueryRunner.builder(closeAfterClass(mySqlServer))
.setExtraProperties(configProperties)
.setCoordinatorProperties(coordinatorProperties)
.setAdditionalSetup(runner -> {
Expand All @@ -55,6 +62,28 @@ protected QueryRunner createQueryRunner(
.build();
}

@Test
@Override
protected void testDeleteWithSubquery()
{
assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("Non-transactional MERGE is disabled");
}

@Test
@Override
protected void testUpdateWithSubquery()
{
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Non-transactional MERGE is disabled");
abort("skipped");
}

@Test
@Override
protected void testMerge()
{
assertThatThrownBy(super::testMerge).hasMessageContaining("Non-transactional MERGE is disabled");
}

@Test
@Override
protected void testUpdate()
Expand All @@ -70,4 +99,10 @@ protected void testUpdate()
.withCleanupQuery(cleanupQuery)
.isCoordinatorOnly();
}

@Override
protected void addPrimaryKeyForMergeTarget(Session session, String tableName, String primaryKey)
{
mySqlServer.execute("ALTER TABLE %s ADD PRIMARY KEY (%s)".formatted(tableName, primaryKey));
}
}
Loading

0 comments on commit 5b93e35

Please sign in to comment.