added do_put() (#86)
* added do_put()
supports writing a parquet dataset to filesystems

Signed-off-by: Doron Chen <[email protected]>

* remove repeat parameter from sample_put.py

Signed-off-by: Doron Chen <[email protected]>

* minor change

Signed-off-by: Doron Chen <[email protected]>

* added documentation
based on Mohammad's comments

Signed-off-by: Doron Chen <[email protected]>
cdoron authored Jul 7, 2021
1 parent 767e245 commit b91d0bd
Showing 5 changed files with 89 additions and 2 deletions.
15 changes: 14 additions & 1 deletion afm/server.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

import json
import logging
import os

@@ -54,6 +55,14 @@ def _filter_columns(self, schema, columns):
        fields = [schema.field(c) for c in columns]
        return pa.schema([pa.field(f.name, f.type, f.nullable, f.metadata) for f in fields])

    # write arrow dataset to filesystem
    def _write_asset(self, asset, reader):
        # in this implementation we currently begin by reading the entire dataset
        t = reader.read_all().combine_chunks()
        # currently, write_dataset supports the parquet format, but not csv
        ds.write_dataset(t, base_dir=asset.path, format=asset.format,
                         filesystem=asset.filesystem)

    def _read_asset(self, asset, columns=None):
        dataset, data_files = self._get_dataset(asset)
        scanner = ds.Scanner.from_dataset(dataset, columns=columns, batch_size=64*2**20)
@@ -144,7 +153,11 @@ def do_get(self, context, ticket: fl.Ticket):
        return fl.GeneratorStream(schema, batches)

    def do_put(self, context, descriptor, reader, writer):
-        raise NotImplementedError("do_put")
+        logging.critical('do_put: descriptor={}'.format(descriptor))
+        asset_info = json.loads(descriptor.command)
+        with Config(self.config_path) as config:
+            asset = asset_from_config(config, asset_info['asset'])
+            self._write_asset(asset, reader)

    def get_schema(self, context, descriptor):
        info = self.get_flight_info(context, descriptor)
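Note that `_write_asset` buffers the entire incoming Flight stream in memory via `read_all()` before handing it to `ds.write_dataset`. For comparison, here is a minimal batch-at-a-time sketch using `pyarrow.parquet.ParquetWriter`; the helper name `_write_asset_streaming`, the single-file `part-0.parquet` layout, and the assumption that `asset.filesystem` is a `pyarrow.fs.FileSystem` are all illustrative, not part of this commit:

```python
import os

import pyarrow as pa
import pyarrow.parquet as pq


# Hypothetical streaming variant of _write_asset: writes each incoming
# record batch as it arrives instead of buffering the whole table.
def _write_asset_streaming(self, asset, reader):
    # assumes asset.filesystem is a pyarrow.fs.FileSystem (e.g. LocalFileSystem)
    asset.filesystem.create_dir(asset.path, recursive=True)
    target = os.path.join(asset.path, "part-0.parquet")
    with pq.ParquetWriter(target, reader.schema,
                          filesystem=asset.filesystem) as writer:
        for chunk in reader:  # each FlightStreamChunk carries one batch in .data
            writer.write_table(pa.Table.from_batches([chunk.data]))
```

Unlike `ds.write_dataset`, this keeps at most one record batch in memory at a time, at the cost of fixing the output layout to a single file.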
4 changes: 4 additions & 0 deletions module.yaml
@@ -28,6 +28,10 @@ spec:
        source:
          protocol: s3
          dataformat: csv
      - flow: write
        source:
          protocol: s3
          dataformat: parquet
  actions:
    - id: redact-ID
      level: 2 # column
6 changes: 5 additions & 1 deletion sample/README.md
@@ -11,5 +11,9 @@ It shows a server configured to serve [The New York City taxi trip record data](
```
1. Run a sample client with
```bash
-pipenv run python sample/sample.py
+pipenv run python sample/sample.py --username qqq --password moo
```
1. Write a parquet dataset to /tmp/new-dataset. The dataset consists of a single column with the numbers 0 to 10239
```bash
pipenv run python sample/sample_put.py --username qqq --password moo
```
5 changes: 5 additions & 0 deletions sample/sample.yaml
@@ -8,6 +8,11 @@ workers:
    address: "localhost"
    port: 8080
data:
  - name: "new-dataset"
    format: parquet
    path: "/tmp/new-dataset"
    connection:
      type: localfs
  - name: "nyc-taxi.parquet"
    format: parquet
    path: "ursa-labs-taxi-data/2019"
61 changes: 61 additions & 0 deletions sample/sample_put.py
@@ -0,0 +1,61 @@
#
# Copyright 2020 IBM Corp.
# SPDX-License-Identifier: Apache-2.0
#
import pyarrow.flight as fl
import pyarrow as pa
import json

# taken from https://github.com/apache/arrow/blob/master/python/pyarrow/tests/test_flight.py#L450
class HttpBasicClientAuthHandler(fl.ClientAuthHandler):
    """An example implementation of HTTP basic authentication."""

    def __init__(self, username, password):
        super().__init__()
        self.basic_auth = fl.BasicAuth(username, password)
        self.token = None

    def authenticate(self, outgoing, incoming):
        auth = self.basic_auth.serialize()
        outgoing.write(auth)
        self.token = incoming.read()

    def get_token(self):
        return self.token

request = {
    "asset": "new-dataset",
}

def main(port, username, password):
    client = fl.connect("grpc://localhost:{}".format(port))
    if username or password:
        client.authenticate(HttpBasicClientAuthHandler(username, password))

    # write the new dataset
    data = pa.Table.from_arrays([pa.array(range(0, 10 * 1024))], names=['a'])
    writer, _ = client.do_put(fl.FlightDescriptor.for_command(json.dumps(request)),
                              data.schema)
    writer.write_table(data, 1024)
    writer.close()

    # now that the dataset is in place, let's try to read it
    info = client.get_flight_info(
        fl.FlightDescriptor.for_command(json.dumps(request)))

    endpoint = info.endpoints[0]
    result: fl.FlightStreamReader = client.do_get(endpoint.ticket)
    print(result.read_all().to_pandas())

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='arrow-flight-module sample')
    parser.add_argument(
        '--port', type=int, default=8080, help='Listening port')
    parser.add_argument(
        '--username', type=str, default=None, help='Authentication username')
    parser.add_argument(
        '--password', type=str, default=None, help='Authentication password')
    args = parser.parse_args()

    main(args.port, args.username, args.password)

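Because sample.yaml maps `new-dataset` to `/tmp/new-dataset` on the local filesystem, the files produced by `do_put` can also be sanity-checked directly with `pyarrow.dataset`, bypassing Flight entirely (assuming the sample server and `sample_put.py` have already run):

```python
import pyarrow.dataset as ds

# Read back the parquet files that do_put wrote under /tmp/new-dataset
dataset = ds.dataset("/tmp/new-dataset", format="parquet")
table = dataset.to_table()
print(table.num_rows)      # expect 10240 rows (the numbers 0 to 10239)
print(table.column_names)  # expect ['a']
```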