From de84293d9c93fe721cd127f1a27acc94fe290f3f Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 29 Jan 2019 10:47:07 -0600
Subject: [PATCH] ARROW-4213: [Flight] Fix incompatibilities between C++ and Java

I'm willing to update the patches here in case you want to resolve incompatibilities differently. The changes are:

- In C++, send/read the schema as the first message in a DoGet stream.
- In Java, encode the GetFlightInfo schema in a Flatbuffer Message payload.
- In C++, skip null columns when encoding IpcPayload.
- In C++, don't write the body tag when encoding IpcPayload if no buffers are present.
- In Java, always align buffers when serializing record batches.

Additionally:

- Add integration tests for Flight. They will fail when trying to test record batches with dictionaries, but otherwise pass.
- Generate an uberjar for Flight, and include Apache Commons CLI and gRPC in the uberjar.
- Explicitly add the generated gRPC/Protobuf sources to the Java build - for some reason, Maven was not picking them up.
- In Java FlightClient, if there's an exception, use it to resolve the VectorSchemaRoot CompletableFuture so that clients do not hang forever.

Author: David Li
Author: Wes McKinney

Closes #3477 from lihalite/arrow-4213 and squashes the following commits:

d5a3ebf6d Fix comment re: flight integration testing procedure
ac337a279 Update description of integration test client
f8679593f Fix Flake8 errors
8c7be8647 clang-format new code
142b97d0a Allow integration test script to control Flight port
3aace90ea Properly wait for Flight integration test server to start
0d0a476fb Clean up style issues
4c8348574 Collect failures at end of integration tests
0f3cce426 Always align buffers when writing record batches in Flight
3c057dfaa Implement Java server side of Flight integration tests
13c0700f9 Set exception on root in FlightClient
4801257e5 Assume schema in Flight GetInfo is encoded in a Message payload
321f8986d Include gRPC in Flight uberjar
d23a2010a Explicitly compile generated Protobuf sources for Flight
6885d34ec Don't write body tag when serializing IpcPayload if no buffers
92c4c139b Skeleton of integration test server for Flight
9bf85255e Fix segfault in Flight when sending nullary columns
aca752ba9 Read schemas from stream in Flight DoGet
2acc8be97 Implement arrow::ipc::ReadSchema that works with a Message
---
 cpp/src/arrow/flight/CMakeLists.txt           |  12 ++
 cpp/src/arrow/flight/client.cc                |  11 +-
 cpp/src/arrow/flight/server.cc                |  49 +++--
 cpp/src/arrow/flight/server.h                 |   5 +
 .../arrow/flight/test-integration-client.cc   |  82 +++++++++
 .../arrow/flight/test-integration-server.cc   | 150 ++++++++++++++++
 cpp/src/arrow/ipc/reader.cc                   |   6 +
 cpp/src/arrow/ipc/reader.h                    |   8 +
 cpp/src/arrow/ipc/writer.cc                   |   9 +
 cpp/src/arrow/ipc/writer.h                    |  12 ++
 integration/integration_test.py               | 169 +++++++++++++++++-
 java/flight/pom.xml                           |  46 ++++-
 .../org/apache/arrow/flight/ArrowMessage.java |  27 ++-
 .../org/apache/arrow/flight/FlightInfo.java   |  29 ++-
 .../apache/arrow/flight/FlightService.java    |   3 +-
 .../org/apache/arrow/flight/FlightStream.java |   1 +
 .../integration/IntegrationTestClient.java    | 108 +++++++++++
 .../integration/IntegrationTestServer.java    | 156 ++++++++++++++++
 18 files changed, 856 insertions(+), 27 deletions(-)
 create mode 100644 cpp/src/arrow/flight/test-integration-client.cc
 create mode 100644 cpp/src/arrow/flight/test-integration-server.cc
 create mode 100644 java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
 create mode 100644 
java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index b0006ac7f5cb8..f59ea3c5e6757 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -90,6 +90,18 @@ if (ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) gflags_static ${GTEST_LIBRARY}) + add_executable(flight-test-integration-server test-integration-server.cc) + target_link_libraries(flight-test-integration-server + ${ARROW_FLIGHT_TEST_STATIC_LINK_LIBS} + gflags_static + gtest_static) + + add_executable(flight-test-integration-client test-integration-client.cc) + target_link_libraries(flight-test-integration-client + ${ARROW_FLIGHT_TEST_STATIC_LINK_LIBS} + gflags_static + gtest_static) + # This is needed for the unit tests if (ARROW_BUILD_TESTS) add_dependencies(arrow-flight-test flight-test-server) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 94c4928d0220d..e25c1875d669f 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -232,7 +232,16 @@ class FlightStreamReader : public RecordBatchReader { // Validate IPC message RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); - return ipc::ReadRecordBatch(*message, schema_, out); + // The first message is a schema; read it and then try to read a + // record batch. + if (message->type() == ipc::Message::Type::SCHEMA) { + RETURN_NOT_OK(ipc::ReadSchema(*message, &schema_)); + return ReadNext(out); + } else if (message->type() == ipc::Message::Type::RECORD_BATCH) { + return ipc::ReadRecordBatch(*message, schema_, out); + } else { + return Status(StatusCode::Invalid, "Unrecognized message in Flight stream"); + } } else { // Stream is completed stream_finished_ = true; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 46815b5476c67..018c079501f2f 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -102,6 +102,10 @@ class SerializationTraits { int64_t body_size = 0; for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. 
+ if (!buffer) continue; + body_size += buffer->size(); const int64_t remainder = buffer->size() % 8; @@ -111,7 +115,11 @@ class SerializationTraits { } // 2 bytes for body tag - total_size += 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); + // Only written when there are body buffers + if (msg.body_length > 0) { + total_size += + 2 + WireFormatLite::LengthDelimitedSize(static_cast(body_size)); + } // TODO(wesm): messages over 2GB unlikely to be yet supported if (total_size > kInt32Max) { @@ -135,20 +143,27 @@ class SerializationTraits { pb_stream.WriteRawMaybeAliased(msg.metadata->data(), static_cast(msg.metadata->size())); - // Write body - WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(static_cast(body_size)); + // Don't write tag if there are no body buffers + if (msg.body_length > 0) { + // Write body + WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); + pb_stream.WriteVarint32(static_cast(body_size)); - constexpr uint8_t kPaddingBytes[8] = {0}; + constexpr uint8_t kPaddingBytes[8] = {0}; - for (const auto& buffer : msg.body_buffers) { - pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); + for (const auto& buffer : msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. + if (!buffer) continue; - // Write padding if not multiple of 8 - const int remainder = static_cast(buffer->size() % 8); - if (remainder) { - pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); + pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast(buffer->size())); + + // Write padding if not multiple of 8 + const int remainder = static_cast(buffer->size() % 8); + if (remainder) { + pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); + } } } @@ -255,6 +270,14 @@ class FlightServiceImpl : public FlightService::Service { // Requires ServerWriter customization in grpc_customizations.h auto custom_writer = reinterpret_cast*>(writer); + // Write the schema as the first message in the stream + IpcPayload schema_payload; + MemoryPool* pool = default_memory_pool(); + ipc::DictionaryMemo dictionary_memo; + GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload( + *data_stream->schema(), pool, &dictionary_memo, &schema_payload)); + custom_writer->Write(schema_payload, grpc::WriteOptions()); + while (true) { IpcPayload payload; GRPC_RETURN_NOT_OK(data_stream->Next(&payload)); @@ -368,6 +391,8 @@ Status FlightServerBase::ListActions(std::vector* actions) { RecordBatchStream::RecordBatchStream(const std::shared_ptr& reader) : pool_(default_memory_pool()), reader_(reader) {} +std::shared_ptr RecordBatchStream::schema() { return reader_->schema(); } + Status RecordBatchStream::Next(IpcPayload* payload) { std::shared_ptr batch; RETURN_NOT_OK(reader_->ReadNext(&batch)); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 9c632d699427c..b3b8239132b7a 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -28,6 +28,7 @@ #include "arrow/util/visibility.h" #include "arrow/flight/types.h" +#include "arrow/ipc/dictionary.h" namespace arrow { @@ -57,6 +58,9 @@ class ARROW_EXPORT FlightDataStream { public: virtual ~FlightDataStream() = default; + // When the stream starts, send the schema. 
+ virtual std::shared_ptr schema() = 0; + // When the stream is completed, the last payload written will have null // metadata virtual Status Next(ipc::internal::IpcPayload* payload) = 0; @@ -69,6 +73,7 @@ class ARROW_EXPORT RecordBatchStream : public FlightDataStream { public: explicit RecordBatchStream(const std::shared_ptr& reader); + std::shared_ptr schema() override; Status Next(ipc::internal::IpcPayload* payload) override; private: diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc new file mode 100644 index 0000000000000..267025a451cc7 --- /dev/null +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Client implementation for Flight integration testing. Requests the given +// path from the Flight server, which reads that file and sends it as a stream +// to the client. The client writes the server stream to the IPC file format at +// the given output file path. The integration test script then uses the +// existing integration test tools to compare the output binary with the +// original JSON + +#include +#include +#include + +#include + +#include "arrow/io/test-common.h" +#include "arrow/ipc/json.h" +#include "arrow/record_batch.h" + +#include "arrow/flight/server.h" +#include "arrow/flight/test-util.h" + +DEFINE_string(host, "localhost", "Server port to connect to"); +DEFINE_int32(port, 31337, "Server port to connect to"); +DEFINE_string(path, "", "Resource path to request"); +DEFINE_string(output, "", "Where to write requested resource"); + +int main(int argc, char** argv) { + gflags::SetUsageMessage("Integration testing client for Flight."); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + std::unique_ptr client; + ABORT_NOT_OK(arrow::flight::FlightClient::Connect(FLAGS_host, FLAGS_port, &client)); + + arrow::flight::FlightDescriptor descr{ + arrow::flight::FlightDescriptor::PATH, "", {FLAGS_path}}; + std::unique_ptr info; + ABORT_NOT_OK(client->GetFlightInfo(descr, &info)); + + std::shared_ptr schema; + ABORT_NOT_OK(info->GetSchema(&schema)); + + if (info->endpoints().size() == 0) { + std::cerr << "No endpoints returned from Flight server." 
<< std::endl; + return -1; + } + + arrow::flight::Ticket ticket = info->endpoints()[0].ticket; + std::unique_ptr stream; + ABORT_NOT_OK(client->DoGet(ticket, schema, &stream)); + + std::shared_ptr out_file; + ABORT_NOT_OK(arrow::io::FileOutputStream::Open(FLAGS_output, &out_file)); + std::shared_ptr writer; + ABORT_NOT_OK(arrow::ipc::RecordBatchFileWriter::Open(out_file.get(), schema, &writer)); + + std::shared_ptr chunk; + while (true) { + ABORT_NOT_OK(stream->ReadNext(&chunk)); + if (chunk == nullptr) break; + ABORT_NOT_OK(writer->WriteRecordBatch(*chunk)); + } + + ABORT_NOT_OK(writer->Close()); + + return 0; +} diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc new file mode 100644 index 0000000000000..80813e7f19a4c --- /dev/null +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Example server implementation for integration testing purposes + +#include +#include +#include +#include + +#include + +#include "arrow/io/test-common.h" +#include "arrow/ipc/json.h" +#include "arrow/record_batch.h" + +#include "arrow/flight/server.h" +#include "arrow/flight/test-util.h" + +DEFINE_int32(port, 31337, "Server port to listen on"); + +namespace arrow { +namespace flight { + +class JsonReaderRecordBatchStream : public FlightDataStream { + public: + explicit JsonReaderRecordBatchStream( + std::unique_ptr&& reader) + : index_(0), pool_(default_memory_pool()), reader_(std::move(reader)) {} + + std::shared_ptr schema() override { return reader_->schema(); } + + Status Next(ipc::internal::IpcPayload* payload) override { + if (index_ >= reader_->num_record_batches()) { + // Signal that iteration is over + payload->metadata = nullptr; + return Status::OK(); + } + + std::shared_ptr batch; + RETURN_NOT_OK(reader_->ReadRecordBatch(index_, &batch)); + index_++; + + if (!batch) { + // Signal that iteration is over + payload->metadata = nullptr; + return Status::OK(); + } else { + return ipc::internal::GetRecordBatchPayload(*batch, pool_, payload); + } + } + + private: + int index_; + MemoryPool* pool_; + std::unique_ptr reader_; +}; + +class FlightIntegrationTestServer : public FlightServerBase { + Status ReadJson(const std::string& json_path, + std::unique_ptr* out) { + std::shared_ptr in_file; + std::cout << "Opening JSON file '" << json_path << "'" << std::endl; + RETURN_NOT_OK(io::ReadableFile::Open(json_path, &in_file)); + + int64_t file_size = 0; + RETURN_NOT_OK(in_file->GetSize(&file_size)); + + std::shared_ptr json_buffer; + RETURN_NOT_OK(in_file->Read(file_size, &json_buffer)); + + RETURN_NOT_OK(arrow::ipc::internal::json::JsonReader::Open(json_buffer, out)); + return Status::OK(); + } + + Status GetFlightInfo(const 
FlightDescriptor& request, + std::unique_ptr* info) override { + if (request.type == FlightDescriptor::PATH) { + if (request.path.size() == 0) { + return Status::Invalid("Invalid path"); + } + + std::unique_ptr reader; + RETURN_NOT_OK(ReadJson(request.path.back(), &reader)); + + FlightEndpoint endpoint1({{request.path.back()}, {}}); + + FlightInfo::Data flight_data; + RETURN_NOT_OK(internal::SchemaToString(*reader->schema(), &flight_data.schema)); + flight_data.descriptor = request; + flight_data.endpoints = {endpoint1}; + flight_data.total_records = reader->num_record_batches(); + flight_data.total_bytes = -1; + FlightInfo value(flight_data); + + *info = std::unique_ptr(new FlightInfo(value)); + return Status::OK(); + } else { + return Status::NotImplemented(request.type); + } + } + + Status DoGet(const Ticket& request, + std::unique_ptr* data_stream) override { + std::unique_ptr reader; + RETURN_NOT_OK(ReadJson(request.ticket, &reader)); + + *data_stream = std::unique_ptr( + new JsonReaderRecordBatchStream(std::move(reader))); + + return Status::OK(); + } +}; + +} // namespace flight +} // namespace arrow + +std::unique_ptr g_server; + +void Shutdown(int signal) { + if (g_server != nullptr) { + g_server->Shutdown(); + } +} + +int main(int argc, char** argv) { + gflags::SetUsageMessage("Integration testing server for Flight."); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + // SIGTERM shuts down the server + signal(SIGTERM, Shutdown); + + g_server.reset(new arrow::flight::FlightIntegrationTestServer); + g_server->Run(FLAGS_port); + return 0; +} diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index e856acafd7138..1f04fad81743c 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -701,6 +701,12 @@ Status ReadSchema(io::InputStream* stream, std::shared_ptr* out) { return Status::OK(); } +Status ReadSchema(const Message& message, std::shared_ptr* out) { + std::shared_ptr reader; + DictionaryMemo dictionary_memo; + return internal::GetSchema(message.header(), dictionary_memo, &*out); +} + Status ReadRecordBatch(const std::shared_ptr& schema, io::InputStream* file, std::shared_ptr* out) { std::unique_ptr message; diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index ebecea13ffb8b..641de3eaf7b41 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -175,6 +175,14 @@ class ARROW_EXPORT RecordBatchFileReader { ARROW_EXPORT Status ReadSchema(io::InputStream* stream, std::shared_ptr* out); +/// \brief Read Schema from encapsulated Message +/// +/// \param[in] message a message instance containing metadata +/// \param[out] out the resulting Schema +/// \return Status +ARROW_EXPORT +Status ReadSchema(const Message& message, std::shared_ptr* out); + /// Read record batch as encapsulated IPC message with metadata size prefix and /// header /// diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 0bf68142c7776..1eb91998b5a93 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -524,6 +524,15 @@ Status WriteIpcPayload(const IpcPayload& payload, io::OutputStream* dst, return Status::OK(); } +Status GetSchemaPayload(const Schema& schema, MemoryPool* pool, + DictionaryMemo* dictionary_memo, IpcPayload* out) { + out->type = Message::Type::SCHEMA; + out->body_buffers.clear(); + out->body_length = 0; + RETURN_NOT_OK(SerializeSchema(schema, pool, &out->metadata)); + return WriteSchemaMessage(schema, dictionary_memo, &out->metadata); +} + Status 
GetRecordBatchPayload(const RecordBatch& batch, MemoryPool* pool, IpcPayload* out) { RecordBatchSerializer writer(pool, 0, kMaxNestingDepth, true, out); diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h index 0a2b0a09c2fd2..5b099d59c0ef0 100644 --- a/cpp/src/arrow/ipc/writer.h +++ b/cpp/src/arrow/ipc/writer.h @@ -30,6 +30,7 @@ namespace arrow { class Buffer; +class DictionaryMemo; class MemoryPool; class RecordBatch; class Schema; @@ -313,6 +314,17 @@ ARROW_EXPORT Status GetDictionaryPayloads(const Schema& schema, std::vector>* out); +/// \brief Compute IpcPayload for the given schema +/// \param[in] schema the Schema that is being serialized +/// \param[in,out] pool for any required temporary memory allocations +/// \param[in,out] dictionary_memo class for tracking dictionaries and assigning +/// dictionary ids +/// \param[out] out the returned IpcPayload +/// \return Status +ARROW_EXPORT +Status GetSchemaPayload(const Schema& schema, MemoryPool* pool, + DictionaryMemo* dictionary_memo, IpcPayload* out); + /// \brief Compute IpcPayload for the given record batch /// \param[in] batch the RecordBatch that is being serialized /// \param[in,out] pool for any required temporary memory allocations diff --git a/integration/integration_test.py b/integration/integration_test.py index c0191c372915c..0bced26f15acd 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -18,6 +18,7 @@ from collections import OrderedDict import argparse import binascii +import contextlib import glob import itertools import json @@ -26,7 +27,9 @@ import six import string import subprocess +import sys import tempfile +import traceback import uuid import errno @@ -931,10 +934,29 @@ def __init__(self, json_files, testers, tempdir=None, debug=False): self.debug = debug def run(self): + failures = [] for producer, consumer in itertools.product( filter(lambda t: t.PRODUCER, self.testers), filter(lambda t: t.CONSUMER, self.testers)): - self._compare_implementations(producer, consumer) + try: + self._compare_implementations(producer, consumer) + except Exception: + traceback.print_exc() + failures.append((producer, consumer, sys.exc_info())) + return failures + + def run_flight(self): + failures = [] + servers = filter(lambda t: t.FLIGHT_SERVER, self.testers) + clients = filter(lambda t: (t.FLIGHT_CLIENT and t.CONSUMER), + self.testers) + for server, client in itertools.product(servers, clients): + try: + self._compare_flight_implementations(server, client) + except Exception: + traceback.print_exc() + failures.append((server, client, sys.exc_info())) + return failures def _compare_implementations(self, producer, consumer): print('##########################################################') @@ -975,10 +997,43 @@ def _compare_implementations(self, producer, consumer): consumer_file_path) consumer.validate(json_path, consumer_file_path) + def _compare_flight_implementations(self, producer, consumer): + print('##########################################################') + print( + '{0} serving, {1} requesting'.format(producer.name, consumer.name) + ) + print('##########################################################') + + for json_path in self.json_files: + print('==========================================================') + print('Testing file {0}'.format(json_path)) + print('==========================================================') + + name = os.path.splitext(os.path.basename(json_path))[0] + + file_id = guid()[:8] + + with producer.flight_server(): + # Have the client 
request the file + consumer_file_path = os.path.join( + self.temp_dir, + file_id + '_' + name + '.consumer_requested_file') + consumer.flight_request(producer.FLIGHT_PORT, + json_path, consumer_file_path) + + # Validate the file + print('-- Validating file') + consumer.validate(json_path, consumer_file_path) + + # TODO: also have the client upload the file + class Tester(object): PRODUCER = False CONSUMER = False + FLIGHT_SERVER = False + FLIGHT_CLIENT = False + FLIGHT_PORT = 31337 def __init__(self, debug=False): self.debug = debug @@ -995,10 +1050,20 @@ def file_to_stream(self, file_path, stream_path): def validate(self, json_path, arrow_path): raise NotImplementedError + def flight_server(self): + raise NotImplementedError + + def flight_request(self, port, json_path, arrow_path): + raise NotImplementedError + class JavaTester(Tester): PRODUCER = True CONSUMER = True + FLIGHT_SERVER = True + FLIGHT_CLIENT = True + + FLIGHT_PORT = 31338 _arrow_version = load_version_from_pom() ARROW_TOOLS_JAR = os.environ.get( @@ -1006,6 +1071,15 @@ class JavaTester(Tester): os.path.join(ARROW_HOME, 'java/tools/target/arrow-tools-{}-' 'jar-with-dependencies.jar'.format(_arrow_version))) + ARROW_FLIGHT_JAR = os.environ.get( + 'ARROW_FLIGHT_JAVA_INTEGRATION_JAR', + os.path.join(ARROW_HOME, + 'java/flight/target/arrow-flight-{}-' + 'jar-with-dependencies.jar'.format(_arrow_version))) + ARROW_FLIGHT_SERVER = ('org.apache.arrow.flight.example.integration.' + 'IntegrationTestServer') + ARROW_FLIGHT_CLIENT = ('org.apache.arrow.flight.example.integration.' + 'IntegrationTestClient') name = 'Java' @@ -1048,10 +1122,41 @@ def file_to_stream(self, file_path, stream_path): print(' '.join(cmd)) run_cmd(cmd) + def flight_request(self, port, json_path, arrow_path): + cmd = ['java', '-cp', self.ARROW_FLIGHT_JAR, + self.ARROW_FLIGHT_CLIENT, + '-port', str(port), + '-j', json_path, + '-a', arrow_path] + if self.debug: + print(' '.join(cmd)) + run_cmd(cmd) + + @contextlib.contextmanager + def flight_server(self): + cmd = ['java', '-cp', self.ARROW_FLIGHT_JAR, + self.ARROW_FLIGHT_SERVER, + '-port', str(self.FLIGHT_PORT)] + if self.debug: + print(' '.join(cmd)) + server = subprocess.Popen(cmd, stdout=subprocess.PIPE) + try: + output = server.stdout.readline().decode() + if not output.startswith("Server listening on localhost"): + raise RuntimeError( + "Flight-Java server did not start properly, output: " + + output) + yield + finally: + server.terminate() + server.wait(5) + class CPPTester(Tester): PRODUCER = True CONSUMER = True + FLIGHT_SERVER = True + FLIGHT_CLIENT = True EXE_PATH = os.environ.get( 'ARROW_CPP_EXE_PATH', @@ -1061,6 +1166,15 @@ class CPPTester(Tester): STREAM_TO_FILE = os.path.join(EXE_PATH, 'arrow-stream-to-file') FILE_TO_STREAM = os.path.join(EXE_PATH, 'arrow-file-to-stream') + FLIGHT_PORT = 31337 + + FLIGHT_SERVER_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-server'), + "-port", str(FLIGHT_PORT)] + FLIGHT_CLIENT_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-client'), + "-host", "localhost"] + name = 'C++' def _run(self, arrow_path=None, json_path=None, command='VALIDATE'): @@ -1099,6 +1213,33 @@ def file_to_stream(self, file_path, stream_path): print(cmd) os.system(cmd) + @contextlib.contextmanager + def flight_server(self): + if self.debug: + print(' '.join(self.FLIGHT_SERVER_CMD)) + server = subprocess.Popen(self.FLIGHT_SERVER_CMD, + stdout=subprocess.PIPE) + try: + output = server.stdout.readline().decode() + if not output.startswith("Server listening on localhost"): + raise 
RuntimeError( + "Flight-C++ server did not start properly, output: " + + output) + yield + finally: + server.terminate() + server.wait(5) + + def flight_request(self, port, json_path, arrow_path): + cmd = self.FLIGHT_CLIENT_CMD + [ + '-port=' + str(port), + '-path=' + json_path, + '-output=' + arrow_path + ] + if self.debug: + print(' '.join(cmd)) + subprocess.run(cmd) + class JSTester(Tester): PRODUCER = True @@ -1166,7 +1307,7 @@ def get_static_json_files(): return glob.glob(glob_pattern) -def run_all_tests(debug=False, tempdir=None): +def run_all_tests(run_flight=False, debug=False, tempdir=None): testers = [CPPTester(debug=debug), JavaTester(debug=debug), JSTester(debug=debug)] @@ -1176,8 +1317,22 @@ def run_all_tests(debug=False, tempdir=None): runner = IntegrationRunner(json_files, testers, tempdir=tempdir, debug=debug) - runner.run() - print('-- All tests passed!') + failures = [] + failures.extend(runner.run()) + if run_flight: + failures.extend(runner.run_flight()) + + print() + print('##########################################################') + if not failures: + print('-- All tests passed!') + else: + print('-- Tests completed, failures:') + for producer, consumer, exc_info in failures: + print("FAILED TEST:", producer.name, "producing, ", + consumer.name, "consuming") + traceback.print_exception(*exc_info) + print() def write_js_test_json(directory): @@ -1197,6 +1352,9 @@ def write_js_test_json(directory): parser.add_argument('--write_generated_json', dest='generated_json_path', action='store', default=False, help='Generate test JSON') + parser.add_argument('--run_flight', dest='run_flight', + action='store_true', default=False, + help='Run Flight integration tests') parser.add_argument('--debug', dest='debug', action='store_true', default=False, help='Run executables in debug mode as relevant') @@ -1213,4 +1371,5 @@ def write_js_test_json(directory): raise write_js_test_json(args.generated_json_path) else: - run_all_tests(debug=args.debug, tempdir=args.tempdir) + run_all_tests(run_flight=args.run_flight, + debug=args.debug, tempdir=args.tempdir) diff --git a/java/flight/pom.xml b/java/flight/pom.xml index 67733f382946e..48939df886fd4 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -48,19 +48,16 @@ io.grpc grpc-netty ${dep.grpc.version} - provided io.grpc grpc-core ${dep.grpc.version} - provided io.grpc grpc-protobuf ${dep.grpc.version} - provided io.netty @@ -75,11 +72,15 @@ com.google.guava guava + + commons-cli + commons-cli + 1.4 + io.grpc grpc-stub ${dep.grpc.version} - provided com.google.protobuf @@ -225,6 +226,43 @@ + + org.codehaus.mojo + build-helper-maven-plugin + 1.9.1 + + + add-generated-sources-to-classpath + generate-sources + + add-source + + + + ${project.build.directory}/generated-sources/protobuf + + + + + + + maven-assembly-plugin + 3.0.0 + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + diff --git a/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 9764ff39a4a19..d2f7bb6c713b5 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -22,6 +22,8 @@ import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.arrow.flatbuf.Message; @@ -52,10 +54,12 @@ import io.grpc.MethodDescriptor.Marshaller; import 
io.grpc.internal.ReadableBuffer; import io.grpc.protobuf.ProtoUtils; + import io.netty.buffer.ArrowBuf; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; /** * The in-memory representation of FlightData used to manage a stream of Arrow messages. @@ -95,6 +99,18 @@ public static HeaderType getHeader(byte b) { } + // Pre-allocated buffers for padding serialized ArrowMessages. + private static List PADDING_BUFFERS = Arrays.asList( + null, + Unpooled.copiedBuffer(new byte[] { 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0, 0, 0 }) + ); + private final FlightDescriptor descriptor; private final Message message; private final List bufs; @@ -253,8 +269,17 @@ private InputStream asInputStream(BufferAllocator allocator) { cos.writeTag(FlightData.DATA_BODY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); int size = 0; + List allBufs = new ArrayList<>(); for (ArrowBuf b : bufs) { + allBufs.add(b); size += b.readableBytes(); + // [ARROW-4213] These buffers must be aligned to an 8-byte boundary in order to be readable from C++. + if (b.readableBytes() % 8 != 0) { + int paddingBytes = 8 - (b.readableBytes() % 8); + assert paddingBytes > 0 && paddingBytes < 8; + size += paddingBytes; + allBufs.add(PADDING_BUFFERS.get(paddingBytes).retain()); + } } // rawvarint is used for length definition. cos.writeUInt32NoTag(size); @@ -263,7 +288,7 @@ private InputStream asInputStream(BufferAllocator allocator) { ArrowBuf initialBuf = allocator.buffer(baos.size()); initialBuf.writeBytes(baos.toByteArray()); final CompositeByteBuf bb = new CompositeByteBuf(allocator.getAsByteBufAllocator(), true, bufs.size() + 1, - ImmutableList.builder().add(initialBuf).addAll(bufs).build()); + ImmutableList.builder().add(initialBuf).addAll(allBufs).build()); final ByteBufInputStream is = new DrainableByteBufInputStream(bb); return is; } catch (Exception ex) { diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java index 5e7aad178e70d..9accbbe434a10 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -17,13 +17,22 @@ package org.apache.arrow.flight; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; import java.util.List; import java.util.stream.Collectors; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.impl.Flight.FlightGetInfo; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; +import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream; + import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; @@ -45,8 +54,15 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List 0 ? 
- Schema.deserialize(flightGetInfo.getSchema().asReadOnlyByteBuffer()) : new Schema(ImmutableList.of()); + try { + final ByteBuffer schemaBuf = flightGetInfo.getSchema().asReadOnlyByteBuffer(); + schema = flightGetInfo.getSchema().size() > 0 ? + MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf)))) + : new Schema(ImmutableList.of()); + } catch (IOException e) { + throw new RuntimeException(e); + } descriptor = new FlightDescriptor(flightGetInfo.getFlightDescriptor()); endpoints = flightGetInfo.getEndpointList().stream().map(t -> new FlightEndpoint(t)).collect(Collectors.toList()); bytes = flightGetInfo.getTotalBytes(); @@ -74,9 +90,16 @@ public List getEndpoints() { } FlightGetInfo toProtocol() { + // Encode schema in a Message payload + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema); + } catch (IOException e) { + throw new RuntimeException(e); + } return Flight.FlightGetInfo.newBuilder() .addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList())) - .setSchema(ByteString.copyFrom(schema.toByteArray())) + .setSchema(ByteString.copyFrom(baos.toByteArray())) .setFlightDescriptor(descriptor.toProtocol()) .setTotalBytes(FlightInfo.this.bytes) .setTotalRecords(records) diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java index 91499123134c3..389497e884d09 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java @@ -128,7 +128,8 @@ public boolean isCancelled() { @Override public void start(VectorSchemaRoot root) { responseObserver.onNext(new ArrowMessage(null, root.getSchema())); - unloader = new VectorUnloader(root, true, false); + // [ARROW-4213] We must align buffers to be compatible with other languages. + unloader = new VectorUnloader(root, true, true); } @Override diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java index 5cba7ab47aa30..616b9cdc267a5 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -178,6 +178,7 @@ public void onNext(ArrowMessage msg) { public void onError(Throwable t) { ex = t; queue.add(DONE_EX); + root.setException(t); } @Override diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java new file mode 100644 index 0000000000000..803a56c6c1afe --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example.integration; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.List; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; + +/** + * An Example Flight Server that provides access to the InMemoryStore. + */ +class IntegrationTestClient { + private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestClient.class); + private final Options options; + + private IntegrationTestClient() { + options = new Options(); + options.addOption("a", "arrow", true, "arrow file"); + options.addOption("j", "json", true, "json file"); + options.addOption("host", true, "The host to connect to."); + options.addOption("port", true, "The port to connect to." 
); + } + + public static void main(String[] args) { + try { + new IntegrationTestClient().run(args); + } catch (ParseException e) { + fatalError("Invalid parameters", e); + } catch (IOException e) { + fatalError("Error accessing files", e); + } + } + + static void fatalError(String message, Throwable e) { + System.err.println(message); + System.err.println(e.getMessage()); + LOGGER.error(message, e); + System.exit(1); + } + + private void run(String[] args) throws ParseException, IOException { + CommandLineParser parser = new DefaultParser(); + CommandLine cmd = parser.parse(options, args, false); + + String fileName = cmd.getOptionValue("arrow"); + if (fileName == null) { + throw new IllegalArgumentException("missing arrow file parameter"); + } + File arrowFile = new File(fileName); + if (arrowFile.exists()) { + throw new IllegalArgumentException("arrow file already exists: " + arrowFile.getAbsolutePath()); + } + + final String host = cmd.getOptionValue("host", "localhost"); + final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); + + final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + FlightClient client = new FlightClient(allocator, new Location(host, port)); + FlightInfo info = client.getInfo(FlightDescriptor.path(cmd.getOptionValue("json"))); + List endpoints = info.getEndpoints(); + if (endpoints.isEmpty()) { + throw new RuntimeException("No endpoints returned from Flight server."); + } + + FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket()); + try (VectorSchemaRoot root = stream.getRoot(); + FileOutputStream fileOutputStream = new FileOutputStream(arrowFile); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, new DictionaryProvider.MapDictionaryProvider(), + fileOutputStream.getChannel())) { + while (stream.next()) { + arrowWriter.writeBatch(); + } + } + } +} diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java new file mode 100644 index 0000000000000..7b45e53a149be --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.example.integration; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.concurrent.Callable; + +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.ActionType; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.JsonFileReader; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; + +class IntegrationTestServer { + private final Options options; + + private IntegrationTestServer() { + options = new Options(); + options.addOption("port", true, "The port to serve on."); + } + + private void run(String[] args) throws Exception { + CommandLineParser parser = new DefaultParser(); + CommandLine cmd = parser.parse(options, args, false); + + final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); + try (final IntegrationFlightProducer producer = new IntegrationFlightProducer(allocator); + final FlightServer server = new FlightServer(allocator, port, producer, ServerAuthHandler.NO_OP)) { + server.start(); + // Print out message for integration test script + System.out.println("Server listening on localhost:" + server.getPort()); + while (true) { + Thread.sleep(30000); + } + } + } + + public static void main(String[] args) { + try { + new IntegrationTestServer().run(args); + } catch (ParseException e) { + IntegrationTestClient.fatalError("Error parsing arguments", e); + } catch (Exception e) { + IntegrationTestClient.fatalError("Runtime error", e); + } + } + + static class IntegrationFlightProducer implements FlightProducer, AutoCloseable { + private final BufferAllocator allocator; + + IntegrationFlightProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + public void close() { + allocator.close(); + } + + @Override + public void getStream(Ticket ticket, ServerStreamListener listener) { + String path = new String(ticket.getBytes(), StandardCharsets.UTF_8); + File inputFile = new File(path); + try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) { + Schema schema = reader.start(); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + listener.start(root); + while (reader.read(root)) { + listener.putNext(); + } + listener.completed(); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void listFlights(Criteria criteria, StreamListener listener) { + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfo(FlightDescriptor descriptor) { + if (descriptor.isCommand()) { + throw 
new UnsupportedOperationException("Commands not supported."); + } + if (descriptor.getPath().size() < 1) { + throw new IllegalArgumentException("Must provide a path."); + } + String path = descriptor.getPath().get(0); + File inputFile = new File(path); + try (JsonFileReader reader = new JsonFileReader(inputFile, allocator)) { + Schema schema = reader.start(); + return new FlightInfo(schema, descriptor, + Collections.singletonList(new FlightEndpoint(new Ticket(path.getBytes()), + new Location("localhost", 31338))), + 0, 0); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public Callable acceptPut(FlightStream flightStream) { + return null; + } + + @Override + public Result doAction(Action action) { + return null; + } + + @Override + public void listActions(StreamListener listener) { + listener.onCompleted(); + } + } +}
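
The Java-side change above encodes the GetFlightInfo schema as an encapsulated IPC Message rather than a bare Flatbuffer Schema, so the C++ client can read it with its existing Message machinery. Below is a minimal round-trip sketch of that encoding; it reuses the MessageSerializer, WriteChannel, and ReadChannel calls that appear in the patch, while the example field ("x": int32) and the class name SchemaPayloadSketch are invented for illustration and are not part of the patch.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.channels.Channels;
import java.util.Collections;

import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

public class SchemaPayloadSketch {
  public static void main(String[] args) throws Exception {
    // An arbitrary one-column schema used only for this sketch.
    Schema schema = new Schema(Collections.singletonList(
        Field.nullable("x", new ArrowType.Int(32, true))));

    // Serialize as an encapsulated IPC Message (what FlightInfo.toProtocol now does).
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema);

    // Deserialize the same bytes (what the FlightInfo constructor now expects).
    Schema roundTripped = MessageSerializer.deserializeSchema(
        new ReadChannel(Channels.newChannel(new ByteArrayInputStream(baos.toByteArray()))));
    System.out.println(roundTripped.equals(schema));  // prints true
  }
}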
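
Similarly, the buffer-alignment fix pads each record batch body buffer out to an 8-byte boundary so the C++ reader sees the layout it expects. The following standalone sketch shows only the padding arithmetic; the paddingFor helper, the class name AlignmentSketch, and the buffer lengths are invented for illustration and are not part of the patch.

public class AlignmentSketch {
  // Zero bytes required to pad a buffer of `length` bytes to the next 8-byte boundary.
  static int paddingFor(long length) {
    int remainder = (int) (length % 8);
    return remainder == 0 ? 0 : 8 - remainder;
  }

  public static void main(String[] args) {
    long offset = 0;
    for (long length : new long[] {13, 8, 1}) {
      int pad = paddingFor(length);
      // Each body buffer is followed by 0-7 zero bytes so the next buffer starts aligned.
      System.out.printf("buffer %d bytes + %d padding -> next offset %d%n",
          length, pad, offset + length + pad);
      offset += length + pad;
    }
  }
}

In the patch itself, the padding bytes come from the pre-allocated PADDING_BUFFERS list in ArrowMessage, and FlightService constructs the VectorUnloader with alignment enabled.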