Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for query cancellation and session reuse #6

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions lib/src/main/java/com/wherobots/db/Runtime.java
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
package com.wherobots.db;

public enum Runtime {
SEDONA("tiny"),
SAN_FRANCISCO("small"),
NEW_YORK("medium"),
CAIRO("large"),
DELHI("x-large"),
TOKYO("2x-large"),

NEW_YORK_HIMEM("medium-himem"),
CAIRO_HIMEM("large-himem"),
DEHLI_HIMEM("x-large-himem"),
TOKYO_HIMEM("2x-large-himem"),
ATLANTIS_HIMEM("4x-large-himem"),

SEDONA_GPU("tiny-a10-gpu"),
SAN_FRANCISCO_GPU("small-a10-gpu"),
NEW_YORK_GPU("medium-a10-gpu");
TINY("tiny"),
SMALL("small"),
MEDIUM("medium"),
LARGE("large"),
X_LARGE("x-large"),
XX_LARGE("2x-large"),

MEDIUM_HIMEM("medium-himem"),
LARGE_HIMEM("large-himem"),
X_LARGE_HIMEM("x-large-himem"),
XX_LARGE_HIMEM("2x-large-himem"),
XXXX_LARGE_HIMEM("4x-large-himem"),

TINY_A10_GPU("tiny-a10-gpu"),
SMALL_GPU("small-a10-gpu"),
MEDIUM_GPU("medium-a10-gpu"),

@Deprecated SEDONA("tiny"),
@Deprecated SAN_FRANCISCO("small"),
@Deprecated NEW_YORK("medium"),
@Deprecated CAIRO("large"),
@Deprecated DELHI("x-large"),
@Deprecated TOKYO("2x-large"),

@Deprecated NEW_YORK_HIMEM("medium-himem"),
@Deprecated CAIRO_HIMEM("large-himem"),
@Deprecated DEHLI_HIMEM("x-large-himem"),
@Deprecated TOKYO_HIMEM("2x-large-himem"),
@Deprecated ATLANTIS_HIMEM("4x-large-himem"),

@Deprecated SEDONA_GPU("tiny-a10-gpu"),
@Deprecated SAN_FRANCISCO_GPU("small-a10-gpu"),
@Deprecated NEW_YORK_GPU("medium-a10-gpu");

public final String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.wherobots.db.jdbc.internal.ExecutionResult;
import com.wherobots.db.jdbc.internal.Frame;
import com.wherobots.db.jdbc.internal.Query;
import com.wherobots.db.jdbc.models.CancelRequest;
import com.wherobots.db.jdbc.models.Event;
import com.wherobots.db.jdbc.models.ExecuteSqlRequest;
import com.wherobots.db.jdbc.models.QueryState;
Expand Down Expand Up @@ -96,6 +97,7 @@ private void handle(Event event) throws Exception {

switch (event.state) {
case succeeded -> this.retrieveResults(event.executionId);
case cancelled -> query.statement().onExecutionResult(new ExecutionResult(null, null));
case failed -> {
// No-op, error event will follow.
}
Expand All @@ -106,13 +108,13 @@ private void handle(Event event) throws Exception {

if (event instanceof Event.ExecutionResultEvent ere) {
Event.Results results = ere.results;

logger.info(
"Received {} bytes of {}-compressed {} results from {}.",
results.resultBytes.length, results.compression, results.format, event.executionId);
ArrowStreamReader reader = ArrowUtil.readFrom(results.resultBytes, results.compression);
query.statement().onExecutionResult(new ExecutionResult(reader, null));

if (results != null) {
logger.info(
"Received {} bytes of {}-compressed {} results from {}.",
results.resultBytes.length, results.compression, results.format, event.executionId);
ArrowStreamReader reader = ArrowUtil.readFrom(results.resultBytes, results.compression);
query.statement().onExecutionResult(new ExecutionResult(reader, null));
}
return;
}

Expand Down Expand Up @@ -159,11 +161,14 @@ void retrieveResults(String executionId) {
}

void cancel(String executionId) throws SQLException {
Query query = this.queries.remove(executionId);
if (query != null) {
query.statement().close();
logger.info("Cancelled query {}.", executionId);
Query query = this.queries.get(executionId);
if (query == null) {
return;
}

String request = JsonUtil.serialize(new CancelRequest(executionId));
this.session.send(request);
logger.info("Cancelled query {}.", executionId);
}

@Override
Expand Down
13 changes: 11 additions & 2 deletions lib/src/main/java/com/wherobots/db/jdbc/WherobotsJdbcDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class WherobotsJdbcDriver implements Driver {
public static final String TOKEN_PROP = "token";
public static final String RUNTIME_PROP = "runtime";
public static final String REGION_PROP = "region";
public static final String REUSE_SESSION_PROP = "reuseSession";
public static final String WS_URI_PROP = "wsUri";

// Results format; one of {@link DataFormat}
Expand All @@ -45,8 +46,9 @@ public class WherobotsJdbcDriver implements Driver {
public static final String DEFAULT_ENDPOINT = "api.cloud.wherobots.com";
public static final String STAGING_ENDPOINT = "api.staging.wherobots.com";

public static final Runtime DEFAULT_RUNTIME = Runtime.SEDONA;
public static final Runtime DEFAULT_RUNTIME = Runtime.TINY;
public static final Region DEFAULT_REGION = Region.AWS_US_WEST_2;
public static final boolean DEFAULT_REUSE_SESSION = true;

public Map<String, String> getUserAgentHeader() {
String javaVersion = System.getProperty("java.version");
Expand Down Expand Up @@ -83,6 +85,13 @@ public Connection connect(String url, Properties info) throws SQLException {
if (StringUtils.isNotBlank(regionName)) {
region = Region.valueOf(regionName);
}

boolean reuse = DEFAULT_REUSE_SESSION;
String reuseSession = info.getProperty(REUSE_SESSION_PROP);
if (StringUtils.isNotBlank(reuseSession)) {
reuse = Boolean.parseBoolean(reuseSession);
}

Map<String, String> headers = new HashMap<>(getAuthHeaders(info));
headers.putAll(getUserAgentHeader());
WherobotsSession session;
Expand All @@ -96,7 +105,7 @@ public Connection connect(String url, Properties info) throws SQLException {
throw new SQLException("Invalid WebSocket URI: " + wsUriString, e);
}
} else {
session = WherobotsSessionSupplier.create(host, runtime, region, headers);
session = WherobotsSessionSupplier.create(host, runtime, region, reuse, headers);
}

return new WherobotsJdbcConnection(session, info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class WherobotsStatement implements Statement {
private int timeoutSeconds = DEFAULT_QUERY_TIMEOUT_SECONDS;
private int maxRows = 0;

private String executionId;
private volatile String executionId;
private ResultSet results;
private int updateCount = -1;

Expand Down Expand Up @@ -151,6 +151,10 @@ public boolean execute(String sql) throws SQLException {
throw new SQLException(result.error());
}

if (result.result() == null) {
return false;
}

// TODO: differentiate between queries and insert/update/delete results
this.results = new WherobotsResultSet(this, result.result());
return true;
Expand Down
17 changes: 17 additions & 0 deletions lib/src/main/java/com/wherobots/db/jdbc/models/CancelRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.wherobots.db.jdbc.models;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;

@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public class CancelRequest {

public final String kind = "cancel";
public String executionId;

public CancelRequest(String executionId) {
this.executionId = executionId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public abstract class WherobotsSessionSupplier {

private static final Logger logger = LoggerFactory.getLogger(WherobotsSessionSupplier.class);

private static final String SQL_SESSION_ENDPOINT = "https://%s/sql/session?region=%s";
private static final String SQL_SESSION_ENDPOINT = "https://%s/sql/session?region=%s&reuse_session=%s";
private static final String PROTOCOL_VERSION = "1.0.0";

@JsonInclude(JsonInclude.Include.NON_NULL)
Expand All @@ -52,7 +52,7 @@ private record SqlSessionResponsePayload(AppStatus status, SqlSessionAppMeta app
* @return
* @throws SQLException
*/
public static WherobotsSession create(String host, Runtime runtime, Region region, Map<String, String> headers)
public static WherobotsSession create(String host, Runtime runtime, Region region, boolean reuse, Map<String, String> headers)
throws SQLException {
HttpClient client = HttpClient.newBuilder()
.followRedirects(HttpClient.Redirect.NORMAL)
Expand All @@ -67,7 +67,7 @@ public static WherobotsSession create(String host, Runtime runtime, Region regio
Retry retry = RetryRegistry.of(config).retry("session");

try {
URI sessionIdUri = new SqlSessionSupplier(client, headers, host, runtime, region).get();
URI sessionIdUri = new SqlSessionSupplier(client, headers, host, runtime, region, reuse).get();
URI wsUri = Retry.decorateCheckedSupplier(retry, new SessionWsUriSupplier(client, headers, sessionIdUri)).get();
return create(wsUri, headers);
} catch (SQLException e) {
Expand Down Expand Up @@ -102,25 +102,27 @@ private record SqlSessionSupplier(HttpClient client,
Map<String, String> headers,
String host,
Runtime runtime,
Region region)
Region region,
boolean reuse)
implements CheckedSupplier<URI> {

@Override
public URI get() throws IOException, InterruptedException {
logger.info("Requesting {}/{} runtime in {} from {}...", runtime, runtime.name, region, host);
logger.info("{} {} runtime in {} from {}...",
reuse ? "Recycling" : "Requesting", runtime.name, region.name, host);

HttpRequest.BodyPublisher body = HttpRequest.BodyPublishers.ofString(
JsonUtil.serialize(new SqlSessionRequestPayload(runtime.name, null)));

HttpRequest.Builder request = HttpRequest.newBuilder()
.POST(body)
.uri(URI.create(String.format(SQL_SESSION_ENDPOINT, host, region.name)))
.uri(URI.create(String.format(SQL_SESSION_ENDPOINT, host, region.name, reuse)))
.header("Content-Type", "application/json");
headers.forEach(request::header);

HttpResponse<String> response = client.send(request.build(), HttpResponse.BodyHandlers.ofString());
if (response.statusCode() != HttpStatus.SC_OK) {
throw new IllegalStateException();
throw new IllegalStateException(String.format("Got %d from SQL session at %s.", response.statusCode(), host));
}
return response.uri();
}
Expand Down
34 changes: 26 additions & 8 deletions lib/src/test/java/com/wherobots/db/jdbc/SmokeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,34 @@ public static void main(String[] args) throws Exception {

Properties props = new Properties();
props.put(WherobotsJdbcDriver.API_KEY_PROP, args[0]);
props.put(WherobotsJdbcDriver.REUSE_SESSION_PROP, "true");

try (Connection conn = DriverManager.getConnection("jdbc:wherobots://api.staging.wherobots.com", props)) {
try (Statement stmt = conn.createStatement(); ResultSet result = stmt.executeQuery(sql)) {
while (result.next()) {
System.out.printf("%s: %s\t%s\t%12d%n",
result.getString("id"),
result.getString("name"),
result.getString("geometry"),
result.getInt("population")
);
try (Statement stmt = conn.createStatement()) {
new Thread(() -> {
try {
System.out.println("Cancelling query in 2s!");
Thread.sleep(2000L);
stmt.cancel();
} catch (Exception e) {
e.printStackTrace();
}
}).start();

boolean hasResult = stmt.execute(sql);
if (!hasResult) {
return;
}

try (ResultSet result = stmt.getResultSet()) {
while (result.next()) {
System.out.printf("%s: %s\t%s\t%12d%n",
result.getString("id"),
result.getString("name"),
result.getString("geometry"),
result.getInt("population")
);
}
}
}
}
Expand Down
Loading