Skip to content

Commit

Permalink
feat(java/driver/fligh-sql): precompile columnNamePattern
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko authored and jduo committed Feb 5, 2024
1 parent 9c7ccb8 commit a3771f5
Showing 1 changed file with 179 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.regex.Pattern;

import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.StandardSchemas;
Expand All @@ -44,196 +46,190 @@

final class ObjectMetadataBuilder {

private final FlightSqlClient client;
private final VectorSchemaRoot root;
private final VarCharVector adbcCatalogNames;
private final UnionListWriter adbcCatalogDbSchemasWriter;
private final BaseWriter.StructWriter adbcCatalogDbSchemasStructWriter;
private final BaseWriter.ListWriter adbcCatalogDbSchemaTablesWriter;
private final VarCharWriter adbcCatalogDbSchemaNameWriter;
private final BaseWriter.StructWriter adbcTablesStructWriter;
private final VarCharWriter adbcTableNameWriter;
private final VarCharWriter adbcTableTypeWriter;
private final BaseWriter.ListWriter adbcTableColumnsWriter;
private final BufferAllocator allocator;
private final AdbcConnection.GetObjectsDepth depth;
private final String catalogPattern;
private final String dbSchemaPattern;
private final String tableNamePattern;
private final String[] tableTypes;
private final String columnNamePattern;

ObjectMetadataBuilder(
BufferAllocator allocator,
FlightSqlClient client,
final AdbcConnection.GetObjectsDepth depth,
final String catalogPattern,
final String dbSchemaPattern,
final String tableNamePattern,
final String[] tableTypes,
final String columnNamePattern) {
this.allocator = allocator;
this.client = client;
this.depth = depth;
this.catalogPattern = catalogPattern;
this.dbSchemaPattern = dbSchemaPattern;
this.tableNamePattern = tableNamePattern;
this.tableTypes = tableTypes;
this.columnNamePattern = columnNamePattern;
this.root = VectorSchemaRoot.create(StandardSchemas.GET_OBJECTS_SCHEMA, allocator);
this.adbcCatalogNames = (VarCharVector) root.getVector(0);
this.adbcCatalogDbSchemasWriter = ((ListVector) root.getVector(1)).getWriter();
this.adbcCatalogDbSchemasStructWriter = adbcCatalogDbSchemasWriter.struct();
this.adbcCatalogDbSchemaTablesWriter =
adbcCatalogDbSchemasStructWriter.list("db_schema_tables");
this.adbcCatalogDbSchemaNameWriter = adbcCatalogDbSchemasStructWriter.varChar("db_schema_name");
this.adbcTablesStructWriter = adbcCatalogDbSchemaTablesWriter.struct();
this.adbcTableNameWriter = adbcTablesStructWriter.varChar("table_name");
this.adbcTableTypeWriter = adbcTablesStructWriter.varChar("table_type");
this.adbcTableColumnsWriter = adbcTablesStructWriter.list("table_columns");
}

private void writeVarChar(VarCharWriter writer, String value) {
byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
try (ArrowBuf tempBuf = allocator.buffer(bytes.length)) {
tempBuf.setBytes(0, bytes, 0, bytes.length);
writer.writeVarChar(0, bytes.length, tempBuf);
}
}

private boolean patternMatched(String name, String pattern) {
if (pattern == null) {
return true;
private final FlightSqlClient client;
private final VectorSchemaRoot root;
private final VarCharVector adbcCatalogNames;
private final UnionListWriter adbcCatalogDbSchemasWriter;
private final BaseWriter.StructWriter adbcCatalogDbSchemasStructWriter;
private final BaseWriter.ListWriter adbcCatalogDbSchemaTablesWriter;
private final VarCharWriter adbcCatalogDbSchemaNameWriter;
private final BaseWriter.StructWriter adbcTablesStructWriter;
private final VarCharWriter adbcTableNameWriter;
private final VarCharWriter adbcTableTypeWriter;
private final BaseWriter.ListWriter adbcTableColumnsWriter;
private final BufferAllocator allocator;
private final AdbcConnection.GetObjectsDepth depth;
private final String catalogPattern;
private final String dbSchemaPattern;
private final String tableNamePattern;
private final String[] tableTypes;
private final Pattern precompiledColumnNamePattern;

ObjectMetadataBuilder(
BufferAllocator allocator,
FlightSqlClient client,
final AdbcConnection.GetObjectsDepth depth,
final String catalogPattern,
final String dbSchemaPattern,
final String tableNamePattern,
final String[] tableTypes,
final String columnNamePattern) {
this.allocator = allocator;
this.client = client;
this.depth = depth;
this.catalogPattern = catalogPattern;
this.dbSchemaPattern = dbSchemaPattern;
this.tableNamePattern = tableNamePattern;
this.precompiledColumnNamePattern = columnNamePattern != null ? Pattern.compile(
Pattern.quote(columnNamePattern).replace("_", ".").replace("%", ".*")
) : null;
this.tableTypes = tableTypes;
this.root = VectorSchemaRoot.create(StandardSchemas.GET_OBJECTS_SCHEMA, allocator);
this.adbcCatalogNames = (VarCharVector) root.getVector(0);
this.adbcCatalogDbSchemasWriter = ((ListVector) root.getVector(1)).getWriter();
this.adbcCatalogDbSchemasStructWriter = adbcCatalogDbSchemasWriter.struct();
this.adbcCatalogDbSchemaTablesWriter =
adbcCatalogDbSchemasStructWriter.list("db_schema_tables");
this.adbcCatalogDbSchemaNameWriter = adbcCatalogDbSchemasStructWriter.varChar("db_schema_name");
this.adbcTablesStructWriter = adbcCatalogDbSchemaTablesWriter.struct();
this.adbcTableNameWriter = adbcTablesStructWriter.varChar("table_name");
this.adbcTableTypeWriter = adbcTablesStructWriter.varChar("table_type");
this.adbcTableColumnsWriter = adbcTablesStructWriter.list("table_columns");
}

return name.matches(pattern.replace("_", ".").replace("%", ".*"));
}

VectorSchemaRoot build() throws AdbcException {
// TODO Catalogs and schemas that don't contain tables are being left out
FlightInfo info;
if (depth == AdbcConnection.GetObjectsDepth.CATALOGS) {
info = client.getCatalogs();
} else if (depth == AdbcConnection.GetObjectsDepth.DB_SCHEMAS) {
info = client.getSchemas(null, dbSchemaPattern);
} else {
info =
client.getTables(
null, // TODO pattern match later during processing
dbSchemaPattern,
tableNamePattern,
tableTypes == null ? null : Arrays.asList(tableTypes),
depth == AdbcConnection.GetObjectsDepth.ALL);
private void writeVarChar(VarCharWriter writer, String value) {
byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
try (ArrowBuf tempBuf = allocator.buffer(bytes.length)) {
tempBuf.setBytes(0, bytes, 0, bytes.length);
writer.writeVarChar(0, bytes.length, tempBuf);
}
}

byte[] lastCatalogAdded = null;
byte[] lastDbSchemaAdded = null;
int catalogIndex = 0;

for (FlightEndpoint endpoint : info.getEndpoints()) {
FlightStream stream = client.getStream(endpoint.getTicket());
while (stream.next()) {
try (VectorSchemaRoot res = stream.getRoot()) {
VarCharVector catalogVector = (VarCharVector) res.getVector(0);

for (int i = 0; i < res.getRowCount(); i++) {
byte[] catalog = catalogVector.get(i);

if (i == 0 || lastCatalogAdded != catalog) {
if (catalog == null) {
adbcCatalogNames.setNull(catalogIndex);
} else {
adbcCatalogNames.setSafe(catalogIndex, catalog);
}
if (depth == AdbcConnection.GetObjectsDepth.CATALOGS) {
adbcCatalogDbSchemasWriter.writeNull();
} else {
if (catalogIndex != 0) {
adbcCatalogDbSchemasWriter.endList();
}
adbcCatalogDbSchemasWriter.startList();
lastDbSchemaAdded = null;
}
catalogIndex++;
lastCatalogAdded = catalog;
}

if (depth != AdbcConnection.GetObjectsDepth.CATALOGS) {
VarCharVector dbSchemaVector = (VarCharVector) res.getVector(1);
byte[] dbSchema = dbSchemaVector.get(i);

if (!Arrays.equals(lastDbSchemaAdded, dbSchema)) {
if (i != 0) {
adbcCatalogDbSchemaTablesWriter.endList();
adbcCatalogDbSchemasStructWriter.end();
}
adbcCatalogDbSchemasStructWriter.start();
writeVarChar(
adbcCatalogDbSchemaNameWriter, new String(dbSchema, StandardCharsets.UTF_8));
if (depth == AdbcConnection.GetObjectsDepth.DB_SCHEMAS) {
adbcCatalogDbSchemaTablesWriter.writeNull();
} else {
adbcCatalogDbSchemaTablesWriter.startList();
}

lastDbSchemaAdded = dbSchema;
}
}

if (depth != AdbcConnection.GetObjectsDepth.CATALOGS
&& depth != AdbcConnection.GetObjectsDepth.DB_SCHEMAS) {
VarCharVector tableNameVector = (VarCharVector) res.getVector(2);
VarCharVector tableTypeVector = (VarCharVector) res.getVector(3);

adbcTablesStructWriter.start();
writeVarChar(
adbcTableNameWriter, new String(tableNameVector.get(i), StandardCharsets.UTF_8));
writeVarChar(
adbcTableTypeWriter, new String(tableTypeVector.get(i), StandardCharsets.UTF_8));

if (depth == AdbcConnection.GetObjectsDepth.ALL) {
VarBinaryVector tableSchemaVector = (VarBinaryVector) res.getVector(4);
Schema schema;

try {
schema =
MessageSerializer.deserializeSchema(
new ReadChannel(
Channels.newChannel(
new ByteArrayInputStream(tableSchemaVector.get(i)))));
} catch (IOException e) {
throw new RuntimeException(e);
}
VectorSchemaRoot build() throws AdbcException {
// TODO Catalogs and schemas that don't contain tables are being left out
FlightInfo info;
if (depth == AdbcConnection.GetObjectsDepth.CATALOGS) {
info = client.getCatalogs();
} else if (depth == AdbcConnection.GetObjectsDepth.DB_SCHEMAS) {
info = client.getSchemas(null, dbSchemaPattern);
} else {
info =
client.getTables(
null, // TODO pattern match later during processing
dbSchemaPattern,
tableNamePattern,
tableTypes == null ? null : Arrays.asList(tableTypes),
depth == AdbcConnection.GetObjectsDepth.ALL);
}

adbcTableColumnsWriter.startList();

for (int y = 0; y < schema.getFields().size(); y++) {
Field field = schema.getFields().get(y);
if (patternMatched(field.getName(), columnNamePattern)) {
adbcTableColumnsWriter.struct().start();
writeVarChar(
adbcTableColumnsWriter.struct().varChar("column_name"), field.getName());
adbcTableColumnsWriter.struct().integer("ordinal_position").writeInt(y + 1);
adbcTableColumnsWriter.struct().end();
}
byte[] lastCatalogAdded = null;
byte[] lastDbSchemaAdded = null;
int catalogIndex = 0;

for (FlightEndpoint endpoint : info.getEndpoints()) {
FlightStream stream = client.getStream(endpoint.getTicket());
while (stream.next()) {
try (VectorSchemaRoot res = stream.getRoot()) {
VarCharVector catalogVector = (VarCharVector) res.getVector(0);

for (int i = 0; i < res.getRowCount(); i++) {
byte[] catalog = catalogVector.get(i);

if (i == 0 || lastCatalogAdded != catalog) {
if (catalog == null) {
adbcCatalogNames.setNull(catalogIndex);
} else {
adbcCatalogNames.setSafe(catalogIndex, catalog);
}
if (depth == AdbcConnection.GetObjectsDepth.CATALOGS) {
adbcCatalogDbSchemasWriter.writeNull();
} else {
if (catalogIndex != 0) {
adbcCatalogDbSchemasWriter.endList();
}
adbcCatalogDbSchemasWriter.startList();
lastDbSchemaAdded = null;
}
catalogIndex++;
lastCatalogAdded = catalog;
}

if (depth != AdbcConnection.GetObjectsDepth.CATALOGS) {
VarCharVector dbSchemaVector = (VarCharVector) res.getVector(1);
byte[] dbSchema = dbSchemaVector.get(i);

if (!Arrays.equals(lastDbSchemaAdded, dbSchema)) {
if (i != 0) {
adbcCatalogDbSchemaTablesWriter.endList();
adbcCatalogDbSchemasStructWriter.end();
}
adbcCatalogDbSchemasStructWriter.start();
writeVarChar(
adbcCatalogDbSchemaNameWriter, new String(dbSchema, StandardCharsets.UTF_8));
if (depth == AdbcConnection.GetObjectsDepth.DB_SCHEMAS) {
adbcCatalogDbSchemaTablesWriter.writeNull();
} else {
adbcCatalogDbSchemaTablesWriter.startList();
}

lastDbSchemaAdded = dbSchema;
}
}

if (depth != AdbcConnection.GetObjectsDepth.CATALOGS
&& depth != AdbcConnection.GetObjectsDepth.DB_SCHEMAS) {
VarCharVector tableNameVector = (VarCharVector) res.getVector(2);
VarCharVector tableTypeVector = (VarCharVector) res.getVector(3);

adbcTablesStructWriter.start();
writeVarChar(
adbcTableNameWriter, new String(tableNameVector.get(i), StandardCharsets.UTF_8));
writeVarChar(
adbcTableTypeWriter, new String(tableTypeVector.get(i), StandardCharsets.UTF_8));

if (depth == AdbcConnection.GetObjectsDepth.ALL) {
VarBinaryVector tableSchemaVector = (VarBinaryVector) res.getVector(4);
Schema schema;

try {
schema =
MessageSerializer.deserializeSchema(
new ReadChannel(
Channels.newChannel(
new ByteArrayInputStream(tableSchemaVector.get(i)))));
} catch (IOException e) {
throw new RuntimeException(e);
}

adbcTableColumnsWriter.startList();

for (int y = 0; y < schema.getFields().size(); y++) {
Field field = schema.getFields().get(y);
if (precompiledColumnNamePattern == null || precompiledColumnNamePattern.matcher(field.getName()).matches()) {
adbcTableColumnsWriter.struct().start();
writeVarChar(
adbcTableColumnsWriter.struct().varChar("column_name"), field.getName());
adbcTableColumnsWriter.struct().integer("ordinal_position").writeInt(y + 1);
adbcTableColumnsWriter.struct().end();
}
}
adbcTableColumnsWriter.endList();
}

adbcTablesStructWriter.end();
}
}

if (depth != AdbcConnection.GetObjectsDepth.CATALOGS) {
adbcCatalogDbSchemaTablesWriter.endList();
adbcCatalogDbSchemasStructWriter.end();
adbcCatalogDbSchemasWriter.endList();
}
}
adbcTableColumnsWriter.endList();
}

adbcTablesStructWriter.end();
}
}

if (depth != AdbcConnection.GetObjectsDepth.CATALOGS) {
adbcCatalogDbSchemaTablesWriter.endList();
adbcCatalogDbSchemasStructWriter.end();
adbcCatalogDbSchemasWriter.endList();
}
}
}
}

this.root.setRowCount(catalogIndex);
return root;
}
root.setRowCount(catalogIndex);
return root;
}
}

0 comments on commit a3771f5

Please sign in to comment.