diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java
index acf55812b..6e8b31810 100644
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java
@@ -16,13 +16,16 @@
package io.asyncer.r2dbc.mysql;
+import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.codec.CodecContext;
import io.asyncer.r2dbc.mysql.collation.CharCollation;
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
+import io.r2dbc.spi.IsolationLevel;
import org.jetbrains.annotations.Nullable;
import java.nio.file.Path;
+import java.time.Duration;
import java.time.ZoneId;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
@@ -37,6 +40,10 @@ public final class ConnectionContext implements CodecContext {
private static final ServerVersion NONE_VERSION = ServerVersion.create(0, 0, 0);
+ private static final ServerVersion MYSQL_5_7_4 = ServerVersion.create(5, 7, 4);
+
+ private static final ServerVersion MARIA_10_1_1 = ServerVersion.create(10, 1, 1, true);
+
private final ZeroDateOption zeroDateOption;
@Nullable
@@ -52,16 +59,47 @@ public final class ConnectionContext implements CodecContext {
private Capability capability = Capability.DEFAULT;
+ private PrepareCache prepareCache;
+
@Nullable
private ZoneId timeZone;
+ private String product = "Unknown";
+
+ /**
+ * Current isolation level inferred by past statements.
+ *
+ * Inference rules:
+ *
- In the beginning, it is also {@link #sessionIsolationLevel}.
+ * - A transaction has began with a {@link IsolationLevel}, it will be changed to the value
+ * - The transaction end (commit or rollback), it will recover to {@link #sessionIsolationLevel}.
+ */
+ private volatile IsolationLevel currentIsolationLevel;
+
+ /**
+ * Session isolation level.
+ *
+ * - It is applied to all subsequent transactions performed within the current session.
+ * - Calls {@link io.r2dbc.spi.Connection#setTransactionIsolationLevel}, it will change to the value.
+ * - It can be changed within transactions, but does not affect the current ongoing transaction.
+ */
+ private volatile IsolationLevel sessionIsolationLevel;
+
private boolean lockWaitTimeoutSupported = false;
+ /**
+ * Current lock wait timeout in seconds.
+ */
+ private volatile Duration currentLockWaitTimeout;
+
+ /**
+ * Session lock wait timeout in seconds.
+ */
+ private volatile Duration sessionLockWaitTimeout;
+
/**
* Assume that the auto commit is always turned on, it will be set after handshake V10 request message, or OK
* message which means handshake V9 completed.
- *
- * It would be updated multiple times, so {@code volatile} is required.
*/
private volatile short serverStatuses = ServerStatuses.AUTO_COMMIT;
@@ -80,18 +118,50 @@ public final class ConnectionContext implements CodecContext {
}
/**
- * Initializes this context.
+ * Initializes handshake information after connection is established.
*
* @param connectionId the connection identifier that is specified by server.
* @param version the server version.
* @param capability the connection capabilities.
*/
- void init(int connectionId, ServerVersion version, Capability capability) {
+ void initHandshake(int connectionId, ServerVersion version, Capability capability) {
this.connectionId = connectionId;
this.serverVersion = version;
this.capability = capability;
}
+ /**
+ * Initializes session information after logged-in.
+ *
+ * @param prepareCache the prepare cache.
+ * @param isolationLevel the session isolation level.
+ * @param lockWaitTimeoutSupported if the server supports lock wait timeout.
+ * @param lockWaitTimeout the lock wait timeout.
+ * @param product the server product name.
+ * @param timeZone the server timezone.
+ */
+ void initSession(
+ PrepareCache prepareCache,
+ IsolationLevel isolationLevel,
+ boolean lockWaitTimeoutSupported,
+ Duration lockWaitTimeout,
+ @Nullable String product,
+ @Nullable ZoneId timeZone
+ ) {
+ this.prepareCache = prepareCache;
+ this.currentIsolationLevel = this.sessionIsolationLevel = isolationLevel;
+ this.lockWaitTimeoutSupported = lockWaitTimeoutSupported;
+ this.currentLockWaitTimeout = this.sessionLockWaitTimeout = lockWaitTimeout;
+ this.product = product == null ? "Unknown" : product;
+
+ if (timeZone != null) {
+ if (isTimeZoneInitialized()) {
+ throw new IllegalStateException("Connection timezone have been initialized");
+ }
+ this.timeZone = timeZone;
+ }
+ }
+
/**
* Get the connection identifier that is specified by server.
*
@@ -128,6 +198,14 @@ public ZoneId getTimeZone() {
return timeZone;
}
+ String getProduct() {
+ return product;
+ }
+
+ PrepareCache getPrepareCache() {
+ return prepareCache;
+ }
+
boolean isTimeZoneInitialized() {
return timeZone != null;
}
@@ -138,13 +216,6 @@ public boolean isMariaDb() {
return (capability != null && capability.isMariaDb()) || serverVersion.isMariaDb();
}
- void initTimeZone(ZoneId timeZone) {
- if (isTimeZoneInitialized()) {
- throw new IllegalStateException("Connection timezone have been initialized");
- }
- this.timeZone = timeZone;
- }
-
@Override
public ZeroDateOption getZeroDateOption() {
return zeroDateOption;
@@ -170,19 +241,23 @@ public int getLocalInfileBufferSize() {
}
/**
- * Checks if the server supports lock wait timeout.
+ * Checks if the server supports InnoDB lock wait timeout.
*
- * @return if the server supports lock wait timeout.
+ * @return if the server supports InnoDB lock wait timeout.
*/
public boolean isLockWaitTimeoutSupported() {
return lockWaitTimeoutSupported;
}
/**
- * Enables lock wait timeout supported when loading session variables.
+ * Checks if the server supports statement timeout.
+ *
+ * @return if the server supports statement timeout.
*/
- void enableLockWaitTimeoutSupported() {
- this.lockWaitTimeoutSupported = true;
+ public boolean isStatementTimeoutSupported() {
+ boolean isMariaDb = isMariaDb();
+ return (isMariaDb && serverVersion.isGreaterThanOrEqualTo(MARIA_10_1_1)) ||
+ (!isMariaDb && serverVersion.isGreaterThanOrEqualTo(MYSQL_5_7_4));
}
/**
@@ -202,4 +277,48 @@ public short getServerStatuses() {
public void setServerStatuses(short serverStatuses) {
this.serverStatuses = serverStatuses;
}
+
+ IsolationLevel getCurrentIsolationLevel() {
+ return currentIsolationLevel;
+ }
+
+ void setCurrentIsolationLevel(IsolationLevel isolationLevel) {
+ this.currentIsolationLevel = isolationLevel;
+ }
+
+ void resetCurrentIsolationLevel() {
+ this.currentIsolationLevel = this.sessionIsolationLevel;
+ }
+
+ IsolationLevel getSessionIsolationLevel() {
+ return sessionIsolationLevel;
+ }
+
+ void setSessionIsolationLevel(IsolationLevel isolationLevel) {
+ this.sessionIsolationLevel = isolationLevel;
+ }
+
+ void setCurrentLockWaitTimeout(Duration timeoutSeconds) {
+ this.currentLockWaitTimeout = timeoutSeconds;
+ }
+
+ void resetCurrentLockWaitTimeout() {
+ this.currentLockWaitTimeout = this.sessionLockWaitTimeout;
+ }
+
+ boolean isLockWaitTimeoutChanged() {
+ return currentLockWaitTimeout != sessionLockWaitTimeout;
+ }
+
+ Duration getSessionLockWaitTimeout() {
+ return sessionLockWaitTimeout;
+ }
+
+ void setAllLockWaitTimeout(Duration timeoutSeconds) {
+ this.currentLockWaitTimeout = this.sessionLockWaitTimeout = timeoutSeconds;
+ }
+
+ boolean isInTransaction() {
+ return (serverStatuses & ServerStatuses.IN_TRANSACTION) != 0;
+ }
}
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionState.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionState.java
deleted file mode 100644
index 73a9caf09..000000000
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionState.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Copyright 2023 asyncer.io projects
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.asyncer.r2dbc.mysql;
-
-import io.r2dbc.spi.IsolationLevel;
-
-/**
- * An internal interface for check, set and reset connection states.
- */
-interface ConnectionState {
-
- /**
- * Sets current isolation level.
- *
- * @param level current level.
- */
- void setIsolationLevel(IsolationLevel level);
-
- /**
- * Returns session lock wait timeout.
- *
- * @return Session lock wait timeout.
- */
- long getSessionLockWaitTimeout();
-
- /**
- * Sets current lock wait timeout.
- *
- * @param timeoutSeconds seconds of current lock wait timeout.
- */
- void setCurrentLockWaitTimeout(long timeoutSeconds);
-
- /**
- * Checks if lock wait timeout has been changed by {@link #setCurrentLockWaitTimeout(long)}.
- *
- * @return if lock wait timeout changed.
- */
- boolean isLockWaitTimeoutChanged();
-
- /**
- * Resets current isolation level in initial state.
- */
- void resetIsolationLevel();
-
- /**
- * Resets current isolation level in initial state.
- */
- void resetCurrentLockWaitTimeout();
-
- /**
- * Checks if connection is processing a transaction.
- *
- * @return if in a transaction.
- */
- boolean isInTransaction();
-}
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java
new file mode 100644
index 000000000..32dcc1c8a
--- /dev/null
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java
@@ -0,0 +1,747 @@
+/*
+ * Copyright 2024 asyncer.io projects
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.asyncer.r2dbc.mysql;
+
+import io.asyncer.r2dbc.mysql.api.MySqlResult;
+import io.asyncer.r2dbc.mysql.authentication.MySqlAuthProvider;
+import io.asyncer.r2dbc.mysql.cache.Caches;
+import io.asyncer.r2dbc.mysql.cache.PrepareCache;
+import io.asyncer.r2dbc.mysql.client.Client;
+import io.asyncer.r2dbc.mysql.client.FluxExchangeable;
+import io.asyncer.r2dbc.mysql.codec.Codecs;
+import io.asyncer.r2dbc.mysql.codec.CodecsBuilder;
+import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm;
+import io.asyncer.r2dbc.mysql.constant.SslMode;
+import io.asyncer.r2dbc.mysql.extension.CodecRegistrar;
+import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
+import io.asyncer.r2dbc.mysql.message.client.AuthResponse;
+import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
+import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse;
+import io.asyncer.r2dbc.mysql.message.client.InitDbMessage;
+import io.asyncer.r2dbc.mysql.message.client.SslRequest;
+import io.asyncer.r2dbc.mysql.message.client.SubsequenceClientMessage;
+import io.asyncer.r2dbc.mysql.message.server.AuthMoreDataMessage;
+import io.asyncer.r2dbc.mysql.message.server.ChangeAuthMessage;
+import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
+import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
+import io.asyncer.r2dbc.mysql.message.server.HandshakeHeader;
+import io.asyncer.r2dbc.mysql.message.server.HandshakeRequest;
+import io.asyncer.r2dbc.mysql.message.server.OkMessage;
+import io.asyncer.r2dbc.mysql.message.server.ServerMessage;
+import io.asyncer.r2dbc.mysql.message.server.SyntheticSslResponseMessage;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.util.ReferenceCountUtil;
+import io.netty.util.internal.logging.InternalLogger;
+import io.netty.util.internal.logging.InternalLoggerFactory;
+import io.r2dbc.spi.IsolationLevel;
+import io.r2dbc.spi.R2dbcNonTransientResourceException;
+import io.r2dbc.spi.R2dbcPermissionDeniedException;
+import io.r2dbc.spi.Readable;
+import org.jetbrains.annotations.Nullable;
+import reactor.core.CoreSubscriber;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.Sinks;
+import reactor.core.publisher.SynchronousSink;
+import reactor.util.concurrent.Queues;
+
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.time.DateTimeException;
+import java.time.Duration;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+
+/**
+ * A message flow utility that can initializes the session of {@link Client}.
+ *
+ * It should not use server-side prepared statements, because {@link PrepareCache} will be initialized after the session
+ * is initialized.
+ */
+final class InitFlow {
+
+ private static final InternalLogger logger = InternalLoggerFactory.getInstance(InitFlow.class);
+
+ private static final ServerVersion MARIA_11_1_1 = ServerVersion.create(11, 1, 1, true);
+
+ private static final ServerVersion MYSQL_8_0_3 = ServerVersion.create(8, 0, 3);
+
+ private static final ServerVersion MYSQL_5_7_20 = ServerVersion.create(5, 7, 20);
+
+ private static final ServerVersion MYSQL_8 = ServerVersion.create(8, 0, 0);
+
+ private static final BiConsumer> INIT_DB = (message, sink) -> {
+ if (message instanceof ErrorMessage) {
+ ErrorMessage msg = (ErrorMessage) message;
+ logger.debug("Use database failed: [{}] [{}] {}", msg.getCode(), msg.getSqlState(), msg.getMessage());
+ sink.next(false);
+ sink.complete();
+ } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
+ sink.next(true);
+ sink.complete();
+ } else {
+ ReferenceCountUtil.safeRelease(message);
+ }
+ };
+
+ private static final BiConsumer> INIT_DB_AFTER = (message, sink) -> {
+ if (message instanceof ErrorMessage) {
+ sink.error(((ErrorMessage) message).toException());
+ } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
+ sink.complete();
+ } else {
+ ReferenceCountUtil.safeRelease(message);
+ }
+ };
+
+ /**
+ * Initializes handshake and login a {@link Client}.
+ *
+ * @param client the {@link Client} to exchange messages with.
+ * @param sslMode the {@link SslMode} defines SSL capability and behavior.
+ * @param database the database that will be connected.
+ * @param user the user that will be login.
+ * @param password the password of the {@code user}.
+ * @param compressionAlgorithms the list of compression algorithms.
+ * @param zstdCompressionLevel the zstd compression level.
+ * @return a {@link Flux} that indicates the initialization is done, or an error if the initialization failed.
+ */
+ static Flux initHandshake(Client client, SslMode sslMode, String database, String user,
+ @Nullable CharSequence password, Set compressionAlgorithms, int zstdCompressionLevel) {
+ return client.exchange(new HandshakeExchangeable(client, sslMode, database, user, password,
+ compressionAlgorithms, zstdCompressionLevel));
+ }
+
+ /**
+ * Initializes the session and {@link Codecs} of a {@link Client}.
+ *
+ * @param client the client
+ * @param database the database to use after session initialization
+ * @param prepareCacheSize the size of prepare cache
+ * @param sessionVariables the session variables to set
+ * @param forceTimeZone if the timezone should be set to session
+ * @param lockWaitTimeout the lock wait timeout that should be set to session
+ * @param statementTimeout the statement timeout that should be set to session
+ * @return a {@link Mono} that indicates the {@link Codecs}, or an error if the initialization failed
+ */
+ static Mono initSession(
+ Client client,
+ String database,
+ int prepareCacheSize,
+ List sessionVariables,
+ boolean forceTimeZone,
+ @Nullable Duration lockWaitTimeout,
+ @Nullable Duration statementTimeout,
+ Extensions extensions
+ ) {
+ return Mono.defer(() -> {
+ ByteBufAllocator allocator = client.getByteBufAllocator();
+ CodecsBuilder builder = Codecs.builder();
+
+ extensions.forEach(CodecRegistrar.class, registrar ->
+ registrar.register(allocator, builder));
+
+ Codecs codecs = builder.build();
+
+ List variables = mergeSessionVariables(client, sessionVariables, forceTimeZone, statementTimeout);
+
+ logger.debug("Initializing client session: {}", variables);
+
+ return QueryFlow.setSessionVariables(client, variables)
+ .then(loadSessionVariables(client, codecs))
+ .flatMap(data -> loadAndInitInnoDbEngineStatus(data, client, codecs, lockWaitTimeout))
+ .flatMap(data -> {
+ ConnectionContext context = client.getContext();
+
+ logger.debug("Initializing connection {} context: {}", context.getConnectionId(), data);
+ context.initSession(
+ Caches.createPrepareCache(prepareCacheSize),
+ data.level,
+ data.lockWaitTimeoutSupported,
+ data.lockWaitTimeout,
+ data.product,
+ data.timeZone
+ );
+
+ if (!data.lockWaitTimeoutSupported) {
+ logger.info(
+ "Lock wait timeout is not supported by server, all related operations will be ignored");
+ }
+
+ return database.isEmpty() ? Mono.just(codecs) :
+ initDatabase(client, database).then(Mono.just(codecs));
+ });
+ });
+ }
+
+ private static Mono loadAndInitInnoDbEngineStatus(
+ SessionState data,
+ Client client,
+ Codecs codecs,
+ @Nullable Duration lockWaitTimeout
+ ) {
+ return new TextSimpleStatement(client, codecs, "SHOW VARIABLES LIKE 'innodb\\\\_lock\\\\_wait\\\\_timeout'")
+ .execute()
+ .flatMap(r -> r.map(readable -> {
+ String value = readable.get(1, String.class);
+
+ if (value == null || value.isEmpty()) {
+ return data;
+ } else {
+ return data.lockWaitTimeout(Duration.ofSeconds(Long.parseLong(value)));
+ }
+ }))
+ .single(data)
+ .flatMap(d -> {
+ if (lockWaitTimeout != null) {
+ // Do not use context.isLockWaitTimeoutSupported() here, because its session variable is not set
+ if (d.lockWaitTimeoutSupported) {
+ return QueryFlow.executeVoid(client, StringUtils.lockWaitTimeoutStatement(lockWaitTimeout))
+ .then(Mono.fromSupplier(() -> d.lockWaitTimeout(lockWaitTimeout)));
+ }
+
+ logger.warn("Lock wait timeout is not supported by server, ignore initial setting");
+ return Mono.just(d);
+ }
+ return Mono.just(d);
+ });
+ }
+
+ private static Mono loadSessionVariables(Client client, Codecs codecs) {
+ ConnectionContext context = client.getContext();
+ StringBuilder query = new StringBuilder(128)
+ .append("SELECT ")
+ .append(transactionIsolationColumn(context))
+ .append(",@@version_comment AS v");
+
+ Function> handler;
+
+ if (context.isTimeZoneInitialized()) {
+ handler = r -> convertSessionData(r, false);
+ } else {
+ query.append(",@@system_time_zone AS s,@@time_zone AS t");
+ handler = r -> convertSessionData(r, true);
+ }
+
+ return new TextSimpleStatement(client, codecs, query.toString())
+ .execute()
+ .flatMap(handler)
+ .last();
+ }
+
+ private static Mono initDatabase(Client client, String database) {
+ return client.exchange(new InitDbMessage(database), INIT_DB)
+ .last()
+ .flatMap(success -> {
+ if (success) {
+ return Mono.empty();
+ }
+
+ String sql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database);
+
+ return QueryFlow.executeVoid(client, sql)
+ .then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then());
+ });
+ }
+
+ private static List mergeSessionVariables(
+ Client client,
+ List sessionVariables,
+ boolean forceTimeZone,
+ @Nullable Duration statementTimeout
+ ) {
+ ConnectionContext context = client.getContext();
+
+ if ((!forceTimeZone || !context.isTimeZoneInitialized()) && statementTimeout == null) {
+ return sessionVariables;
+ }
+
+ List variables = new ArrayList<>(sessionVariables.size() + 2);
+
+ variables.addAll(sessionVariables);
+
+ if (forceTimeZone && context.isTimeZoneInitialized()) {
+ variables.add(timeZoneVariable(context.getTimeZone()));
+ }
+
+ if (statementTimeout != null) {
+ if (context.isStatementTimeoutSupported()) {
+ variables.add(StringUtils.statementTimeoutVariable(statementTimeout, context.isMariaDb()));
+ } else {
+ logger.warn("Statement timeout is not supported in {}, ignore initial setting",
+ context.getServerVersion());
+ }
+ }
+
+ return variables;
+ }
+
+ private static String timeZoneVariable(ZoneId timeZone) {
+ String offerStr = timeZone instanceof ZoneOffset && "Z".equalsIgnoreCase(timeZone.getId()) ?
+ "+00:00" : timeZone.getId();
+
+ return "time_zone='" + offerStr + "'";
+ }
+
+ private static Flux convertSessionData(MySqlResult r, boolean timeZone) {
+ return r.map(readable -> {
+ IsolationLevel level = convertIsolationLevel(readable.get(0, String.class));
+ String product = readable.get(1, String.class);
+
+ return new SessionState(level, product, timeZone ? readZoneId(readable) : null);
+ });
+ }
+
+ /**
+ * Resolves the column of session isolation level, the {@literal @@tx_isolation} has been marked as deprecated.
+ *
+ * If server is MariaDB, {@literal @@transaction_isolation} is used starting from {@literal 11.1.1}.
+ *
+ * If the server is MySQL, use {@literal @@transaction_isolation} starting from {@literal 8.0.3}, or between
+ * {@literal 5.7.20} and {@literal 8.0.0} (exclusive).
+ */
+ private static String transactionIsolationColumn(ConnectionContext context) {
+ ServerVersion version = context.getServerVersion();
+
+ if (context.isMariaDb()) {
+ return version.isGreaterThanOrEqualTo(MARIA_11_1_1) ? "@@transaction_isolation AS i" :
+ "@@tx_isolation AS i";
+ }
+
+ return version.isGreaterThanOrEqualTo(MYSQL_8_0_3) ||
+ (version.isGreaterThanOrEqualTo(MYSQL_5_7_20) && version.isLessThan(MYSQL_8)) ?
+ "@@transaction_isolation AS i" : "@@tx_isolation AS i";
+ }
+
+ private static ZoneId readZoneId(Readable readable) {
+ String systemTimeZone = readable.get(2, String.class);
+ String timeZone = readable.get(3, String.class);
+
+ if (timeZone == null || timeZone.isEmpty() || "SYSTEM".equalsIgnoreCase(timeZone)) {
+ if (systemTimeZone == null || systemTimeZone.isEmpty()) {
+ logger.warn("MySQL does not return any timezone, trying to use system default timezone");
+ return ZoneId.systemDefault().normalized();
+ } else {
+ return convertZoneId(systemTimeZone);
+ }
+ } else {
+ return convertZoneId(timeZone);
+ }
+ }
+
+ private static ZoneId convertZoneId(String id) {
+ try {
+ return StringUtils.parseZoneId(id);
+ } catch (DateTimeException e) {
+ logger.warn("The server timezone is unknown <{}>, trying to use system default timezone", id, e);
+
+ return ZoneId.systemDefault().normalized();
+ }
+ }
+
+ private static IsolationLevel convertIsolationLevel(@Nullable String name) {
+ if (name == null) {
+ logger.warn("Isolation level is null in current session, fallback to repeatable read");
+
+ return IsolationLevel.REPEATABLE_READ;
+ }
+
+ switch (name) {
+ case "READ-UNCOMMITTED":
+ return IsolationLevel.READ_UNCOMMITTED;
+ case "READ-COMMITTED":
+ return IsolationLevel.READ_COMMITTED;
+ case "REPEATABLE-READ":
+ return IsolationLevel.REPEATABLE_READ;
+ case "SERIALIZABLE":
+ return IsolationLevel.SERIALIZABLE;
+ }
+
+ logger.warn("Unknown isolation level {} in current session, fallback to repeatable read", name);
+
+ return IsolationLevel.REPEATABLE_READ;
+ }
+
+ private InitFlow() {
+ }
+
+ private static final class SessionState {
+
+ private final IsolationLevel level;
+
+ @Nullable
+ private final String product;
+
+ @Nullable
+ private final ZoneId timeZone;
+
+ private final Duration lockWaitTimeout;
+
+ private final boolean lockWaitTimeoutSupported;
+
+ SessionState(IsolationLevel level, @Nullable String product, @Nullable ZoneId timeZone) {
+ this(level, product, timeZone, Duration.ZERO, false);
+ }
+
+ private SessionState(
+ IsolationLevel level,
+ @Nullable String product,
+ @Nullable ZoneId timeZone,
+ Duration lockWaitTimeout,
+ boolean lockWaitTimeoutSupported
+ ) {
+ this.level = level;
+ this.product = product;
+ this.timeZone = timeZone;
+ this.lockWaitTimeout = lockWaitTimeout;
+ this.lockWaitTimeoutSupported = lockWaitTimeoutSupported;
+ }
+
+ SessionState lockWaitTimeout(Duration timeout) {
+ return new SessionState(level, product, timeZone, timeout, true);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof SessionState)) {
+ return false;
+ }
+
+ SessionState that = (SessionState) o;
+
+ return lockWaitTimeoutSupported == that.lockWaitTimeoutSupported &&
+ level.equals(that.level) &&
+ Objects.equals(product, that.product) &&
+ Objects.equals(timeZone, that.timeZone) &&
+ lockWaitTimeout.equals(that.lockWaitTimeout);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = level.hashCode();
+ result = 31 * result + (product != null ? product.hashCode() : 0);
+ result = 31 * result + (timeZone != null ? timeZone.hashCode() : 0);
+ result = 31 * result + lockWaitTimeout.hashCode();
+ return 31 * result + (lockWaitTimeoutSupported ? 1 : 0);
+ }
+
+ @Override
+ public String toString() {
+ return "SessionState{level=" + level +
+ ", product='" + product +
+ "', timeZone=" + timeZone +
+ ", lockWaitTimeout=" + lockWaitTimeout +
+ ", lockWaitTimeoutSupported=" + lockWaitTimeoutSupported +
+ '}';
+ }
+ }
+}
+
+/**
+ * An implementation of {@link FluxExchangeable} that considers login to the database.
+ *
+ * Not like other {@link FluxExchangeable}s, it is started by a server-side message, which should be an implementation
+ * of {@link HandshakeRequest}.
+ */
+final class HandshakeExchangeable extends FluxExchangeable {
+
+ private static final InternalLogger logger = InternalLoggerFactory.getInstance(HandshakeExchangeable.class);
+
+ private static final Map ATTRIBUTES = Collections.emptyMap();
+
+ private static final String CLI_SPECIFIC = "HY000";
+
+ private static final int HANDSHAKE_VERSION = 10;
+
+ private final Sinks.Many requests = Sinks.many().unicast()
+ .onBackpressureBuffer(Queues.one().get());
+
+ private final Client client;
+
+ private final SslMode sslMode;
+
+ private final String database;
+
+ private final String user;
+
+ @Nullable
+ private final CharSequence password;
+
+ private final Set compressions;
+
+ private final int zstdCompressionLevel;
+
+ private boolean handshake = true;
+
+ private MySqlAuthProvider authProvider;
+
+ private byte[] salt;
+
+ private boolean sslCompleted;
+
+ HandshakeExchangeable(Client client, SslMode sslMode, String database, String user,
+ @Nullable CharSequence password, Set compressions,
+ int zstdCompressionLevel) {
+ this.client = client;
+ this.sslMode = sslMode;
+ this.database = database;
+ this.user = user;
+ this.password = password;
+ this.compressions = compressions;
+ this.zstdCompressionLevel = zstdCompressionLevel;
+ this.sslCompleted = sslMode == SslMode.TUNNEL;
+ }
+
+ @Override
+ public void subscribe(CoreSubscriber super ClientMessage> actual) {
+ requests.asFlux().subscribe(actual);
+ }
+
+ @Override
+ public void accept(ServerMessage message, SynchronousSink sink) {
+ if (message instanceof ErrorMessage) {
+ sink.error(((ErrorMessage) message).toException());
+ return;
+ }
+
+ // Ensures it will be initialized only once.
+ if (handshake) {
+ handshake = false;
+ if (message instanceof HandshakeRequest) {
+ HandshakeRequest request = (HandshakeRequest) message;
+ Capability capability = initHandshake(request);
+
+ if (capability.isSslEnabled()) {
+ emitNext(SslRequest.from(capability, client.getContext().getClientCollation().getId()), sink);
+ } else {
+ emitNext(createHandshakeResponse(capability), sink);
+ }
+ } else {
+ sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" +
+ message.getClass().getSimpleName() + "' in init phase"));
+ }
+
+ return;
+ }
+
+ if (message instanceof OkMessage) {
+ logger.trace("Connection (id {}) login success", client.getContext().getConnectionId());
+ client.loginSuccess();
+ sink.complete();
+ } else if (message instanceof SyntheticSslResponseMessage) {
+ sslCompleted = true;
+ emitNext(createHandshakeResponse(client.getContext().getCapability()), sink);
+ } else if (message instanceof AuthMoreDataMessage) {
+ AuthMoreDataMessage msg = (AuthMoreDataMessage) message;
+
+ if (msg.isFailed()) {
+ if (logger.isDebugEnabled()) {
+ logger.debug("Connection (id {}) fast authentication failed, use full authentication",
+ client.getContext().getConnectionId());
+ }
+
+ emitNext(createAuthResponse("full authentication"), sink);
+ }
+ // Otherwise success, wait until OK message or Error message.
+ } else if (message instanceof ChangeAuthMessage) {
+ ChangeAuthMessage msg = (ChangeAuthMessage) message;
+
+ authProvider = MySqlAuthProvider.build(msg.getAuthType());
+ salt = msg.getSalt();
+ emitNext(createAuthResponse("change authentication"), sink);
+ } else {
+ sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" +
+ message.getClass().getSimpleName() + "' in login phase"));
+ }
+ }
+
+ @Override
+ public void dispose() {
+ // No particular error condition handling for complete signal.
+ this.requests.tryEmitComplete();
+ }
+
+ private void emitNext(SubsequenceClientMessage message, SynchronousSink sink) {
+ Sinks.EmitResult result = requests.tryEmitNext(message);
+
+ if (result != Sinks.EmitResult.OK) {
+ sink.error(new IllegalStateException("Fail to emit a login request due to " + result));
+ }
+ }
+
+ private AuthResponse createAuthResponse(String phase) {
+ MySqlAuthProvider authProvider = getAndNextProvider();
+
+ if (authProvider.isSslNecessary() && !sslCompleted) {
+ throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), phase), CLI_SPECIFIC);
+ }
+
+ return new AuthResponse(authProvider.authentication(password, salt, client.getContext().getClientCollation()));
+ }
+
+ private Capability clientCapability(Capability serverCapability) {
+ Capability.Builder builder = serverCapability.mutate();
+
+ builder.disableSessionTrack();
+ builder.disableDatabasePinned();
+ builder.disableIgnoreAmbiguitySpace();
+ builder.disableInteractiveTimeout();
+
+ if (sslMode == SslMode.TUNNEL) {
+ // Tunnel does not use MySQL SSL protocol, disable it.
+ builder.disableSsl();
+ } else if (!serverCapability.isSslEnabled()) {
+ // Server unsupported SSL.
+ if (sslMode.requireSsl()) {
+ // Before handshake, Client.context does not be initialized
+ throw new R2dbcPermissionDeniedException("Server does not support SSL but mode '" + sslMode +
+ "' requires SSL", CLI_SPECIFIC);
+ } else if (sslMode.startSsl()) {
+ // SSL has start yet, and client can disable SSL, disable now.
+ client.sslUnsupported();
+ }
+ } else {
+ // The server supports SSL, but the user does not want to use SSL, disable it.
+ if (!sslMode.startSsl()) {
+ builder.disableSsl();
+ }
+ }
+
+ if (isZstdAllowed(serverCapability)) {
+ if (isZstdSupported()) {
+ builder.disableZlibCompression();
+ } else {
+ logger.warn("Server supports zstd, but zstd-jni dependency is missing");
+
+ if (isZlibAllowed(serverCapability)) {
+ builder.disableZstdCompression();
+ } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) {
+ builder.disableCompression();
+ } else {
+ throw new R2dbcNonTransientResourceException(
+ "Environment does not support a compression algorithm in " + compressions +
+ ", config does not allow uncompressed mode", CLI_SPECIFIC);
+ }
+ }
+ } else if (isZlibAllowed(serverCapability)) {
+ builder.disableZstdCompression();
+ } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) {
+ builder.disableCompression();
+ } else {
+ throw new R2dbcPermissionDeniedException(
+ "Environment does not support a compression algorithm in " + compressions +
+ ", config does not allow uncompressed mode", CLI_SPECIFIC);
+ }
+
+ if (database.isEmpty()) {
+ builder.disableConnectWithDatabase();
+ }
+
+ if (client.getContext().getLocalInfilePath() == null) {
+ builder.disableLoadDataLocalInfile();
+ }
+
+ if (ATTRIBUTES.isEmpty()) {
+ builder.disableConnectAttributes();
+ }
+
+ return builder.build();
+ }
+
+ private Capability initHandshake(HandshakeRequest message) {
+ HandshakeHeader header = message.getHeader();
+ int handshakeVersion = header.getProtocolVersion();
+ ServerVersion serverVersion = header.getServerVersion();
+
+ if (handshakeVersion < HANDSHAKE_VERSION) {
+ logger.warn("MySQL use handshake V{}, server version is {}, maybe most features are unavailable",
+ handshakeVersion, serverVersion);
+ }
+
+ Capability capability = clientCapability(message.getServerCapability());
+
+ // No need initialize server statuses because it has initialized by read filter.
+ this.client.getContext().initHandshake(header.getConnectionId(), serverVersion, capability);
+ this.authProvider = MySqlAuthProvider.build(message.getAuthType());
+ this.salt = message.getSalt();
+
+ return capability;
+ }
+
+ private MySqlAuthProvider getAndNextProvider() {
+ MySqlAuthProvider authProvider = this.authProvider;
+ this.authProvider = authProvider.next();
+ return authProvider;
+ }
+
+ private HandshakeResponse createHandshakeResponse(Capability capability) {
+ MySqlAuthProvider authProvider = getAndNextProvider();
+
+ if (authProvider.isSslNecessary() && !sslCompleted) {
+ throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), "handshake"),
+ CLI_SPECIFIC);
+ }
+
+ byte[] authorization = authProvider.authentication(password, salt, client.getContext().getClientCollation());
+ String authType = authProvider.getType();
+
+ if (MySqlAuthProvider.NO_AUTH_PROVIDER.equals(authType)) {
+ // Authentication type is not matter because of it has no authentication type.
+ // Server need send a Change Authentication Message after handshake response.
+ authType = MySqlAuthProvider.CACHING_SHA2_PASSWORD;
+ }
+
+ return HandshakeResponse.from(capability, client.getContext().getClientCollation().getId(), user, authorization,
+ authType, database, ATTRIBUTES, zstdCompressionLevel);
+ }
+
+ private boolean isZstdAllowed(Capability capability) {
+ return capability.isZstdCompression() && compressions.contains(CompressionAlgorithm.ZSTD);
+ }
+
+ private boolean isZlibAllowed(Capability capability) {
+ return capability.isZlibCompression() && compressions.contains(CompressionAlgorithm.ZLIB);
+ }
+
+ private static String authFails(String authType, String phase) {
+ return "Authentication type '" + authType + "' must require SSL in " + phase + " phase";
+ }
+
+ private static boolean isZstdSupported() {
+ try {
+ ClassLoader loader = AccessController.doPrivileged((PrivilegedAction) () -> {
+ ClassLoader cl = Thread.currentThread().getContextClassLoader();
+ return cl == null ? ClassLoader.getSystemClassLoader() : cl;
+ });
+ Class.forName("com.github.luben.zstd.Zstd", false, loader);
+ return true;
+ } catch (ClassNotFoundException e) {
+ return false;
+ }
+ }
+}
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionMetadata.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlClientConnectionMetadata.java
similarity index 60%
rename from r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionMetadata.java
rename to r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlClientConnectionMetadata.java
index ee7faf42d..61cb1d0b8 100644
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionMetadata.java
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlClientConnectionMetadata.java
@@ -17,39 +17,31 @@
package io.asyncer.r2dbc.mysql;
import io.asyncer.r2dbc.mysql.api.MySqlConnectionMetadata;
-import org.jetbrains.annotations.Nullable;
-
-import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
+import io.asyncer.r2dbc.mysql.client.Client;
/**
* Connection metadata for a connection connected to MySQL database.
*/
-final class MySqlSimpleConnectionMetadata implements MySqlConnectionMetadata {
-
- private final String version;
-
- private final String product;
+final class MySqlClientConnectionMetadata implements MySqlConnectionMetadata {
- private final boolean isMariaDb;
+ private final Client client;
- MySqlSimpleConnectionMetadata(String version, @Nullable String product, boolean isMariaDb) {
- this.version = requireNonNull(version, "version must not be null");
- this.product = product == null ? "Unknown" : product;
- this.isMariaDb = isMariaDb;
+ MySqlClientConnectionMetadata(Client client) {
+ this.client = client;
}
@Override
public String getDatabaseVersion() {
- return version;
+ return client.getContext().getServerVersion().toString();
}
@Override
public boolean isMariaDb() {
- return isMariaDb;
+ return client.getContext().isMariaDb();
}
@Override
public String getDatabaseProductName() {
- return product;
+ return client.getContext().getProduct();
}
}
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java
index 6d76a8bed..d003db2b0 100644
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java
@@ -18,16 +18,9 @@
import io.asyncer.r2dbc.mysql.api.MySqlConnection;
import io.asyncer.r2dbc.mysql.cache.Caches;
-import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.cache.QueryCache;
import io.asyncer.r2dbc.mysql.client.Client;
-import io.asyncer.r2dbc.mysql.codec.Codecs;
-import io.asyncer.r2dbc.mysql.codec.CodecsBuilder;
-import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm;
-import io.asyncer.r2dbc.mysql.constant.SslMode;
-import io.asyncer.r2dbc.mysql.extension.CodecRegistrar;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
-import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.unix.DomainSocketAddress;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
@@ -38,13 +31,9 @@
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.time.ZoneId;
-import java.time.ZoneOffset;
-import java.util.ArrayList;
-import java.util.List;
import java.util.Objects;
-import java.util.Set;
import java.util.concurrent.locks.ReentrantLock;
-import java.util.function.Predicate;
+import java.util.function.Supplier;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
@@ -93,102 +82,103 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
address = new DomainSocketAddress(configuration.getDomain());
}
- String database = configuration.getDatabase();
- boolean createDbIfNotExist = configuration.isCreateDatabaseIfNotExist();
String user = configuration.getUser();
CharSequence password = configuration.getPassword();
- SslMode sslMode = ssl.getSslMode();
- int zstdCompressionLevel = configuration.getZstdCompressionLevel();
- ZoneId connectionTimeZone = retrieveZoneId(configuration.getConnectionTimeZone());
- ConnectionContext context = new ConnectionContext(
- configuration.getZeroDateOption(),
- configuration.getLoadLocalInfilePath(),
- configuration.getLocalInfileBufferSize(),
- configuration.isPreserveInstants(),
- connectionTimeZone
- );
- Set compressionAlgorithms = configuration.getCompressionAlgorithms();
- Extensions extensions = configuration.getExtensions();
- Predicate prepare = configuration.getPreferPrepareStatement();
- int prepareCacheSize = configuration.getPrepareCacheSize();
Publisher passwordPublisher = configuration.getPasswordPublisher();
- boolean forceTimeZone = configuration.isForceConnectionTimeZoneToSession();
- List sessionVariables = forceTimeZone && connectionTimeZone != null ?
- mergeSessionVariables(configuration.getSessionVariables(), connectionTimeZone) :
- configuration.getSessionVariables();
if (Objects.nonNull(passwordPublisher)) {
return Mono.from(passwordPublisher).flatMap(token -> getMySqlConnection(
- configuration, queryCache,
- ssl, address,
- database, createDbIfNotExist,
- user, sslMode,
- compressionAlgorithms, zstdCompressionLevel,
- context, extensions, sessionVariables, prepare,
- prepareCacheSize, token
+ configuration, ssl,
+ queryCache,
+ address,
+ user,
+ token
));
}
return getMySqlConnection(
- configuration, queryCache,
- ssl, address,
- database, createDbIfNotExist,
- user, sslMode,
- compressionAlgorithms, zstdCompressionLevel,
- context, extensions, sessionVariables, prepare,
- prepareCacheSize, password
+ configuration, ssl,
+ queryCache,
+ address,
+ user,
+ password
);
}));
}
+ /**
+ * Gets an initialized {@link MySqlConnection} from authentication credential and configurations.
+ *
+ * It contains following steps:
+ *
- Create connection context
+ * - Connect to MySQL server with TCP or Unix Domain Socket
+ * - Handshake/login and init handshake states
+ * - Init session states
+ *
+ * @param configuration the connection configuration.
+ * @param ssl the SSL configuration.
+ * @param queryCache lazy-init query cache, it is shared among all connections from the same factory.
+ * @param address TCP or Unix Domain Socket address.
+ * @param user the user of the authentication.
+ * @param password the password of the authentication.
+ * @return a {@link MySqlConnection}.
+ */
private static Mono getMySqlConnection(
- final MySqlConnectionConfiguration configuration,
- final LazyQueryCache queryCache,
- final MySqlSslConfiguration ssl,
- final SocketAddress address,
- final String database,
- final boolean createDbIfNotExist,
- final String user,
- final SslMode sslMode,
- final Set compressionAlgorithms,
- final int zstdLevel,
- final ConnectionContext context,
- final Extensions extensions,
- final List sessionVariables,
- @Nullable final Predicate prepare,
- final int prepareCacheSize,
- @Nullable final CharSequence password) {
- return Client.connect(ssl, address, configuration.isTcpKeepAlive(), configuration.isTcpNoDelay(),
- context, configuration.getConnectTimeout(), configuration.getLoopResources())
- .flatMap(client -> {
- // Lazy init database after handshake/login
- String db = createDbIfNotExist ? "" : database;
- return QueryFlow.login(client, sslMode, db, user, password, compressionAlgorithms, zstdLevel);
- })
- .flatMap(client -> {
- ByteBufAllocator allocator = client.getByteBufAllocator();
- CodecsBuilder builder = Codecs.builder();
- PrepareCache prepareCache = Caches.createPrepareCache(prepareCacheSize);
- String db = createDbIfNotExist ? database : "";
-
- extensions.forEach(CodecRegistrar.class, registrar ->
- registrar.register(allocator, builder));
-
- Mono c = MySqlSimpleConnection.init(client, builder.build(), db, queryCache.get(),
- prepareCache, sessionVariables, prepare);
-
- if (configuration.getLockWaitTimeout() != null) {
- c = c.flatMap(connection -> connection.setLockWaitTimeout(configuration.getLockWaitTimeout())
- .thenReturn(connection));
- }
-
- if (configuration.getStatementTimeout() != null) {
- c = c.flatMap(connection -> connection.setStatementTimeout(configuration.getStatementTimeout())
- .thenReturn(connection));
- }
-
- return c;
- });
+ final MySqlConnectionConfiguration configuration,
+ final MySqlSslConfiguration ssl,
+ final LazyQueryCache queryCache,
+ final SocketAddress address,
+ final String user,
+ @Nullable final CharSequence password
+ ) {
+ return Mono.fromSupplier(() -> {
+ ZoneId connectionTimeZone = retrieveZoneId(configuration.getConnectionTimeZone());
+ return new ConnectionContext(
+ configuration.getZeroDateOption(),
+ configuration.getLoadLocalInfilePath(),
+ configuration.getLocalInfileBufferSize(),
+ configuration.isPreserveInstants(),
+ connectionTimeZone
+ );
+ }).flatMap(context -> Client.connect(
+ ssl,
+ address,
+ configuration.isTcpKeepAlive(),
+ configuration.isTcpNoDelay(),
+ context,
+ configuration.getConnectTimeout(),
+ configuration.getLoopResources()
+ )).flatMap(client -> {
+ // Lazy init database after handshake/login
+ boolean deferDatabase = configuration.isCreateDatabaseIfNotExist();
+ String database = configuration.getDatabase();
+ String loginDb = deferDatabase ? "" : database;
+ String sessionDb = deferDatabase ? database : "";
+
+ return InitFlow.initHandshake(
+ client,
+ ssl.getSslMode(),
+ loginDb,
+ user,
+ password,
+ configuration.getCompressionAlgorithms(),
+ configuration.getZstdCompressionLevel()
+ ).then(InitFlow.initSession(
+ client,
+ sessionDb,
+ configuration.getPrepareCacheSize(),
+ configuration.getSessionVariables(),
+ configuration.isForceConnectionTimeZoneToSession(),
+ configuration.getLockWaitTimeout(),
+ configuration.getStatementTimeout(),
+ configuration.getExtensions()
+ )).map(codecs -> new MySqlSimpleConnection(
+ client,
+ codecs,
+ queryCache.get(),
+ configuration.getPreferPrepareStatement()
+ )).onErrorResume(e -> client.forceClose().then(Mono.error(e)));
+ });
}
@Nullable
@@ -202,19 +192,7 @@ private static ZoneId retrieveZoneId(String timeZone) {
return StringUtils.parseZoneId(timeZone);
}
- private static List mergeSessionVariables(List sessionVariables, ZoneId timeZone) {
- List res = new ArrayList<>(sessionVariables.size() + 1);
-
- String offerStr = timeZone instanceof ZoneOffset && "Z".equalsIgnoreCase(timeZone.getId()) ?
- "+00:00" : timeZone.getId();
-
- res.addAll(sessionVariables);
- res.add("time_zone='" + offerStr + "'");
-
- return res;
- }
-
- private static final class LazyQueryCache {
+ private static final class LazyQueryCache implements Supplier {
private final int capacity;
@@ -227,6 +205,7 @@ private LazyQueryCache(int capacity) {
this.capacity = capacity;
}
+ @Override
public QueryCache get() {
QueryCache cache = this.cache;
if (cache == null) {
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java
index 660e25e06..ce5ba41e4 100644
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java
@@ -19,16 +19,13 @@
import io.asyncer.r2dbc.mysql.api.MySqlBatch;
import io.asyncer.r2dbc.mysql.api.MySqlConnection;
import io.asyncer.r2dbc.mysql.api.MySqlConnectionMetadata;
-import io.asyncer.r2dbc.mysql.api.MySqlResult;
import io.asyncer.r2dbc.mysql.api.MySqlStatement;
import io.asyncer.r2dbc.mysql.api.MySqlTransactionDefinition;
-import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.cache.QueryCache;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
-import io.asyncer.r2dbc.mysql.message.client.InitDbMessage;
import io.asyncer.r2dbc.mysql.message.client.PingMessage;
import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
@@ -38,18 +35,15 @@
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.r2dbc.spi.IsolationLevel;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
-import io.r2dbc.spi.Readable;
import io.r2dbc.spi.TransactionDefinition;
import io.r2dbc.spi.ValidationDepth;
import org.jetbrains.annotations.Nullable;
+import org.jetbrains.annotations.TestOnly;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;
-import java.time.DateTimeException;
import java.time.Duration;
-import java.time.ZoneId;
-import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Predicate;
@@ -60,24 +54,12 @@
/**
* An implementation of {@link MySqlConnection} for connecting to the MySQL database.
*/
-final class MySqlSimpleConnection implements MySqlConnection, ConnectionState {
+final class MySqlSimpleConnection implements MySqlConnection {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(MySqlSimpleConnection.class);
private static final String PING_MARKER = "/* ping */";
- private static final ServerVersion MARIA_11_1_1 = ServerVersion.create(11, 1, 1, true);
-
- private static final ServerVersion MYSQL_8_0_3 = ServerVersion.create(8, 0, 3);
-
- private static final ServerVersion MYSQL_5_7_20 = ServerVersion.create(5, 7, 20);
-
- private static final ServerVersion MYSQL_8 = ServerVersion.create(8, 0, 0);
-
- private static final ServerVersion MYSQL_5_7_4 = ServerVersion.create(5, 7, 4);
-
- private static final ServerVersion MARIA_10_1_1 = ServerVersion.create(10, 1, 1, true);
-
private static final Function VALIDATE = message -> {
if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
return true;
@@ -106,87 +88,29 @@ final class MySqlSimpleConnection implements MySqlConnection, ConnectionState {
}
};
- private static final BiConsumer> INIT_DB = (message, sink) -> {
- if (message instanceof ErrorMessage) {
- ErrorMessage msg = (ErrorMessage) message;
- logger.debug("Use database failed: [{}] [{}] {}", msg.getCode(), msg.getSqlState(),
- msg.getMessage());
- sink.next(false);
- sink.complete();
- } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
- sink.next(true);
- sink.complete();
- } else {
- ReferenceCountUtil.safeRelease(message);
- }
- };
-
- private static final BiConsumer> INIT_DB_AFTER = (message, sink) -> {
- if (message instanceof ErrorMessage) {
- sink.error(((ErrorMessage) message).toException());
- } else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
- sink.complete();
- } else {
- ReferenceCountUtil.safeRelease(message);
- }
- };
-
private final Client client;
private final Codecs codecs;
- private final boolean batchSupported;
-
private final MySqlConnectionMetadata metadata;
- private volatile IsolationLevel sessionLevel;
-
private final QueryCache queryCache;
- private final PrepareCache prepareCache;
-
@Nullable
private final Predicate prepare;
- /**
- * Current isolation level inferred by past statements.
- *
- * Inference rules:
- *
- In the beginning, it is also {@link #sessionLevel}.
- * - After the user calls {@link #setTransactionIsolationLevel(IsolationLevel)}, it will change to
- * the user-specified value.
- * - After the end of a transaction (commit or rollback), it will recover to {@link #sessionLevel}.
- *
- */
- private volatile IsolationLevel currentLevel;
-
- /**
- * Session lock wait timeout.
- */
- private volatile long lockWaitTimeout;
-
- /**
- * Current transaction lock wait timeout.
- */
- private volatile long currentLockWaitTimeout;
+ // TODO: Check it when executing
+ private final boolean batchSupported;
- MySqlSimpleConnection(Client client, Codecs codecs, IsolationLevel level,
- long lockWaitTimeout, QueryCache queryCache, PrepareCache prepareCache, @Nullable String product,
- @Nullable Predicate prepare) {
+ MySqlSimpleConnection(Client client, Codecs codecs, QueryCache queryCache, @Nullable Predicate prepare) {
ConnectionContext context = client.getContext();
this.client = client;
- this.sessionLevel = level;
- this.currentLevel = level;
this.codecs = codecs;
- this.lockWaitTimeout = lockWaitTimeout;
- this.currentLockWaitTimeout = lockWaitTimeout;
+ this.metadata = new MySqlClientConnectionMetadata(client);
this.queryCache = queryCache;
- this.prepareCache = prepareCache;
- this.metadata = new MySqlSimpleConnectionMetadata(context.getServerVersion().toString(), product,
- context.isMariaDb());
- this.batchSupported = context.getCapability().isMultiStatementsAllowed();
this.prepare = prepare;
+ this.batchSupported = context.getCapability().isMultiStatementsAllowed();
if (this.batchSupported) {
logger.debug("Batch is supported by server");
@@ -202,7 +126,7 @@ public Mono beginTransaction() {
@Override
public Mono beginTransaction(TransactionDefinition definition) {
- return Mono.defer(() -> QueryFlow.beginTransaction(client, this, batchSupported, definition));
+ return Mono.defer(() -> QueryFlow.beginTransaction(client, batchSupported, definition));
}
@Override
@@ -219,7 +143,7 @@ public Mono close() {
@Override
public Mono commitTransaction() {
- return Mono.defer(() -> QueryFlow.doneTransaction(client, this, true, batchSupported));
+ return Mono.defer(() -> QueryFlow.doneTransaction(client, true, batchSupported));
}
@Override
@@ -231,7 +155,7 @@ public MySqlBatch createBatch() {
public Mono createSavepoint(String name) {
requireNonEmpty(name, "Savepoint name must not be empty");
- return QueryFlow.createSavepoint(client, this, name, batchSupported);
+ return QueryFlow.createSavepoint(client, name, batchSupported);
}
@Override
@@ -247,7 +171,7 @@ public MySqlStatement createStatement(String sql) {
if (query.isSimple()) {
if (prepare != null && prepare.test(sql)) {
logger.debug("Create a simple statement provided by prepare query");
- return new PrepareSimpleStatement(client, codecs, sql, prepareCache);
+ return new PrepareSimpleStatement(client, codecs, sql);
}
logger.debug("Create a simple statement provided by text query");
@@ -262,7 +186,7 @@ public MySqlStatement createStatement(String sql) {
logger.debug("Create a parameterized statement provided by prepare query");
- return new PrepareParameterizedStatement(client, codecs, query, prepareCache);
+ return new PrepareParameterizedStatement(client, codecs, query);
}
@Override
@@ -285,7 +209,7 @@ public Mono releaseSavepoint(String name) {
@Override
public Mono rollbackTransaction() {
- return Mono.defer(() -> QueryFlow.doneTransaction(client, this, false, batchSupported));
+ return Mono.defer(() -> QueryFlow.doneTransaction(client, false, batchSupported));
}
@Override
@@ -301,7 +225,7 @@ public MySqlConnectionMetadata getMetadata() {
}
/**
- * MySQL does not have any way to query the isolation level of the current transaction, only inferred from past
+ * MySQL does not have a way to query the isolation level of the current transaction, only inferred from past
* statements, so driver can not make sure the result is right.
*
* See MySQL Bug 53341
@@ -310,16 +234,7 @@ public MySqlConnectionMetadata getMetadata() {
*/
@Override
public IsolationLevel getTransactionIsolationLevel() {
- return currentLevel;
- }
-
- /**
- * Gets session transaction isolation level(Only for testing).
- *
- * @return session transaction isolation level.
- */
- IsolationLevel getSessionTransactionIsolationLevel() {
- return sessionLevel;
+ return client.getContext().getCurrentIsolationLevel();
}
@Override
@@ -330,9 +245,11 @@ public Mono setTransactionIsolationLevel(IsolationLevel isolationLevel) {
return QueryFlow.executeVoid(client,
"SET SESSION TRANSACTION ISOLATION LEVEL " + isolationLevel.asSql())
.doOnSuccess(ignored -> {
- this.sessionLevel = isolationLevel;
- if (!this.isInTransaction()) {
- this.currentLevel = isolationLevel;
+ ConnectionContext context = client.getContext();
+
+ context.setSessionIsolationLevel(isolationLevel);
+ if (!context.isInTransaction()) {
+ context.setCurrentIsolationLevel(isolationLevel);
}
});
}
@@ -366,12 +283,13 @@ public Mono validate(ValidationDepth depth) {
public boolean isAutoCommit() {
// Within transaction, autocommit remains disabled until end the transaction with COMMIT or ROLLBACK.
// The autocommit mode then reverts to its previous state.
- return !isInTransaction() && isSessionAutoCommit();
+ return !client.getContext().isInTransaction() && isSessionAutoCommit();
}
@Override
public Mono setAutoCommit(boolean autoCommit) {
return Mono.defer(() -> {
+ // TODO: remove the check or checking when executing
if (autoCommit == isSessionAutoCommit()) {
return Mono.empty();
}
@@ -380,321 +298,58 @@ public Mono setAutoCommit(boolean autoCommit) {
});
}
- @Override
- public void setIsolationLevel(IsolationLevel level) {
- this.currentLevel = level;
- }
-
- @Override
- public long getSessionLockWaitTimeout() {
- return lockWaitTimeout;
- }
-
- @Override
- public void setCurrentLockWaitTimeout(long timeoutSeconds) {
- this.currentLockWaitTimeout = timeoutSeconds;
- }
-
- @Override
- public void resetIsolationLevel() {
- this.currentLevel = this.sessionLevel;
- }
-
- @Override
- public boolean isLockWaitTimeoutChanged() {
- return currentLockWaitTimeout != lockWaitTimeout;
- }
-
- @Override
- public void resetCurrentLockWaitTimeout() {
- this.currentLockWaitTimeout = this.lockWaitTimeout;
- }
-
- @Override
- public boolean isInTransaction() {
- return (client.getContext().getServerStatuses() & ServerStatuses.IN_TRANSACTION) != 0;
- }
-
@Override
public Mono setLockWaitTimeout(Duration timeout) {
requireNonNull(timeout, "timeout must not be null");
- if (!client.getContext().isLockWaitTimeoutSupported()) {
- logger.warn("Lock wait timeout is not supported by server, setLockWaitTimeout operation is ignored");
- return Mono.empty();
+ if (client.getContext().isLockWaitTimeoutSupported()) {
+ return QueryFlow.executeVoid(client, StringUtils.lockWaitTimeoutStatement(timeout))
+ .doOnSuccess(ignored -> client.getContext().setAllLockWaitTimeout(timeout));
}
- long timeoutSeconds = timeout.getSeconds();
- return QueryFlow.executeVoid(client, "SET innodb_lock_wait_timeout=" + timeoutSeconds)
- .doOnSuccess(ignored -> this.lockWaitTimeout = this.currentLockWaitTimeout = timeoutSeconds);
+ logger.warn("Lock wait timeout is not supported by server, setLockWaitTimeout operation is ignored");
+ return Mono.empty();
+
}
@Override
public Mono setStatementTimeout(Duration timeout) {
requireNonNull(timeout, "timeout must not be null");
- final ConnectionContext context = client.getContext();
- final boolean isMariaDb = context.isMariaDb();
- final ServerVersion serverVersion = context.getServerVersion();
- final long timeoutMs = timeout.toMillis();
- final String sql = isMariaDb ? "SET max_statement_time=" + timeoutMs / 1000.0
- : "SET SESSION MAX_EXECUTION_TIME=" + timeoutMs;
+ ConnectionContext context = client.getContext();
// mariadb: https://mariadb.com/kb/en/aborting-statements/
// mysql: https://dev.mysql.com/blog-archive/server-side-select-statement-timeouts/
// ref: https://github.com/mariadb-corporation/mariadb-connector-r2dbc
- if (isMariaDb && serverVersion.isGreaterThanOrEqualTo(MARIA_10_1_1)
- || !isMariaDb && serverVersion.isGreaterThanOrEqualTo(MYSQL_5_7_4)) {
- return QueryFlow.executeVoid(client, sql);
+ if (context.isStatementTimeoutSupported()) {
+ String variable = StringUtils.statementTimeoutVariable(timeout, context.isMariaDb());
+ return QueryFlow.setSessionVariable(client, variable);
}
return Mono.error(
new R2dbcNonTransientResourceException(
- "Statement timeout is not supported by server version " + serverVersion,
+ "Statement timeout is not supported by server version " + context.getServerVersion(),
"HY000",
- -1,
- sql
+ -1
)
);
}
- private boolean isSessionAutoCommit() {
- return (client.getContext().getServerStatuses() & ServerStatuses.AUTO_COMMIT) != 0;
- }
-
- static Flux doPingInternal(Client client) {
- return client.exchange(PingMessage.INSTANCE, PING);
- }
-
/**
- * Initialize a {@link MySqlConnection} after login.
+ * Visible only for testing.
*
- * @param client must be logged-in.
- * @param codecs the {@link Codecs}.
- * @param database the database that should be lazy init.
- * @param queryCache the cache of {@link Query}.
- * @param prepareCache the cache of server-preparing result.
- * @param sessionVariables the session variables to set.
- * @param prepare judging for prefer use prepare statement to execute simple query.
- * @return a {@link Mono} will emit an initialized {@link MySqlConnection}.
+ * @return current connection context
*/
- static Mono init(
- Client client, Codecs codecs, String database,
- QueryCache queryCache, PrepareCache prepareCache,
- List sessionVariables, @Nullable Predicate prepare
- ) {
- Mono connection = initSessionVariables(client, sessionVariables)
- .then(loadSessionVariables(client, codecs))
- .flatMap(data -> loadInnoDbEngineStatus(data, client, codecs))
- .map(data -> {
- ConnectionContext context = client.getContext();
- ZoneId timeZone = data.timeZone;
- if (timeZone != null) {
- logger.debug("Got server time zone {} from loading session variables", timeZone);
- context.initTimeZone(timeZone);
- }
-
- if (data.lockWaitTimeoutSupported) {
- context.enableLockWaitTimeoutSupported();
- } else {
- logger.info("Lock wait timeout is not supported by server, all related operations will be ignored");
- }
-
- return new MySqlSimpleConnection(client, codecs, data.level, data.lockWaitTimeout,
- queryCache, prepareCache, data.product, prepare);
- });
-
- if (database.isEmpty()) {
- return connection;
- }
-
- return connection.flatMap(c -> initDatabase(client, database).thenReturn(c));
- }
-
- private static Mono initSessionVariables(Client client, List sessionVariables) {
- if (sessionVariables.isEmpty()) {
- return Mono.empty();
- }
-
- StringBuilder query = new StringBuilder(sessionVariables.size() * 32 + 16).append("SET ");
- boolean comma = false;
-
- for (String variable : sessionVariables) {
- if (variable.isEmpty()) {
- continue;
- }
-
- if (comma) {
- query.append(',');
- } else {
- comma = true;
- }
-
- if (variable.startsWith("@")) {
- query.append(variable);
- } else {
- query.append("SESSION ").append(variable);
- }
- }
-
- return QueryFlow.executeVoid(client, query.toString());
- }
-
- private static Mono loadSessionVariables(Client client, Codecs codecs) {
- ConnectionContext context = client.getContext();
- StringBuilder query = new StringBuilder(128)
- .append("SELECT ")
- .append(transactionIsolationColumn(context))
- .append(",@@version_comment AS v");
-
- Function> handler;
-
- if (context.isTimeZoneInitialized()) {
- handler = r -> convertSessionData(r, false);
- } else {
- query.append(",@@system_time_zone AS s,@@time_zone AS t");
- handler = r -> convertSessionData(r, true);
- }
-
- return new TextSimpleStatement(client, codecs, query.toString())
- .execute()
- .flatMap(handler)
- .last();
- }
-
- private static Mono loadInnoDbEngineStatus(SessionData data, Client client, Codecs codecs) {
- return new TextSimpleStatement(client, codecs, "SHOW VARIABLES LIKE 'innodb\\\\_lock\\\\_wait\\\\_timeout'")
- .execute()
- .flatMap(r -> r.map(readable -> {
- String value = readable.get(1, String.class);
-
- if (value == null || value.isEmpty()) {
- return data;
- } else {
- return data.lockWaitTimeout(Long.parseLong(value));
- }
- }))
- .single(data);
- }
-
- private static Mono initDatabase(Client client, String database) {
- return client.exchange(new InitDbMessage(database), INIT_DB)
- .last()
- .flatMap(success -> {
- if (success) {
- return Mono.empty();
- }
-
- String sql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database);
-
- return QueryFlow.executeVoid(client, sql)
- .then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then());
- });
- }
-
- private static Flux convertSessionData(MySqlResult r, boolean timeZone) {
- return r.map(readable -> {
- IsolationLevel level = convertIsolationLevel(readable.get(0, String.class));
- String product = readable.get(1, String.class);
-
- return new SessionData(level, product, timeZone ? readZoneId(readable) : null);
- });
- }
-
- private static ZoneId readZoneId(Readable readable) {
- String systemTimeZone = readable.get(2, String.class);
- String timeZone = readable.get(3, String.class);
-
- if (timeZone == null || timeZone.isEmpty() || "SYSTEM".equalsIgnoreCase(timeZone)) {
- if (systemTimeZone == null || systemTimeZone.isEmpty()) {
- logger.warn("MySQL does not return any timezone, trying to use system default timezone");
- return ZoneId.systemDefault().normalized();
- } else {
- return convertZoneId(systemTimeZone);
- }
- } else {
- return convertZoneId(timeZone);
- }
- }
-
- private static ZoneId convertZoneId(String id) {
- try {
- return StringUtils.parseZoneId(id);
- } catch (DateTimeException e) {
- logger.warn("The server timezone is unknown <{}>, trying to use system default timezone", id, e);
-
- return ZoneId.systemDefault().normalized();
- }
- }
-
- private static IsolationLevel convertIsolationLevel(@Nullable String name) {
- if (name == null) {
- logger.warn("Isolation level is null in current session, fallback to repeatable read");
-
- return IsolationLevel.REPEATABLE_READ;
- }
-
- switch (name) {
- case "READ-UNCOMMITTED":
- return IsolationLevel.READ_UNCOMMITTED;
- case "READ-COMMITTED":
- return IsolationLevel.READ_COMMITTED;
- case "REPEATABLE-READ":
- return IsolationLevel.REPEATABLE_READ;
- case "SERIALIZABLE":
- return IsolationLevel.SERIALIZABLE;
- }
-
- logger.warn("Unknown isolation level {} in current session, fallback to repeatable read", name);
-
- return IsolationLevel.REPEATABLE_READ;
+ @TestOnly
+ ConnectionContext context() {
+ return client.getContext();
}
- /**
- * Resolves the column of session isolation level, the {@literal @@tx_isolation} has been marked as deprecated.
- *
- * If server is MariaDB, {@literal @@transaction_isolation} is used starting from {@literal 11.1.1}.
- *
- * If the server is MySQL, use {@literal @@transaction_isolation} starting from {@literal 8.0.3}, or between
- * {@literal 5.7.20} and {@literal 8.0.0} (exclusive).
- */
- private static String transactionIsolationColumn(ConnectionContext context) {
- ServerVersion version = context.getServerVersion();
-
- if (context.isMariaDb()) {
- return version.isGreaterThanOrEqualTo(MARIA_11_1_1) ? "@@transaction_isolation AS i" :
- "@@tx_isolation AS i";
- }
-
- return version.isGreaterThanOrEqualTo(MYSQL_8_0_3) ||
- (version.isGreaterThanOrEqualTo(MYSQL_5_7_20) && version.isLessThan(MYSQL_8)) ?
- "@@transaction_isolation AS i" : "@@tx_isolation AS i";
+ private boolean isSessionAutoCommit() {
+ return (client.getContext().getServerStatuses() & ServerStatuses.AUTO_COMMIT) != 0;
}
- private static final class SessionData {
-
- private final IsolationLevel level;
-
- @Nullable
- private final String product;
-
- @Nullable
- private final ZoneId timeZone;
-
- private long lockWaitTimeout = -1;
-
- private boolean lockWaitTimeoutSupported;
-
- private SessionData(IsolationLevel level, @Nullable String product, @Nullable ZoneId timeZone) {
- this.level = level;
- this.product = product;
- this.timeZone = timeZone;
- }
-
- SessionData lockWaitTimeout(long timeout) {
- this.lockWaitTimeoutSupported = true;
- this.lockWaitTimeout = timeout;
- return this;
- }
+ static Flux doPingInternal(Client client) {
+ return client.exchange(PingMessage.INSTANCE, PING);
}
}
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatement.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatement.java
index d9e290811..44edd9509 100644
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatement.java
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatement.java
@@ -18,7 +18,6 @@
import io.asyncer.r2dbc.mysql.api.MySqlResult;
import io.asyncer.r2dbc.mysql.api.MySqlStatement;
-import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
@@ -33,20 +32,17 @@
*/
final class PrepareParameterizedStatement extends ParameterizedStatementSupport {
- private final PrepareCache prepareCache;
-
private int fetchSize = 0;
- PrepareParameterizedStatement(Client client, Codecs codecs, Query query, PrepareCache prepareCache) {
+ PrepareParameterizedStatement(Client client, Codecs codecs, Query query) {
super(client, codecs, query);
- this.prepareCache = prepareCache;
}
@Override
public Flux execute(List bindings) {
return Flux.defer(() -> QueryFlow.execute(client,
StringUtils.extendReturning(query.getFormattedSql(), returningIdentifiers()),
- bindings, fetchSize, prepareCache
+ bindings, fetchSize
))
.map(messages -> MySqlSegmentResult.toResult(true, client, codecs, syntheticKeyName(), messages));
}
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java
index d78bb3488..7ff6b06f6 100644
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java
@@ -36,19 +36,16 @@ final class PrepareSimpleStatement extends SimpleStatementSupport {
private static final List BINDINGS = Collections.singletonList(new Binding(0));
- private final PrepareCache prepareCache;
-
private int fetchSize = 0;
- PrepareSimpleStatement(Client client, Codecs codecs, String sql, PrepareCache prepareCache) {
+ PrepareSimpleStatement(Client client, Codecs codecs, String sql) {
super(client, codecs, sql);
- this.prepareCache = prepareCache;
}
@Override
public Flux execute() {
return Flux.defer(() -> QueryFlow.execute(client,
- StringUtils.extendReturning(sql, returningIdentifiers()), BINDINGS, fetchSize, prepareCache))
+ StringUtils.extendReturning(sql, returningIdentifiers()), BINDINGS, fetchSize))
.map(messages -> MySqlSegmentResult.toResult(true, client, codecs, syntheticKeyName(), messages));
}
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
index e7a5de4bc..23ce5e806 100644
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
@@ -18,19 +18,12 @@
import io.asyncer.r2dbc.mysql.api.MySqlBatch;
import io.asyncer.r2dbc.mysql.api.MySqlTransactionDefinition;
-import io.asyncer.r2dbc.mysql.authentication.MySqlAuthProvider;
-import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.client.FluxExchangeable;
-import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm;
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
-import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
-import io.asyncer.r2dbc.mysql.message.client.AuthResponse;
import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
-import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse;
import io.asyncer.r2dbc.mysql.message.client.LocalInfileResponse;
-import io.asyncer.r2dbc.mysql.message.client.SubsequenceClientMessage;
import io.asyncer.r2dbc.mysql.message.client.PingMessage;
import io.asyncer.r2dbc.mysql.message.client.PrepareQueryMessage;
import io.asyncer.r2dbc.mysql.message.client.PreparedCloseMessage;
@@ -38,29 +31,21 @@
import io.asyncer.r2dbc.mysql.message.client.PreparedFetchMessage;
import io.asyncer.r2dbc.mysql.message.client.PreparedResetMessage;
import io.asyncer.r2dbc.mysql.message.client.PreparedTextQueryMessage;
-import io.asyncer.r2dbc.mysql.message.client.SslRequest;
import io.asyncer.r2dbc.mysql.message.client.TextQueryMessage;
-import io.asyncer.r2dbc.mysql.message.server.AuthMoreDataMessage;
-import io.asyncer.r2dbc.mysql.message.server.ChangeAuthMessage;
import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
import io.asyncer.r2dbc.mysql.message.server.EofMessage;
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
-import io.asyncer.r2dbc.mysql.message.server.HandshakeHeader;
-import io.asyncer.r2dbc.mysql.message.server.HandshakeRequest;
import io.asyncer.r2dbc.mysql.message.server.LocalInfileRequest;
import io.asyncer.r2dbc.mysql.message.server.OkMessage;
import io.asyncer.r2dbc.mysql.message.server.PreparedOkMessage;
import io.asyncer.r2dbc.mysql.message.server.ServerMessage;
import io.asyncer.r2dbc.mysql.message.server.ServerStatusMessage;
import io.asyncer.r2dbc.mysql.message.server.SyntheticMetadataMessage;
-import io.asyncer.r2dbc.mysql.message.server.SyntheticSslResponseMessage;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.r2dbc.spi.IsolationLevel;
-import io.r2dbc.spi.R2dbcNonTransientResourceException;
-import io.r2dbc.spi.R2dbcPermissionDeniedException;
import io.r2dbc.spi.TransactionDefinition;
import org.jetbrains.annotations.Nullable;
import reactor.core.CoreSubscriber;
@@ -72,15 +57,10 @@
import reactor.core.publisher.SynchronousSink;
import reactor.util.concurrent.Queues;
-import java.security.AccessController;
-import java.security.PrivilegedAction;
import java.time.Duration;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.Iterator;
import java.util.List;
-import java.util.Map;
-import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
@@ -116,18 +96,16 @@ final class QueryFlow {
* @param sql the statement for exception tracing.
* @param bindings the data of bindings.
* @param fetchSize the size of fetching, if it less than or equal to {@literal 0} means fetch all rows.
- * @param cache the cache of server-preparing result.
* @return the messages received in response to this exchange.
*/
- static Flux> execute(Client client, String sql, List bindings, int fetchSize,
- PrepareCache cache) {
+ static Flux> execute(Client client, String sql, List bindings, int fetchSize) {
return Flux.defer(() -> {
if (bindings.isEmpty()) {
return Flux.empty();
}
// Note: the prepared SQL may not be sent when the cache matches.
- return client.exchange(new PrepareExchangeable(cache, sql, bindings.iterator(), fetchSize))
+ return client.exchange(new PrepareExchangeable(client, sql, bindings.iterator(), fetchSize))
.windowUntil(RESULT_DONE);
});
}
@@ -194,29 +172,6 @@ static Flux> execute(Client client, List statements)
});
}
- /**
- * Login a {@link Client} and receive the {@code client} after logon. It will emit an exception when client receives
- * a {@link ErrorMessage}.
- *
- * @param client the {@link Client} to exchange messages with.
- * @param sslMode the {@link SslMode} defines SSL capability and behavior.
- * @param database the database that will be connected.
- * @param user the user that will be login.
- * @param password the password of the {@code user}.
- * @param compressionAlgorithms the list of compression algorithms.
- * @param zstdCompressionLevel the zstd compression level.
- * @param context the {@link ConnectionContext} for initialization.
- * @return the messages received in response to the login exchange.
- */
- static Mono login(Client client, SslMode sslMode, String database, String user,
- @Nullable CharSequence password,
- Set compressionAlgorithms, int zstdCompressionLevel) {
- return client.exchange(new LoginExchangeable(client, sslMode, database, user, password,
- compressionAlgorithms, zstdCompressionLevel))
- .onErrorResume(e -> client.forceClose().then(Mono.error(e)))
- .then(Mono.just(client));
- }
-
/**
* Execute a simple query and return a {@link Mono} for the complete signal or error. Query execution terminates
* with the last {@link CompleteMessage} or a {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception.
@@ -245,17 +200,15 @@ static Mono executeVoid(Client client, String sql) {
/**
* Begins a new transaction with a {@link TransactionDefinition}. It will change current transaction statuses of
- * the {@link ConnectionState}.
+ * the {@link ConnectionContext}.
*
* @param client the {@link Client} to exchange messages with.
- * @param state the connection state for checks and sets transaction statuses.
* @param batchSupported if connection supports batch query.
* @param definition the {@link TransactionDefinition}.
* @return receives complete signal.
*/
- static Mono beginTransaction(Client client, ConnectionState state, boolean batchSupported,
- TransactionDefinition definition) {
- final StartTransactionState startState = new StartTransactionState(state, definition, client);
+ static Mono beginTransaction(Client client, boolean batchSupported, TransactionDefinition definition) {
+ final StartTransactionState startState = new StartTransactionState(client, definition);
if (batchSupported) {
return client.exchange(new TransactionBatchExchangeable(startState)).then();
@@ -265,18 +218,15 @@ static Mono beginTransaction(Client client, ConnectionState state, boolean
}
/**
- * Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionState} in the initial
- * connection state.
+ * Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionContext}.
*
* @param client the {@link Client} to exchange messages with.
- * @param state the connection state for checks and resets transaction statuses.
* @param commit if it is commit, otherwise rollback.
* @param batchSupported if connection supports batch query.
* @return receives complete signal.
*/
- static Mono doneTransaction(Client client, ConnectionState state, boolean commit,
- boolean batchSupported) {
- final CommitRollbackState commitState = new CommitRollbackState(state, commit);
+ static Mono doneTransaction(Client client, boolean commit, boolean batchSupported) {
+ final CommitRollbackState commitState = new CommitRollbackState(client, commit);
if (batchSupported) {
return client.exchange(new TransactionBatchExchangeable(commitState)).then();
@@ -285,15 +235,80 @@ static Mono doneTransaction(Client client, ConnectionState state, boolean
return client.exchange(new TransactionMultiExchangeable(commitState)).then();
}
- static Mono createSavepoint(Client client, ConnectionState state, String name,
- boolean batchSupported) {
- final CreateSavepointState savepointState = new CreateSavepointState(state, name);
+ /**
+ * Creates a savepoint with a name. It will begin a new transaction before creating a savepoint if the connection is
+ * not in a transaction.
+ *
+ * @param client the {@link Client} to exchange messages with.
+ * @param name the name of the savepoint.
+ * @param batchSupported if connection supports batch query.
+ * @return a {@link Mono} receives complete signal.
+ */
+ static Mono createSavepoint(Client client, String name, boolean batchSupported) {
+ final CreateSavepointState savepointState = new CreateSavepointState(client, name);
if (batchSupported) {
return client.exchange(new TransactionBatchExchangeable(savepointState)).then();
}
return client.exchange(new TransactionMultiExchangeable(savepointState)).then();
}
+ /**
+ * Sets a session variable to the server.
+ *
+ * @param client the {@link Client} to exchange messages with.
+ * @param variable the session variable to set, e.g. {@code "sql_mode='ANSI'"}.
+ * @return a {@link Mono} receives complete signal.
+ */
+ static Mono setSessionVariable(Client client, String variable) {
+ if (variable.isEmpty()) {
+ return Mono.empty();
+ } else if (variable.startsWith("@")) {
+ return executeVoid(client, "SET " + variable);
+ }
+
+ return executeVoid(client, "SET SESSION " + variable);
+ }
+
+ /**
+ * Sets multiple session variables to the server.
+ *
+ * @param client the {@link Client} to exchange messages with.
+ * @param sessionVariables the session variables to set, e.g. {@code ["sql_mode='ANSI'", "time_zone='+09:00'"]}.
+ * @return a {@link Mono} receives complete signal.
+ */
+ static Mono setSessionVariables(Client client, List sessionVariables) {
+ switch (sessionVariables.size()) {
+ case 0:
+ return Mono.empty();
+ case 1:
+ return setSessionVariable(client, sessionVariables.get(0));
+ default: {
+ StringBuilder query = new StringBuilder(sessionVariables.size() * 32 + 16).append("SET ");
+ boolean comma = false;
+
+ for (String variable : sessionVariables) {
+ if (variable.isEmpty()) {
+ continue;
+ }
+
+ if (comma) {
+ query.append(',');
+ } else {
+ comma = true;
+ }
+
+ if (variable.startsWith("@")) {
+ query.append(variable);
+ } else {
+ query.append("SESSION ").append(variable);
+ }
+ }
+
+ return executeVoid(client, query.toString());
+ }
+ }
+ }
+
/**
* Execute a simple query statement. Query execution terminates with the last {@link CompleteMessage} or a
* {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception. The exchange will be completed by
@@ -544,7 +559,7 @@ final class PrepareExchangeable extends FluxExchangeable {
private final Sinks.Many requests = Sinks.many().unicast()
.onBackpressureBuffer(Queues.one().get());
- private final PrepareCache cache;
+ private final Client client;
private final String sql;
@@ -559,8 +574,8 @@ final class PrepareExchangeable extends FluxExchangeable {
private boolean shouldClose;
- PrepareExchangeable(PrepareCache cache, String sql, Iterator bindings, int fetchSize) {
- this.cache = cache;
+ PrepareExchangeable(Client client, String sql, Iterator bindings, int fetchSize) {
+ this.client = client;
this.sql = sql;
this.bindings = bindings;
this.fetchSize = fetchSize;
@@ -572,7 +587,7 @@ public void subscribe(CoreSubscriber super ClientMessage> actual) {
requests.asFlux().subscribe(actual);
// After subscribe.
- Integer statementId = cache.getIfPresent(sql);
+ Integer statementId = client.getContext().getPrepareCache().getIfPresent(sql);
if (statementId == null) {
logger.debug("Prepare cache mismatch, try to preparing");
this.shouldClose = true;
@@ -713,7 +728,7 @@ private void putToCache(Integer statementId) {
boolean putSucceed;
try {
- putSucceed = cache.putIfAbsent(sql, statementId, evictId -> {
+ putSucceed = client.getContext().getPrepareCache().putIfAbsent(sql, statementId, evictId -> {
logger.debug("Prepare cache evicts statement {} when putting", evictId);
Sinks.EmitResult result = requests.tryEmitNext(new PreparedCloseMessage(evictId));
@@ -809,292 +824,9 @@ private void onCompleteMessage(CompleteMessage message, SynchronousSink
- * Not like other {@link FluxExchangeable}s, it is started by a server-side message, which should be an implementation
- * of {@link HandshakeRequest}.
- */
-final class LoginExchangeable extends FluxExchangeable {
-
- private static final InternalLogger logger = InternalLoggerFactory.getInstance(LoginExchangeable.class);
-
- private static final Map ATTRIBUTES = Collections.emptyMap();
-
- private static final String CLI_SPECIFIC = "HY000";
-
- private static final int HANDSHAKE_VERSION = 10;
-
- private final Sinks.Many requests = Sinks.many().unicast()
- .onBackpressureBuffer(Queues.one().get());
-
- private final Client client;
-
- private final SslMode sslMode;
-
- private final String database;
-
- private final String user;
-
- @Nullable
- private final CharSequence password;
-
- private final Set compressions;
-
- private final int zstdCompressionLevel;
-
- private boolean handshake = true;
-
- private MySqlAuthProvider authProvider;
-
- private byte[] salt;
-
- private boolean sslCompleted;
-
- LoginExchangeable(Client client, SslMode sslMode, String database, String user,
- @Nullable CharSequence password, Set compressions,
- int zstdCompressionLevel) {
- this.client = client;
- this.sslMode = sslMode;
- this.database = database;
- this.user = user;
- this.password = password;
- this.compressions = compressions;
- this.zstdCompressionLevel = zstdCompressionLevel;
- this.sslCompleted = sslMode == SslMode.TUNNEL;
- }
-
- @Override
- public void subscribe(CoreSubscriber super ClientMessage> actual) {
- requests.asFlux().subscribe(actual);
- }
-
- @Override
- public void accept(ServerMessage message, SynchronousSink sink) {
- if (message instanceof ErrorMessage) {
- sink.error(((ErrorMessage) message).toException());
- return;
- }
-
- // Ensures it will be initialized only once.
- if (handshake) {
- handshake = false;
- if (message instanceof HandshakeRequest) {
- HandshakeRequest request = (HandshakeRequest) message;
- Capability capability = initHandshake(request);
-
- if (capability.isSslEnabled()) {
- emitNext(SslRequest.from(capability, client.getContext().getClientCollation().getId()), sink);
- } else {
- emitNext(createHandshakeResponse(capability), sink);
- }
- } else {
- sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" +
- message.getClass().getSimpleName() + "' in init phase"));
- }
-
- return;
- }
-
- if (message instanceof OkMessage) {
- client.loginSuccess();
- sink.complete();
- } else if (message instanceof SyntheticSslResponseMessage) {
- sslCompleted = true;
- emitNext(createHandshakeResponse(client.getContext().getCapability()), sink);
- } else if (message instanceof AuthMoreDataMessage) {
- AuthMoreDataMessage msg = (AuthMoreDataMessage) message;
-
- if (msg.isFailed()) {
- if (logger.isDebugEnabled()) {
- logger.debug("Connection (id {}) fast authentication failed, use full authentication",
- client.getContext().getConnectionId());
- }
-
- emitNext(createAuthResponse("full authentication"), sink);
- }
- // Otherwise success, wait until OK message or Error message.
- } else if (message instanceof ChangeAuthMessage) {
- ChangeAuthMessage msg = (ChangeAuthMessage) message;
-
- authProvider = MySqlAuthProvider.build(msg.getAuthType());
- salt = msg.getSalt();
- emitNext(createAuthResponse("change authentication"), sink);
- } else {
- sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" +
- message.getClass().getSimpleName() + "' in login phase"));
- }
- }
-
- @Override
- public void dispose() {
- // No particular error condition handling for complete signal.
- this.requests.tryEmitComplete();
- }
-
- private void emitNext(SubsequenceClientMessage message, SynchronousSink sink) {
- Sinks.EmitResult result = requests.tryEmitNext(message);
-
- if (result != Sinks.EmitResult.OK) {
- sink.error(new IllegalStateException("Fail to emit a login request due to " + result));
- }
- }
-
- private AuthResponse createAuthResponse(String phase) {
- MySqlAuthProvider authProvider = getAndNextProvider();
-
- if (authProvider.isSslNecessary() && !sslCompleted) {
- throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), phase), CLI_SPECIFIC);
- }
-
- return new AuthResponse(authProvider.authentication(password, salt, client.getContext().getClientCollation()));
- }
-
- private Capability clientCapability(Capability serverCapability) {
- Capability.Builder builder = serverCapability.mutate();
-
- builder.disableSessionTrack();
- builder.disableDatabasePinned();
- builder.disableIgnoreAmbiguitySpace();
- builder.disableInteractiveTimeout();
-
- if (sslMode == SslMode.TUNNEL) {
- // Tunnel does not use MySQL SSL protocol, disable it.
- builder.disableSsl();
- } else if (!serverCapability.isSslEnabled()) {
- // Server unsupported SSL.
- if (sslMode.requireSsl()) {
- // Before handshake, Client.context does not be initialized
- throw new R2dbcPermissionDeniedException("Server does not support SSL but mode '" + sslMode +
- "' requires SSL", CLI_SPECIFIC);
- } else if (sslMode.startSsl()) {
- // SSL has start yet, and client can disable SSL, disable now.
- client.sslUnsupported();
- }
- } else {
- // The server supports SSL, but the user does not want to use SSL, disable it.
- if (!sslMode.startSsl()) {
- builder.disableSsl();
- }
- }
-
- if (isZstdAllowed(serverCapability)) {
- if (isZstdSupported()) {
- builder.disableZlibCompression();
- } else {
- logger.warn("Server supports zstd, but zstd-jni dependency is missing");
-
- if (isZlibAllowed(serverCapability)) {
- builder.disableZstdCompression();
- } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) {
- builder.disableCompression();
- } else {
- throw new R2dbcNonTransientResourceException(
- "Environment does not support a compression algorithm in " + compressions +
- ", config does not allow uncompressed mode", CLI_SPECIFIC);
- }
- }
- } else if (isZlibAllowed(serverCapability)) {
- builder.disableZstdCompression();
- } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) {
- builder.disableCompression();
- } else {
- throw new R2dbcPermissionDeniedException(
- "Environment does not support a compression algorithm in " + compressions +
- ", config does not allow uncompressed mode", CLI_SPECIFIC);
- }
-
- if (database.isEmpty()) {
- builder.disableConnectWithDatabase();
- }
-
- if (client.getContext().getLocalInfilePath() == null) {
- builder.disableLoadDataLocalInfile();
- }
-
- if (ATTRIBUTES.isEmpty()) {
- builder.disableConnectAttributes();
- }
-
- return builder.build();
- }
-
- private Capability initHandshake(HandshakeRequest message) {
- HandshakeHeader header = message.getHeader();
- int handshakeVersion = header.getProtocolVersion();
- ServerVersion serverVersion = header.getServerVersion();
-
- if (handshakeVersion < HANDSHAKE_VERSION) {
- logger.warn("MySQL use handshake V{}, server version is {}, maybe most features are unavailable",
- handshakeVersion, serverVersion);
- }
-
- Capability capability = clientCapability(message.getServerCapability());
-
- // No need initialize server statuses because it has initialized by read filter.
- this.client.getContext().init(header.getConnectionId(), serverVersion, capability);
- this.authProvider = MySqlAuthProvider.build(message.getAuthType());
- this.salt = message.getSalt();
-
- return capability;
- }
-
- private MySqlAuthProvider getAndNextProvider() {
- MySqlAuthProvider authProvider = this.authProvider;
- this.authProvider = authProvider.next();
- return authProvider;
- }
-
- private HandshakeResponse createHandshakeResponse(Capability capability) {
- MySqlAuthProvider authProvider = getAndNextProvider();
-
- if (authProvider.isSslNecessary() && !sslCompleted) {
- throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), "handshake"),
- CLI_SPECIFIC);
- }
-
- byte[] authorization = authProvider.authentication(password, salt, client.getContext().getClientCollation());
- String authType = authProvider.getType();
-
- if (MySqlAuthProvider.NO_AUTH_PROVIDER.equals(authType)) {
- // Authentication type is not matter because of it has no authentication type.
- // Server need send a Change Authentication Message after handshake response.
- authType = MySqlAuthProvider.CACHING_SHA2_PASSWORD;
- }
-
- return HandshakeResponse.from(capability, client.getContext().getClientCollation().getId(), user, authorization,
- authType, database, ATTRIBUTES, zstdCompressionLevel);
- }
-
- private boolean isZstdAllowed(Capability capability) {
- return capability.isZstdCompression() && compressions.contains(CompressionAlgorithm.ZSTD);
- }
-
- private boolean isZlibAllowed(Capability capability) {
- return capability.isZlibCompression() && compressions.contains(CompressionAlgorithm.ZLIB);
- }
-
- private static String authFails(String authType, String phase) {
- return "Authentication type '" + authType + "' must require SSL in " + phase + " phase";
- }
-
- private static boolean isZstdSupported() {
- try {
- ClassLoader loader = AccessController.doPrivileged((PrivilegedAction) () -> {
- ClassLoader cl = Thread.currentThread().getContextClassLoader();
- return cl == null ? ClassLoader.getSystemClassLoader() : cl;
- });
- Class.forName("com.github.luben.zstd.Zstd", false, loader);
- return true;
- } catch (ClassNotFoundException e) {
- return false;
- }
- }
-}
-
abstract class AbstractTransactionState {
- final ConnectionState state;
+ final Client client;
final List statements = new ArrayList<>(5);
@@ -1106,8 +838,8 @@ abstract class AbstractTransactionState {
@Nullable
private String sql;
- protected AbstractTransactionState(ConnectionState state) {
- this.state = state;
+ protected AbstractTransactionState(Client client) {
+ this.client = client;
}
final void setSql(String sql) {
@@ -1165,22 +897,24 @@ final class CommitRollbackState extends AbstractTransactionState {
private final boolean commit;
- CommitRollbackState(ConnectionState state, boolean commit) {
- super(state);
+ CommitRollbackState(Client client, boolean commit) {
+ super(client);
this.commit = commit;
}
@Override
boolean cancelTasks() {
- if (!state.isInTransaction()) {
+ ConnectionContext context = client.getContext();
+
+ if (!context.isInTransaction()) {
tasks |= CANCEL;
return true;
}
- if (state.isLockWaitTimeoutChanged()) {
+ if (context.isLockWaitTimeoutChanged()) {
// If server does not support lock wait timeout, the state will not be changed, so it is safe.
tasks |= LOCK_WAIT_TIMEOUT;
- statements.add("SET innodb_lock_wait_timeout=" + state.getSessionLockWaitTimeout());
+ statements.add(StringUtils.lockWaitTimeoutStatement(context.getSessionLockWaitTimeout()));
}
tasks |= COMMIT_OR_ROLLBACK;
@@ -1193,10 +927,10 @@ boolean cancelTasks() {
protected boolean process(int task, SynchronousSink sink) {
switch (task) {
case LOCK_WAIT_TIMEOUT:
- state.resetCurrentLockWaitTimeout();
+ client.getContext().resetCurrentLockWaitTimeout();
return true;
case COMMIT_OR_ROLLBACK:
- state.resetIsolationLevel();
+ client.getContext().resetCurrentIsolationLevel();
sink.complete();
return false;
case CANCEL:
@@ -1222,26 +956,24 @@ final class StartTransactionState extends AbstractTransactionState {
private final TransactionDefinition definition;
- private final Client client;
-
- StartTransactionState(ConnectionState state, TransactionDefinition definition, Client client) {
- super(state);
+ StartTransactionState(Client client, TransactionDefinition definition) {
+ super(client);
this.definition = definition;
- this.client = client;
}
@Override
boolean cancelTasks() {
- if (state.isInTransaction()) {
+ final ConnectionContext context = client.getContext();
+ if (context.isInTransaction()) {
tasks |= CANCEL;
return true;
}
+
final Duration timeout = definition.getAttribute(TransactionDefinition.LOCK_WAIT_TIMEOUT);
if (timeout != null) {
- if (client.getContext().isLockWaitTimeoutSupported()) {
- long lockWaitTimeout = timeout.getSeconds();
+ if (context.isLockWaitTimeoutSupported()) {
tasks |= LOCK_WAIT_TIMEOUT;
- statements.add("SET innodb_lock_wait_timeout=" + lockWaitTimeout);
+ statements.add(StringUtils.lockWaitTimeoutStatement(timeout));
} else {
QueryFlow.logger.warn(
"Lock wait timeout is not supported by server, transaction definition lockWaitTimeout is ignored");
@@ -1267,22 +999,19 @@ protected boolean process(int task, SynchronousSink sink) {
case LOCK_WAIT_TIMEOUT:
final Duration timeout = definition.getAttribute(TransactionDefinition.LOCK_WAIT_TIMEOUT);
if (timeout != null) {
- final long lockWaitTimeout = timeout.getSeconds();
- state.setCurrentLockWaitTimeout(lockWaitTimeout);
+ client.getContext().setCurrentLockWaitTimeout(timeout);
}
return true;
case ISOLATION_LEVEL:
- final IsolationLevel isolationLevel =
- definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL);
+ final IsolationLevel isolationLevel = definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL);
if (isolationLevel != null) {
- state.setIsolationLevel(isolationLevel);
+ client.getContext().setCurrentIsolationLevel(isolationLevel);
}
return true;
case START_TRANSACTION:
case CANCEL:
sink.complete();
return false;
-
}
sink.error(new IllegalStateException("Undefined transaction task: " + task + ", remain: " + tasks));
@@ -1352,14 +1081,14 @@ final class CreateSavepointState extends AbstractTransactionState {
private final String name;
- CreateSavepointState(final ConnectionState state, final String name) {
- super(state);
+ CreateSavepointState(final Client client, final String name) {
+ super(client);
this.name = name;
}
@Override
boolean cancelTasks() {
- if (!state.isInTransaction()) {
+ if (!client.getContext().isInTransaction()) {
tasks |= START_TRANSACTION;
statements.add("BEGIN");
}
diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java
index e5c3596b6..1a96e2d79 100644
--- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java
+++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java
@@ -16,13 +16,14 @@
package io.asyncer.r2dbc.mysql.internal.util;
+import java.time.Duration;
import java.time.ZoneId;
import java.time.ZoneOffset;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty;
/**
- * A utility for processing {@link String} in MySQL/MariaDB.
+ * A utility for processing {@link String} and simple statements in MySQL/MariaDB.
*/
public final class StringUtils {
@@ -79,16 +80,48 @@ public static String extendReturning(String sql, String returning) {
return returning.isEmpty() ? sql : sql + " RETURNING " + returning;
}
+ /**
+ * Generates a {@link String} indicating the statement timeout variable. e.g. {@code "max_statement_time=1.5"} for
+ * MariaDB or {@code "max_execution_time=1500"} for MySQL.
+ *
+ * @param timeout the statement timeout
+ * @param isMariaDb whether the current server is MariaDB
+ * @return the statement timeout variable
+ */
+ public static String statementTimeoutVariable(Duration timeout, boolean isMariaDb) {
+ // mariadb: https://mariadb.com/kb/en/aborting-statements/
+ // mysql: https://dev.mysql.com/blog-archive/server-side-select-statement-timeouts/
+ // ref: https://github.com/mariadb-corporation/mariadb-connector-r2dbc
+ if (isMariaDb) {
+ // MariaDB supports fractional seconds with microsecond precision
+ double seconds = (timeout.getSeconds() + timeout.getNano() / 1_000_000_000.0);
+ return "max_statement_time=" + seconds;
+ }
+
+ return "max_execution_time=" + timeout.toMillis();
+ }
+
+ /**
+ * Generates a statement to set the lock wait timeout for the current session. It is using InnoDB-specific session
+ * variable {@code innodb_lock_wait_timeout}.
+ *
+ * @param timeout the lock wait timeout
+ * @return the lock wait timeout statement
+ */
+ public static String lockWaitTimeoutStatement(Duration timeout) {
+ return "SET innodb_lock_wait_timeout=" + timeout.getSeconds();
+ }
+
/**
* Parses a normalized {@link ZoneId} from a time zone string of MySQL.
*
- * Note: since java 14.0.2, 11.0.8, 8u261 and 7u271, America/Nuuk is already renamed from America/Godthab.
- * See also tzdata2020a
+ * Note: since java 14.0.2, 11.0.8, 8u261 and 7u271, America/Nuuk is already renamed from America/Godthab. See also
+ * tzdata2020a
*
* @param zoneId the time zone string
* @return the normalized {@link ZoneId}
- * @throws IllegalArgumentException if the time zone string is {@code null} or empty
- * @throws java.time.DateTimeException if the time zone string has an invalid format
+ * @throws IllegalArgumentException if the time zone string is {@code null} or empty
+ * @throws java.time.DateTimeException if the time zone string has an invalid format
* @throws java.time.zone.ZoneRulesException if the time zone string cannot be found
*/
public static ZoneId parseZoneId(String zoneId) {
diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java
index 5e2be6114..5d0635412 100644
--- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java
+++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java
@@ -16,10 +16,13 @@
package io.asyncer.r2dbc.mysql;
+import io.asyncer.r2dbc.mysql.cache.Caches;
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
+import io.r2dbc.spi.IsolationLevel;
import org.junit.jupiter.api.Test;
+import java.time.Duration;
import java.time.ZoneId;
import static org.assertj.core.api.Assertions.assertThat;
@@ -46,15 +49,36 @@ void getTimeZone() {
void setTwiceTimeZone() {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, null);
- context.initTimeZone(ZoneId.systemDefault());
- assertThatIllegalStateException().isThrownBy(() -> context.initTimeZone(ZoneId.systemDefault()));
+
+ context.initSession(
+ Caches.createPrepareCache(0),
+ IsolationLevel.REPEATABLE_READ,
+ false, Duration.ZERO,
+ null,
+ ZoneId.systemDefault()
+ );
+ assertThatIllegalStateException().isThrownBy(() -> context.initSession(
+ Caches.createPrepareCache(0),
+ IsolationLevel.REPEATABLE_READ,
+ false,
+ Duration.ZERO,
+ null,
+ ZoneId.systemDefault()
+ ));
}
@Test
void badSetTimeZone() {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, ZoneId.systemDefault());
- assertThatIllegalStateException().isThrownBy(() -> context.initTimeZone(ZoneId.systemDefault()));
+ assertThatIllegalStateException().isThrownBy(() -> context.initSession(
+ Caches.createPrepareCache(0),
+ IsolationLevel.REPEATABLE_READ,
+ false,
+ Duration.ZERO,
+ null,
+ ZoneId.systemDefault()
+ ));
}
public static ConnectionContext mock() {
@@ -69,7 +93,7 @@ public static ConnectionContext mock(boolean isMariaDB, ZoneId zoneId) {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, zoneId);
- context.init(1, ServerVersion.parse(isMariaDB ? "11.2.22.MOCKED" : "8.0.11.MOCKED"),
+ context.initHandshake(1, ServerVersion.parse(isMariaDB ? "11.2.22.MOCKED" : "8.0.11.MOCKED"),
Capability.of(~(isMariaDB ? 1 : 0)));
context.setServerStatuses(ServerStatuses.AUTO_COMMIT);
diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java
index b45d7f91c..8fa06f1f9 100644
--- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java
+++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java
@@ -68,16 +68,16 @@ class ConnectionIntegrationTest extends IntegrationTestSupport {
@Test
void isInTransaction() {
- castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.isInTransaction())
+ castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.context().isInTransaction())
.isFalse())
.then(connection.beginTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(connection.commitTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse())
.then(connection.beginTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(connection.rollbackTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()));
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()));
}
@DisabledIf("envIsLessThanMySql56")
@@ -88,16 +88,16 @@ void startTransaction() {
TransactionDefinition readWriteConsistent = MySqlTransactionDefinition.mutability(true)
.consistent();
- castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.isInTransaction())
+ castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.context().isInTransaction())
.isFalse())
.then(connection.beginTransaction(readOnlyConsistent))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(connection.rollbackTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse())
.then(connection.beginTransaction(readWriteConsistent))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(connection.rollbackTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()));
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()));
}
@Test
@@ -115,9 +115,9 @@ void autoRollbackPreRelease() {
.flatMap(MySqlResult::getRowsUpdated)
.single()
.doOnNext(it -> assertThat(it).isEqualTo(1))
- .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(conn.context().isInTransaction()).isTrue())
.then(conn.preRelease())
- .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isFalse())
+ .doOnSuccess(ignored -> assertThat(conn.context().isInTransaction()).isFalse())
.then(conn.postAllocate())
.thenMany(conn.createStatement("SELECT * FROM test")
.execute())
@@ -143,7 +143,7 @@ void shouldNotRollbackCommittedPreRelease() {
.doOnNext(it -> assertThat(it).isEqualTo(1))
.then(conn.commitTransaction())
.then(conn.preRelease())
- .doOnSuccess(ignored -> assertThat(conn.isInTransaction()).isFalse())
+ .doOnSuccess(ignored -> assertThat(conn.context().isInTransaction()).isFalse())
.then(conn.postAllocate())
.thenMany(conn.createStatement("SELECT * FROM test")
.execute())
@@ -158,15 +158,15 @@ void transactionDefinitionLockWaitTimeout() {
.beginTransaction(MySqlTransactionDefinition.empty()
.lockWaitTimeout(Duration.ofSeconds(345)))
.doOnSuccess(ignored -> {
- assertThat(connection.isInTransaction()).isTrue();
+ assertThat(connection.context().isInTransaction()).isTrue();
assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ);
- assertThat(connection.isLockWaitTimeoutChanged()).isTrue();
+ assertThat(connection.context().isLockWaitTimeoutChanged()).isTrue();
})
.then(connection.rollbackTransaction())
.doOnSuccess(ignored -> {
- assertThat(connection.isInTransaction()).isFalse();
+ assertThat(connection.context().isInTransaction()).isFalse();
assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ);
- assertThat(connection.isLockWaitTimeoutChanged()).isFalse();
+ assertThat(connection.context().isLockWaitTimeoutChanged()).isFalse();
}));
}
@@ -175,15 +175,15 @@ void transactionDefinitionIsolationLevel() {
castedComplete(connection -> connection
.beginTransaction(MySqlTransactionDefinition.from(READ_COMMITTED))
.doOnSuccess(ignored -> {
- assertThat(connection.isInTransaction()).isTrue();
+ assertThat(connection.context().isInTransaction()).isTrue();
assertThat(connection.getTransactionIsolationLevel()).isEqualTo(READ_COMMITTED);
- assertThat(connection.isLockWaitTimeoutChanged()).isFalse();
+ assertThat(connection.context().isLockWaitTimeoutChanged()).isFalse();
})
.then(connection.rollbackTransaction())
.doOnSuccess(ignored -> {
- assertThat(connection.isInTransaction()).isFalse();
+ assertThat(connection.context().isInTransaction()).isFalse();
assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ);
- assertThat(connection.isLockWaitTimeoutChanged()).isFalse();
+ assertThat(connection.context().isLockWaitTimeoutChanged()).isFalse();
}));
}
@@ -194,7 +194,7 @@ void setTransactionLevelNotInTransaction() {
Mono.fromSupplier(connection::getTransactionIsolationLevel)
.doOnSuccess(it -> assertThat(it).isEqualTo(REPEATABLE_READ))
.then(connection.beginTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(Mono.fromSupplier(connection::getTransactionIsolationLevel))
.doOnSuccess(it -> assertThat(it).isEqualTo(REPEATABLE_READ))
.then(connection.rollbackTransaction())
@@ -203,7 +203,7 @@ void setTransactionLevelNotInTransaction() {
.then(Mono.fromSupplier(connection::getTransactionIsolationLevel))
.doOnSuccess(it -> assertThat(it).isEqualTo(READ_COMMITTED))
.then(connection.beginTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
// ensure transaction isolation level applies to subsequent transactions
.then(Mono.fromSupplier(connection::getTransactionIsolationLevel))
.doOnSuccess(it -> assertThat(it).isEqualTo(READ_COMMITTED))
@@ -222,13 +222,13 @@ void setTransactionLevelInTransaction() {
.then(Mono.fromSupplier(connection::getTransactionIsolationLevel))
.doOnSuccess(it -> assertThat(it).isNotEqualTo(READ_COMMITTED))
.then(connection.rollbackTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse())
// ensure that session isolation level is changed after rollback
.then(Mono.fromSupplier(connection::getTransactionIsolationLevel))
.doOnSuccess(it -> assertThat(it).isEqualTo(READ_COMMITTED))
// ensure transaction isolation level applies to subsequent transactions
.then(connection.beginTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
);
}
@@ -240,15 +240,15 @@ void transactionDefinition() {
.lockWaitTimeout(Duration.ofSeconds(112))
.consistent())
.doOnSuccess(ignored -> {
- assertThat(connection.isInTransaction()).isTrue();
+ assertThat(connection.context().isInTransaction()).isTrue();
assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ);
- assertThat(connection.isLockWaitTimeoutChanged()).isTrue();
+ assertThat(connection.context().isLockWaitTimeoutChanged()).isTrue();
})
.then(connection.rollbackTransaction())
.doOnSuccess(ignored -> {
- assertThat(connection.isInTransaction()).isFalse();
+ assertThat(connection.context().isInTransaction()).isFalse();
assertThat(connection.getTransactionIsolationLevel()).isEqualTo(REPEATABLE_READ);
- assertThat(connection.isLockWaitTimeoutChanged()).isFalse();
+ assertThat(connection.context().isLockWaitTimeoutChanged()).isFalse();
}));
}
@@ -290,7 +290,7 @@ void createSavepointAndRollbackToSavepoint(String savepoint) {
"CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(connection.beginTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
@@ -301,7 +301,7 @@ void createSavepointAndRollbackToSavepoint(String savepoint) {
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
.then(connection.createSavepoint(savepoint))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
@@ -312,12 +312,12 @@ void createSavepointAndRollbackToSavepoint(String savepoint) {
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(4))
.then(connection.rollbackTransactionToSavepoint(savepoint))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
.then(connection.rollbackTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse())
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(0))
@@ -331,7 +331,7 @@ void createSavepointAndRollbackEntireTransaction(String savepoint) {
"CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(connection.beginTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
@@ -342,7 +342,7 @@ void createSavepointAndRollbackEntireTransaction(String savepoint) {
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
.then(connection.createSavepoint(savepoint))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
@@ -353,7 +353,7 @@ void createSavepointAndRollbackEntireTransaction(String savepoint) {
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(4))
.then(connection.rollbackTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse())
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(0))
@@ -374,8 +374,7 @@ void rollbackTransactionWithoutBegin() {
void setTransactionIsolationLevel() {
complete(connection -> Flux.just(READ_UNCOMMITTED, READ_COMMITTED, REPEATABLE_READ, SERIALIZABLE)
.concatMap(level -> connection.setTransactionIsolationLevel(level)
- .map(ignored -> assertThat(level))
- .doOnNext(a -> a.isEqualTo(connection.getTransactionIsolationLevel()))));
+ .doOnSuccess(ignored -> assertThat(level).isEqualTo(connection.getTransactionIsolationLevel()))));
}
@Test
@@ -400,7 +399,7 @@ void commitTransactionShouldRespectQueuedMessages() {
.execute(),
connection.commitTransaction()
))
- .doOnComplete(() -> assertThat(connection.isInTransaction()).isFalse())
+ .doOnComplete(() -> assertThat(connection.context().isInTransaction()).isFalse())
.thenMany(connection.createStatement("SELECT COUNT(*) FROM test").execute())
.flatMap(result ->
Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))
@@ -421,7 +420,7 @@ void rollbackTransactionShouldRespectQueuedMessages() {
.execute(),
connection.rollbackTransaction()
))
- .doOnComplete(() -> assertThat(connection.isInTransaction()).isFalse())
+ .doOnComplete(() -> assertThat(connection.context().isInTransaction()).isFalse())
.thenMany(connection.createStatement("SELECT COUNT(*) FROM test").execute())
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))
.doOnNext(count -> assertThat(count).isEqualTo(0L)))
@@ -435,15 +434,15 @@ void beginTransactionShouldRespectQueuedMessages() {
Mono.from(connection.createStatement(tdl).execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.beginTransaction()))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isTrue())
.thenMany(Flux.merge(
connection.createStatement("INSERT INTO test VALUES (1, 'test1')").execute(),
connection.commitTransaction(),
connection.beginTransaction()
))
- .doOnComplete(() -> assertThat(connection.isInTransaction()).isTrue())
+ .doOnComplete(() -> assertThat(connection.context().isInTransaction()).isTrue())
.then(Mono.from(connection.rollbackTransaction()))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
+ .doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse())
.thenMany(connection.createStatement("SELECT COUNT(*) FROM test").execute())
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))
.doOnNext(count -> assertThat(count).isEqualTo(1L)))
diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java
index c8d50c633..b2847c20d 100644
--- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java
+++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java
@@ -16,21 +16,36 @@
package io.asyncer.r2dbc.mysql;
+import io.asyncer.r2dbc.mysql.api.MySqlTransactionDefinition;
import io.asyncer.r2dbc.mysql.cache.Caches;
+import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.client.Client;
+import io.asyncer.r2dbc.mysql.client.FluxExchangeable;
import io.asyncer.r2dbc.mysql.codec.Codecs;
+import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
import io.asyncer.r2dbc.mysql.message.client.TextQueryMessage;
+import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
import io.r2dbc.spi.IsolationLevel;
import org.assertj.core.api.ThrowableTypeAssert;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import reactor.core.CoreSubscriber;
import reactor.core.publisher.Flux;
+import reactor.core.publisher.SynchronousSink;
import reactor.test.StepVerifier;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
+
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -39,39 +54,24 @@
*/
class MySqlSimpleConnectionTest {
- private final Client client;
-
- private final Codecs codecs = mock(Codecs.class);
-
- private final IsolationLevel level = IsolationLevel.REPEATABLE_READ;
-
- private final String product = "MockConnection";
-
- private final MySqlSimpleConnection noPrepare;
-
- MySqlSimpleConnectionTest() {
- Client client = mock(Client.class);
-
- when(client.getContext()).thenReturn(ConnectionContextTest.mock());
-
- this.client = client;
- this.noPrepare = new MySqlSimpleConnection(client,
- codecs, level, 50, Caches.createQueryCache(0),
- Caches.createPrepareCache(0), product, null);
- }
+ private static final Codecs CODECS = mock(Codecs.class);
@Test
void createStatement() {
String condition = "SELECT * FROM test";
- MySqlSimpleConnection allPrepare = new MySqlSimpleConnection(client,
- codecs, level, 50, Caches.createQueryCache(0),
- Caches.createPrepareCache(0), product, sql -> true);
- MySqlSimpleConnection halfPrepare = new MySqlSimpleConnection(client,
- codecs, level, 50, Caches.createQueryCache(0),
- Caches.createPrepareCache(0), product, sql -> false);
- MySqlSimpleConnection conditionPrepare = new MySqlSimpleConnection(client,
- codecs, level, 50, Caches.createQueryCache(0),
- Caches.createPrepareCache(0), product, sql -> sql.equals(condition));
+ MySqlSimpleConnection allPrepare = new MySqlSimpleConnection(
+ mockClient(),
+ CODECS,
+ Caches.createQueryCache(0), sql -> true);
+ MySqlSimpleConnection halfPrepare = new MySqlSimpleConnection(
+ mockClient(),
+ CODECS,
+ Caches.createQueryCache(0), sql -> false);
+ MySqlSimpleConnection conditionPrepare = new MySqlSimpleConnection(
+ mockClient(),
+ CODECS,
+ Caches.createQueryCache(0), sql -> sql.equals(condition));
+ MySqlSimpleConnection noPrepare = newNoPrepare(mockClient());
assertThat(noPrepare.createStatement("SELECT * FROM test WHERE id=1"))
.isExactlyInstanceOf(TextSimpleStatement.class);
@@ -105,12 +105,14 @@ void createStatement() {
@SuppressWarnings("ConstantConditions")
@Test
void badCreateStatement() {
+ MySqlSimpleConnection noPrepare = newNoPrepare(mockClient());
assertThatIllegalArgumentException().isThrownBy(() -> noPrepare.createStatement(null));
}
@SuppressWarnings("ConstantConditions")
@Test
void badCreateSavepoint() {
+ MySqlSimpleConnection noPrepare = newNoPrepare(mockClient());
ThrowableTypeAssert> asserted = assertThatIllegalArgumentException();
asserted.isThrownBy(() -> noPrepare.createSavepoint(""));
@@ -120,6 +122,7 @@ void badCreateSavepoint() {
@SuppressWarnings("ConstantConditions")
@Test
void badReleaseSavepoint() {
+ MySqlSimpleConnection noPrepare = newNoPrepare(mockClient());
ThrowableTypeAssert> asserted = assertThatIllegalArgumentException();
asserted.isThrownBy(() -> noPrepare.releaseSavepoint(""));
@@ -129,6 +132,7 @@ void badReleaseSavepoint() {
@SuppressWarnings("ConstantConditions")
@Test
void badRollbackTransactionToSavepoint() {
+ MySqlSimpleConnection noPrepare = newNoPrepare(mockClient());
ThrowableTypeAssert> asserted = assertThatIllegalArgumentException();
asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint(""));
@@ -138,24 +142,120 @@ void badRollbackTransactionToSavepoint() {
@SuppressWarnings("ConstantConditions")
@Test
void badSetTransactionIsolationLevel() {
+ MySqlSimpleConnection noPrepare = newNoPrepare(mockClient());
assertThatIllegalArgumentException().isThrownBy(() -> noPrepare.setTransactionIsolationLevel(null));
}
- @Test
- void shouldSetTransactionIsolationLevelSuccessfully() {
- ClientMessage message = new TextQueryMessage("SET SESSION TRANSACTION ISOLATION LEVEL SERIALIZABLE");
+ @ParameterizedTest
+ @ValueSource(strings = { "READ UNCOMMITTED", "READ COMMITTED", "REPEATABLE READ", "SERIALIZABLE" })
+ void shouldSetTransactionIsolationLevelSuccessfully(String levelSql) {
+ Client client = mockClient();
+ IsolationLevel level = IsolationLevel.valueOf(levelSql);
+ ClientMessage message = new TextQueryMessage("SET SESSION TRANSACTION ISOLATION LEVEL " + levelSql);
+
when(client.exchange(eq(message), any())).thenReturn(Flux.empty());
- noPrepare.setTransactionIsolationLevel(IsolationLevel.SERIALIZABLE)
+ MySqlSimpleConnection noPrepare = newNoPrepare(client);
+ noPrepare.setTransactionIsolationLevel(level)
.as(StepVerifier::create)
.verifyComplete();
- assertThat(noPrepare.getSessionTransactionIsolationLevel()).isEqualTo(IsolationLevel.SERIALIZABLE);
+ assertThat(client.getContext().getCurrentIsolationLevel()).isEqualTo(level);
+ assertThat(client.getContext().getSessionIsolationLevel()).isEqualTo(level);
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = {
+ "READ UNCOMMITTED,SERIALIZABLE",
+ "READ COMMITTED,REPEATABLE READ",
+ "REPEATABLE READ,READ UNCOMMITTED"
+ })
+ void shouldSetTransactionIsolationLevelInTransaction(String levels) {
+ String[] levelStatements = levels.split(",");
+ IsolationLevel currentLevel = IsolationLevel.valueOf(levelStatements[0]);
+ IsolationLevel sessionLevel = IsolationLevel.valueOf(levelStatements[1]);
+ Client client = mockClient();
+ ClientMessage session = new TextQueryMessage("SET SESSION TRANSACTION ISOLATION LEVEL " + sessionLevel.asSql());
+ CompleteMessage mockDone = mock(CompleteMessage.class);
+ @SuppressWarnings("unchecked")
+ SynchronousSink sink = (SynchronousSink) mock(SynchronousSink.class);
+ AtomicBoolean completed = new AtomicBoolean(false);
+
+ doAnswer(it -> {
+ throw it.getArgument(0, Exception.class);
+ }).when(sink).error(any());
+ doAnswer(it -> {
+ completed.set(true);
+ return null;
+ }).when(sink).complete();
+ when(mockDone.isDone()).thenReturn(true);
+ when(client.exchange(eq(session), any())).thenReturn(Flux.empty());
+ when(client.exchange(any())).thenAnswer(it -> {
+ FluxExchangeable exchangeable = it.getArgument(0);
+ @SuppressWarnings("unchecked")
+ CoreSubscriber super ClientMessage> subscriber = mock(CoreSubscriber.class);
+ exchangeable.subscribe(subscriber);
+
+ while (!completed.get()) {
+ exchangeable.accept(mockDone, sink);
+ }
+
+ // Mock server status to be in transaction
+ client.getContext().setServerStatuses(ServerStatuses.IN_TRANSACTION);
+
+ return Flux.empty();
+ });
+
+ IsolationLevel mockLevel = IsolationLevel.valueOf("DEFAULT");
+ client.getContext().initSession(
+ mock(PrepareCache.class),
+ mockLevel,
+ false,
+ Duration.ZERO,
+ null,
+ null
+ );
+ MySqlSimpleConnection noPrepare = newNoPrepare(client);
+
+ assertThat(client.getContext().getCurrentIsolationLevel()).isEqualTo(mockLevel);
+ assertThat(client.getContext().getSessionIsolationLevel()).isEqualTo(mockLevel);
+
+ noPrepare.beginTransaction(MySqlTransactionDefinition.from(currentLevel))
+ .as(StepVerifier::create)
+ .verifyComplete();
+
+ assertThat(client.getContext().getCurrentIsolationLevel()).isEqualTo(currentLevel);
+ assertThat(client.getContext().getSessionIsolationLevel()).isEqualTo(mockLevel);
+
+ noPrepare.setTransactionIsolationLevel(sessionLevel)
+ .as(StepVerifier::create)
+ .verifyComplete();
+
+ assertThat(client.getContext().getCurrentIsolationLevel()).isEqualTo(currentLevel);
+ assertThat(client.getContext().getSessionIsolationLevel()).isEqualTo(sessionLevel);
}
@SuppressWarnings("ConstantConditions")
@Test
void badValidate() {
+ MySqlSimpleConnection noPrepare = newNoPrepare(mockClient());
assertThatIllegalArgumentException().isThrownBy(() -> noPrepare.validate(null));
}
+
+ private static Client mockClient() {
+ Client client = mock(Client.class);
+
+ when(client.getContext()).thenReturn(ConnectionContextTest.mock());
+
+ return client;
+ }
+
+ private static MySqlSimpleConnection newNoPrepare(Client client) {
+ return new MySqlSimpleConnection(
+ client,
+ CODECS,
+ Caches.createQueryCache(0),
+ null
+ );
+ }
}
diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatementTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatementTest.java
index 345704af5..94e1591f4 100644
--- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatementTest.java
+++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParameterizedStatementTest.java
@@ -52,8 +52,7 @@ public PrepareParameterizedStatement makeInstance(boolean isMariaDB, String sql,
return new PrepareParameterizedStatement(
client,
codecs,
- Query.parse(sql),
- Caches.createPrepareCache(0)
+ Query.parse(sql)
);
}
diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java
index 0e18e7233..56d5ac907 100644
--- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java
+++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java
@@ -16,7 +16,6 @@
package io.asyncer.r2dbc.mysql;
-import io.asyncer.r2dbc.mysql.cache.Caches;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
@@ -64,12 +63,7 @@ public PrepareSimpleStatement makeInstance(boolean isMariaDB, String ignored, St
when(client.getContext()).thenReturn(ConnectionContextTest.mock(isMariaDB));
- return new PrepareSimpleStatement(
- client,
- codecs,
- sql,
- Caches.createPrepareCache(0)
- );
+ return new PrepareSimpleStatement(client, codecs, sql);
}
@Override