Skip to content

Commit

Permalink
Support MERGE for MySQL connector
Browse files Browse the repository at this point in the history
  • Loading branch information
chenjian2664 committed Dec 10, 2024
1 parent 90db2f5 commit c0b86ce
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1861,47 +1861,47 @@ public void testUpdateWithPredicates()
public void testConstantUpdateWithVarcharEqualityPredicates()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"))) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY)) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 = 'A'", MODIFYING_ROWS_MESSAGE);
return;
}
assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 = 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 'a'), (20, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 'a', 1), (20, 'A', 2)");
}
}

@Test
public void testConstantUpdateWithVarcharInequalityPredicates()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"), "col2")) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 != 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 != 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (2, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (2, 'A', 2)");
}
}

@Test
public void testConstantUpdateWithVarcharGreaterAndLowerPredicate()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"), "col2")) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 > 'A'", MODIFYING_ROWS_MESSAGE);
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 < 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 > 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (2, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (2, 'A', 2)");

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 < 'a'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (20, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (20, 'A', 2)");
}
}

Expand All @@ -1927,14 +1927,14 @@ public void testDeleteWithVarcharEqualityPredicate()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_ROW_LEVEL_DELETE));
// TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated
try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_varchar", "(col varchar(1))", ImmutableList.of("'a'", "'A'", "null"))) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY)) {
try (TestTable table = createTestTableForWrites( "test_delete_varchar", "(col varchar(1), pk INT)", ImmutableList.of("'a', 1", "'A', 2", "null, 3"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("DELETE FROM " + table.getName() + " WHERE col = 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("DELETE FROM " + table.getName() + " WHERE col = 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES 'a', null");
assertQuery("SELECT col FROM " + table.getName(), "VALUES 'a', null");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@
import java.time.OffsetDateTime;
import java.util.AbstractMap.SimpleEntry;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Stream;

Expand Down Expand Up @@ -975,6 +977,12 @@ public boolean isLimitGuaranteed(ConnectorSession session)
return true;
}

@Override
public boolean supportsMerge()
{
return true;
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
{
Expand Down Expand Up @@ -1159,6 +1167,50 @@ private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableH
}
}

@Override
public List<JdbcColumnHandle> getPrimaryKeys(ConnectorSession session, RemoteTableName remoteTableName)
{
SchemaTableName tableName = new SchemaTableName(remoteTableName.getCatalogName().orElse(null), remoteTableName.getTableName());
List<JdbcColumnHandle> columns = getColumns(session, tableName, remoteTableName);
try (Connection connection = connectionFactory.openConnection(session)) {
DatabaseMetaData metaData = connection.getMetaData();

ResultSet primaryKeys = metaData.getPrimaryKeys(remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName());

Set<String> primaryKeyNames = new HashSet<>();
while (primaryKeys.next()) {
String columnName = primaryKeys.getString("COLUMN_NAME");
primaryKeyNames.add(columnName);
}
if (primaryKeyNames.isEmpty()) {
return ImmutableList.of();
}
ImmutableList.Builder<JdbcColumnHandle> primaryKeysBuilder = ImmutableList.builder();
for (JdbcColumnHandle columnHandle : columns) {
String name = columnHandle.getColumnName();
if (!primaryKeyNames.contains(name)) {
continue;
}
JdbcTypeHandle handle = columnHandle.getJdbcTypeHandle();
CaseSensitivity caseSensitivity = handle.caseSensitivity().orElse(CASE_INSENSITIVE);
if (caseSensitivity == CASE_INSENSITIVE) {
primaryKeysBuilder.add(new JdbcColumnHandle(
name,
// make sure the primary keys that are varchar/char relate types can be pushdown
new JdbcTypeHandle(handle.jdbcType(), handle.jdbcTypeName(), handle.columnSize(), handle.decimalDigits(), handle.arrayDimensions(), Optional.of(CASE_SENSITIVE)),
columnHandle.getColumnType()));
}
else {
primaryKeysBuilder.add(columnHandle);
}
}
return primaryKeysBuilder.build();
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

private static Optional<ColumnHistogram> getColumnHistogram(Map<String, String> columnHistograms, String columnName)
{
return Optional.ofNullable(columnHistograms.get(columnName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
import java.sql.ResultSet;
import java.time.LocalDate;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Strings.nullToEmpty;
import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE;
import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.NON_TRANSACTIONAL_MERGE;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
Expand All @@ -56,7 +59,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
{
return switch (connectorBehavior) {
case SUPPORTS_AGGREGATION_PUSHDOWN,
SUPPORTS_JOIN_PUSHDOWN -> true;
SUPPORTS_JOIN_PUSHDOWN,
SUPPORTS_MERGE,
SUPPORTS_ROW_LEVEL_UPDATE -> true;
case SUPPORTS_ADD_COLUMN_WITH_COMMENT,
SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION,
SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT,
Expand All @@ -79,6 +84,15 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
};
}

@Override
protected Session getSession()
{
Session session = super.getSession();
return Session.builder(session)
.setCatalogSessionProperty(session.getCatalog().orElseThrow(), NON_TRANSACTIONAL_MERGE, "true")
.build();
}

@Override
protected TestTable createTableWithDefaultColumns()
{
Expand Down Expand Up @@ -181,14 +195,6 @@ public void testShowCreateTable()
")");
}

@Test
@Override
public void testDeleteWithLike()
{
assertThatThrownBy(super::testDeleteWithLike)
.hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
}

@Test
public void testViews()
{
Expand Down Expand Up @@ -628,4 +634,44 @@ public void verifyMySqlJdbcDriverNegativeDateHandling()
}
}
}

@Override
protected void createTableForWrites(String createTable, String tableName, Optional<String> primaryKey, OptionalInt updateCount)
{
super.createTableForWrites(createTable, tableName, primaryKey, updateCount);
primaryKey.ifPresent(key -> addPrimaryKey(createTable, tableName, key));
}

private void addPrimaryKey(String createTable, String tableName, String primaryKey)
{
Matcher matcher = Pattern.compile("CREATE TABLE .* \\(.*\\b" + primaryKey + "\\b\\s+([a-zA-Z0-9()]+).*", Pattern.CASE_INSENSITIVE).matcher(createTable);
if (matcher.matches()) {
String type = matcher.group(1).toLowerCase(Locale.ENGLISH);
if (type.contains("varchar") || type.contains("char")) {
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s(255))", tableName, primaryKey));
return;
}
}

// ctas or the type is not varchar/char
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", tableName, primaryKey));
}

@Override
protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, String primaryKey)
{
TestTable testTable = super.createTestTableForWrites(namePrefix, tableDefinition, primaryKey);
String tableName = testTable.getName();
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", tableName, primaryKey));
return testTable;
}

@Override
protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, List<String> rowsToInsert, String primaryKey)
{
TestTable testTable = super.createTestTableForWrites(namePrefix, tableDefinition, rowsToInsert, primaryKey);
String tableName = testTable.getName();
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", tableName, primaryKey));
return testTable;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableMap;
import com.google.inject.Module;
import io.trino.Session;
import io.trino.operator.RetryPolicy;
import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin;
import io.trino.plugin.jdbc.BaseJdbcFailureRecoveryTest;
Expand All @@ -32,6 +33,8 @@
public abstract class BaseMySqlFailureRecoveryTest
extends BaseJdbcFailureRecoveryTest
{
private TestingMySqlServer mySqlServer;

public BaseMySqlFailureRecoveryTest(RetryPolicy retryPolicy)
{
super(retryPolicy);
Expand All @@ -45,7 +48,8 @@ protected QueryRunner createQueryRunner(
Module failureInjectionModule)
throws Exception
{
return MySqlQueryRunner.builder(closeAfterClass(new TestingMySqlServer()))
this.mySqlServer = new TestingMySqlServer();
return MySqlQueryRunner.builder(closeAfterClass(mySqlServer))
.setExtraProperties(configProperties)
.setCoordinatorProperties(coordinatorProperties)
.setAdditionalSetup(runner -> {
Expand All @@ -58,6 +62,13 @@ protected QueryRunner createQueryRunner(
.build();
}

@Test
@Override
protected void testDeleteWithSubquery()
{
assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("Non-transactional MERGE is disabled");
}

@Test
@Override
protected void testUpdateWithSubquery()
Expand All @@ -66,6 +77,13 @@ protected void testUpdateWithSubquery()
abort("skipped");
}

@Test
@Override
protected void testMerge()
{
assertThatThrownBy(super::testMerge).hasMessageContaining("Non-transactional MERGE is disabled");
}

@Test
@Override
protected void testUpdate()
Expand All @@ -81,4 +99,10 @@ protected void testUpdate()
.withCleanupQuery(cleanupQuery)
.isCoordinatorOnly();
}

@Override
protected void addPrimaryKeyForMergeTarget(Session session, String tableName, String primaryKey)
{
mySqlServer.execute("ALTER TABLE %s ADD PRIMARY KEY (%s)".formatted(tableName, primaryKey));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
*/
package io.trino.plugin.mysql;

import io.trino.Session;
import io.trino.plugin.jdbc.BaseJdbcConnectorSmokeTest;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
import io.trino.testing.sql.TestTable;

import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.NON_TRANSACTIONAL_MERGE;
import static java.lang.String.format;

public class TestMySqlGlobalTransactionMyConnectorSmokeTest
extends BaseJdbcConnectorSmokeTest
Expand All @@ -32,15 +37,37 @@ protected QueryRunner createQueryRunner()
.build();
}

@Override
protected Session getSession()
{
Session session = super.getSession();
return Session.builder(session)
.setCatalogSessionProperty(session.getCatalog().orElseThrow(), NON_TRANSACTIONAL_MERGE, "true")
.build();
}

@Override
protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
{
switch (connectorBehavior) {
case SUPPORTS_RENAME_SCHEMA:
return false;

case SUPPORTS_ROW_LEVEL_UPDATE:
case SUPPORTS_MERGE:
return true;

default:
return super.hasBehavior(connectorBehavior);
}
}

@Override
protected TestTable createTestTableForWrites(String tablePrefix)
{
TestTable table = super.createTestTableForWrites(tablePrefix);
String tableName = table.getName();
mySqlServer.execute(format("ALTER TABLE %s ADD PRIMARY KEY (a)", tableName));
return table;
}
}
Loading

0 comments on commit c0b86ce

Please sign in to comment.