Skip to content

Commit

Permalink
Refactor GetSqlInfo and GetObjects workflows
Browse files Browse the repository at this point in the history
* Change GetSqlInfo and GetObjects to use the same reader as
FlightInfoReader. This fixes bugs where locations aren't used
for these metadata calls and also makes them use resources
lazily.
* Fix a bug in the FlightInfoReader path where the first stream
is assumed to be from the current client rather than using
a location in the first endpoint.
* Fix handling of empty schemas or empty collections when
using GetObjects.
  • Loading branch information
jduo committed Feb 9, 2024
1 parent 25acfe3 commit dfff096
Show file tree
Hide file tree
Showing 9 changed files with 960 additions and 615 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.adbc.driver.flightsql;

import static org.apache.arrow.adbc.driver.flightsql.FlightSqlDriverUtil.tryLoadNextStream;

import com.github.benmanes.caffeine.cache.LoadingCache;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;

/**
 * Base class for ArrowReaders based on consuming data from FlightEndpoints.
 *
 * <p>The list of endpoints is obtained from the supplier given to the constructor when {@link
 * #populateEndpointData()} is invoked. Streams are then opened one endpoint at a time; each root
 * read from a stream is handed to {@link #processRootFromStream(VectorSchemaRoot)}, which
 * subclasses implement.
 */
public abstract class BaseFlightReader extends ArrowReader {

  private final List<FlightEndpoint> flightEndpoints;
  private final Supplier<List<FlightEndpoint>> rpcCall;
  private int nextEndpointIndex = 0;
  // Unset until populateEndpointData() has opened the first stream.
  private FlightStream currentStream;
  // Schema of the first stream; every later stream is checked against it.
  private Schema schema;
  private long bytesRead = 0;
  protected final FlightSqlClientWithCallOptions client;
  protected final LoadingCache<Location, FlightSqlClientWithCallOptions> clientCache;

  /**
   * Creates a reader. No endpoints are fetched and no stream is opened here; that happens in
   * {@link #populateEndpointData()}.
   *
   * @param allocator allocator passed to the {@link ArrowReader} superclass
   * @param client client used for endpoints that list no locations
   * @param clientCache cache of clients keyed by location, used for endpoints that list locations
   * @param rpcCall supplier producing the endpoints to read from
   */
  protected BaseFlightReader(
      BufferAllocator allocator,
      FlightSqlClientWithCallOptions client,
      LoadingCache<Location, FlightSqlClientWithCallOptions> clientCache,
      Supplier<List<FlightEndpoint>> rpcCall) {
    super(allocator);
    this.client = client;
    this.clientCache = clientCache;
    this.flightEndpoints = new ArrayList<>();
    this.rpcCall = rpcCall;
  }

  /**
   * Advances the current stream, moving on to the next endpoint when the current stream is
   * exhausted, and passes the stream's root to {@link #processRootFromStream(VectorSchemaRoot)}.
   *
   * @return {@code false} once the current stream is exhausted and no endpoints remain, {@code
   *     true} otherwise
   * @throws IOException if the next stream cannot be opened, if closing the finished stream
   *     fails, or if the next stream's schema differs from the first stream's schema
   */
  @Override
  public boolean loadNextBatch() throws IOException {
    if (!currentStream.next()) {
      if (nextEndpointIndex >= flightEndpoints.size()) {
        return false;
      } else {
        try {
          currentStream.close();
          FlightEndpoint endpoint = flightEndpoints.get(nextEndpointIndex++);
          currentStream = tryLoadNextStream(endpoint, client, clientCache);
          if (!schema.equals(currentStream.getSchema())) {
            throw new IOException(
                "Stream has inconsistent schema. Expected: "
                    + schema
                    + "\nFound: "
                    + currentStream.getSchema());
          }
        } catch (IOException e) {
          throw e;
        } catch (Exception e) {
          // FlightStream#close and unchecked failures are surfaced as IOException.
          throw new IOException(e);
        }
      }
    }
    // After switching endpoints, the new stream's root is processed without next() having been
    // called on that stream during this invocation.
    // NOTE(review): presumably this yields one empty batch per endpoint switch - confirm.
    processRootFromStream(currentStream.getRoot());
    return true;
  }

  /** Returns the schema captured from the first stream by {@link #populateEndpointData()}. */
  @Override
  protected Schema readSchema() throws IOException {
    return schema;
  }

  /** Returns the running total accumulated through {@link #addBytesRead(long)}. */
  @Override
  public long bytesRead() {
    return bytesRead;
  }

  /**
   * Closes the current stream, if one was ever opened.
   *
   * @throws IOException if closing the stream fails
   */
  @Override
  protected void closeReadSource() throws IOException {
    // No stream exists if populateEndpointData() was never called or failed before opening one;
    // closing such a reader must not fail with a wrapped NullPointerException.
    if (currentStream == null) {
      return;
    }
    try {
      currentStream.close();
    } catch (Exception e) {
      throw new IOException(e);
    }
  }

  /**
   * Consumes the root currently exposed by the active stream.
   *
   * @param root the active stream's root
   */
  protected abstract void processRootFromStream(VectorSchemaRoot root);

  /**
   * Adds to the total reported by {@link #bytesRead()}.
   *
   * @param bytes number of bytes to add
   */
  protected void addBytesRead(long bytes) {
    this.bytesRead += bytes;
  }

  /**
   * Fetches the endpoints from the supplier, opens the stream for the first unread endpoint and
   * records that stream's schema.
   *
   * @throws AdbcException if there is no endpoint to read from, or if fetching the endpoints or
   *     opening the stream fails
   */
  protected void populateEndpointData() throws AdbcException {
    try {
      this.flightEndpoints.addAll(rpcCall.get());
      // Without at least one endpoint there is no stream to take the schema from; report that
      // explicitly instead of failing with an IndexOutOfBoundsException below.
      if (this.nextEndpointIndex >= this.flightEndpoints.size()) {
        throw AdbcException.invalidState("[Flight SQL] No endpoints were returned to read from");
      }
      this.currentStream =
          tryLoadNextStream(flightEndpoints.get(this.nextEndpointIndex++), client, clientCache);
      this.schema = this.currentStream.getSchema();
    } catch (FlightRuntimeException e) {
      throw FlightSqlDriverUtil.fromFlightException(e);
    } catch (IOException e) {
      throw new AdbcException(e.getMessage(), e, AdbcStatusCode.IO, null, 0);
    }
  }

  /**
   * Unloads the given root into a record batch, counts its body length towards {@link
   * #bytesRead()} and loads the batch into this reader via the inherited {@code loadRecordBatch}.
   *
   * @param root root whose contents are copied into this reader
   */
  protected void loadRoot(VectorSchemaRoot root) {
    final VectorUnloader unloader = new VectorUnloader(root);
    final ArrowRecordBatch recordBatch = unloader.getRecordBatch();
    addBytesRead(recordBatch.computeBodyLength());
    loadRecordBatch(recordBatch);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,16 @@

import com.github.benmanes.caffeine.cache.LoadingCache;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;
import org.checkerframework.checker.nullness.qual.Nullable;

/** An ArrowReader that wraps a FlightInfo. */
public class FlightInfoReader extends ArrowReader {
private final Schema schema;
private final FlightSqlClientWithCallOptions client;
private final LoadingCache<Location, FlightSqlClientWithCallOptions> clientCache;
private final List<FlightEndpoint> flightEndpoints;
private int nextEndpointIndex;
private FlightStream currentStream;
private long bytesRead;

public class FlightInfoReader extends BaseFlightReader {
@SuppressWarnings(
"method.invocation") // Checker Framework does not like the ensureInitialized call
FlightInfoReader(
Expand All @@ -54,21 +36,9 @@ public class FlightInfoReader extends ArrowReader {
LoadingCache<Location, FlightSqlClientWithCallOptions> clientCache,
List<FlightEndpoint> flightEndpoints)
throws AdbcException {
super(allocator);
this.client = client;
this.clientCache = clientCache;
this.flightEndpoints = flightEndpoints;
this.nextEndpointIndex = 0;
this.bytesRead = 0;

try {
this.currentStream =
client.getStream(flightEndpoints.get(this.nextEndpointIndex++).getTicket());
this.schema = this.currentStream.getSchema();
} catch (FlightRuntimeException e) {
throw FlightSqlDriverUtil.fromFlightException(e);
}
super(allocator, client, clientCache, () -> flightEndpoints);

populateEndpointData();
try {
this.ensureInitialized();
} catch (IOException e) {
Expand All @@ -82,85 +52,7 @@ public class FlightInfoReader extends ArrowReader {
}

@Override
public boolean loadNextBatch() throws IOException {
if (!currentStream.next()) {
if (nextEndpointIndex >= flightEndpoints.size()) {
return false;
} else {
try {
currentStream.close();
FlightEndpoint endpoint = flightEndpoints.get(nextEndpointIndex++);
currentStream = tryLoadNextStream(endpoint);
if (!schema.equals(currentStream.getSchema())) {
throw new IOException(
"Stream has inconsistent schema. Expected: "
+ schema
+ "\nFound: "
+ currentStream.getSchema());
}
} catch (IOException e) {
throw e;
} catch (Exception e) {
throw new IOException(e);
}
}
}
final VectorSchemaRoot root = currentStream.getRoot();
final VectorUnloader unloader = new VectorUnloader(root);
final ArrowRecordBatch recordBatch = unloader.getRecordBatch();
bytesRead += recordBatch.computeBodyLength();
loadRecordBatch(recordBatch);
return true;
}

private FlightStream tryLoadNextStream(FlightEndpoint endpoint) throws IOException {
if (endpoint.getLocations().isEmpty()) {
return client.getStream(endpoint.getTicket());
} else {
List<Location> locations = new ArrayList<>(endpoint.getLocations());
Collections.shuffle(locations);
IOException failure = null;
for (final Location location : locations) {
final @Nullable FlightSqlClientWithCallOptions client = clientCache.get(location);
if (client == null) {
throw new IllegalStateException("Could not connect to " + location);
}
try {
return client.getStream(endpoint.getTicket());
} catch (RuntimeException e) {
// Also handles CompletionException (from clientCache#get), FlightRuntimeException
if (failure == null) {
failure =
new IOException("Failed to get stream from location " + location + ": " + e, e);
} else {
failure.addSuppressed(
new IOException("Failed to get stream from location " + location + ": " + e, e));
}
}
}
if (failure == null) {
throw new IllegalStateException("FlightEndpoint had no locations");
}
throw Objects.requireNonNull(failure);
}
}

@Override
public long bytesRead() {
return bytesRead;
}

@Override
protected void closeReadSource() throws IOException {
try {
currentStream.close();
} catch (Exception e) {
throw new IOException(e);
}
}

@Override
protected Schema readSchema() {
return schema;
protected void processRootFromStream(VectorSchemaRoot root) {
loadRoot(root);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.checkerframework.checker.initialization.qual.UnknownInitialization;
import org.checkerframework.checker.nullness.qual.Nullable;
Expand Down Expand Up @@ -165,29 +164,25 @@ public ArrowReader getObjects(
String[] tableTypes,
String columnNamePattern)
throws AdbcException {
try (final VectorSchemaRoot root =
new ObjectMetadataBuilder(
allocator,
client,
depth,
catalogPattern,
dbSchemaPattern,
tableNamePattern,
tableTypes,
columnNamePattern)
.build()) {
return RootArrowReader.fromRoot(allocator, root);
}
return GetObjectsMetadataReaders.CreateGetObjectsReader(
allocator,
client,
clientCache,
depth,
catalogPattern,
dbSchemaPattern,
tableNamePattern,
tableTypes,
columnNamePattern);
}

@Override
public ArrowReader getInfo(int @Nullable [] infoCodes) throws AdbcException {
try (InfoMetadataBuilder builder = new InfoMetadataBuilder(allocator, client, infoCodes)) {
try (final VectorSchemaRoot root = builder.build()) {
return RootArrowReader.fromRoot(allocator, root);
}
try {
return GetInfoMetadataReader.CreateGetInfoMetadataReader(
allocator, client, clientCache, infoCodes);
} catch (Exception e) {
throw AdbcException.invalidState("[Flight SQL] Failed to get info");
throw AdbcException.invalidState("[Flight SQL] Failed to get info").withCause(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@
*/
package org.apache.arrow.adbc.driver.flightsql;

import com.github.benmanes.caffeine.cache.LoadingCache;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.adbc.core.ErrorDetail;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.checkerframework.checker.nullness.qual.Nullable;

final class FlightSqlDriverUtil {
Expand Down Expand Up @@ -104,4 +111,40 @@ static AdbcException fromFlightException(FlightRuntimeException e) {
0,
errorDetails);
}

/**
 * Opens a stream for the ticket of the given endpoint.
 *
 * <p>If the endpoint lists no locations, the stream is requested from {@code rootClient}.
 * Otherwise the endpoint's locations are tried in random order, each through the client obtained
 * from {@code clientCache}, and the first stream that opens successfully is returned.
 *
 * @param endpoint endpoint whose ticket should be fetched
 * @param rootClient client used when the endpoint lists no locations
 * @param clientCache cache providing a client for each location
 * @return the opened stream
 * @throws IOException if every attempted location failed; the first failure is thrown and later
 *     ones are attached to it as suppressed exceptions
 * @throws IllegalStateException if the cache yields no client for a location
 */
static FlightStream tryLoadNextStream(
    FlightEndpoint endpoint,
    FlightSqlClientWithCallOptions rootClient,
    LoadingCache<Location, FlightSqlClientWithCallOptions> clientCache)
    throws IOException {
  if (endpoint.getLocations().isEmpty()) {
    return rootClient.getStream(endpoint.getTicket());
  } else {
    List<Location> locations = new ArrayList<>(endpoint.getLocations());
    Collections.shuffle(locations);
    IOException failure = null;
    for (final Location location : locations) {
      try {
        // The cache lookup is inside the try so that a failure to create the client for one
        // location (e.g. CompletionException from clientCache#get) falls through to the next
        // location instead of aborting the whole call.
        final @Nullable FlightSqlClientWithCallOptions client = clientCache.get(location);
        if (client != null) {
          return client.getStream(endpoint.getTicket());
        }
      } catch (RuntimeException e) {
        // Handles CompletionException (from clientCache#get) and FlightRuntimeException
        if (failure == null) {
          failure =
              new IOException("Failed to get stream from location " + location + ": " + e, e);
        } else {
          failure.addSuppressed(
              new IOException("Failed to get stream from location " + location + ": " + e, e));
        }
        continue;
      }
      // Only reached when the cache returned null for this location.
      throw new IllegalStateException("Could not connect to " + location);
    }
    if (failure == null) {
      throw new IllegalStateException("FlightEndpoint had no locations");
    }
    throw Objects.requireNonNull(failure);
  }
}
}
Loading

0 comments on commit dfff096

Please sign in to comment.