diff --git a/data-prepper-plugins/rds-source/build.gradle b/data-prepper-plugins/rds-source/build.gradle index 77f1022f63..1b325457bf 100644 --- a/data-prepper-plugins/rds-source/build.gradle +++ b/data-prepper-plugins/rds-source/build.gradle @@ -23,6 +23,7 @@ dependencies { implementation 'com.zendesk:mysql-binlog-connector-java:0.29.2' implementation 'com.mysql:mysql-connector-j:8.4.0' + implementation 'org.postgresql:postgresql:42.7.4' compileOnly 'org.projectlombok:lombok:1.18.20' annotationProcessor 'org.projectlombok:lombok:1.18.20' diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/RdsService.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/RdsService.java index 5dad3cb3c6..106827cb69 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/RdsService.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/RdsService.java @@ -14,6 +14,7 @@ import org.opensearch.dataprepper.model.plugin.PluginConfigObservable; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.export.DataFileScheduler; import org.opensearch.dataprepper.plugins.source.rds.export.ExportScheduler; import org.opensearch.dataprepper.plugins.source.rds.export.ExportTaskManager; @@ -26,9 +27,13 @@ import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata; import org.opensearch.dataprepper.plugins.source.rds.resync.ResyncScheduler; import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager; +import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManagerFactory; +import org.opensearch.dataprepper.plugins.source.rds.schema.MySqlConnectionManager; +import org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager; import org.opensearch.dataprepper.plugins.source.rds.schema.QueryManager; import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager; -import org.opensearch.dataprepper.plugins.source.rds.stream.BinlogClientFactory; +import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManagerFactory; +import org.opensearch.dataprepper.plugins.source.rds.stream.ReplicationLogClientFactory; import org.opensearch.dataprepper.plugins.source.rds.stream.StreamScheduler; import org.opensearch.dataprepper.plugins.source.rds.utils.IdentifierShortener; import org.slf4j.Logger; @@ -37,6 +42,7 @@ import software.amazon.awssdk.services.s3.S3Client; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -101,9 +107,16 @@ public void start(Buffer> buffer) { new ClusterApiStrategy(rdsClient) : new InstanceApiStrategy(rdsClient); final DbMetadata dbMetadata = rdsApiStrategy.describeDb(sourceConfig.getDbIdentifier()); final String s3PathPrefix = getS3PathPrefix(); + final SchemaManager schemaManager = getSchemaManager(sourceConfig, dbMetadata); - final Map> tableColumnDataTypeMap = getColumnDataTypeMap(schemaManager); - final DbTableMetadata dbTableMetadata = new DbTableMetadata(dbMetadata, tableColumnDataTypeMap); + DbTableMetadata dbTableMetadata; + if (sourceConfig.getEngine() == EngineType.MYSQL) { + final Map> tableColumnDataTypeMap = getColumnDataTypeMap( + (MySqlSchemaManager) schemaManager); + dbTableMetadata = new DbTableMetadata(dbMetadata, tableColumnDataTypeMap); + } else { + dbTableMetadata = new DbTableMetadata(dbMetadata, Collections.emptyMap()); + } leaderScheduler = new LeaderScheduler( sourceCoordinator, sourceConfig, s3PathPrefix, schemaManager, dbTableMetadata); @@ -121,21 +134,23 @@ public void start(Buffer> buffer) { } if (sourceConfig.isStreamEnabled()) { - BinlogClientFactory binaryLogClientFactory = new BinlogClientFactory(sourceConfig, rdsClient, dbMetadata); + ReplicationLogClientFactory replicationLogClientFactory = new ReplicationLogClientFactory(sourceConfig, rdsClient, dbMetadata); if (sourceConfig.isTlsEnabled()) { - binaryLogClientFactory.setSSLMode(SSLMode.REQUIRED); + replicationLogClientFactory.setSSLMode(SSLMode.REQUIRED); } else { - binaryLogClientFactory.setSSLMode(SSLMode.DISABLED); + replicationLogClientFactory.setSSLMode(SSLMode.DISABLED); } streamScheduler = new StreamScheduler( - sourceCoordinator, sourceConfig, s3PathPrefix, binaryLogClientFactory, buffer, pluginMetrics, acknowledgementSetManager, pluginConfigObservable); + sourceCoordinator, sourceConfig, s3PathPrefix, replicationLogClientFactory, buffer, pluginMetrics, acknowledgementSetManager, pluginConfigObservable); runnableList.add(streamScheduler); - resyncScheduler = new ResyncScheduler( - sourceCoordinator, sourceConfig, getQueryManager(sourceConfig, dbMetadata), s3PathPrefix, buffer, pluginMetrics, acknowledgementSetManager); - runnableList.add(resyncScheduler); + if (sourceConfig.getEngine() == EngineType.MYSQL) { + resyncScheduler = new ResyncScheduler( + sourceCoordinator, sourceConfig, getQueryManager(sourceConfig, dbMetadata), s3PathPrefix, buffer, pluginMetrics, acknowledgementSetManager); + runnableList.add(resyncScheduler); + } } executor = Executors.newFixedThreadPool(runnableList.size()); @@ -164,19 +179,14 @@ public void shutdown() { } private SchemaManager getSchemaManager(final RdsSourceConfig sourceConfig, final DbMetadata dbMetadata) { - final ConnectionManager connectionManager = new ConnectionManager( - dbMetadata.getEndpoint(), - dbMetadata.getPort(), - sourceConfig.getAuthenticationConfig().getUsername(), - sourceConfig.getAuthenticationConfig().getPassword(), - sourceConfig.isTlsEnabled()); - return new SchemaManager(connectionManager); + final ConnectionManager connectionManager = new ConnectionManagerFactory(sourceConfig, dbMetadata).getConnectionManager(); + return new SchemaManagerFactory(connectionManager).getSchemaManager(); } private QueryManager getQueryManager(final RdsSourceConfig sourceConfig, final DbMetadata dbMetadata) { final String readerEndpoint = dbMetadata.getReaderEndpoint() != null ? dbMetadata.getReaderEndpoint() : dbMetadata.getEndpoint(); final int readerPort = dbMetadata.getReaderPort() == 0 ? dbMetadata.getPort() : dbMetadata.getReaderPort(); - final ConnectionManager readerConnectionManager = new ConnectionManager( + final MySqlConnectionManager readerConnectionManager = new MySqlConnectionManager( readerEndpoint, readerPort, sourceConfig.getAuthenticationConfig().getUsername(), @@ -203,13 +213,11 @@ private String getS3PathPrefix() { return s3PathPrefix; } - private Map> getColumnDataTypeMap(final SchemaManager schemaManager) { + private Map> getColumnDataTypeMap(final MySqlSchemaManager schemaManager) { return sourceConfig.getTableNames().stream() .collect(Collectors.toMap( fullTableName -> fullTableName, fullTableName -> schemaManager.getColumnDataTypes(fullTableName.split("\\.")[0], fullTableName.split("\\.")[1]) )); } - - } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/configuration/EngineType.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/configuration/EngineType.java index f75ec32bfe..20f7f3b534 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/configuration/EngineType.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/configuration/EngineType.java @@ -13,7 +13,8 @@ public enum EngineType { - MYSQL("mysql"); + MYSQL("mysql"), + POSTGRES("postgres"); private static final Map ENGINE_TYPE_MAP = Arrays.stream(EngineType.values()) .collect(Collectors.toMap( diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/StreamProgressState.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/StreamProgressState.java index 1f751e2087..80615bdebf 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/StreamProgressState.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/StreamProgressState.java @@ -10,26 +10,74 @@ import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyRelation; import java.util.List; +import java.util.Map; public class StreamProgressState { - @JsonProperty("currentPosition") - private BinlogCoordinate currentPosition; + // TODO: separate MySQL and Postgres properties into different progress state classes + // Common + @JsonProperty("engineType") + private String engineType; @JsonProperty("waitForExport") private boolean waitForExport = false; + /** + * Map of table name to primary keys + */ + @JsonProperty("primaryKeyMap") + private Map> primaryKeyMap; + + // For MySQL + @JsonProperty("currentPosition") + private BinlogCoordinate currentPosition; + @JsonProperty("foreignKeyRelations") private List foreignKeyRelations; + // For Postgres + @JsonProperty("currentLsn") + private String currentLsn; + + @JsonProperty("replicationSlotName") + private String replicationSlotName; + + public String getEngineType() { + return engineType; + } + + public void setEngineType(String engineType) { + this.engineType = engineType; + } + public BinlogCoordinate getCurrentPosition() { return currentPosition; } + public String getCurrentLsn() { + return currentLsn; + } + + public Map> getPrimaryKeyMap() { + return primaryKeyMap; + } + + public void setPrimaryKeyMap(Map> primaryKeyMap) { + this.primaryKeyMap = primaryKeyMap; + } + + public String getReplicationSlotName() { + return replicationSlotName; + } + public void setCurrentPosition(BinlogCoordinate currentPosition) { this.currentPosition = currentPosition; } + public void setReplicationSlotName(String replicationSlotName) { + this.replicationSlotName = replicationSlotName; + } + public boolean shouldWaitForExport() { return waitForExport; } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/datatype/postgres/ColumnType.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/datatype/postgres/ColumnType.java new file mode 100644 index 0000000000..c03e4e67c4 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/datatype/postgres/ColumnType.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.datatype.postgres; + +import java.util.HashMap; +import java.util.Map; + +public enum ColumnType { + BOOLEAN(16, "boolean"), + SMALLINT(21, "smallint"), + INTEGER(23, "integer"), + BIGINT(20, "bigint"), + REAL(700, "real"), + DOUBLE_PRECISION(701, "double precision"), + NUMERIC(1700, "numeric"), + TEXT(25, "text"), + VARCHAR(1043, "varchar"), + DATE(1082, "date"), + TIME(1083, "time"), + TIMESTAMP(1114, "timestamp"), + TIMESTAMPTZ(1184, "timestamptz"), + UUID(2950, "uuid"), + JSON(114, "json"), + JSONB(3802, "jsonb"); + + private final int typeId; + private final String typeName; + + private static final Map TYPE_ID_MAP = new HashMap<>(); + + static { + for (ColumnType type : values()) { + TYPE_ID_MAP.put(type.typeId, type); + } + } + + ColumnType(int typeId, String typeName) { + this.typeId = typeId; + this.typeName = typeName; + } + + public int getTypeId() { + return typeId; + } + + public String getTypeName() { + return typeName; + } + + public static ColumnType getByTypeId(int typeId) { + return TYPE_ID_MAP.get(typeId); + } + + public static String getTypeNameByEnum(ColumnType columnType) { + return columnType.getTypeName(); + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderScheduler.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderScheduler.java index 3f7e6d5cb2..f3beb3e12c 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderScheduler.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderScheduler.java @@ -8,6 +8,7 @@ import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition; import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.ExportPartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.LeaderPartition; @@ -17,6 +18,8 @@ import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata; +import org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager; +import org.opensearch.dataprepper.plugins.source.rds.schema.PostgresSchemaManager; import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.UUID; import java.util.stream.Collectors; import static org.opensearch.dataprepper.plugins.source.rds.RdsService.S3_PATH_DELIMITER; @@ -152,22 +156,41 @@ private Map> getPrimaryKeyMap() { return sourceConfig.getTableNames().stream() .collect(Collectors.toMap( fullTableName -> fullTableName, - fullTableName -> schemaManager.getPrimaryKeys(fullTableName.split("\\.")[0], fullTableName.split("\\.")[1]) + fullTableName -> schemaManager.getPrimaryKeys(fullTableName) )); } private void createStreamPartition(RdsSourceConfig sourceConfig) { final StreamProgressState progressState = new StreamProgressState(); + progressState.setEngineType(sourceConfig.getEngine().toString()); progressState.setWaitForExport(sourceConfig.isExportEnabled()); - getCurrentBinlogPosition().ifPresent(progressState::setCurrentPosition); - progressState.setForeignKeyRelations(schemaManager.getForeignKeyRelations(sourceConfig.getTableNames())); + progressState.setPrimaryKeyMap(getPrimaryKeyMap()); + if (sourceConfig.getEngine() == EngineType.MYSQL) { + getCurrentBinlogPosition().ifPresent(progressState::setCurrentPosition); + progressState.setForeignKeyRelations(((MySqlSchemaManager)schemaManager).getForeignKeyRelations(sourceConfig.getTableNames())); + } else { + // Postgres + // Create replication slot, which will mark the starting point for stream + final String publicationName = generatePublicationName(); + final String slotName = generateReplicationSlotName(); + ((PostgresSchemaManager)schemaManager).createLogicalReplicationSlot(sourceConfig.getTableNames(), publicationName, slotName); + progressState.setReplicationSlotName(slotName); + } StreamPartition streamPartition = new StreamPartition(sourceConfig.getDbIdentifier(), progressState); sourceCoordinator.createPartition(streamPartition); } private Optional getCurrentBinlogPosition() { - Optional binlogCoordinate = schemaManager.getCurrentBinaryLogPosition(); + Optional binlogCoordinate = ((MySqlSchemaManager)schemaManager).getCurrentBinaryLogPosition(); LOG.debug("Current binlog position: {}", binlogCoordinate.orElse(null)); return binlogCoordinate; } + + private String generatePublicationName() { + return "data_prepper_publication_" + UUID.randomUUID().toString().substring(0, 8); + } + + private String generateReplicationSlotName() { + return "data_prepper_slot_" + UUID.randomUUID().toString().substring(0, 8); + } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/MessageType.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/MessageType.java new file mode 100644 index 0000000000..a537835099 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/MessageType.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.model; + +public enum MessageType { + BEGIN('B'), + RELATION('R'), + INSERT('I'), + UPDATE('U'), + DELETE('D'), + COMMIT('C'); + + private final char value; + + MessageType(char value) { + this.value = value; + } + + public char getValue() { + return value; + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManager.java index 542724d49d..dc475d0173 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManager.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManager.java @@ -1,55 +1,22 @@ /* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * */ package org.opensearch.dataprepper.plugins.source.rds.schema; import java.sql.Connection; -import java.sql.DriverManager; import java.sql.SQLException; -import java.util.Properties; -public class ConnectionManager { - static final String JDBC_URL_FORMAT = "jdbc:mysql://%s:%d"; - static final String USERNAME_KEY = "user"; - static final String PASSWORD_KEY = "password"; - static final String USE_SSL_KEY = "useSSL"; - static final String REQUIRE_SSL_KEY = "requireSSL"; - static final String TINY_INT_ONE_IS_BIT_KEY = "tinyInt1isBit"; - static final String TRUE_VALUE = "true"; - static final String FALSE_VALUE = "false"; - private final String hostName; - private final int port; - private final String username; - private final String password; - private final boolean requireSSL; - - public ConnectionManager(String hostName, int port, String username, String password, boolean requireSSL) { - this.hostName = hostName; - this.port = port; - this.username = username; - this.password = password; - this.requireSSL = requireSSL; - } - - public Connection getConnection() throws SQLException { - final Properties props = new Properties(); - props.setProperty(USERNAME_KEY, username); - props.setProperty(PASSWORD_KEY, password); - if (requireSSL) { - props.setProperty(USE_SSL_KEY, TRUE_VALUE); - props.setProperty(REQUIRE_SSL_KEY, TRUE_VALUE); - } else { - props.setProperty(USE_SSL_KEY, FALSE_VALUE); - } - props.setProperty(TINY_INT_ONE_IS_BIT_KEY, FALSE_VALUE); - final String jdbcUrl = String.format(JDBC_URL_FORMAT, hostName, port); - return doGetConnection(jdbcUrl, props); - } +/** + * Interface for managing connections to a database. + */ +public interface ConnectionManager { - // VisibleForTesting - Connection doGetConnection(String jdbcUrl, Properties props) throws SQLException { - return DriverManager.getConnection(jdbcUrl, props); - } + Connection getConnection() throws SQLException; } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerFactory.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerFactory.java new file mode 100644 index 0000000000..4a4dff966e --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerFactory.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; +import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata; + +import java.util.List; + +public class ConnectionManagerFactory { + private final RdsSourceConfig sourceConfig; + private final DbMetadata dbMetadata; + + public ConnectionManagerFactory(final RdsSourceConfig sourceConfig, final DbMetadata dbMetadata) { + this.sourceConfig = sourceConfig; + this.dbMetadata = dbMetadata; + } + + public ConnectionManager getConnectionManager() { + if (sourceConfig.getEngine() == EngineType.MYSQL) { + return new MySqlConnectionManager( + dbMetadata.getEndpoint(), + dbMetadata.getPort(), + sourceConfig.getAuthenticationConfig().getUsername(), + sourceConfig.getAuthenticationConfig().getPassword(), + sourceConfig.isTlsEnabled()); + } + + return new PostgresConnectionManager( + dbMetadata.getEndpoint(), + dbMetadata.getPort(), + sourceConfig.getAuthenticationConfig().getUsername(), + sourceConfig.getAuthenticationConfig().getPassword(), + sourceConfig.isTlsEnabled(), + getDatabaseName(sourceConfig.getTableNames())); + } + + private String getDatabaseName(List tableNames) { + return tableNames.get(0).split("\\.")[0]; + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlConnectionManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlConnectionManager.java new file mode 100644 index 0000000000..6b0ff01ba8 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlConnectionManager.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; + +public class MySqlConnectionManager implements ConnectionManager { + static final String JDBC_URL_FORMAT = "jdbc:mysql://%s:%d"; + static final String USERNAME_KEY = "user"; + static final String PASSWORD_KEY = "password"; + static final String USE_SSL_KEY = "useSSL"; + static final String REQUIRE_SSL_KEY = "requireSSL"; + static final String TINY_INT_ONE_IS_BIT_KEY = "tinyInt1isBit"; + static final String TRUE_VALUE = "true"; + static final String FALSE_VALUE = "false"; + private final String hostName; + private final int port; + private final String username; + private final String password; + private final boolean requireSSL; + + public MySqlConnectionManager(String hostName, int port, String username, String password, boolean requireSSL) { + this.hostName = hostName; + this.port = port; + this.username = username; + this.password = password; + this.requireSSL = requireSSL; + } + + @Override + public Connection getConnection() throws SQLException { + final Properties props = new Properties(); + props.setProperty(USERNAME_KEY, username); + props.setProperty(PASSWORD_KEY, password); + if (requireSSL) { + props.setProperty(USE_SSL_KEY, TRUE_VALUE); + props.setProperty(REQUIRE_SSL_KEY, TRUE_VALUE); + } else { + props.setProperty(USE_SSL_KEY, FALSE_VALUE); + } + props.setProperty(TINY_INT_ONE_IS_BIT_KEY, FALSE_VALUE); + final String jdbcUrl = String.format(JDBC_URL_FORMAT, hostName, port); + return doGetConnection(jdbcUrl, props); + } + + // VisibleForTesting + Connection doGetConnection(String jdbcUrl, Properties props) throws SQLException { + return DriverManager.getConnection(jdbcUrl, props); + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlSchemaManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlSchemaManager.java new file mode 100644 index 0000000000..1ca9182b40 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlSchemaManager.java @@ -0,0 +1,194 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyAction; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyRelation; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class MySqlSchemaManager implements SchemaManager { + private static final Logger LOG = LoggerFactory.getLogger(MySqlSchemaManager.class); + + static final String[] TABLE_TYPES = new String[]{"TABLE"}; + static final String COLUMN_NAME = "COLUMN_NAME"; + static final String BINLOG_STATUS_QUERY = "SHOW MASTER STATUS"; + static final String BINLOG_FILE = "File"; + static final String BINLOG_POSITION = "Position"; + static final int NUM_OF_RETRIES = 3; + static final int BACKOFF_IN_MILLIS = 500; + static final String TYPE_NAME = "TYPE_NAME"; + static final String FKTABLE_NAME = "FKTABLE_NAME"; + static final String FKCOLUMN_NAME = "FKCOLUMN_NAME"; + static final String PKTABLE_NAME = "PKTABLE_NAME"; + static final String PKCOLUMN_NAME = "PKCOLUMN_NAME"; + static final String UPDATE_RULE = "UPDATE_RULE"; + static final String DELETE_RULE = "DELETE_RULE"; + static final String COLUMN_DEF = "COLUMN_DEF"; + private final ConnectionManager connectionManager; + + public MySqlSchemaManager(ConnectionManager connectionManager) { + this.connectionManager = connectionManager; + } + + @Override + public List getPrimaryKeys(final String fullTableName) { + final String database = fullTableName.split("\\.")[0]; + final String table = fullTableName.split("\\.")[1]; + int retry = 0; + while (retry <= NUM_OF_RETRIES) { + final List primaryKeys = new ArrayList<>(); + try (final Connection connection = connectionManager.getConnection()) { + try (final ResultSet rs = connection.getMetaData().getPrimaryKeys(database, null, table)) { + while (rs.next()) { + primaryKeys.add(rs.getString(COLUMN_NAME)); + } + return primaryKeys; + } + } catch (Exception e) { + LOG.error("Failed to get primary keys for table {}, retrying", table, e); + } + applyBackoff(); + retry++; + } + LOG.warn("Failed to get primary keys for table {}", table); + return List.of(); + } + + public Map getColumnDataTypes(final String database, final String tableName) { + final Map columnsToDataType = new HashMap<>(); + for (int retry = 0; retry <= NUM_OF_RETRIES; retry++) { + try (Connection connection = connectionManager.getConnection()) { + final DatabaseMetaData metaData = connection.getMetaData(); + + // Retrieve column metadata + try (ResultSet columns = metaData.getColumns(database, null, tableName, null)) { + while (columns.next()) { + columnsToDataType.put( + columns.getString(COLUMN_NAME), + columns.getString(TYPE_NAME) + ); + } + } + } catch (final Exception e) { + LOG.error("Failed to get dataTypes for database {} table {}, retrying", database, tableName, e); + if (retry == NUM_OF_RETRIES) { + throw new RuntimeException(String.format("Failed to get dataTypes for database %s table %s after " + + "%d retries", database, tableName, retry), e); + } + } + applyBackoff(); + } + return columnsToDataType; + } + + public Optional getCurrentBinaryLogPosition() { + int retry = 0; + while (retry <= NUM_OF_RETRIES) { + try (final Connection connection = connectionManager.getConnection()) { + final Statement statement = connection.createStatement(); + final ResultSet rs = statement.executeQuery(BINLOG_STATUS_QUERY); + if (rs.next()) { + return Optional.of(new BinlogCoordinate(rs.getString(BINLOG_FILE), rs.getLong(BINLOG_POSITION))); + } + } catch (Exception e) { + LOG.error("Failed to get current binary log position, retrying", e); + } + applyBackoff(); + retry++; + } + LOG.warn("Failed to get current binary log position"); + return Optional.empty(); + } + + /** + * Get the foreign key relations associated with the given tables. + * + * @param tableNames the table names + * @return the foreign key relations + */ + public List getForeignKeyRelations(List tableNames) { + int retry = 0; + while (retry <= NUM_OF_RETRIES) { + try (final Connection connection = connectionManager.getConnection()) { + final List foreignKeyRelations = new ArrayList<>(); + DatabaseMetaData metaData = connection.getMetaData(); + for (final String tableName : tableNames) { + String database = tableName.split("\\.")[0]; + String table = tableName.split("\\.")[1]; + ResultSet tableResult = metaData.getTables(database, null, table, TABLE_TYPES); + while (tableResult.next()) { + ResultSet foreignKeys = metaData.getImportedKeys(database, null, table); + + while (foreignKeys.next()) { + String fkTableName = foreignKeys.getString(FKTABLE_NAME); + String fkColumnName = foreignKeys.getString(FKCOLUMN_NAME); + String pkTableName = foreignKeys.getString(PKTABLE_NAME); + String pkColumnName = foreignKeys.getString(PKCOLUMN_NAME); + ForeignKeyAction updateAction = ForeignKeyAction.getActionFromMetadata(foreignKeys.getShort(UPDATE_RULE)); + ForeignKeyAction deleteAction = ForeignKeyAction.getActionFromMetadata(foreignKeys.getShort(DELETE_RULE)); + + Object defaultValue = null; + if (updateAction == ForeignKeyAction.SET_DEFAULT || deleteAction == ForeignKeyAction.SET_DEFAULT) { + // Get column default + ResultSet columnResult = metaData.getColumns(database, null, table, fkColumnName); + + if (columnResult.next()) { + defaultValue = columnResult.getObject(COLUMN_DEF); + } + } + + ForeignKeyRelation foreignKeyRelation = ForeignKeyRelation.builder() + .databaseName(database) + .parentTableName(pkTableName) + .referencedKeyName(pkColumnName) + .childTableName(fkTableName) + .foreignKeyName(fkColumnName) + .foreignKeyDefaultValue(defaultValue) + .updateAction(updateAction) + .deleteAction(deleteAction) + .build(); + + foreignKeyRelations.add(foreignKeyRelation); + } + } + } + + return foreignKeyRelations; + } catch (Exception e) { + LOG.error("Failed to scan foreign key references, retrying", e); + } + applyBackoff(); + retry++; + } + LOG.warn("Failed to scan foreign key references"); + return List.of(); + } + + private void applyBackoff() { + try { + Thread.sleep(BACKOFF_IN_MILLIS); + } catch (final InterruptedException e){ + Thread.currentThread().interrupt(); + } + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresConnectionManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresConnectionManager.java new file mode 100644 index 0000000000..c7b02a0c10 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresConnectionManager.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import org.postgresql.PGProperty; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; + +public class PostgresConnectionManager implements ConnectionManager { + private static final Logger LOG = LoggerFactory.getLogger(PostgresConnectionManager.class); + + public static final String JDBC_URL_FORMAT = "jdbc:postgresql://%s:%d/%s"; + public static final String SERVER_VERSION_9_4 = "9.4"; + public static final String DATABASE_REPLICATION = "database"; + public static final String SIMPLE_QUERY = "simple"; + public static final String TRUE_VALUE = "true"; + public static final String FALSE_VALUE = "false"; + public static final String REQUIRE_SSL = "require"; + + private final String endpoint; + private final int port; + private final String username; + private final String password; + private final boolean requireSSL; + private final String database; + + public PostgresConnectionManager(String endpoint, int port, String username, String password, boolean requireSSL, String database) { + this.endpoint = endpoint; + this.port = port; + this.username = username; + this.password = password; + this.requireSSL = requireSSL; + this.database = database; + } + + @Override + public Connection getConnection() throws SQLException { + final Properties props = new Properties(); + PGProperty.USER.set(props, username); + if (!password.isEmpty()) { + PGProperty.PASSWORD.set(props, password); + } + PGProperty.ASSUME_MIN_SERVER_VERSION.set(props, SERVER_VERSION_9_4); // This is required + PGProperty.REPLICATION.set(props, DATABASE_REPLICATION); // This is also required + PGProperty.PREFER_QUERY_MODE.set(props, SIMPLE_QUERY); + + if (requireSSL) { + PGProperty.SSL.set(props, TRUE_VALUE); + PGProperty.SSL_MODE.set(props, REQUIRE_SSL); + } else { + PGProperty.SSL.set(props, FALSE_VALUE); + } + + final String jdbcUrl = String.format(JDBC_URL_FORMAT, this.endpoint, this.port, this.database); + LOG.debug("Connecting to JDBC URL: {}", jdbcUrl); + return doGetConnection(jdbcUrl, props); + } + + // VisibleForTesting + Connection doGetConnection(String jdbcUrl, Properties props) throws SQLException { + return DriverManager.getConnection(jdbcUrl, props); + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresSchemaManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresSchemaManager.java new file mode 100644 index 0000000000..dcd604f8a8 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresSchemaManager.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import org.postgresql.PGConnection; +import org.postgresql.replication.PGReplicationConnection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.util.ArrayList; +import java.util.List; + +public class PostgresSchemaManager implements SchemaManager { + private static final Logger LOG = LoggerFactory.getLogger(PostgresSchemaManager.class); + private final ConnectionManager connectionManager; + + static final int NUM_OF_RETRIES = 3; + static final int BACKOFF_IN_MILLIS = 500; + static final String COLUMN_NAME = "COLUMN_NAME"; + + public PostgresSchemaManager(ConnectionManager connectionManager) { + this.connectionManager = connectionManager; + } + + public void createLogicalReplicationSlot(final List tableNames, final String publicationName, final String slotName) { + StringBuilder createPublicationStatementBuilder = new StringBuilder("CREATE PUBLICATION ") + .append(publicationName) + .append(" FOR TABLE "); + for (int i = 0; i < tableNames.size(); i++) { + createPublicationStatementBuilder.append(tableNames.get(i)); + if (i < tableNames.size() - 1) { + createPublicationStatementBuilder.append(", "); + } + } + createPublicationStatementBuilder.append(";"); + final String createPublicationStatement = createPublicationStatementBuilder.toString(); + + try (Connection conn = connectionManager.getConnection()) { + try { + PreparedStatement statement = conn.prepareStatement(createPublicationStatement); + statement.executeUpdate(); + } catch (Exception e) { + LOG.info("Failed to create publication: {}", e.getMessage()); + } + + PGConnection pgConnection = conn.unwrap(PGConnection.class); + + // Create replication slot + PGReplicationConnection replicationConnection = pgConnection.getReplicationAPI(); + try { + replicationConnection.createReplicationSlot() + .logical() + .withSlotName(slotName) + .withOutputPlugin("pgoutput") + .make(); + LOG.info("Replication slot {} created successfully. ", slotName); + } catch (Exception e) { + LOG.info("Failed to create replication slot {}: {}", slotName, e.getMessage()); + } + } catch (Exception e) { + LOG.error("Exception when creating replication slot. ", e); + } + } + + @Override + public List getPrimaryKeys(final String fullTableName) { + final String[] splits = fullTableName.split("\\."); + final String database = splits[0]; + final String schema = splits[1]; + final String table = splits[2]; + int retry = 0; + while (retry <= NUM_OF_RETRIES) { + final List primaryKeys = new ArrayList<>(); + try (final Connection connection = connectionManager.getConnection()) { + try (final ResultSet rs = connection.getMetaData().getPrimaryKeys(database, schema, table)) { + while (rs.next()) { + primaryKeys.add(rs.getString(COLUMN_NAME)); + } + return primaryKeys; + } + } catch (Exception e) { + LOG.error("Failed to get primary keys for table {}, retrying", table, e); + } + applyBackoff(); + retry++; + } + LOG.warn("Failed to get primary keys for table {}", table); + return List.of(); + } + + private void applyBackoff() { + try { + Thread.sleep(BACKOFF_IN_MILLIS); + } catch (final InterruptedException e){ + Thread.currentThread().interrupt(); + } + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/QueryManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/QueryManager.java index d89345fb71..95e61bc729 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/QueryManager.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/QueryManager.java @@ -26,9 +26,9 @@ public class QueryManager { static final int NUM_OF_RETRIES = 3; static final int BACKOFF_IN_MILLIS = 500; - private final ConnectionManager connectionManager; + private final MySqlConnectionManager connectionManager; - public QueryManager(ConnectionManager connectionManager) { + public QueryManager(MySqlConnectionManager connectionManager) { this.connectionManager = connectionManager; } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManager.java index bbe01ba160..000a1eea05 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManager.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManager.java @@ -1,186 +1,15 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - package org.opensearch.dataprepper.plugins.source.rds.schema; -import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; -import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyAction; -import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyRelation; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.sql.Connection; -import java.sql.DatabaseMetaData; -import java.sql.ResultSet; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Optional; - -public class SchemaManager { - private static final Logger LOG = LoggerFactory.getLogger(SchemaManager.class); - - static final String[] TABLE_TYPES = new String[]{"TABLE"}; - static final String COLUMN_NAME = "COLUMN_NAME"; - static final String BINLOG_STATUS_QUERY = "SHOW MASTER STATUS"; - static final String BINLOG_FILE = "File"; - static final String BINLOG_POSITION = "Position"; - static final int NUM_OF_RETRIES = 3; - static final int BACKOFF_IN_MILLIS = 500; - static final String TYPE_NAME = "TYPE_NAME"; - static final String FKTABLE_NAME = "FKTABLE_NAME"; - static final String FKCOLUMN_NAME = "FKCOLUMN_NAME"; - static final String PKTABLE_NAME = "PKTABLE_NAME"; - static final String PKCOLUMN_NAME = "PKCOLUMN_NAME"; - static final String UPDATE_RULE = "UPDATE_RULE"; - static final String DELETE_RULE = "DELETE_RULE"; - static final String COLUMN_DEF = "COLUMN_DEF"; - private final ConnectionManager connectionManager; - - public SchemaManager(ConnectionManager connectionManager) { - this.connectionManager = connectionManager; - } - - public List getPrimaryKeys(final String database, final String table) { - int retry = 0; - while (retry <= NUM_OF_RETRIES) { - final List primaryKeys = new ArrayList<>(); - try (final Connection connection = connectionManager.getConnection()) { - try (final ResultSet rs = connection.getMetaData().getPrimaryKeys(database, null, table)) { - while (rs.next()) { - primaryKeys.add(rs.getString(COLUMN_NAME)); - } - return primaryKeys; - } - } catch (Exception e) { - LOG.error("Failed to get primary keys for table {}, retrying", table, e); - } - applyBackoff(); - retry++; - } - LOG.warn("Failed to get primary keys for table {}", table); - return List.of(); - } - - public Map getColumnDataTypes(final String database, final String tableName) { - final Map columnsToDataType = new HashMap<>(); - for (int retry = 0; retry <= NUM_OF_RETRIES; retry++) { - try (Connection connection = connectionManager.getConnection()) { - final DatabaseMetaData metaData = connection.getMetaData(); - - // Retrieve column metadata - try (ResultSet columns = metaData.getColumns(database, null, tableName, null)) { - while (columns.next()) { - columnsToDataType.put( - columns.getString(COLUMN_NAME), - columns.getString(TYPE_NAME) - ); - } - } - } catch (final Exception e) { - LOG.error("Failed to get dataTypes for database {} table {}, retrying", database, tableName, e); - if (retry == NUM_OF_RETRIES) { - throw new RuntimeException(String.format("Failed to get dataTypes for database %s table %s after " + - "%d retries", database, tableName, retry), e); - } - } - applyBackoff(); - } - return columnsToDataType; - } - - public Optional getCurrentBinaryLogPosition() { - int retry = 0; - while (retry <= NUM_OF_RETRIES) { - try (final Connection connection = connectionManager.getConnection()) { - final Statement statement = connection.createStatement(); - final ResultSet rs = statement.executeQuery(BINLOG_STATUS_QUERY); - if (rs.next()) { - return Optional.of(new BinlogCoordinate(rs.getString(BINLOG_FILE), rs.getLong(BINLOG_POSITION))); - } - } catch (Exception e) { - LOG.error("Failed to get current binary log position, retrying", e); - } - applyBackoff(); - retry++; - } - LOG.warn("Failed to get current binary log position"); - return Optional.empty(); - } +/** + * Interface for manager classes that are used to get metadata of a database, such as table schemas + */ +public interface SchemaManager { /** - * Get the foreign key relations associated with the given tables. - * - * @param tableNames the table names - * @return the foreign key relations + * Get the primary keys for a table + * @param fullTableName The full table name + * @return List of primary keys */ - public List getForeignKeyRelations(List tableNames) { - int retry = 0; - while (retry <= NUM_OF_RETRIES) { - try (final Connection connection = connectionManager.getConnection()) { - final List foreignKeyRelations = new ArrayList<>(); - DatabaseMetaData metaData = connection.getMetaData(); - for (final String tableName : tableNames) { - String database = tableName.split("\\.")[0]; - String table = tableName.split("\\.")[1]; - ResultSet tableResult = metaData.getTables(database, null, table, TABLE_TYPES); - while (tableResult.next()) { - ResultSet foreignKeys = metaData.getImportedKeys(database, null, table); - - while (foreignKeys.next()) { - String fkTableName = foreignKeys.getString(FKTABLE_NAME); - String fkColumnName = foreignKeys.getString(FKCOLUMN_NAME); - String pkTableName = foreignKeys.getString(PKTABLE_NAME); - String pkColumnName = foreignKeys.getString(PKCOLUMN_NAME); - ForeignKeyAction updateAction = ForeignKeyAction.getActionFromMetadata(foreignKeys.getShort(UPDATE_RULE)); - ForeignKeyAction deleteAction = ForeignKeyAction.getActionFromMetadata(foreignKeys.getShort(DELETE_RULE)); - - Object defaultValue = null; - if (updateAction == ForeignKeyAction.SET_DEFAULT || deleteAction == ForeignKeyAction.SET_DEFAULT) { - // Get column default - ResultSet columnResult = metaData.getColumns(database, null, table, fkColumnName); - - if (columnResult.next()) { - defaultValue = columnResult.getObject(COLUMN_DEF); - } - } - - ForeignKeyRelation foreignKeyRelation = ForeignKeyRelation.builder() - .databaseName(database) - .parentTableName(pkTableName) - .referencedKeyName(pkColumnName) - .childTableName(fkTableName) - .foreignKeyName(fkColumnName) - .foreignKeyDefaultValue(defaultValue) - .updateAction(updateAction) - .deleteAction(deleteAction) - .build(); - - foreignKeyRelations.add(foreignKeyRelation); - } - } - } - - return foreignKeyRelations; - } catch (Exception e) { - LOG.error("Failed to scan foreign key references, retrying", e); - } - applyBackoff(); - retry++; - } - LOG.warn("Failed to scan foreign key references"); - return List.of(); - } - - private void applyBackoff() { - try { - Thread.sleep(BACKOFF_IN_MILLIS); - } catch (final InterruptedException e){ - Thread.currentThread().interrupt(); - } - } + List getPrimaryKeys(final String fullTableName); } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerFactory.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerFactory.java new file mode 100644 index 0000000000..c11d427e12 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerFactory.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.schema; + +public class SchemaManagerFactory { + private final ConnectionManager connectionManager; + + public SchemaManagerFactory(final ConnectionManager connectionManager) { + this.connectionManager = connectionManager; + } + + public SchemaManager getSchemaManager() { + if (connectionManager instanceof MySqlConnectionManager) { + return new MySqlSchemaManager(connectionManager); + } + + return new PostgresSchemaManager(connectionManager); + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientFactory.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientFactory.java deleted file mode 100644 index b63e588f01..0000000000 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientFactory.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.dataprepper.plugins.source.rds.stream; - -import com.github.shyiko.mysql.binlog.BinaryLogClient; -import com.github.shyiko.mysql.binlog.event.deserialization.EventDeserializer; -import com.github.shyiko.mysql.binlog.network.SSLMode; -import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; -import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata; -import software.amazon.awssdk.services.rds.RdsClient; - -public class BinlogClientFactory { - - private final RdsClient rdsClient; - private final DbMetadata dbMetadata; - private String username; - private String password; - private SSLMode sslMode = SSLMode.REQUIRED; - - public BinlogClientFactory(final RdsSourceConfig sourceConfig, - final RdsClient rdsClient, - final DbMetadata dbMetadata) { - this.rdsClient = rdsClient; - this.dbMetadata = dbMetadata; - username = sourceConfig.getAuthenticationConfig().getUsername(); - password = sourceConfig.getAuthenticationConfig().getPassword(); - } - - public BinaryLogClient create() { - BinaryLogClient binaryLogClient = new BinaryLogClient( - dbMetadata.getEndpoint(), - dbMetadata.getPort(), - username, - password); - binaryLogClient.setSSLMode(sslMode); - final EventDeserializer eventDeserializer = new EventDeserializer(); - eventDeserializer.setCompatibilityMode( - EventDeserializer.CompatibilityMode.DATE_AND_TIME_AS_LONG - ); - binaryLogClient.setEventDeserializer(eventDeserializer); - return binaryLogClient; - } - - public void setSSLMode(SSLMode sslMode) { - this.sslMode = sslMode; - } - - public void setCredentials(String username, String password) { - this.username = username; - this.password = password; - } -} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientWrapper.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientWrapper.java new file mode 100644 index 0000000000..36d8195106 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientWrapper.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.stream; + +import com.github.shyiko.mysql.binlog.BinaryLogClient; + +import java.io.IOException; + +public class BinlogClientWrapper implements ReplicationLogClient { + + private final BinaryLogClient binlogClient; + + public BinlogClientWrapper(final BinaryLogClient binlogClient) { + this.binlogClient = binlogClient; + } + + @Override + public void connect() throws IOException { + binlogClient.connect(); + } + + @Override + public void disconnect() throws IOException { + binlogClient.disconnect(); + } + + public BinaryLogClient getBinlogClient() { + return binlogClient; + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java index 4491a7c643..2bc21ca786 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java @@ -30,12 +30,12 @@ import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; import org.opensearch.dataprepper.plugins.source.rds.converter.StreamRecordConverter; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; -import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; -import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata; -import org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata; import org.opensearch.dataprepper.plugins.source.rds.datatype.DataTypeHelper; import org.opensearch.dataprepper.plugins.source.rds.datatype.MySQLDataType; +import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata; import org.opensearch.dataprepper.plugins.source.rds.model.ParentTable; +import org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata; import org.opensearch.dataprepper.plugins.source.rds.resync.CascadingActionDetector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java new file mode 100644 index 0000000000..cc83fdd232 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.stream; + +import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager; +import org.postgresql.PGConnection; +import org.postgresql.replication.LogSequenceNumber; +import org.postgresql.replication.PGReplicationStream; +import org.postgresql.replication.fluent.logical.ChainedLogicalStreamBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.sql.Connection; + +public class LogicalReplicationClient implements ReplicationLogClient { + + private static final Logger LOG = LoggerFactory.getLogger(LogicalReplicationClient.class); + + private final ConnectionManager connectionManager; + private final String replicationSlotName; + private LogSequenceNumber startLsn; + private LogicalReplicationEventProcessor eventProcessor; + + private volatile boolean disconnectRequested = false; + + public LogicalReplicationClient(final ConnectionManager connectionManager, + final String replicationSlotName) { + this.connectionManager = connectionManager; + this.replicationSlotName = replicationSlotName; + } + + @Override + public void connect() { + PGReplicationStream stream; + try (Connection conn = connectionManager.getConnection()) { + PGConnection pgConnection = conn.unwrap(PGConnection.class); + + // Create a replication stream + ChainedLogicalStreamBuilder logicalStreamBuilder = pgConnection.getReplicationAPI() + .replicationStream() + .logical() + .withSlotName(replicationSlotName) + .withSlotOption("proto_version", "1") + .withSlotOption("publication_names", "my_publication"); + if (startLsn != null) { + logicalStreamBuilder.withStartPosition(startLsn); + } + stream = logicalStreamBuilder.start(); + + if (eventProcessor != null) { + while (!disconnectRequested) { + try { + // Read changes + ByteBuffer msg = stream.readPending(); + + if (msg == null) { + Thread.sleep(10); + continue; + } + + // decode and convert events to Data Prepper events + eventProcessor.process(msg); + + // Acknowledge receiving the message + LogSequenceNumber lsn = stream.getLastReceiveLSN(); + stream.setFlushedLSN(lsn); + stream.setAppliedLSN(lsn); + } catch (Exception e) { + LOG.error("Exception while processing Postgres replication stream. ", e); + } + } + } + + stream.close(); + LOG.info("Replication stream closed successfully."); + } catch (Exception e) { + LOG.error("Exception while creating Postgres replication stream. ", e); + } + } + + @Override + public void disconnect() { + disconnectRequested = true; + } + + public void setEventProcessor(LogicalReplicationEventProcessor eventProcessor) { + this.eventProcessor = eventProcessor; + } + + public void setStartLsn(LogSequenceNumber startLsn) { + this.startLsn = startLsn; + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java new file mode 100644 index 0000000000..a48e468586 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java @@ -0,0 +1,312 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.stream; + +import org.opensearch.dataprepper.buffer.common.BufferAccumulator; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.opensearch.OpenSearchBulkActions; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.converter.StreamRecordConverter; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; +import org.opensearch.dataprepper.plugins.source.rds.datatype.postgres.ColumnType; +import org.opensearch.dataprepper.plugins.source.rds.model.MessageType; +import org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class LogicalReplicationEventProcessor { + + private static final Logger LOG = LoggerFactory.getLogger(LogicalReplicationEventProcessor.class); + + static final Duration BUFFER_TIMEOUT = Duration.ofSeconds(60); + static final int DEFAULT_BUFFER_BATCH_SIZE = 1_000; + + private final StreamPartition streamPartition; + private final RdsSourceConfig sourceConfig; + private final StreamRecordConverter recordConverter; + private final Buffer> buffer; + private final BufferAccumulator> bufferAccumulator; + private final List pipelineEvents; + + private long currentLsn; + private long currentEventTimestamp; + + private Map tableMetadataMap; + + public LogicalReplicationEventProcessor(final StreamPartition streamPartition, + final RdsSourceConfig sourceConfig, + final Buffer> buffer, + final String s3Prefix) { + this.streamPartition = streamPartition; + this.sourceConfig = sourceConfig; + recordConverter = new StreamRecordConverter(s3Prefix, sourceConfig.getPartitionCount()); + this.buffer = buffer; + bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT); + + tableMetadataMap = new HashMap<>(); + pipelineEvents = new ArrayList<>(); + } + + public void process(ByteBuffer msg) { + // Message processing logic: + // If it's a BEGIN, note its LSN + // If it's a RELATION, update table metadata map + // If it's INSERT/UPDATE/DELETE, prepare events + // If it's a COMMIT, convert all prepared events and send to buffer + char messageType = (char) msg.get(); + if (messageType == MessageType.BEGIN.getValue()) { + processBeginMessage(msg); + } else if (messageType == MessageType.RELATION.getValue()) { + processRelationMessage(msg); + } else if (messageType == MessageType.INSERT.getValue()) { + processInsertMessage(msg); + } else if (messageType == MessageType.UPDATE.getValue()) { + processUpdateMessage(msg); + } else if (messageType == MessageType.DELETE.getValue()) { + processDeleteMessage(msg); + } else if (messageType == MessageType.COMMIT.getValue()) { + processCommitMessage(msg); + } else { + throw new IllegalArgumentException("Replication message type [" + messageType + "] is not supported. "); + } + } + + void processBeginMessage(ByteBuffer msg) { + currentLsn = msg.getLong(); + long epochMicro = msg.getLong(); + currentEventTimestamp = convertPostgresEventTimestamp(epochMicro); + int transaction_xid = msg.getInt(); + + LOG.debug("Processed BEGIN message with LSN: {}, Timestamp: {}, TransactionId: {}", currentLsn, currentEventTimestamp, transaction_xid); + } + + void processRelationMessage(ByteBuffer msg) { + int tableId = msg.getInt(); + // null terminated string + String schemaName = getNullTerminatedString(msg); + String tableName = getNullTerminatedString(msg); + int replicaId = msg.get(); + short numberOfColumns = msg.getShort(); + + List columnNames = new ArrayList<>(); + for (int i = 0; i < numberOfColumns; i++) { + int flag = msg.get(); // 1 indicates this column is part of the replica identity + // null terminated string + String columnName = getNullTerminatedString(msg); + ColumnType columnType = ColumnType.getByTypeId(msg.getInt()); + String columnTypeName = columnType.getTypeName(); + int typeModifier = msg.getInt(); + if (columnType == ColumnType.VARCHAR) { + int varcharLength = typeModifier - 4; + } else if (columnType == ColumnType.NUMERIC) { + int precision = (typeModifier - 4) >> 16; + int scale = (typeModifier - 4) & 0xFFFF; + } + columnNames.add(columnName); + } + + final List primaryKeys = getPrimaryKeys(schemaName, tableName); + final TableMetadata tableMetadata = new TableMetadata( + tableName, schemaName, columnNames, primaryKeys); + + tableMetadataMap.put((long) tableId, tableMetadata); + + LOG.debug("Processed an Relation message with RelationId: {} Namespace: {} RelationName: {} ReplicaId: {}", tableId, schemaName, tableName, replicaId); + } + + void processCommitMessage(ByteBuffer msg) { + int flag = msg.get(); + long commitLsn = msg.getLong(); + long endLsn = msg.getLong(); + long epochMicro = msg.getLong(); + + if (currentLsn != commitLsn) { + // This shouldn't happen + pipelineEvents.clear(); + throw new RuntimeException("Commit LSN does not match current LSN, skipping"); + } + + writeToBuffer(bufferAccumulator); + LOG.debug("Processed a COMMIT message with Flag: {} CommitLsn: {} EndLsn: {} Timestamp: {}", flag, commitLsn, endLsn, epochMicro); + } + + void processInsertMessage(ByteBuffer msg) { + int tableId = msg.getInt(); + char n_char = (char) msg.get(); // Skip the 'N' character + + final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final List columnNames = tableMetadata.getColumnNames(); + final List primaryKeys = tableMetadata.getPrimaryKeys(); + final long eventTimestampMillis = currentEventTimestamp; + + doProcess(msg, columnNames, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.INDEX); + LOG.debug("Processed an INSERT message with table id: {}", tableId); + } + + void processUpdateMessage(ByteBuffer msg) { + final int tableId = msg.getInt(); + + final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final List columnNames = tableMetadata.getColumnNames(); + final List primaryKeys = tableMetadata.getPrimaryKeys(); + final long eventTimestampMillis = currentEventTimestamp; + + char typeId = (char) msg.get(); + if (typeId == 'N') { + doProcess(msg, columnNames, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.INDEX); + LOG.debug("Processed an UPDATE message with table id: {}", tableId); + } else if (typeId == 'K') { + // Primary keys were changed + doProcess(msg, columnNames, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.DELETE); + msg.get(); // should be a char 'N' + doProcess(msg, columnNames, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.INDEX); + LOG.debug("Processed an UPDATE message with table id: {} and primary key(s) were changed", tableId); + + } else if (typeId == 'O') { + // Replica Identity is set to full, containing both old and new row data + Map oldRowDataMap = getRowDataMap(msg, columnNames); + msg.get(); // should be a char 'N' + Map newRowDataMap = getRowDataMap(msg, columnNames); + + if (isPrimaryKeyChanged(oldRowDataMap, newRowDataMap, primaryKeys)) { + createPipelineEvent(oldRowDataMap, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.DELETE); + } + createPipelineEvent(newRowDataMap, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.INDEX); + } + } + + private boolean isPrimaryKeyChanged(Map oldRowDataMap, Map newRowDataMap, List primaryKeys) { + for (String primaryKey : primaryKeys) { + if (!oldRowDataMap.get(primaryKey).equals(newRowDataMap.get(primaryKey))) { + return true; + } + } + return false; + } + + void processDeleteMessage(ByteBuffer msg) { + int tableId = msg.getInt(); + char n_char = (char) msg.get(); // Skip the 'N' character + + final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final List columnNames = tableMetadata.getColumnNames(); + final List primaryKeys = tableMetadata.getPrimaryKeys(); + final long eventTimestampMillis = currentEventTimestamp; + + doProcess(msg, columnNames, tableMetadata, primaryKeys, eventTimestampMillis, OpenSearchBulkActions.DELETE); + LOG.debug("Processed a DELETE message with table id: {}", tableId); + } + + private void doProcess(ByteBuffer msg, List columnNames, TableMetadata tableMetadata, + List primaryKeys, long eventTimestampMillis, OpenSearchBulkActions bulkAction) { + Map rowDataMap = getRowDataMap(msg, columnNames); + + createPipelineEvent(rowDataMap, tableMetadata, primaryKeys, eventTimestampMillis, bulkAction); + } + + private Map getRowDataMap(ByteBuffer msg, List columnNames) { + Map rowDataMap = new HashMap<>(); + short numberOfColumns = msg.getShort(); + for (int i = 0; i < numberOfColumns; i++) { + char type = (char) msg.get(); + if (type == 'n') { + rowDataMap.put(columnNames.get(i), null); + } else if (type == 't') { + int length = msg.getInt(); + byte[] bytes = new byte[length]; + msg.get(bytes); + rowDataMap.put(columnNames.get(i), new String(bytes)); + } else { + LOG.warn("Unknown column type: {}", type); + } + } + return rowDataMap; + } + + private void createPipelineEvent(Map rowDataMap, TableMetadata tableMetadata, List primaryKeys, long eventTimestampMillis, OpenSearchBulkActions bulkAction) { + final Event dataPrepperEvent = JacksonEvent.builder() + .withEventType("event") + .withData(rowDataMap) + .build(); + + final Event pipelineEvent = recordConverter.convert( + dataPrepperEvent, + tableMetadata.getDatabaseName(), + tableMetadata.getTableName(), + bulkAction, + primaryKeys, + eventTimestampMillis, + eventTimestampMillis, + null); + pipelineEvents.add(pipelineEvent); + } + + private void writeToBuffer(BufferAccumulator> bufferAccumulator) { + for (Event pipelineEvent : pipelineEvents) { + addToBufferAccumulator(bufferAccumulator, new Record<>(pipelineEvent)); + } + + flushBufferAccumulator(bufferAccumulator, pipelineEvents.size()); + pipelineEvents.clear(); + } + + private void addToBufferAccumulator(final BufferAccumulator> bufferAccumulator, final Record record) { + try { + bufferAccumulator.add(record); + } catch (Exception e) { + LOG.error("Failed to add event to buffer", e); + } + } + + private void flushBufferAccumulator(BufferAccumulator> bufferAccumulator, int eventCount) { + try { + bufferAccumulator.flush(); + } catch (Exception e) { + // this will only happen if writing to buffer gets interrupted from shutdown, + // otherwise bufferAccumulator will keep retrying with backoff + LOG.error("Failed to flush buffer", e); + } + } + + private long convertPostgresEventTimestamp(long postgresMicro) { + // Offset in microseconds between 1970-01-01 and 2000-01-01 + long offsetMicro = 946684800L * 1_000_000L; + return (postgresMicro + offsetMicro) / 1000; + } + + private String getNullTerminatedString(ByteBuffer msg) { + StringBuilder sb = new StringBuilder(); + while (msg.hasRemaining()) { + byte b = msg.get(); + if (b == 0) break; // Stop at null terminator + sb.append((char) b); + } + return sb.toString(); + } + + private List getPrimaryKeys(String schemaName, String tableName) { + final String databaseName = sourceConfig.getTableNames().get(0).split("\\.")[0]; + StreamProgressState progressState = streamPartition.getProgressState().get(); + + return progressState.getPrimaryKeyMap().get(databaseName + "." + schemaName + "." + tableName); + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClient.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClient.java new file mode 100644 index 0000000000..bef9064b9c --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClient.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.stream; + +import java.io.IOException; + +public interface ReplicationLogClient { + + void connect() throws IOException; + + void disconnect() throws IOException; +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClientFactory.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClientFactory.java new file mode 100644 index 0000000000..d9bc54570a --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClientFactory.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.stream; + +import com.github.shyiko.mysql.binlog.BinaryLogClient; +import com.github.shyiko.mysql.binlog.event.deserialization.EventDeserializer; +import com.github.shyiko.mysql.binlog.network.SSLMode; +import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata; +import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager; +import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManagerFactory; +import software.amazon.awssdk.services.rds.RdsClient; + +import java.util.List; +import java.util.NoSuchElementException; + +public class ReplicationLogClientFactory { + + private final RdsSourceConfig sourceConfig; + private final RdsClient rdsClient; + private final DbMetadata dbMetadata; + private String username; + private String password; + private SSLMode sslMode = SSLMode.REQUIRED; + + public ReplicationLogClientFactory(final RdsSourceConfig sourceConfig, + final RdsClient rdsClient, + final DbMetadata dbMetadata) { + this.sourceConfig = sourceConfig; + this.rdsClient = rdsClient; + this.dbMetadata = dbMetadata; + username = sourceConfig.getAuthenticationConfig().getUsername(); + password = sourceConfig.getAuthenticationConfig().getPassword(); + } + + public ReplicationLogClient create(StreamPartition streamPartition) { + if (sourceConfig.getEngine() == EngineType.MYSQL) { + return new BinlogClientWrapper(createBinaryLogClient()); + } else { // Postgres + return createLogicalReplicationClient(streamPartition); + } + } + + private BinaryLogClient createBinaryLogClient() { + BinaryLogClient binaryLogClient = new BinaryLogClient( + dbMetadata.getEndpoint(), + dbMetadata.getPort(), + username, + password); + binaryLogClient.setSSLMode(sslMode); + final EventDeserializer eventDeserializer = new EventDeserializer(); + eventDeserializer.setCompatibilityMode( + EventDeserializer.CompatibilityMode.DATE_AND_TIME_AS_LONG + ); + binaryLogClient.setEventDeserializer(eventDeserializer); + return binaryLogClient; + } + + private LogicalReplicationClient createLogicalReplicationClient(StreamPartition streamPartition) { + final String replicationSlotName = streamPartition.getProgressState().get().getReplicationSlotName(); + if (replicationSlotName == null) { + throw new NoSuchElementException("Replication slot name is not found in progress state."); + } + final ConnectionManagerFactory connectionManagerFactory = new ConnectionManagerFactory(sourceConfig, dbMetadata); + final ConnectionManager connectionManager = connectionManagerFactory.getConnectionManager(); + return new LogicalReplicationClient(connectionManager, replicationSlotName); + } + + public void setSSLMode(SSLMode sslMode) { + this.sslMode = sslMode; + } + + public void setCredentials(String username, String password) { + this.username = username; + this.password = password; + } + + private String getDatabaseName(List tableNames) { + return tableNames.get(0).split("\\.")[0]; + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamScheduler.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamScheduler.java index 31c2900607..7d638931a3 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamScheduler.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamScheduler.java @@ -34,7 +34,7 @@ public class StreamScheduler implements Runnable { private final EnhancedSourceCoordinator sourceCoordinator; private final RdsSourceConfig sourceConfig; private final String s3Prefix; - private final BinlogClientFactory binlogClientFactory; + private ReplicationLogClientFactory replicationLogClientFactory; private final Buffer> buffer; private final PluginMetrics pluginMetrics; private final AcknowledgementSetManager acknowledgementSetManager; @@ -46,7 +46,7 @@ public class StreamScheduler implements Runnable { public StreamScheduler(final EnhancedSourceCoordinator sourceCoordinator, final RdsSourceConfig sourceConfig, final String s3Prefix, - final BinlogClientFactory binlogClientFactory, + final ReplicationLogClientFactory replicationLogClientFactory, final Buffer> buffer, final PluginMetrics pluginMetrics, final AcknowledgementSetManager acknowledgementSetManager, @@ -54,7 +54,7 @@ public StreamScheduler(final EnhancedSourceCoordinator sourceCoordinator, this.sourceCoordinator = sourceCoordinator; this.sourceConfig = sourceConfig; this.s3Prefix = s3Prefix; - this.binlogClientFactory = binlogClientFactory; + this.replicationLogClientFactory = replicationLogClientFactory; this.buffer = buffer; this.pluginMetrics = pluginMetrics; this.acknowledgementSetManager = acknowledgementSetManager; @@ -80,7 +80,7 @@ public void run() { final StreamCheckpointer streamCheckpointer = new StreamCheckpointer(sourceCoordinator, streamPartition, pluginMetrics); streamWorkerTaskRefresher = StreamWorkerTaskRefresher.create( - sourceCoordinator, streamPartition, streamCheckpointer, s3Prefix, binlogClientFactory, buffer, + sourceCoordinator, streamPartition, streamCheckpointer, s3Prefix, replicationLogClientFactory, buffer, () -> Executors.newSingleThreadExecutor(BackgroundThreadFactory.defaultExecutorThreadFactory("rds-source-stream-worker")), acknowledgementSetManager, pluginMetrics); diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorker.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorker.java index 0b92e19d85..d6404a42ef 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorker.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorker.java @@ -11,6 +11,7 @@ import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -22,21 +23,21 @@ public class StreamWorker { private static final int DEFAULT_EXPORT_COMPLETE_WAIT_INTERVAL_MILLIS = 60_000; private final EnhancedSourceCoordinator sourceCoordinator; - private final BinaryLogClient binaryLogClient; + private final ReplicationLogClient replicationLogClient; private final PluginMetrics pluginMetrics; StreamWorker(final EnhancedSourceCoordinator sourceCoordinator, - final BinaryLogClient binaryLogClient, + final ReplicationLogClient replicationLogClient, final PluginMetrics pluginMetrics) { this.sourceCoordinator = sourceCoordinator; - this.binaryLogClient = binaryLogClient; + this.replicationLogClient = replicationLogClient; this.pluginMetrics = pluginMetrics; } public static StreamWorker create(final EnhancedSourceCoordinator sourceCoordinator, - final BinaryLogClient binaryLogClient, + final ReplicationLogClient replicationLogClient, final PluginMetrics pluginMetrics) { - return new StreamWorker(sourceCoordinator, binaryLogClient, pluginMetrics); + return new StreamWorker(sourceCoordinator, replicationLogClient, pluginMetrics); } public void processStream(final StreamPartition streamPartition) { @@ -51,16 +52,20 @@ public void processStream(final StreamPartition streamPartition) { } } - setStartBinlogPosition(streamPartition); + if (replicationLogClient instanceof BinlogClientWrapper) { + setStartBinlogPosition(streamPartition); + } else { + setStartLsn(streamPartition); + } try { LOG.info("Connect to database to read change events."); - binaryLogClient.connect(); + replicationLogClient.connect(); } catch (Exception e) { throw new RuntimeException(e); } finally { try { - binaryLogClient.disconnect(); + replicationLogClient.disconnect(); } catch (Exception e) { LOG.error("Binary log client failed to disconnect.", e); } @@ -90,8 +95,19 @@ private void setStartBinlogPosition(final StreamPartition streamPartition) { final String binlogFilename = startBinlogPosition.getBinlogFilename(); final long binlogPosition = startBinlogPosition.getBinlogPosition(); LOG.debug("Will start binlog stream from binlog file {} and position {}.", binlogFilename, binlogPosition); + BinaryLogClient binaryLogClient = ((BinlogClientWrapper) replicationLogClient).getBinlogClient(); binaryLogClient.setBinlogFilename(binlogFilename); binaryLogClient.setBinlogPosition(binlogPosition); } } + + private void setStartLsn(final StreamPartition streamPartition) { + final String startLsn = streamPartition.getProgressState().get().getCurrentLsn(); + + if (startLsn != null) { + LOG.debug("Will start logical replication from LSN {}", startLsn); + LogicalReplicationClient logicalReplicationClient = (LogicalReplicationClient) replicationLogClient; + logicalReplicationClient.setStartLsn(LogSequenceNumber.valueOf(startLsn)); + } + } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java index dcc64354d0..acd8d0535f 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java @@ -16,6 +16,7 @@ import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition; import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata; @@ -38,7 +39,7 @@ public class StreamWorkerTaskRefresher implements PluginConfigObserver> buffer; private final Supplier executorServiceSupplier; private final PluginMetrics pluginMetrics; @@ -53,7 +54,7 @@ public StreamWorkerTaskRefresher(final EnhancedSourceCoordinator sourceCoordinat final StreamPartition streamPartition, final StreamCheckpointer streamCheckpointer, final String s3Prefix, - final BinlogClientFactory binlogClientFactory, + final ReplicationLogClientFactory replicationLogClientFactory, final Buffer> buffer, final Supplier executorServiceSupplier, final AcknowledgementSetManager acknowledgementSetManager, @@ -67,7 +68,7 @@ public StreamWorkerTaskRefresher(final EnhancedSourceCoordinator sourceCoordinat executorService = executorServiceSupplier.get(); this.pluginMetrics = pluginMetrics; this.acknowledgementSetManager = acknowledgementSetManager; - this.binlogClientFactory = binlogClientFactory; + this.replicationLogClientFactory = replicationLogClientFactory; this.credentialsChangeCounter = pluginMetrics.counter(CREDENTIALS_CHANGED); this.taskRefreshErrorsCounter = pluginMetrics.counter(TASK_REFRESH_ERRORS); } @@ -76,7 +77,7 @@ public static StreamWorkerTaskRefresher create(final EnhancedSourceCoordinator s final StreamPartition streamPartition, final StreamCheckpointer streamCheckpointer, final String s3Prefix, - final BinlogClientFactory binlogClientFactory, + final ReplicationLogClientFactory binlogClientFactory, final Buffer> buffer, final Supplier executorServiceSupplier, final AcknowledgementSetManager acknowledgementSetManager, @@ -98,7 +99,7 @@ public void update(RdsSourceConfig sourceConfig) { try { executorService.shutdownNow(); executorService = executorServiceSupplier.get(); - binlogClientFactory.setCredentials( + replicationLogClientFactory.setCredentials( sourceConfig.getAuthenticationConfig().getUsername(), sourceConfig.getAuthenticationConfig().getPassword()); refreshTask(sourceConfig); @@ -117,13 +118,22 @@ public void shutdown() { } private void refreshTask(RdsSourceConfig sourceConfig) { - final BinaryLogClient binaryLogClient = binlogClientFactory.create(); final DbTableMetadata dbTableMetadata = getDBTableMetadata(streamPartition); final CascadingActionDetector cascadeActionDetector = new CascadingActionDetector(sourceCoordinator); - binaryLogClient.registerEventListener(BinlogEventListener.create( - streamPartition, buffer, sourceConfig, s3Prefix, pluginMetrics, binaryLogClient, - streamCheckpointer, acknowledgementSetManager, dbTableMetadata, cascadeActionDetector)); - final StreamWorker streamWorker = StreamWorker.create(sourceCoordinator, binaryLogClient, pluginMetrics); + + final ReplicationLogClient replicationLogClient = replicationLogClientFactory.create(streamPartition); + if (sourceConfig.getEngine() == EngineType.MYSQL) { + final BinaryLogClient binaryLogClient = ((BinlogClientWrapper) replicationLogClient).getBinlogClient(); + binaryLogClient.registerEventListener(BinlogEventListener.create( + streamPartition, buffer, sourceConfig, s3Prefix, pluginMetrics, binaryLogClient, + streamCheckpointer, acknowledgementSetManager, dbTableMetadata, cascadeActionDetector)); + } else { + final LogicalReplicationClient logicalReplicationClient = (LogicalReplicationClient) replicationLogClient; + logicalReplicationClient.setEventProcessor(new LogicalReplicationEventProcessor( + streamPartition, sourceConfig, buffer, s3Prefix + )); + } + final StreamWorker streamWorker = StreamWorker.create(sourceCoordinator, replicationLogClient, pluginMetrics); executorService.submit(() -> streamWorker.processStream(streamPartition)); } diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/RdsServiceTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/RdsServiceTest.java index 102b57f508..afde56fb63 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/RdsServiceTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/RdsServiceTest.java @@ -22,6 +22,7 @@ import org.opensearch.dataprepper.model.plugin.PluginConfigObservable; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.configuration.TlsConfig; import org.opensearch.dataprepper.plugins.source.rds.export.DataFileScheduler; import org.opensearch.dataprepper.plugins.source.rds.export.ExportScheduler; @@ -92,6 +93,7 @@ class RdsServiceTest { @BeforeEach void setUp() { when(clientFactory.buildRdsClient()).thenReturn(rdsClient); + when(sourceConfig.getEngine()).thenReturn(EngineType.MYSQL); } @Test diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderSchedulerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderSchedulerTest.java index dbd21cbe4a..0277258ac9 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderSchedulerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderSchedulerTest.java @@ -20,7 +20,7 @@ import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.LeaderPartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.state.LeaderProgressState; import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata; -import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager; +import org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager; import java.time.Duration; import java.util.Optional; @@ -49,7 +49,7 @@ class LeaderSchedulerTest { private RdsSourceConfig sourceConfig; @Mock - private SchemaManager schemaManager; + private MySqlSchemaManager schemaManager; @Mock private DbTableMetadata dbTableMetadata; diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerFactoryTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerFactoryTest.java new file mode 100644 index 0000000000..91d76fef04 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerFactoryTest.java @@ -0,0 +1,56 @@ +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; +import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata; + +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class ConnectionManagerFactoryTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private RdsSourceConfig sourceConfig; + + @Mock + private DbMetadata dbMetadata; + + private ConnectionManagerFactory connectionManagerFactory; + + @BeforeEach + void setUp() { + connectionManagerFactory = createObjectUnderTest(); + } + + @Test + void test_getConnectionManager_for_mysql() { + when(sourceConfig.getEngine()).thenReturn(EngineType.MYSQL); + final ConnectionManager connectionManager = connectionManagerFactory.getConnectionManager(); + assertThat(connectionManager, notNullValue()); + assertThat(connectionManager, instanceOf(MySqlConnectionManager.class)); + } + + @Test + void test_getConnectionManager_for_postgres() { + when(sourceConfig.getEngine()).thenReturn(EngineType.POSTGRES); + when(sourceConfig.getTableNames()).thenReturn(List.of("schema1.table1", "schema1.table2")); + final ConnectionManager connectionManager = connectionManagerFactory.getConnectionManager(); + assertThat(connectionManager, notNullValue()); + assertThat(connectionManager, instanceOf(PostgresConnectionManager.class)); + } + + private ConnectionManagerFactory createObjectUnderTest() { + return new ConnectionManagerFactory(sourceConfig, dbMetadata); + } +} diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlConnectionManagerTest.java similarity index 81% rename from data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerTest.java rename to data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlConnectionManagerTest.java index 83c93d91c3..3c18f5e1ff 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/ConnectionManagerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlConnectionManagerTest.java @@ -20,16 +20,16 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; -import static org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager.FALSE_VALUE; -import static org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager.PASSWORD_KEY; -import static org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager.REQUIRE_SSL_KEY; -import static org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager.TINY_INT_ONE_IS_BIT_KEY; -import static org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager.TRUE_VALUE; -import static org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager.USERNAME_KEY; -import static org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager.USE_SSL_KEY; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlConnectionManager.FALSE_VALUE; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlConnectionManager.PASSWORD_KEY; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlConnectionManager.REQUIRE_SSL_KEY; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlConnectionManager.TINY_INT_ONE_IS_BIT_KEY; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlConnectionManager.TRUE_VALUE; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlConnectionManager.USERNAME_KEY; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlConnectionManager.USE_SSL_KEY; -class ConnectionManagerTest { +class MySqlConnectionManagerTest { private String hostName; private int port; @@ -49,14 +49,14 @@ void setUp() { @Test void test_getConnection_when_requireSSL_is_true() throws SQLException { requireSSL = true; - final ConnectionManager connectionManager = spy(createObjectUnderTest()); + final MySqlConnectionManager connectionManager = spy(createObjectUnderTest()); final ArgumentCaptor jdbcUrlArgumentCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor propertiesArgumentCaptor = ArgumentCaptor.forClass(Properties.class); doReturn(mock(Connection.class)).when(connectionManager).doGetConnection(jdbcUrlArgumentCaptor.capture(), propertiesArgumentCaptor.capture()); connectionManager.getConnection(); - assertThat(jdbcUrlArgumentCaptor.getValue(), is(String.format(ConnectionManager.JDBC_URL_FORMAT, hostName, port))); + assertThat(jdbcUrlArgumentCaptor.getValue(), is(String.format(MySqlConnectionManager.JDBC_URL_FORMAT, hostName, port))); final Properties properties = propertiesArgumentCaptor.getValue(); assertThat(properties.getProperty(USERNAME_KEY), is(username)); assertThat(properties.getProperty(PASSWORD_KEY), is(password)); @@ -68,14 +68,14 @@ void test_getConnection_when_requireSSL_is_true() throws SQLException { @Test void test_getConnection_when_requireSSL_is_false() throws SQLException { requireSSL = false; - final ConnectionManager connectionManager = spy(createObjectUnderTest()); + final MySqlConnectionManager connectionManager = spy(createObjectUnderTest()); final ArgumentCaptor jdbcUrlArgumentCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor propertiesArgumentCaptor = ArgumentCaptor.forClass(Properties.class); doReturn(mock(Connection.class)).when(connectionManager).doGetConnection(jdbcUrlArgumentCaptor.capture(), propertiesArgumentCaptor.capture()); connectionManager.getConnection(); - assertThat(jdbcUrlArgumentCaptor.getValue(), is(String.format(ConnectionManager.JDBC_URL_FORMAT, hostName, port))); + assertThat(jdbcUrlArgumentCaptor.getValue(), is(String.format(MySqlConnectionManager.JDBC_URL_FORMAT, hostName, port))); final Properties properties = propertiesArgumentCaptor.getValue(); assertThat(properties.getProperty(USERNAME_KEY), is(username)); assertThat(properties.getProperty(PASSWORD_KEY), is(password)); @@ -83,7 +83,7 @@ void test_getConnection_when_requireSSL_is_false() throws SQLException { assertThat(properties.getProperty(TINY_INT_ONE_IS_BIT_KEY), is(FALSE_VALUE)); } - private ConnectionManager createObjectUnderTest() { - return new ConnectionManager(hostName, port, username, password, requireSSL); + private MySqlConnectionManager createObjectUnderTest() { + return new MySqlConnectionManager(hostName, port, username, password, requireSSL); } } \ No newline at end of file diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlSchemaManagerTest.java similarity index 93% rename from data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerTest.java rename to data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlSchemaManagerTest.java index ce6af88009..3856cb3f00 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/MySqlSchemaManagerTest.java @@ -37,23 +37,23 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata.DOT_DELIMITER; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_FILE; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_POSITION; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_STATUS_QUERY; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.COLUMN_NAME; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.TYPE_NAME; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.DELETE_RULE; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.FKCOLUMN_NAME; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.FKTABLE_NAME; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.PKCOLUMN_NAME; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.PKTABLE_NAME; -import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.UPDATE_RULE; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.BINLOG_FILE; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.BINLOG_POSITION; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.BINLOG_STATUS_QUERY; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.COLUMN_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.TYPE_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.DELETE_RULE; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.FKCOLUMN_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.FKTABLE_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.PKCOLUMN_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.PKTABLE_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager.UPDATE_RULE; @ExtendWith(MockitoExtension.class) -class SchemaManagerTest { +class MySqlSchemaManagerTest { @Mock - private ConnectionManager connectionManager; + private MySqlConnectionManager connectionManager; @Mock(answer = Answers.RETURNS_DEEP_STUBS) private Connection connection; @@ -64,7 +64,7 @@ class SchemaManagerTest { @Mock private ResultSet resultSet; - private SchemaManager schemaManager; + private MySqlSchemaManager schemaManager; @BeforeEach void setUp() { @@ -81,7 +81,7 @@ void test_getPrimaryKeys_returns_primary_keys() throws SQLException { when(resultSet.next()).thenReturn(true, false); when(resultSet.getString(COLUMN_NAME)).thenReturn(primaryKey); - final List primaryKeys = schemaManager.getPrimaryKeys(databaseName, tableName); + final List primaryKeys = schemaManager.getPrimaryKeys(databaseName + "." + tableName); assertThat(primaryKeys, contains(primaryKey)); } @@ -92,7 +92,7 @@ void test_getPrimaryKeys_throws_exception_then_returns_empty_list() throws SQLEx final String tableName = UUID.randomUUID().toString(); when(connectionManager.getConnection()).thenThrow(SQLException.class); - final List primaryKeys = schemaManager.getPrimaryKeys(databaseName, tableName); + final List primaryKeys = schemaManager.getPrimaryKeys(databaseName + "." + tableName); assertThat(primaryKeys, empty()); } @@ -217,7 +217,7 @@ void test_getForeignKeyRelations_returns_foreign_key_relations() throws SQLExcep assertThat(foreignKeyRelation.getDeleteAction(), is(ForeignKeyAction.SET_NULL)); } - private SchemaManager createObjectUnderTest() { - return new SchemaManager(connectionManager); + private MySqlSchemaManager createObjectUnderTest() { + return new MySqlSchemaManager(connectionManager); } } \ No newline at end of file diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresConnectionManagerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresConnectionManagerTest.java new file mode 100644 index 0000000000..e66e830684 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresConnectionManagerTest.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.postgresql.PGProperty; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Properties; +import java.util.Random; +import java.util.UUID; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager.DATABASE_REPLICATION; +import static org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager.FALSE_VALUE; +import static org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager.REQUIRE_SSL; +import static org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager.SERVER_VERSION_9_4; +import static org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager.SIMPLE_QUERY; +import static org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager.TRUE_VALUE; + +class PostgresConnectionManagerTest { + + private String endpoint; + private int port; + private String username; + private String password; + private boolean requireSSL; + private String database; + private final Random random = new Random(); + + @BeforeEach + void setUp() { + endpoint = UUID.randomUUID().toString(); + port = random.nextInt(65536); + username = UUID.randomUUID().toString(); + password = UUID.randomUUID().toString(); + } + + @Test + void test_getConnection_when_requireSSL_is_true() throws SQLException { + requireSSL = true; + final PostgresConnectionManager connectionManager = spy(createObjectUnderTest()); + final ArgumentCaptor jdbcUrlArgumentCaptor = ArgumentCaptor.forClass(String.class); + final ArgumentCaptor propertiesArgumentCaptor = ArgumentCaptor.forClass(Properties.class); + doReturn(mock(Connection.class)).when(connectionManager).doGetConnection(jdbcUrlArgumentCaptor.capture(), propertiesArgumentCaptor.capture()); + + connectionManager.getConnection(); + + assertThat(jdbcUrlArgumentCaptor.getValue(), is(String.format(PostgresConnectionManager.JDBC_URL_FORMAT, endpoint, port, database))); + final Properties properties = propertiesArgumentCaptor.getValue(); + assertThat(PGProperty.USER.getOrDefault(properties), is(username)); + assertThat(PGProperty.PASSWORD.getOrDefault(properties), is(password)); + assertThat(PGProperty.ASSUME_MIN_SERVER_VERSION.getOrDefault(properties), is(SERVER_VERSION_9_4)); + assertThat(PGProperty.REPLICATION.getOrDefault(properties), is(DATABASE_REPLICATION)); + assertThat(PGProperty.PREFER_QUERY_MODE.getOrDefault(properties), is(SIMPLE_QUERY)); + assertThat(PGProperty.SSL.getOrDefault(properties), is(TRUE_VALUE)); + assertThat(PGProperty.SSL_MODE.getOrDefault(properties), is(REQUIRE_SSL)); + } + + @Test + void test_getConnection_when_requireSSL_is_false() throws SQLException { + requireSSL = false; + final PostgresConnectionManager connectionManager = spy(createObjectUnderTest()); + final ArgumentCaptor jdbcUrlArgumentCaptor = ArgumentCaptor.forClass(String.class); + final ArgumentCaptor propertiesArgumentCaptor = ArgumentCaptor.forClass(Properties.class); + doReturn(mock(Connection.class)).when(connectionManager).doGetConnection(jdbcUrlArgumentCaptor.capture(), propertiesArgumentCaptor.capture()); + + connectionManager.getConnection(); + + assertThat(jdbcUrlArgumentCaptor.getValue(), is(String.format(PostgresConnectionManager.JDBC_URL_FORMAT, endpoint, port, database))); + final Properties properties = propertiesArgumentCaptor.getValue(); + assertThat(PGProperty.USER.getOrDefault(properties), is(username)); + assertThat(PGProperty.PASSWORD.getOrDefault(properties), is(password)); + assertThat(PGProperty.ASSUME_MIN_SERVER_VERSION.getOrDefault(properties), is(SERVER_VERSION_9_4)); + assertThat(PGProperty.REPLICATION.getOrDefault(properties), is(DATABASE_REPLICATION)); + assertThat(PGProperty.PREFER_QUERY_MODE.getOrDefault(properties), is(SIMPLE_QUERY)); + assertThat(PGProperty.SSL.getOrDefault(properties), is(FALSE_VALUE)); + } + + private PostgresConnectionManager createObjectUnderTest() { + return new PostgresConnectionManager(endpoint, port, username, password, requireSSL, database); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresSchemaManagerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresSchemaManagerTest.java new file mode 100644 index 0000000000..0602b2c6d5 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/PostgresSchemaManagerTest.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.postgresql.PGConnection; +import org.postgresql.replication.PGReplicationConnection; +import org.postgresql.replication.fluent.ChainedCreateReplicationSlotBuilder; +import org.postgresql.replication.fluent.logical.ChainedLogicalCreateSlotBuilder; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class PostgresSchemaManagerTest { + + @Mock + private PostgresConnectionManager connectionManager; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Connection connection; + + private PostgresSchemaManager schemaManager; + + @BeforeEach + void setUp() { + schemaManager = createObjectUnderTest(); + } + + @Test + void test_createLogicalReplicationSlot() throws SQLException { + final List tableNames = List.of("table1", "table2"); + final String publicationName = "publication1"; + final String slotName = "slot1"; + final PreparedStatement preparedStatement = mock(PreparedStatement.class); + final PGConnection pgConnection = mock(PGConnection.class); + final PGReplicationConnection replicationConnection = mock(PGReplicationConnection.class); + final ChainedCreateReplicationSlotBuilder chainedCreateSlotBuilder = mock(ChainedCreateReplicationSlotBuilder.class); + final ChainedLogicalCreateSlotBuilder slotBuilder = mock(ChainedLogicalCreateSlotBuilder.class); + + ArgumentCaptor statementCaptor = ArgumentCaptor.forClass(String.class); + + when(connectionManager.getConnection()).thenReturn(connection); + when(connection.prepareStatement(statementCaptor.capture())).thenReturn(preparedStatement); + when(connection.unwrap(PGConnection.class)).thenReturn(pgConnection); + when(pgConnection.getReplicationAPI()).thenReturn(replicationConnection); + when(replicationConnection.createReplicationSlot()).thenReturn(chainedCreateSlotBuilder); + when(chainedCreateSlotBuilder.logical()).thenReturn(slotBuilder); + when(slotBuilder.withSlotName(anyString())).thenReturn(slotBuilder); + when(slotBuilder.withOutputPlugin(anyString())).thenReturn(slotBuilder); + + schemaManager.createLogicalReplicationSlot(tableNames, publicationName, slotName); + + String statement = statementCaptor.getValue(); + assertThat(statement, is("CREATE PUBLICATION " + publicationName + " FOR TABLE " + String.join(", ", tableNames) + ";")); + verify(preparedStatement).executeUpdate(); + verify(pgConnection).getReplicationAPI(); + verify(replicationConnection).createReplicationSlot(); + verify(chainedCreateSlotBuilder).logical(); + verify(slotBuilder).withSlotName(slotName); + verify(slotBuilder).withOutputPlugin("pgoutput"); + verify(slotBuilder).make(); + } + + private PostgresSchemaManager createObjectUnderTest() { + return new PostgresSchemaManager(connectionManager); + } +} diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/QueryManagerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/QueryManagerTest.java index 9ed44e908c..86cc5d4431 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/QueryManagerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/QueryManagerTest.java @@ -32,7 +32,7 @@ class QueryManagerTest { @Mock - private ConnectionManager connectionManager; + private MySqlConnectionManager connectionManager; @Mock(answer = Answers.RETURNS_DEEP_STUBS) private Connection connection; diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerFactoryTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerFactoryTest.java new file mode 100644 index 0000000000..e98752ceeb --- /dev/null +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerFactoryTest.java @@ -0,0 +1,48 @@ +package org.opensearch.dataprepper.plugins.source.rds.schema; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; + + +@ExtendWith(MockitoExtension.class) +class SchemaManagerFactoryTest { + + @Mock + private MySqlConnectionManager mySqlConnectionManager; + + @Mock + private PostgresConnectionManager postgresConnectionManager; + + private SchemaManagerFactory schemaManagerFactory; + private ConnectionManager connectionManager; + + @BeforeEach + void setUp() { + } + + @Test + void test_getSchemaManager_for_mysql() { + connectionManager = mySqlConnectionManager; + schemaManagerFactory = createObjectUnderTest(); + + assertThat(schemaManagerFactory.getSchemaManager(), instanceOf(MySqlSchemaManager.class)); + } + + @Test + void test_getSchemaManager_for_postgres() { + connectionManager = postgresConnectionManager; + schemaManagerFactory = createObjectUnderTest(); + + assertThat(schemaManagerFactory.getSchemaManager(), instanceOf(PostgresSchemaManager.class)); + } + + private SchemaManagerFactory createObjectUnderTest() { + return new SchemaManagerFactory(connectionManager); + } +} diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientFactoryTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientFactoryTest.java deleted file mode 100644 index c56ffd94d5..0000000000 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogClientFactoryTest.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.dataprepper.plugins.source.rds.stream; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Answers; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; -import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata; -import software.amazon.awssdk.services.rds.RdsClient; - -import java.util.UUID; - -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -@ExtendWith(MockitoExtension.class) -class BinlogClientFactoryTest { - - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private RdsSourceConfig sourceConfig; - - @Mock - private RdsClient rdsClient; - - @Mock - private DbMetadata dbMetadata; - - private BinlogClientFactory binlogClientFactory; - - @Test - void test_create() { - final String username = UUID.randomUUID().toString(); - final String password = UUID.randomUUID().toString(); - when(sourceConfig.getAuthenticationConfig().getUsername()).thenReturn(username); - when(sourceConfig.getAuthenticationConfig().getPassword()).thenReturn(password); - - binlogClientFactory = createObjectUnderTest(); - binlogClientFactory.create(); - - verify(dbMetadata).getEndpoint(); - verify(dbMetadata).getPort(); - } - - private BinlogClientFactory createObjectUnderTest() { - return new BinlogClientFactory(sourceConfig, rdsClient, dbMetadata); - } -} \ No newline at end of file diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java new file mode 100644 index 0000000000..5089f0cf67 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.stream; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.source.rds.schema.PostgresConnectionManager; +import org.postgresql.PGConnection; +import org.postgresql.replication.LogSequenceNumber; +import org.postgresql.replication.PGReplicationStream; +import org.postgresql.replication.fluent.logical.ChainedLogicalStreamBuilder; + +import java.nio.ByteBuffer; +import java.sql.Connection; +import java.sql.SQLException; +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import static org.awaitility.Awaitility.await; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class LogicalReplicationClientTest { + + @Mock + private PostgresConnectionManager connectionManager; + + @Mock + private LogicalReplicationEventProcessor eventProcessor; + + private String replicationSlotName; + private LogicalReplicationClient logicalReplicationClient; + + @BeforeEach + void setUp() { + replicationSlotName = UUID.randomUUID().toString(); + logicalReplicationClient = createObjectUnderTest(); + logicalReplicationClient.setEventProcessor(eventProcessor); + } + + @Test + void test_connect() throws SQLException, InterruptedException { + final Connection connection = mock(Connection.class); + final PGConnection pgConnection = mock(PGConnection.class, RETURNS_DEEP_STUBS); + final ChainedLogicalStreamBuilder logicalStreamBuilder = mock(ChainedLogicalStreamBuilder.class); + final PGReplicationStream stream = mock(PGReplicationStream.class); + final ByteBuffer message = mock(ByteBuffer.class); + final LogSequenceNumber lsn = mock(LogSequenceNumber.class); + + when(connectionManager.getConnection()).thenReturn(connection); + when(connection.unwrap(PGConnection.class)).thenReturn(pgConnection); + when(pgConnection.getReplicationAPI().replicationStream().logical()).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotName(anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotOption(anyString(), anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.start()).thenReturn(stream); + when(stream.readPending()).thenReturn(message).thenReturn(null); + when(stream.getLastReceiveLSN()).thenReturn(lsn); + + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + executorService.submit(() -> logicalReplicationClient.connect()); + + await().atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> verify(eventProcessor).process(message)); + Thread.sleep(20); + executorService.shutdownNow(); + + verify(stream).setAppliedLSN(lsn); + verify(stream).setFlushedLSN(lsn); + } + + private LogicalReplicationClient createObjectUnderTest() { + return new LogicalReplicationClient(connectionManager, replicationSlotName); + } +} diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java new file mode 100644 index 0000000000..22614e4f02 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.stream; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class LogicalReplicationEventProcessorTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private StreamPartition streamPartition; + + @Mock + private RdsSourceConfig sourceConfig; + + @Mock + private Buffer> buffer; + + @Mock + private ByteBuffer message; + + private String s3Prefix; + + private LogicalReplicationEventProcessor objectUnderTest; + + @BeforeEach + void setUp() { + s3Prefix = UUID.randomUUID().toString(); + + objectUnderTest = spy(createObjectUnderTest()); + } + + @Test + void test_correct_process_method_invoked_for_begin_message() { + when(message.get()).thenReturn((byte) 'B'); + + objectUnderTest.process(message); + + verify(objectUnderTest).processBeginMessage(message); + } + + @Test + void test_correct_process_method_invoked_for_relation_message() { + when(message.get()).thenReturn((byte) 'R'); + final StreamProgressState progressState = mock(StreamProgressState.class); + when(streamPartition.getProgressState()).thenReturn(Optional.of(progressState)); + when(sourceConfig.getTableNames()).thenReturn(List.of("database.schema.table1")); + when(progressState.getPrimaryKeyMap()).thenReturn(Map.of("database.schema.table1", List.of("key1", "key2"))); + + objectUnderTest.process(message); + + verify(objectUnderTest).processRelationMessage(message); + } + + @Test + void test_correct_process_method_invoked_for_commit_message() { + when(message.get()).thenReturn((byte) 'C'); + + objectUnderTest.process(message); + + verify(objectUnderTest).processCommitMessage(message); + } + + @Test + void test_correct_process_method_invoked_for_insert_message() { + when(message.get()).thenReturn((byte) 'I'); + doNothing().when(objectUnderTest).processInsertMessage(message); + + objectUnderTest.process(message); + + verify(objectUnderTest).processInsertMessage(message); + } + + @Test + void test_correct_process_method_invoked_for_update_message() { + when(message.get()).thenReturn((byte) 'U'); + doNothing().when(objectUnderTest).processUpdateMessage(message); + + objectUnderTest.process(message); + + verify(objectUnderTest).processUpdateMessage(message); + } + + @Test + void test_correct_process_method_invoked_for_delete_message() { + when(message.get()).thenReturn((byte) 'D'); + doNothing().when(objectUnderTest).processDeleteMessage(message); + + objectUnderTest.process(message); + + verify(objectUnderTest).processDeleteMessage(message); + } + + @Test + void test_unsupported_message_type_throws_exception() { + when(message.get()).thenReturn((byte) 'A'); + + assertThrows(IllegalArgumentException.class, () -> objectUnderTest.process(message)); + } + + private LogicalReplicationEventProcessor createObjectUnderTest() { + return new LogicalReplicationEventProcessor(streamPartition, sourceConfig, buffer, s3Prefix); + } +} diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClientFactoryTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClientFactoryTest.java new file mode 100644 index 0000000000..43978eaea4 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClientFactoryTest.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.rds.stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; +import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata; +import software.amazon.awssdk.services.rds.RdsClient; + +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class ReplicationLogClientFactoryTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private RdsSourceConfig sourceConfig; + + @Mock + private RdsClient rdsClient; + + @Mock + private DbMetadata dbMetadata; + + @Mock + private StreamPartition streamPartition; + + private ReplicationLogClientFactory replicationLogClientFactory; + + @Test + void test_create_binlog_client() { + final String username = UUID.randomUUID().toString(); + final String password = UUID.randomUUID().toString(); + + when(sourceConfig.getEngine()).thenReturn(EngineType.MYSQL); + when(sourceConfig.getAuthenticationConfig().getUsername()).thenReturn(username); + when(sourceConfig.getAuthenticationConfig().getPassword()).thenReturn(password); + + replicationLogClientFactory = createObjectUnderTest(); + ReplicationLogClient replicationLogClient = replicationLogClientFactory.create(streamPartition); + + verify(dbMetadata).getEndpoint(); + verify(dbMetadata).getPort(); + assertThat(replicationLogClient, instanceOf(BinlogClientWrapper.class)); + } + + @Test + void test_create_logical_replication_client() { + final String username = UUID.randomUUID().toString(); + final String password = UUID.randomUUID().toString(); + final StreamProgressState streamProgressState = mock(StreamProgressState.class); + final String slotName = UUID.randomUUID().toString(); + final List tableNames = List.of("table1", "table2"); + + when(sourceConfig.getEngine()).thenReturn(EngineType.POSTGRES); + when(sourceConfig.isTlsEnabled()).thenReturn(true); + when(sourceConfig.getTableNames()).thenReturn(tableNames); + when(sourceConfig.getAuthenticationConfig().getUsername()).thenReturn(username); + when(sourceConfig.getAuthenticationConfig().getPassword()).thenReturn(password); + when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState)); + when(streamProgressState.getReplicationSlotName()).thenReturn(slotName); + + replicationLogClientFactory = createObjectUnderTest(); + ReplicationLogClient replicationLogClient = replicationLogClientFactory.create(streamPartition); + + verify(dbMetadata).getEndpoint(); + verify(dbMetadata).getPort(); + assertThat(replicationLogClient, instanceOf(LogicalReplicationClient.class)); + } + + private ReplicationLogClientFactory createObjectUnderTest() { + return new ReplicationLogClientFactory(sourceConfig, rdsClient, dbMetadata); + } +} diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamSchedulerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamSchedulerTest.java index fc0ac8268a..325be414d9 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamSchedulerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamSchedulerTest.java @@ -51,7 +51,7 @@ class StreamSchedulerTest { private RdsSourceConfig sourceConfig; @Mock - private BinlogClientFactory binlogClientFactory; + private ReplicationLogClientFactory replicationLogClientFactory; @Mock private PluginMetrics pluginMetrics; @@ -88,7 +88,7 @@ void test_given_no_stream_partition_then_no_stream_actions() throws InterruptedE Thread.sleep(100); executorService.shutdownNow(); - verifyNoInteractions(binlogClientFactory, pluginConfigObservable); + verifyNoInteractions(replicationLogClientFactory, pluginConfigObservable); } @Test @@ -100,7 +100,7 @@ void test_given_stream_partition_then_start_stream() throws InterruptedException executorService.submit(() -> { try (MockedStatic streamWorkerTaskRefresherMockedStatic = mockStatic(StreamWorkerTaskRefresher.class)) { streamWorkerTaskRefresherMockedStatic.when(() -> StreamWorkerTaskRefresher.create(eq(sourceCoordinator), eq(streamPartition), any(StreamCheckpointer.class), - eq(s3Prefix), eq(binlogClientFactory), eq(buffer), any(Supplier.class), eq(acknowledgementSetManager), eq(pluginMetrics))) + eq(s3Prefix), eq(replicationLogClientFactory), eq(buffer), any(Supplier.class), eq(acknowledgementSetManager), eq(pluginMetrics))) .thenReturn(streamWorkerTaskRefresher); objectUnderTest.run(); } @@ -129,6 +129,6 @@ void test_shutdown() throws InterruptedException { private StreamScheduler createObjectUnderTest() { return new StreamScheduler( - sourceCoordinator, sourceConfig, s3Prefix, binlogClientFactory, buffer, pluginMetrics, acknowledgementSetManager, pluginConfigObservable); + sourceCoordinator, sourceConfig, s3Prefix, replicationLogClientFactory, buffer, pluginMetrics, acknowledgementSetManager, pluginConfigObservable); } } \ No newline at end of file diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresherTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresherTest.java index 7647a5c008..51c5c77636 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresherTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresherTest.java @@ -22,6 +22,7 @@ import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata; @@ -64,10 +65,16 @@ class StreamWorkerTaskRefresherTest { private StreamCheckpointer streamCheckpointer; @Mock - private BinlogClientFactory binlogClientFactory; + private ReplicationLogClientFactory replicationLogClientFactory; @Mock - private BinaryLogClient binlogClient; + private ReplicationLogClient replicationLogClient; + + @Mock + private BinlogClientWrapper binaryLogClientWrapper; + + @Mock + private BinaryLogClient binaryLogClient; @Mock private Buffer> buffer; @@ -112,28 +119,30 @@ void setUp() { when(pluginMetrics.counter(CREDENTIALS_CHANGED)).thenReturn(credentialsChangeCounter); when(pluginMetrics.counter(TASK_REFRESH_ERRORS)).thenReturn(taskRefreshErrorsCounter); when(executorServiceSupplier.get()).thenReturn(executorService).thenReturn(newExecutorService); + when(sourceConfig.getEngine()).thenReturn(EngineType.MYSQL); streamWorkerTaskRefresher = createObjectUnderTest(); } @Test void test_initialize_then_process_stream() { - when(binlogClientFactory.create()).thenReturn(binlogClient); + when(replicationLogClientFactory.create(streamPartition)).thenReturn(binaryLogClientWrapper); + when(binaryLogClientWrapper.getBinlogClient()).thenReturn(binaryLogClient); final Map progressState = mockGlobalStateAndProgressState(); try (MockedStatic streamWorkerMockedStatic = mockStatic(StreamWorker.class); MockedStatic binlogEventListenerMockedStatic = mockStatic(BinlogEventListener.class); MockedStatic dbTableMetadataMockedStatic = mockStatic(DbTableMetadata.class)) { dbTableMetadataMockedStatic.when(() -> DbTableMetadata.fromMap(progressState)).thenReturn(dbTableMetadata); - streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(BinaryLogClient.class), eq(pluginMetrics))) + streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(ReplicationLogClient.class), eq(pluginMetrics))) .thenReturn(streamWorker); binlogEventListenerMockedStatic.when(() -> BinlogEventListener.create(eq(streamPartition), eq(buffer), any(RdsSourceConfig.class), - any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), + any(String.class), eq(pluginMetrics), eq(binaryLogClient), eq(streamCheckpointer), eq(acknowledgementSetManager), eq(dbTableMetadata), any(CascadingActionDetector.class))) .thenReturn(binlogEventListener); streamWorkerTaskRefresher.initialize(sourceConfig); } - verify(binlogClientFactory).create(); - verify(binlogClient).registerEventListener(binlogEventListener); + verify(replicationLogClientFactory).create(streamPartition); + verify(binaryLogClient).registerEventListener(binlogEventListener); ArgumentCaptor runnableArgumentCaptor = ArgumentCaptor.forClass(Runnable.class); verify(executorService).submit(runnableArgumentCaptor.capture()); @@ -154,17 +163,19 @@ void test_update_when_credentials_changed_then_refresh_task() { final String password2 = UUID.randomUUID().toString(); when(sourceConfig2.getAuthenticationConfig().getUsername()).thenReturn(username); when(sourceConfig2.getAuthenticationConfig().getPassword()).thenReturn(password2); + when(sourceConfig2.getEngine()).thenReturn(EngineType.MYSQL); - when(binlogClientFactory.create()).thenReturn(binlogClient).thenReturn(binlogClient); + when(replicationLogClientFactory.create(streamPartition)).thenReturn(binaryLogClientWrapper).thenReturn(binaryLogClientWrapper); + when(binaryLogClientWrapper.getBinlogClient()).thenReturn(binaryLogClient); final Map progressState = mockGlobalStateAndProgressState(); try (MockedStatic streamWorkerMockedStatic = mockStatic(StreamWorker.class); MockedStatic binlogEventListenerMockedStatic = mockStatic(BinlogEventListener.class); MockedStatic dbTableMetadataMockedStatic = mockStatic(DbTableMetadata.class)) { dbTableMetadataMockedStatic.when(() -> DbTableMetadata.fromMap(progressState)).thenReturn(dbTableMetadata); - streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(BinaryLogClient.class), eq(pluginMetrics))) + streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(ReplicationLogClient.class), eq(pluginMetrics))) .thenReturn(streamWorker); binlogEventListenerMockedStatic.when(() -> BinlogEventListener.create(eq(streamPartition), eq(buffer), any(RdsSourceConfig.class), - any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), + any(String.class), eq(pluginMetrics), eq(binaryLogClient), eq(streamCheckpointer), eq(acknowledgementSetManager), eq(dbTableMetadata), any(CascadingActionDetector.class))) .thenReturn(binlogEventListener); streamWorkerTaskRefresher.initialize(sourceConfig); @@ -174,8 +185,8 @@ void test_update_when_credentials_changed_then_refresh_task() { verify(credentialsChangeCounter).increment(); verify(executorService).shutdownNow(); - verify(binlogClientFactory, times(2)).create(); - verify(binlogClient, times(2)).registerEventListener(binlogEventListener); + verify(replicationLogClientFactory, times(2)).create(streamPartition); + verify(binaryLogClient, times(2)).registerEventListener(binlogEventListener); ArgumentCaptor runnableArgumentCaptor = ArgumentCaptor.forClass(Runnable.class); verify(newExecutorService).submit(runnableArgumentCaptor.capture()); @@ -192,16 +203,17 @@ void test_update_when_credentials_unchanged_then_do_nothing() { when(sourceConfig.getAuthenticationConfig().getUsername()).thenReturn(username); when(sourceConfig.getAuthenticationConfig().getPassword()).thenReturn(password); - when(binlogClientFactory.create()).thenReturn(binlogClient); + when(replicationLogClientFactory.create(streamPartition)).thenReturn(binaryLogClientWrapper); + when(binaryLogClientWrapper.getBinlogClient()).thenReturn(binaryLogClient); final Map progressState = mockGlobalStateAndProgressState(); try (MockedStatic streamWorkerMockedStatic = mockStatic(StreamWorker.class); MockedStatic binlogEventListenerMockedStatic = mockStatic(BinlogEventListener.class); MockedStatic dbTableMetadataMockedStatic = mockStatic(DbTableMetadata.class)) { dbTableMetadataMockedStatic.when(() -> DbTableMetadata.fromMap(progressState)).thenReturn(dbTableMetadata); - streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(BinaryLogClient.class), eq(pluginMetrics))) + streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(ReplicationLogClient.class), eq(pluginMetrics))) .thenReturn(streamWorker); binlogEventListenerMockedStatic.when(() -> BinlogEventListener.create(eq(streamPartition), eq(buffer), any(RdsSourceConfig.class), - any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), + any(String.class), eq(pluginMetrics), eq(binaryLogClient), eq(streamCheckpointer), eq(acknowledgementSetManager), eq(dbTableMetadata), any(CascadingActionDetector.class))) .thenReturn(binlogEventListener); streamWorkerTaskRefresher.initialize(sourceConfig); @@ -222,7 +234,7 @@ private StreamWorkerTaskRefresher createObjectUnderTest() { final String s3Prefix = UUID.randomUUID().toString(); return new StreamWorkerTaskRefresher( - sourceCoordinator, streamPartition, streamCheckpointer, s3Prefix, binlogClientFactory, buffer, executorServiceSupplier, acknowledgementSetManager, pluginMetrics); + sourceCoordinator, streamPartition, streamCheckpointer, s3Prefix, replicationLogClientFactory, buffer, executorServiceSupplier, acknowledgementSetManager, pluginMetrics); } private Map mockGlobalStateAndProgressState() { diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTest.java index ecc7d86d47..1eaf719cf5 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTest.java @@ -32,6 +32,9 @@ class StreamWorkerTest { @Mock private EnhancedSourceCoordinator sourceCoordinator; + @Mock + private BinlogClientWrapper binlogClientWrapper; + @Mock private BinaryLogClient binaryLogClient; @@ -56,12 +59,13 @@ void test_processStream_with_given_binlog_coordinates() throws IOException { when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState)); when(streamProgressState.getCurrentPosition()).thenReturn(new BinlogCoordinate(binlogFilename, binlogPosition)); when(streamProgressState.shouldWaitForExport()).thenReturn(false); + when(binlogClientWrapper.getBinlogClient()).thenReturn(binaryLogClient); streamWorker.processStream(streamPartition); verify(binaryLogClient).setBinlogFilename(binlogFilename); verify(binaryLogClient).setBinlogPosition(binlogPosition); - verify(binaryLogClient).connect(); + verify(binlogClientWrapper).connect(); } @Test @@ -69,7 +73,7 @@ void test_processStream_without_current_binlog_coordinates() throws IOException StreamProgressState streamProgressState = mock(StreamProgressState.class); when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState)); final String binlogFilename = "binlog-001"; - final Long binlogPosition = 100L; + final long binlogPosition = 100L; when(streamProgressState.getCurrentPosition()).thenReturn(null); when(streamProgressState.shouldWaitForExport()).thenReturn(false); @@ -77,10 +81,10 @@ void test_processStream_without_current_binlog_coordinates() throws IOException verify(binaryLogClient, never()).setBinlogFilename(binlogFilename); verify(binaryLogClient, never()).setBinlogPosition(binlogPosition); - verify(binaryLogClient).connect(); + verify(binlogClientWrapper).connect(); } private StreamWorker createObjectUnderTest() { - return new StreamWorker(sourceCoordinator, binaryLogClient, pluginMetrics); + return new StreamWorker(sourceCoordinator, binlogClientWrapper, pluginMetrics); } } \ No newline at end of file