Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MERGE for MySQL connector #24428

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/src/main/sphinx/connector/mysql.md
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ following features:
- [](/sql/insert), see also [](mysql-insert)
- [](/sql/update), see also [](mysql-update)
- [](/sql/delete), see also [](mysql-delete)
- [](/sql/merge), see also [](mysql-merge)
- [](/sql/truncate)
- [](/sql/create-table)
- [](/sql/create-table-as)
Expand All @@ -359,6 +360,10 @@ following features:
```{include} sql-delete-limitation.fragment
```

(mysql-merge)=
```{include} non-transactional-merge.fragment
```

(mysql-procedures)=
### Procedures

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1855,47 +1855,47 @@ public void testUpdateWithPredicates()
public void testConstantUpdateWithVarcharEqualityPredicates()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = newTrinoTable("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"))) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY)) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
chenjian2664 marked this conversation as resolved.
Show resolved Hide resolved
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 = 'A'", MODIFYING_ROWS_MESSAGE);
return;
}
assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 = 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 'a'), (20, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 'a', 1), (20, 'A', 2)");
}
}

@Test
public void testConstantUpdateWithVarcharInequalityPredicates()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"), "col2")) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 != 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 != 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (2, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (2, 'A', 2)");
}
}

@Test
public void testConstantUpdateWithVarcharGreaterAndLowerPredicate()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE));
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"), "col2")) {
try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1), pk INT)", ImmutableList.of("1, 'a', 1", "2, 'A', 2"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 > 'A'", MODIFYING_ROWS_MESSAGE);
assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 < 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 > 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (2, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (2, 'A', 2)");

assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 < 'a'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (20, 'A')");
assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a', 1), (20, 'A', 2)");
}
}

Expand All @@ -1921,14 +1921,14 @@ public void testDeleteWithVarcharEqualityPredicate()
{
skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_ROW_LEVEL_DELETE));
// TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated
try (TestTable table = newTrinoTable("test_delete_varchar", "(col varchar(1))", ImmutableList.of("'a'", "'A'", "null"))) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY)) {
try (TestTable table = createTestTableForWrites( "test_delete_varchar", "(col varchar(1), pk INT)", ImmutableList.of("'a', 1", "'A', 2", "null, 3"), "pk")) {
if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) {
assertQueryFails("DELETE FROM " + table.getName() + " WHERE col = 'A'", MODIFYING_ROWS_MESSAGE);
return;
}

assertUpdate("DELETE FROM " + table.getName() + " WHERE col = 'A'", 1);
assertQuery("SELECT * FROM " + table.getName(), "VALUES 'a', null");
assertQuery("SELECT col FROM " + table.getName(), "VALUES 'a', null");
}
}

Expand Down Expand Up @@ -2019,6 +2019,25 @@ public void testDeleteWithSemiJoin()
.hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
}

@Test
public void testMergeTargetWithoutPrimaryKeys()
{
skipTestUnless(hasBehavior(SUPPORTS_MERGE));

String tableName = "test_merge_target_no_pks_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (a int, b int)");
assertUpdate("INSERT INTO " + tableName + " VALUES(1, 1), (2, 2)", 2);

assertQueryFails(format("DELETE FROM %s WHERE a IS NOT NULL AND abs(a + b) > 10", tableName), "The connector can not perform merge on the target table without primary keys");
assertQueryFails(format("UPDATE %s SET a = a+b WHERE a IS NOT NULL AND (a + b) > 10", tableName), "The connector can not perform merge on the target table without primary keys");
assertQueryFails(format("MERGE INTO %s t USING (VALUES (3, 3)) AS s(x, y) " +
" ON t.a = s.x " +
" WHEN MATCHED THEN UPDATE SET b = y " +
" WHEN NOT MATCHED THEN INSERT (a, b) VALUES (s.x, s.y) ", tableName), "The connector can not perform merge on the target table without primary keys");

assertUpdate("DROP TABLE " + tableName);
}

@Test
@Override
public void testDeleteWithVarcharPredicate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -553,4 +553,11 @@ public void testExecuteProcedure()
assertUpdate("DROP TABLE IF EXISTS " + schemaTableName);
}
}

@Test
@Override
public void testMergeTargetWithoutPrimaryKeys()
{
abort("Ignite table always has primary key");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@
import java.time.OffsetDateTime;
import java.util.AbstractMap.SimpleEntry;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Stream;

Expand Down Expand Up @@ -975,6 +977,12 @@ public boolean isLimitGuaranteed(ConnectorSession session)
return true;
}

@Override
public boolean supportsMerge()
{
return true;
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<JdbcSortItem> sortOrder)
{
Expand Down Expand Up @@ -1159,6 +1167,43 @@ private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableH
}
}

@Override
public List<JdbcColumnHandle> getPrimaryKeys(ConnectorSession session, RemoteTableName remoteTableName)
{
SchemaTableName tableName = new SchemaTableName(remoteTableName.getCatalogName().orElse(null), remoteTableName.getTableName());
List<JdbcColumnHandle> columns = getColumns(session, tableName, remoteTableName);
try (Connection connection = connectionFactory.openConnection(session)) {
DatabaseMetaData metaData = connection.getMetaData();

ResultSet primaryKeys = metaData.getPrimaryKeys(remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName());

Set<String> primaryKeyNames = new HashSet<>();
while (primaryKeys.next()) {
primaryKeyNames.add(primaryKeys.getString("COLUMN_NAME"));
}
if (primaryKeyNames.isEmpty()) {
return ImmutableList.of();
}
ImmutableList.Builder<JdbcColumnHandle> primaryKeysBuilder = ImmutableList.builder();
for (JdbcColumnHandle columnHandle : columns) {
String name = columnHandle.getColumnName();
if (!primaryKeyNames.contains(name)) {
continue;
}
JdbcTypeHandle handle = columnHandle.getJdbcTypeHandle();
primaryKeysBuilder.add(new JdbcColumnHandle(
name,
// make sure the primary keys that are varchar/char relate types can be pushdown
new JdbcTypeHandle(handle.jdbcType(), handle.jdbcTypeName(), handle.columnSize(), handle.decimalDigits(), handle.arrayDimensions(), Optional.of(CASE_SENSITIVE)),
columnHandle.getColumnType()));
}
return primaryKeysBuilder.build();
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

private static Optional<ColumnHistogram> getColumnHistogram(Map<String, String> columnHistograms, String columnName)
{
return Optional.ofNullable(columnHistograms.get(columnName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
import java.sql.ResultSet;
import java.time.LocalDate;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Strings.nullToEmpty;
import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE;
import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.NON_TRANSACTIONAL_MERGE;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
Expand All @@ -56,7 +59,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
{
return switch (connectorBehavior) {
case SUPPORTS_AGGREGATION_PUSHDOWN,
SUPPORTS_JOIN_PUSHDOWN -> true;
SUPPORTS_JOIN_PUSHDOWN,
SUPPORTS_MERGE,
SUPPORTS_ROW_LEVEL_UPDATE -> true;
case SUPPORTS_ADD_COLUMN_WITH_COMMENT,
SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION,
SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT,
Expand All @@ -79,6 +84,15 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
};
}

@Override
protected Session getSession()
{
Session session = super.getSession();
return Session.builder(session)
.setCatalogSessionProperty(session.getCatalog().orElseThrow(), NON_TRANSACTIONAL_MERGE, "true")
.build();
}

@Override
protected TestTable createTableWithDefaultColumns()
{
Expand Down Expand Up @@ -181,14 +195,6 @@ public void testShowCreateTable()
")");
}

@Test
@Override
public void testDeleteWithLike()
{
assertThatThrownBy(super::testDeleteWithLike)
.hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
}

@Test
public void testViews()
{
Expand Down Expand Up @@ -628,4 +634,43 @@ public void verifyMySqlJdbcDriverNegativeDateHandling()
}
}
}

@Override
protected void createTableForWrites(String createTable, String tableName, Optional<String> primaryKey, OptionalInt updateCount)
{
super.createTableForWrites(createTable, tableName, primaryKey, updateCount);
primaryKey.ifPresent(key -> addPrimaryKey(createTable, tableName, key));
}

private void addPrimaryKey(String createTable, String tableName, String primaryKey)
{
Matcher matcher = Pattern.compile("CREATE TABLE .* \\(.*\\b" + primaryKey + "\\b\\s+([a-zA-Z0-9()]+).*", Pattern.CASE_INSENSITIVE).matcher(createTable);
if (matcher.matches()) {
String type = matcher.group(1).toLowerCase(Locale.ENGLISH);
if (type.contains("varchar") || type.contains("char")) {
// Mysql requires the primary keys must hava a fixed length, here use the 255 length that is just long enough for the test
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s(255))", tableName, primaryKey));
return;
}
}

// ctas or the type is not varchar/char
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", tableName, primaryKey));
}

@Override
protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, String primaryKey)
{
TestTable testTable = super.createTestTableForWrites(namePrefix, tableDefinition, primaryKey);
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", testTable.getName(), primaryKey));
return testTable;
}

@Override
protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, List<String> rowsToInsert, String primaryKey)
{
TestTable testTable = super.createTestTableForWrites(namePrefix, tableDefinition, rowsToInsert, primaryKey);
onRemoteDatabase().execute(format("ALTER TABLE %s ADD PRIMARY KEY (%s)", testTable.getName(), primaryKey));
return testTable;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableMap;
import com.google.inject.Module;
import io.trino.Session;
import io.trino.operator.RetryPolicy;
import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin;
import io.trino.plugin.jdbc.BaseJdbcFailureRecoveryTest;
Expand All @@ -26,9 +27,13 @@
import java.util.Map;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

public abstract class BaseMySqlFailureRecoveryTest
extends BaseJdbcFailureRecoveryTest
{
private TestingMySqlServer mySqlServer;

public BaseMySqlFailureRecoveryTest(RetryPolicy retryPolicy)
{
super(retryPolicy);
Expand All @@ -42,7 +47,8 @@ protected QueryRunner createQueryRunner(
Module failureInjectionModule)
throws Exception
{
return MySqlQueryRunner.builder(closeAfterClass(new TestingMySqlServer()))
this.mySqlServer = new TestingMySqlServer();
return MySqlQueryRunner.builder(closeAfterClass(mySqlServer))
.setExtraProperties(configProperties)
.setCoordinatorProperties(coordinatorProperties)
.setAdditionalSetup(runner -> {
Expand All @@ -55,6 +61,27 @@ protected QueryRunner createQueryRunner(
.build();
}

@Test
@Override
protected void testDeleteWithSubquery()
{
assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("Non-transactional MERGE is disabled");
}

@Test
@Override
protected void testUpdateWithSubquery()
{
assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("Non-transactional MERGE is disabled");
}

@Test
@Override
protected void testMerge()
{
assertThatThrownBy(super::testMerge).hasMessageContaining("Non-transactional MERGE is disabled");
}

@Test
@Override
protected void testUpdate()
Expand All @@ -70,4 +97,10 @@ protected void testUpdate()
.withCleanupQuery(cleanupQuery)
.isCoordinatorOnly();
}

@Override
protected void addPrimaryKeyForMergeTarget(Session session, String tableName, String primaryKey)
{
mySqlServer.execute("ALTER TABLE %s ADD PRIMARY KEY (%s)".formatted(tableName, primaryKey));
}
}
Loading
Loading