Skip to content

Commit

Permalink
ARROW-4213: [Flight] Fix incompatibilities between C++ and Java
Browse files Browse the repository at this point in the history
I'm willing to update the patches here in case you want to resolve incompatibilities differently. The changes are:

- In C++, send/read the schema as the first message in a DoGet stream.
- In Java, encode the GetFlightInfo schema in a Flatbuffer Message payload.
- In C++, skip null columns when encoding IpcPayload.
- In C++, don't write the body tag when encoding IpcPayload if no buffers are present.
- In Java, always align buffers when serializing record batches.

Additionally:
- Add integration tests for Flight. They will fail when trying to test record batches with dictionaries, but otherwise pass.
- Generate an uberjar for Flight, and include Apache Commons CLI and gRPC in the uberjar.
- Explicitly add the generated gRPC/Protobuf sources to the Java build - for some reason, Maven was not picking them up.
- In Java FlightClient, if there's an exception, use it to resolve the VectorSchemaRoot CompletableFuture so that clients do not hang forever.

Author: David Li <[email protected]>
Author: Wes McKinney <[email protected]>

Closes apache#3477 from lihalite/arrow-4213 and squashes the following commits:

d5a3ebf <Wes McKinney> Fix comment re: flight integration testing procedure
ac337a2 <David Li> Update description of integration test client
f867959 <David Li> Fix Flake8 errors
8c7be86 <David Li> clang-format new code
142b97d <David Li> Allow integration test script to control Flight port
3aace90 <David Li> Properly wait for Flight integration test server to start
0d0a476 <David Li> Clean up style issues
4c83485 <David Li> Collect failures at end of integration tests
0f3cce4 <David Li> Always align buffers when writing record batches in Flight
3c057df <David Li> Implement Java server side of Flight integration tests
13c0700 <David Li> Set exception on root in FlightClient
4801257 <David Li> Assume schema in Flight GetInfo is encoded in a Message payload
321f898 <David Li> Include gRPC in Flight uberjar
d23a201 <David Li> Explicitly compile generated Protobuf sources for Flight
6885d34 <David Li> Don't write body tag when serializing IpcPayload if no buffers
92c4c13 <David Li> Skeleton of integration test server for Flight
9bf8525 <David Li> Fix segfault in Flight when sending nullary columns
aca752b <David Li> Read schemas from stream in Flight DoGet
2acc8be <David Li> Implement arrow::ipc::ReadSchema that works with a Message
  • Loading branch information
David Li authored and wesm committed Jan 29, 2019
1 parent 3d435e4 commit de84293
Show file tree
Hide file tree
Showing 18 changed files with 856 additions and 27 deletions.
12 changes: 12 additions & 0 deletions cpp/src/arrow/flight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ if (ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS)
gflags_static
${GTEST_LIBRARY})

# Server executable for the Flight integration tests, built from
# test-integration-server.cc and linked against the Flight test libraries,
# gflags and gtest.
add_executable(flight-test-integration-server test-integration-server.cc)
target_link_libraries(flight-test-integration-server
${ARROW_FLIGHT_TEST_STATIC_LINK_LIBS}
gflags_static
gtest_static)

# Client executable for the Flight integration tests, built from
# test-integration-client.cc and linked against the same libraries as the
# integration server.
add_executable(flight-test-integration-client test-integration-client.cc)
target_link_libraries(flight-test-integration-client
${ARROW_FLIGHT_TEST_STATIC_LINK_LIBS}
gflags_static
gtest_static)

# This is needed for the unit tests
if (ARROW_BUILD_TESTS)
add_dependencies(arrow-flight-test flight-test-server)
Expand Down
11 changes: 10 additions & 1 deletion cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,16 @@ class FlightStreamReader : public RecordBatchReader {

// Validate IPC message
RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message));
return ipc::ReadRecordBatch(*message, schema_, out);
// The first message is a schema; read it and then try to read a
// record batch.
if (message->type() == ipc::Message::Type::SCHEMA) {
RETURN_NOT_OK(ipc::ReadSchema(*message, &schema_));
return ReadNext(out);
} else if (message->type() == ipc::Message::Type::RECORD_BATCH) {
return ipc::ReadRecordBatch(*message, schema_, out);
} else {
return Status(StatusCode::Invalid, "Unrecognized message in Flight stream");
}
} else {
// Stream is completed
stream_finished_ = true;
Expand Down
49 changes: 37 additions & 12 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ class SerializationTraits<IpcPayload> {

int64_t body_size = 0;
for (const auto& buffer : msg.body_buffers) {
// Buffer may be null when the row length is zero, or when all
// entries are invalid.
if (!buffer) continue;

body_size += buffer->size();

const int64_t remainder = buffer->size() % 8;
Expand All @@ -111,7 +115,11 @@ class SerializationTraits<IpcPayload> {
}

// 2 bytes for body tag
total_size += 2 + WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size));
// Only written when there are body buffers
if (msg.body_length > 0) {
total_size +=
2 + WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size));
}

// TODO(wesm): messages over 2GB unlikely to be yet supported
if (total_size > kInt32Max) {
Expand All @@ -135,20 +143,27 @@ class SerializationTraits<IpcPayload> {
pb_stream.WriteRawMaybeAliased(msg.metadata->data(),
static_cast<int>(msg.metadata->size()));

// Write body
WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream);
pb_stream.WriteVarint32(static_cast<uint32_t>(body_size));
// Don't write tag if there are no body buffers
if (msg.body_length > 0) {
// Write body
WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream);
pb_stream.WriteVarint32(static_cast<uint32_t>(body_size));

constexpr uint8_t kPaddingBytes[8] = {0};
constexpr uint8_t kPaddingBytes[8] = {0};

for (const auto& buffer : msg.body_buffers) {
pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast<int>(buffer->size()));
for (const auto& buffer : msg.body_buffers) {
// Buffer may be null when the row length is zero, or when all
// entries are invalid.
if (!buffer) continue;

// Write padding if not multiple of 8
const int remainder = static_cast<int>(buffer->size() % 8);
if (remainder) {
pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder);
pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast<int>(buffer->size()));

// Write padding if not multiple of 8
const int remainder = static_cast<int>(buffer->size() % 8);
if (remainder) {
pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder);
}
}
}

Expand Down Expand Up @@ -255,6 +270,14 @@ class FlightServiceImpl : public FlightService::Service {
// Requires ServerWriter customization in grpc_customizations.h
auto custom_writer = reinterpret_cast<ServerWriter<IpcPayload>*>(writer);

// Write the schema as the first message in the stream
IpcPayload schema_payload;
MemoryPool* pool = default_memory_pool();
ipc::DictionaryMemo dictionary_memo;
GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload(
*data_stream->schema(), pool, &dictionary_memo, &schema_payload));
custom_writer->Write(schema_payload, grpc::WriteOptions());

while (true) {
IpcPayload payload;
GRPC_RETURN_NOT_OK(data_stream->Next(&payload));
Expand Down Expand Up @@ -368,6 +391,8 @@ Status FlightServerBase::ListActions(std::vector<ActionType>* actions) {
// Builds a stream over `reader`. The reader is kept in reader_ and the
// default memory pool is stored in pool_.
RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader)
: pool_(default_memory_pool()), reader_(reader) {}

// Returns the schema reported by the wrapped RecordBatchReader.
std::shared_ptr<Schema> RecordBatchStream::schema() { return reader_->schema(); }

Status RecordBatchStream::Next(IpcPayload* payload) {
std::shared_ptr<RecordBatch> batch;
RETURN_NOT_OK(reader_->ReadNext(&batch));
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/util/visibility.h"

#include "arrow/flight/types.h"
#include "arrow/ipc/dictionary.h"

namespace arrow {

Expand Down Expand Up @@ -57,6 +58,9 @@ class ARROW_EXPORT FlightDataStream {
public:
virtual ~FlightDataStream() = default;

// When the stream starts, send the schema.
virtual std::shared_ptr<Schema> schema() = 0;

// When the stream is completed, the last payload written will have null
// metadata
virtual Status Next(ipc::internal::IpcPayload* payload) = 0;
Expand All @@ -69,6 +73,7 @@ class ARROW_EXPORT RecordBatchStream : public FlightDataStream {
public:
explicit RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader);

std::shared_ptr<Schema> schema() override;
Status Next(ipc::internal::IpcPayload* payload) override;

private:
Expand Down
82 changes: 82 additions & 0 deletions cpp/src/arrow/flight/test-integration-client.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Client implementation for Flight integration testing. Requests the given
// path from the Flight server, which reads that file and sends it as a stream
// to the client. The client writes the server stream to the IPC file format at
// the given output file path. The integration test script then uses the
// existing integration test tools to compare the output binary with the
// original JSON

#include <iostream>
#include <memory>
#include <string>

#include <gflags/gflags.h>

#include "arrow/io/test-common.h"
#include "arrow/ipc/json.h"
#include "arrow/record_batch.h"

#include "arrow/flight/server.h"
#include "arrow/flight/test-util.h"

// Command-line flags: which server to contact (--host/--port), which resource
// to request (--path), and where to write what the server returns (--output).
DEFINE_string(host, "localhost", "Server host to connect to");
DEFINE_int32(port, 31337, "Server port to connect to");
DEFINE_string(path, "", "Resource path to request");
DEFINE_string(output, "", "Where to write requested resource");

int main(int argc, char** argv) {
gflags::SetUsageMessage("Integration testing client for Flight.");
gflags::ParseCommandLineFlags(&argc, &argv, true);

std::unique_ptr<arrow::flight::FlightClient> client;
ABORT_NOT_OK(arrow::flight::FlightClient::Connect(FLAGS_host, FLAGS_port, &client));

arrow::flight::FlightDescriptor descr{
arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}};
std::unique_ptr<arrow::flight::FlightInfo> info;
ABORT_NOT_OK(client->GetFlightInfo(descr, &info));

std::shared_ptr<arrow::Schema> schema;
ABORT_NOT_OK(info->GetSchema(&schema));

if (info->endpoints().size() == 0) {
std::cerr << "No endpoints returned from Flight server." << std::endl;
return -1;
}

arrow::flight::Ticket ticket = info->endpoints()[0].ticket;
std::unique_ptr<arrow::RecordBatchReader> stream;
ABORT_NOT_OK(client->DoGet(ticket, schema, &stream));

std::shared_ptr<arrow::io::FileOutputStream> out_file;
ABORT_NOT_OK(arrow::io::FileOutputStream::Open(FLAGS_output, &out_file));
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;
ABORT_NOT_OK(arrow::ipc::RecordBatchFileWriter::Open(out_file.get(), schema, &writer));

std::shared_ptr<arrow::RecordBatch> chunk;
while (true) {
ABORT_NOT_OK(stream->ReadNext(&chunk));
if (chunk == nullptr) break;
ABORT_NOT_OK(writer->WriteRecordBatch(*chunk));
}

ABORT_NOT_OK(writer->Close());

return 0;
}
150 changes: 150 additions & 0 deletions cpp/src/arrow/flight/test-integration-server.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Example server implementation for integration testing purposes

#include <signal.h>
#include <iostream>
#include <memory>
#include <string>

#include <gflags/gflags.h>

#include "arrow/io/test-common.h"
#include "arrow/ipc/json.h"
#include "arrow/record_batch.h"

#include "arrow/flight/server.h"
#include "arrow/flight/test-util.h"

// --port: port the integration test server listens on (default 31337).
DEFINE_int32(port, 31337, "Server port to listen on");

namespace arrow {
namespace flight {

// FlightDataStream backed by a JsonReader. Record batches are fetched from the
// reader by index and converted to IPC payloads one per Next() call.
class JsonReaderRecordBatchStream : public FlightDataStream {
 public:
  // Takes ownership of `reader`; iteration starts at batch 0.
  explicit JsonReaderRecordBatchStream(
      std::unique_ptr<ipc::internal::json::JsonReader>&& reader)
      : index_(0), pool_(default_memory_pool()), reader_(std::move(reader)) {}

  // Schema as reported by the underlying JsonReader.
  std::shared_ptr<Schema> schema() override { return reader_->schema(); }

  // Fills `payload` with the next record batch. When no batch is left (or the
  // reader hands back a null batch), payload->metadata is set to null to
  // signal the end of the stream and OK is returned.
  Status Next(ipc::internal::IpcPayload* payload) override {
    std::shared_ptr<RecordBatch> next_batch;
    if (index_ < reader_->num_record_batches()) {
      RETURN_NOT_OK(reader_->ReadRecordBatch(index_, &next_batch));
      ++index_;
    }

    if (next_batch == nullptr) {
      // Signal that iteration is over
      payload->metadata = nullptr;
      return Status::OK();
    }
    return ipc::internal::GetRecordBatchPayload(*next_batch, pool_, payload);
  }

 private:
  int index_;          // index of the next batch to read
  MemoryPool* pool_;   // pool handed to GetRecordBatchPayload
  std::unique_ptr<ipc::internal::json::JsonReader> reader_;
};

// Flight server used by the integration tests. Descriptor paths and tickets
// are treated as paths of JSON files, which are opened with JsonReader.
class FlightIntegrationTestServer : public FlightServerBase {
  // Reads the entire file at `json_path` into a buffer and opens a JsonReader
  // on that buffer, returning it through `out`.
  Status ReadJson(const std::string& json_path,
                  std::unique_ptr<ipc::internal::json::JsonReader>* out) {
    std::cout << "Opening JSON file '" << json_path << "'" << std::endl;

    std::shared_ptr<io::ReadableFile> source;
    RETURN_NOT_OK(io::ReadableFile::Open(json_path, &source));

    int64_t num_bytes = 0;
    RETURN_NOT_OK(source->GetSize(&num_bytes));

    std::shared_ptr<Buffer> contents;
    RETURN_NOT_OK(source->Read(num_bytes, &contents));

    return arrow::ipc::internal::json::JsonReader::Open(contents, out);
  }

  // Answers PATH descriptors only: the last path component names the JSON
  // file, and the resulting FlightInfo has a single endpoint whose ticket is
  // that same path. Other descriptor types yield NotImplemented; an empty
  // path yields Invalid.
  Status GetFlightInfo(const FlightDescriptor& request,
                       std::unique_ptr<FlightInfo>* info) override {
    if (request.type != FlightDescriptor::PATH) {
      return Status::NotImplemented(request.type);
    }
    if (request.path.empty()) {
      return Status::Invalid("Invalid path");
    }

    const std::string& json_path = request.path.back();
    std::unique_ptr<arrow::ipc::internal::json::JsonReader> reader;
    RETURN_NOT_OK(ReadJson(json_path, &reader));

    FlightEndpoint endpoint({{json_path}, {}});

    FlightInfo::Data flight_data;
    RETURN_NOT_OK(internal::SchemaToString(*reader->schema(), &flight_data.schema));
    flight_data.descriptor = request;
    flight_data.endpoints = {endpoint};
    // NOTE(review): total_records is filled with the number of record
    // batches, not a row count -- confirm this is intended.
    flight_data.total_records = reader->num_record_batches();
    flight_data.total_bytes = -1;

    info->reset(new FlightInfo(flight_data));
    return Status::OK();
  }

  // Opens the JSON file named by the ticket and exposes it as a
  // JsonReaderRecordBatchStream.
  Status DoGet(const Ticket& request,
               std::unique_ptr<FlightDataStream>* data_stream) override {
    std::unique_ptr<arrow::ipc::internal::json::JsonReader> reader;
    RETURN_NOT_OK(ReadJson(request.ticket, &reader));

    data_stream->reset(new JsonReaderRecordBatchStream(std::move(reader)));
    return Status::OK();
  }
};

} // namespace flight
} // namespace arrow

// Server instance; global so that the signal handler can reach it.
std::unique_ptr<arrow::flight::FlightIntegrationTestServer> g_server;

// Signal handler: asks the server to shut down if one has been created.
void Shutdown(int signal) {
  if (!g_server) return;
  g_server->Shutdown();
}

int main(int argc, char** argv) {
gflags::SetUsageMessage("Integration testing server for Flight.");
gflags::ParseCommandLineFlags(&argc, &argv, true);

// SIGTERM shuts down the server
signal(SIGTERM, Shutdown);

g_server.reset(new arrow::flight::FlightIntegrationTestServer);
g_server->Run(FLAGS_port);
return 0;
}
6 changes: 6 additions & 0 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,12 @@ Status ReadSchema(io::InputStream* stream, std::shared_ptr<Schema>* out) {
return Status::OK();
}

// Decodes a Schema from an already-opened IPC message.
//
// message: IPC message that must be of type SCHEMA
// out: receives the decoded schema on success
//
// Returns Status::Invalid when `message` is not a schema message; otherwise
// returns the status of internal::GetSchema.
Status ReadSchema(const Message& message, std::shared_ptr<Schema>* out) {
  // GetSchema is handed only the raw header pointer, so it presumably cannot
  // tell which kind of message it came from; check the type here instead.
  if (message.type() != Message::Type::SCHEMA) {
    return Status::Invalid("Expected a schema message");
  }

  // The memo is local to this call and is discarded once the schema is built.
  DictionaryMemo dictionary_memo;
  return internal::GetSchema(message.header(), dictionary_memo, out);
}

Status ReadRecordBatch(const std::shared_ptr<Schema>& schema, io::InputStream* file,
std::shared_ptr<RecordBatch>* out) {
std::unique_ptr<Message> message;
Expand Down
Loading

0 comments on commit de84293

Please sign in to comment.