Skip to content

Commit

Permalink
Add support for create database if not exist
Browse files Browse the repository at this point in the history
- Support `InitDbMessage`.
- Support `changeDatabase` in `MySqlConnection`.
- Add integration tests for that.
  • Loading branch information
mirromutth committed Dec 14, 2023
1 parent 46e9e69 commit 24f603c
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 30 deletions.
2 changes: 2 additions & 0 deletions src/main/java/io/asyncer/r2dbc/mysql/Capability.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public final class Capability {

/**
* Can use long password.
* <p>
* TODO: Reinterpret it as {@code CLIENT_MYSQL} to support MariaDB 10.2 and above.
*/
private static final int LONG_PASSWORD = 1;

Expand Down
54 changes: 51 additions & 3 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.message.client.InitDbMessage;
import io.asyncer.r2dbc.mysql.message.client.PingMessage;
import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
Expand Down Expand Up @@ -91,6 +92,30 @@ public final class MySqlConnection implements Connection, ConnectionState {
}
};

private static final BiConsumer<ServerMessage, SynchronousSink<Boolean>> INIT_DB = (message, sink) -> {
if (message instanceof ErrorMessage) {
ErrorMessage msg = (ErrorMessage) message;
logger.debug("Use database failed: [{}] {}", msg.getCode(), msg.getMessage());
sink.next(false);
sink.complete();
} else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
sink.next(true);
sink.complete();
} else {
ReferenceCountUtil.safeRelease(message);
}
};

private static final BiConsumer<ServerMessage, SynchronousSink<Void>> INIT_DB_AFTER = (message, sink) -> {
if (message instanceof ErrorMessage) {
sink.error(((ErrorMessage) message).toException());
} else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
sink.complete();
} else {
ReferenceCountUtil.safeRelease(message);
}
};

private final Client client;

private final Codecs codecs;
Expand Down Expand Up @@ -403,13 +428,17 @@ boolean isSessionAutoCommit() {
* @param client must be logged-in.
* @param codecs the {@link Codecs}.
* @param context must be initialized.
* @param database the database that should be lazy init.
* @param queryCache the cache of {@link Query}.
* @param prepareCache the cache of server-preparing result.
* @param prepare judging for prefer use prepare statement to execute simple query.
* @return a {@link Mono} will emit an initialized {@link MySqlConnection}.
*/
static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContext context,
QueryCache queryCache, PrepareCache prepareCache, @Nullable Predicate<String> prepare) {
static Mono<MySqlConnection> init(
Client client, Codecs codecs, ConnectionContext context, String database,
QueryCache queryCache, PrepareCache prepareCache,
@Nullable Predicate<String> prepare
) {
ServerVersion version = context.getServerVersion();
StringBuilder query = new StringBuilder(128);

Expand All @@ -431,7 +460,7 @@ static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContex
handler = MySqlConnection::init;
}

return new TextSimpleStatement(client, codecs, context, query.toString())
Mono<MySqlConnection> connection = new TextSimpleStatement(client, codecs, context, query.toString())
.execute()
.flatMap(handler)
.last()
Expand All @@ -445,6 +474,25 @@ static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContex
return new MySqlConnection(client, context, codecs, data.level, data.lockWaitTimeout,
queryCache, prepareCache, data.product, prepare);
});

if (database.isEmpty()) {
return connection;
}

requireValidName(database, "database must not be empty and not contain backticks");

return connection.flatMap(conn -> client.exchange(new InitDbMessage(database), INIT_DB)
.last()
.flatMap(success -> {
if (success) {
return Mono.just(conn);
}

String sql = String.format("CREATE DATABASE IF NOT EXISTS `%s`", database);

return QueryFlow.executeVoid(client, sql)
.then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then(Mono.just(conn)));
}));
}

private static Publisher<InitData> init(MySqlResult r) {
Expand Down
41 changes: 25 additions & 16 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import io.netty.channel.unix.DomainSocketAddress;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -86,6 +85,7 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
}

String database = configuration.getDatabase();
boolean createDbIfNotExist = configuration.isCreateDatabaseIfNotExist();
String user = configuration.getUser();
CharSequence password = configuration.getPassword();
SslMode sslMode = ssl.getSslMode();
Expand All @@ -95,32 +95,36 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
Predicate<String> prepare = configuration.getPreferPrepareStatement();
int prepareCacheSize = configuration.getPrepareCacheSize();
Publisher<String> passwordPublisher = configuration.getPasswordPublisher();

if (Objects.nonNull(passwordPublisher)) {
return Mono.from(passwordPublisher)
.flatMap(token -> getMySqlConnection(
configuration, queryCache,
ssl, address,
database, user,
sslMode, context,
extensions, prepare,
prepareCacheSize, token));
return Mono.from(passwordPublisher).flatMap(token -> getMySqlConnection(
configuration, queryCache,
ssl, address,
database, createDbIfNotExist,
user, sslMode, context,
extensions, prepare,
prepareCacheSize, token
));
}
return getMySqlConnection(configuration, queryCache,

return getMySqlConnection(
configuration, queryCache,
ssl, address,
database, user,
sslMode, context,
database, createDbIfNotExist,
user, sslMode, context,
extensions, prepare,
prepareCacheSize, password);
prepareCacheSize, password
);
}));
}

@NotNull
private static Mono<MySqlConnection> getMySqlConnection(
final MySqlConnectionConfiguration configuration,
final LazyQueryCache queryCache,
final MySqlSslConfiguration ssl,
final SocketAddress address,
final String database,
final boolean createDbIfNotExist,
final String user,
final SslMode sslMode,
final ConnectionContext context,
Expand All @@ -130,16 +134,21 @@ private static Mono<MySqlConnection> getMySqlConnection(
@Nullable final CharSequence password) {
return Client.connect(ssl, address, configuration.isTcpKeepAlive(), configuration.isTcpNoDelay(),
context, configuration.getConnectTimeout(), configuration.getSocketTimeout())
.flatMap(client -> QueryFlow.login(client, sslMode, database, user, password, context))
.flatMap(client -> {
// Lazy init database after handshake/login
String db = createDbIfNotExist ? "" : database;
return QueryFlow.login(client, sslMode, db, user, password, context);
})
.flatMap(client -> {
ByteBufAllocator allocator = client.getByteBufAllocator();
CodecsBuilder builder = Codecs.builder(allocator);
PrepareCache prepareCache = Caches.createPrepareCache(prepareCacheSize);
String db = createDbIfNotExist ? database : "";

extensions.forEach(CodecRegistrar.class, registrar ->
registrar.register(allocator, builder));

return MySqlConnection.init(client, builder.build(), context, queryCache.get(),
return MySqlConnection.init(client, builder.build(), context, db, queryCache.get(),
prepareCache, prepare);
});
}
Expand Down
1 change: 1 addition & 0 deletions src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ private Capability clientCapability(Capability serverCapability) {

builder.disableDatabasePinned();
builder.disableCompression();
// TODO: support LOAD DATA LOCAL INFILE
builder.disableLoadDataInfile();
builder.disableIgnoreAmbiguitySpace();
builder.disableInteractiveTimeout();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.asyncer.r2dbc.mysql.message.client;

import io.asyncer.r2dbc.mysql.ConnectionContext;
import io.netty.buffer.ByteBuf;

public final class InitDbMessage extends ScalarClientMessage {

private static final byte FLAG = 0x02;

private final String database;

public InitDbMessage(String database) { this.database = database; }

@Override
protected void writeTo(ByteBuf buf, ConnectionContext context) {
// RestOfPacketString, no need terminal or length
buf.writeByte(FLAG).writeCharSequence(database, context.getClientCollation().getCharset());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
class ConnectionIntegrationTest extends IntegrationTestSupport {

ConnectionIntegrationTest() {
super(configuration(false, null, null));
super(configuration("r2dbc", false, false, null, null));
}

@Test
Expand Down
35 changes: 35 additions & 0 deletions src/test/java/io/asyncer/r2dbc/mysql/InitDbIntegrationTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.asyncer.r2dbc.mysql;

import org.junit.jupiter.api.Test;

import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Integration tests for {@code createDatabaseIfNotExist}.
*/
class InitDbIntegrationTest extends IntegrationTestSupport {

private static final String DATABASE = "test-" + ThreadLocalRandom.current().nextInt(10000);

InitDbIntegrationTest() {
super(configuration(
DATABASE, true, false,
null, null
));
}

@Test
void shouldCreateDatabase() {
complete(conn -> conn.createStatement("SHOW DATABASES")
.execute()
.flatMap(it -> it.map((row, rowMetadata) -> row.get(0, String.class)))
.collect(Collectors.toSet())
.doOnNext(it -> assertThat(it).contains(DATABASE))
.thenMany(conn.createStatement("DROP DATABASE `" + DATABASE + "`")
.execute()
.flatMap(MySqlResult::getRowsUpdated)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ static Mono<Long> extractRowsUpdated(Result result) {
return Mono.from(result.getRowsUpdated());
}

static MySqlConnectionConfiguration configuration(boolean autodetectExtensions,
@Nullable ZoneId serverZoneId, @Nullable Predicate<String> preferPrepared) {
static MySqlConnectionConfiguration configuration(
String database, boolean createDatabaseIfNotExist, boolean autodetectExtensions,
@Nullable ZoneId serverZoneId, @Nullable Predicate<String> preferPrepared
) {
String password = System.getProperty("test.mysql.password");

assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty")
Expand All @@ -84,7 +86,8 @@ static MySqlConnectionConfiguration configuration(boolean autodetectExtensions,
.connectTimeout(Duration.ofSeconds(3))
.user("root")
.password(password)
.database("r2dbc")
.database(database)
.createDatabaseIfNotExist(createDatabaseIfNotExist)
.autodetectExtensions(autodetectExtensions);

if (serverZoneId != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class JacksonPrepareIntegrationTest extends JacksonIntegrationTestSupport {

JacksonPrepareIntegrationTest() {
super(configuration(true, null, sql -> false));
super(configuration("r2dbc", false, true, null, sql -> false));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
class JacksonTextIntegrationTest extends JacksonIntegrationTestSupport {

JacksonTextIntegrationTest() {
super(configuration(true, null, null));
super(configuration("r2dbc", false, true, null, null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class MySqlPrepareTestKit extends MySqlTestKitSupport {

MySqlPrepareTestKit() {
super(IntegrationTestSupport.configuration(false, null, sql -> true));
super(IntegrationTestSupport.configuration("r2dbc", false, false, null, sql -> true));
}

@Override
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/io/asyncer/r2dbc/mysql/MySqlTextTestKit.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
class MySqlTextTestKit extends MySqlTestKitSupport {

MySqlTextTestKit() {
super(IntegrationTestSupport.configuration(false, null, null));
super(IntegrationTestSupport.configuration("r2dbc", false, false, null, null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
class PrepareQueryIntegrationTest extends QueryIntegrationTestSupport {

PrepareQueryIntegrationTest() {
super(configuration(false, null, sql -> true));
super(configuration("r2dbc", false, false, null, sql -> true));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
class TextQueryIntegrationTest extends QueryIntegrationTestSupport {

TextQueryIntegrationTest() {
super(configuration(false, null, null));
super(configuration("r2dbc", false, false, null, null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ abstract class TimeZoneIntegrationTestSupport extends IntegrationTestSupport {
}

TimeZoneIntegrationTestSupport(@Nullable Predicate<String> preferPrepared) {
super(configuration(false, SERVER_ZONE, preferPrepared));
super(configuration("r2dbc", false, false, SERVER_ZONE, preferPrepared));
}

@Test
Expand Down

0 comments on commit 24f603c

Please sign in to comment.