Skip to content

Commit

Permalink
Add support for SQL mode NO_BACKSLASH_ESCAPES
Browse files Browse the repository at this point in the history
  • Loading branch information
mirromutth committed Mar 29, 2024
1 parent aa90297 commit 8c24b1c
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ public boolean isMariaDb() {
return (capability != null && capability.isMariaDb()) || serverVersion.isMariaDb();
}

public boolean isNoBackslashEscapes() {
return (serverStatuses & ServerStatuses.NO_BACKSLASH_ESCAPES) != 0;
}

void initTimeZone(ZoneId timeZone) {
if (isTimeZoneInitialized()) {
throw new IllegalStateException("Connection timezone have been initialized");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ private static Mono<SessionData> loadSessionVariables(Client client, Codecs code
}

private static Mono<SessionData> loadInnoDbEngineStatus(SessionData data, Client client, Codecs codecs) {
return new TextSimpleStatement(client, codecs, "SHOW VARIABLES LIKE 'innodb\\\\_lock\\\\_wait\\\\_timeout'")
return new TextSimpleStatement(client, codecs, "SHOW VARIABLES LIKE 'innodb\\_lock\\_wait\\_timeout'")
.execute()
.flatMap(r -> r.map(readable -> {
String value = readable.get(1, String.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,20 @@ public final class ServerStatuses {
public static final short LAST_ROW_SENT = 128;

// public static final short DB_DROPPED = 256;
// public static final short NO_BACKSLASH_ESCAPES = 512;

/**
* Server does not permit backslash escapes.
*
* @since 1.1.3
*/
public static final short NO_BACKSLASH_ESCAPES = 512;

// public static final short METADATA_CHANGED = 1024;
// public static final short QUERY_WAS_SLOW = 2048;
// public static final short PS_OUT_PARAMS = 4096;
// public static final short IN_TRANS_READONLY = 8192;
// public static final short SESSION_STATE_CHANGED = 16384;

private ServerStatuses() { }
private ServerStatuses() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,17 @@ final class ParamWriter extends ParameterWriter {

private final StringBuilder builder;

private final boolean noBackslashEscapes;

private final Query query;

private int index;

private Mode mode;

private ParamWriter(Query query) {
private ParamWriter(boolean noBackslashEscapes, Query query) {
this.builder = newBuilder(query);
this.noBackslashEscapes = noBackslashEscapes;
this.query = query;
this.index = 1;
this.mode = 1 < query.getPartSize() ? Mode.AVAILABLE : Mode.FULL;
Expand Down Expand Up @@ -318,15 +321,19 @@ private void write0(char[] s, int off, int len) {
}

private void escape(char c) {
if (c == '\'') {
// MySQL will auto-combine consecutive strings, whatever backslash is used or not, e.g. '1''2' -> '1\'2'
builder.append('\'').append('\'');
return;
} else if (noBackslashEscapes) {
builder.append(c);
return;
}

switch (c) {
case '\\':
builder.append('\\').append('\\');
break;
case '\'':
// MySQL will auto-combine consecutive strings, like '1''2' -> '12'.
// Sure, there can use '1\'2', but this will be better. (For some logging systems)
builder.append('\'').append('\'');
break;
// Maybe useful in the future, keep '"' here.
// case '"': buf.append('\\').append('"'); break;
// SHIFT-JIS, WINDOWS-932, EUC-JP and eucJP-OPEN will encode '\u00a5' (the sign of Japanese Yen
Expand All @@ -335,20 +342,19 @@ private void escape(char c) {
// case '\u00a5': do something; break;
// case '\u20a9': do something; break;
case 0:
// MySQL is based on C/C++, must escape '\0' which is an end flag in C style string.
// Should escape '\0' which is an end flag in C style string.
builder.append('\\').append('0');
break;
case '\032':
// It seems like a problem on Windows 32, maybe check current OS here?
// It gives some problems on Win32.
builder.append('\\').append('Z');
break;
case '\n':
// Should escape it for some logging such as Relational Database Service (RDS) Logging
// System, etc. Sure, it is not necessary, but this will be better.
// Should be escaped for better logging.
builder.append('\\').append('n');
break;
case '\r':
// Should escape it for some logging such as RDS Logging System, etc.
// Should be escaped for better logging.
builder.append('\\').append('r');
break;
default:
Expand All @@ -357,9 +363,9 @@ private void escape(char c) {
}
}

static Mono<String> publish(Query query, Flux<MySqlParameter> values) {
static Mono<String> publish(boolean noBackslashEscapes, Query query, Flux<MySqlParameter> values) {
return Mono.defer(() -> {
ParamWriter writer = new ParamWriter(query);
ParamWriter writer = new ParamWriter(noBackslashEscapes, query);

return OperatorUtils.discardOnCancel(values)
.doOnDiscard(MySqlParameter.class, DISPOSE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public Mono<ByteBuf> encode(ByteBufAllocator allocator, ConnectionContext contex
return Flux.fromArray(values);
});

return ParamWriter.publish(query, parameters).handle((it, sink) -> {
return ParamWriter.publish(context.isNoBackslashEscapes(), query, parameters).handle((it, sink) -> {
ByteBuf buf = allocator.buffer();

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.util.Collections;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Stream;

import static io.r2dbc.spi.IsolationLevel.READ_COMMITTED;
import static io.r2dbc.spi.IsolationLevel.READ_UNCOMMITTED;
Expand Down Expand Up @@ -80,6 +81,51 @@ void isInTransaction() {
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()));
}

@ParameterizedTest
@ValueSource(strings = {
"test",
"test`data",
"test\ndata",
"I'm feeling good",
})
void sqlModeNoBackslashEscapes(String value) {
String tdl = "CREATE TEMPORARY TABLE `test` (`id` INT NOT NULL PRIMARY KEY, `value` VARCHAR(50) NOT NULL)";

// Add NO_BACKSLASH_ESCAPES instead of replace
// TODO: get context from connection, check if NO_BACKSLASH_ESCAPES is already in server statuses
castedComplete(connection -> Flux.from(connection.createStatement(tdl).execute())
.flatMap(MySqlResult::getRowsUpdated)
.thenMany(connection.createStatement("INSERT INTO test VALUES (1, ?)")
.bind(0, value)
.execute())
.flatMap(MySqlResult::getRowsUpdated)
.thenMany(connection.createStatement("SELECT COUNT(0) FROM `test` WHERE `value` = ?")
.bind(0, value)
.execute())
.flatMap(result -> result.map((row, metadata) -> row.get(0, Integer.class)))
.collectList()
.doOnNext(counts -> assertThat(counts).isEqualTo(Collections.singletonList(1)))
.thenMany(connection.createStatement("SELECT @@sql_mode").execute())
.flatMap(result -> result.map((row, metadata) -> row.get(0, String.class)))
.map(modes -> Stream.concat(Stream.of(modes.split(",")), Stream.of("NO_BACKSLASH_ESCAPES"))
.toArray(String[]::new))
.last()
.flatMapMany(modes -> connection.createStatement("SET sql_mode = ?")
.bind(0, modes)
.execute())
.flatMap(MySqlResult::getRowsUpdated)
.thenMany(connection.createStatement("INSERT INTO test VALUES (2, ?)")
.bind(0, value)
.execute())
.flatMap(MySqlResult::getRowsUpdated)
.thenMany(connection.createStatement("SELECT COUNT(0) FROM `test` WHERE `value` = ?")
.bind(0, value)
.execute())
.flatMap(result -> result.map((row, metadata) -> row.get(0, Integer.class)))
.collectList()
.doOnNext(counts -> assertThat(counts).isEqualTo(Collections.singletonList(2))));
}

@DisabledIf("envIsLessThanMySql56")
@Test
void startTransaction() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ default void encodeStringify() {
Query query = Query.parse("?");

for (int i = 0; i < origin.length; ++i) {
ParameterWriter writer = ParameterWriterHelper.get(query);
ParameterWriter writer = ParameterWriterHelper.get(false, query);
codec.encode(origin[i], context())
.publishText(writer)
.as(StepVerifier::create)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void stringifySet() {
Query query = Query.parse("?");

for (int i = 0; i < sets.length; ++i) {
ParameterWriter writer = ParameterWriterHelper.get(query);
ParameterWriter writer = ParameterWriterHelper.get(false, query);
codec.encode(sets[i], context())
.publishText(writer)
.as(StepVerifier::create)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import io.asyncer.r2dbc.mysql.Query;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;

Expand Down Expand Up @@ -84,42 +86,42 @@ void badFollowNull() {

@Test
void appendPart() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.append("define", 2, 5);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'fin'");
}

@Test
void writePart() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write("define", 2, 3);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'fin'");
}

@Test
void appendNull() {
assertThat(ParameterWriterHelper.toSql(ParameterWriterHelper.get(parameterOnly(1)).append(null)))
assertThat(ParameterWriterHelper.toSql(ParameterWriterHelper.get(false, parameterOnly(1)).append(null)))
.isEqualTo("'null'");
assertThat(ParameterWriterHelper.toSql(ParameterWriterHelper.get(parameterOnly(1))
assertThat(ParameterWriterHelper.toSql(ParameterWriterHelper.get(false, parameterOnly(1))
.append(null, 1, 3)))
.isEqualTo("'ul'");
}

@Test
void writeNull() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write((String) null);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'null'");

writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write((String) null, 1, 2);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'ul'");

writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write((char[]) null);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'null'");

writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write((char[]) null, 1, 2);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'ul'");
}
Expand All @@ -132,7 +134,7 @@ void publishSuccess() {
values[i] = new MockMySqlParameter(true);
}

Flux.from(ParamWriter.publish(parameterOnly(SIZE), Flux.fromArray(values)))
Flux.from(ParamWriter.publish(false, parameterOnly(SIZE), Flux.fromArray(values)))
.as(StepVerifier::create)
.expectNext(new String(new char[SIZE]).replace("\0", "''"))
.verifyComplete();
Expand All @@ -154,7 +156,7 @@ void publishPartially() {
values[i] = new MockMySqlParameter(false);
}

Flux.from(ParamWriter.publish(parameterOnly(SIZE), Flux.fromArray(values)))
Flux.from(ParamWriter.publish(false, parameterOnly(SIZE), Flux.fromArray(values)))
.as(StepVerifier::create)
.verifyError(MockException.class);

Expand All @@ -169,13 +171,30 @@ void publishNothing() {
values[i] = new MockMySqlParameter(false);
}

Flux.from(ParamWriter.publish(parameterOnly(SIZE), Flux.fromArray(values)))
Flux.from(ParamWriter.publish(false, parameterOnly(SIZE), Flux.fromArray(values)))
.as(StepVerifier::create)
.verifyError(MockException.class);

assertThat(values).extracting(MockMySqlParameter::refCnt).containsOnly(0);
}

@ParameterizedTest
@ValueSource(strings = {
"abc",
"a'b'c",
"a\nb\rc",
"a\"b\"c",
"a\\b\\c",
"a\0b\0c",
"a\u00a5b\u20a9c",
"a\032b\032c",
})
void noBackslashEscapes(String value) {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(true, parameterOnly(1));
writer.write(value);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'" + value.replaceAll("'", "''") + "'");
}

private static Query parameterOnly(int parameters) {
char[] chars = new char[parameters];
Arrays.fill(chars, '?');
Expand All @@ -184,13 +203,13 @@ private static Query parameterOnly(int parameters) {
}

private static ParamWriter stringWriter() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write('0');
return writer;
}

private static ParamWriter nullWriter() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.writeNull();
return writer;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ public final class ParameterWriterHelper {
ReflectionUtils.findMethod(ParamWriter.class, "toSql")
.orElseThrow(RuntimeException::new);

public static ParameterWriter get(Query query) {
public static ParameterWriter get(boolean noBackslashEscapes, Query query) {
assertThat(query.getPartSize()).isGreaterThan(1);

return ReflectionUtils.newInstance(CONSTRUCTOR, query);
return ReflectionUtils.newInstance(CONSTRUCTOR, noBackslashEscapes, query);
}

public static String toSql(ParameterWriter writer) {
Expand Down

0 comments on commit 8c24b1c

Please sign in to comment.