diff --git a/tiled/_tests/adapters/test_arrow.py b/tiled/_tests/adapters/test_arrow.py
index 2f10b1dfb..7515cb85b 100644
--- a/tiled/_tests/adapters/test_arrow.py
+++ b/tiled/_tests/adapters/test_arrow.py
@@ -29,7 +29,7 @@ def adapter() -> ArrowAdapter:
 
     table = pa.Table.from_arrays(data0, names)
     structure = TableStructure.from_arrow_table(table, npartitions=3)
-    assets = ArrowAdapter.init_storage(data_uri, structure=structure)
+    assets = ArrowAdapter.init_storage(data_uri, structure=structure, path_parts=[])
     return ArrowAdapter([asset.data_uri for asset in assets], structure=structure)
 
 
diff --git a/tiled/_tests/adapters/test_sql.py b/tiled/_tests/adapters/test_sql.py
index 4a8e6fba1..44eab8f74 100644
--- a/tiled/_tests/adapters/test_sql.py
+++ b/tiled/_tests/adapters/test_sql.py
@@ -6,6 +6,8 @@
 import pytest
 
 from tiled.adapters.sql import SQLAdapter
+from tiled.structures.core import StructureFamily
+from tiled.structures.data_source import DataSource, Management, Storage
 from tiled.structures.table import TableStructure
 
 names = ["f0", "f1", "f2"]
@@ -26,33 +28,33 @@
 batch2 = pa.record_batch(data2, names=names)
 
 
-def test_invalid_uri() -> None:
-    data_uri = "/some_random_uri/test.db"
+@pytest.fixture
+def data_source_from_init_storage() -> DataSource:
     table = pa.Table.from_arrays(data0, names)
     structure = TableStructure.from_arrow_table(table, npartitions=1)
-    asset = SQLAdapter.init_storage(data_uri, structure=structure)
-    with pytest.raises(
-        ValueError,
-        match="The database uri must start with either `sqlite://` or `postgresql://` ",
-    ):
-        SQLAdapter(asset.data_uri, structure=structure)
-
-
-def test_invalid_structure() -> None:
-    data_uri = "/some_random_uri/test.db"
-    table = pa.Table.from_arrays(data0, names)
-    structure = TableStructure.from_arrow_table(table, npartitions=3)
-    with pytest.raises(ValueError, match="The SQL adapter must have only 1 partition"):
-        SQLAdapter.init_storage(data_uri, structure=structure)
+    data_source = DataSource(
+        management=Management.writable,
+        mimetype="application/x-tiled-sql-table",
+        structure_family=StructureFamily.table,
+        structure=structure,
+        assets=[],
+    )
+    storage = Storage(filesystem=None, sql="sqlite:////tmp/test.sqlite")
+    return SQLAdapter.init_storage(
+        data_source=data_source, storage=storage, path_parts=[]
+    )
 
 
 @pytest.fixture
-def adapter_sql() -> SQLAdapter:
+def adapter_sql(data_source_from_init_storage: DataSource) -> SQLAdapter:
     data_uri = "sqlite://file://localhost" + tempfile.gettempdir() + "/test.db"
-    table = pa.Table.from_arrays(data0, names)
-    structure = TableStructure.from_arrow_table(table, npartitions=1)
-    asset = SQLAdapter.init_storage(data_uri, structure=structure)
-    return SQLAdapter(asset.data_uri, structure=structure)
+    data_source = data_source_from_init_storage
+    return SQLAdapter(
+        data_uri,
+        data_source.structure,
+        data_source.parameters["table_name"],
+        data_source.parameters["dataset_id"],
+    )
 
 
 def test_attributes(adapter_sql: SQLAdapter) -> None:
@@ -61,7 +63,7 @@ def test_attributes(adapter_sql: SQLAdapter) -> None:
     assert isinstance(adapter_sql.conn, adbc_driver_sqlite.dbapi.AdbcSqliteConnection)
 
 
-def test_write_read(adapter_sql: SQLAdapter) -> None:
+def test_write_read_sql(adapter_sql: SQLAdapter) -> None:
     # test writing and reading it
     adapter_sql.write(batch0)
     result = adapter_sql.read()
@@ -111,11 +113,16 @@ def postgres_uri() -> str:
 
 
 @pytest.fixture
-def adapter_psql(postgres_uri: str) -> SQLAdapter:
-    table = pa.Table.from_arrays(data0, names)
-    structure = TableStructure.from_arrow_table(table, npartitions=1)
-    asset = SQLAdapter.init_storage(postgres_uri, structure=structure)
-    return SQLAdapter(asset.data_uri, structure=structure)
+def adapter_psql(
+    data_source_from_init_storage: DataSource, postgres_uri: str
+) -> SQLAdapter:
+    data_source = data_source_from_init_storage
+    return SQLAdapter(
+        postgres_uri,
+        data_source.structure,
+        data_source.parameters["table_name"],
+        data_source.parameters["dataset_id"],
+    )
 
 
 def test_psql(postgres_uri: str, adapter_psql: SQLAdapter) -> None:
@@ -124,3 +131,43 @@ def test_psql(postgres_uri: str, adapter_psql: SQLAdapter) -> None:
     # assert isinstance(
     #     adapter_psql.conn, adbc_driver_postgresql.dbapi.AdbcSqliteConnection
     # )
+
+
+def test_write_read_psql(adapter_psql: SQLAdapter) -> None:
+    # test writing and reading it
+    adapter_psql.write(batch0)
+    result = adapter_psql.read()
+    # the dataframe comes back with 0 and 1 in the last column since SQL does not store booleans,
+    # so we explicitly convert the last column back to boolean for the comparison
+    result["f2"] = result["f2"].astype("boolean")
+
+    assert pa.Table.from_arrays(data0, names) == pa.Table.from_pandas(result)
+
+    adapter_psql.write([batch0, batch1])
+    result = adapter_psql.read()
+    # the dataframe comes back with 0 and 1 in the last column since SQL does not store booleans,
+    # so we explicitly convert the last column back to boolean for the comparison
+    result["f2"] = result["f2"].astype("boolean")
+    assert pa.Table.from_batches([batch0, batch1]) == pa.Table.from_pandas(result)
+
+    adapter_psql.write([batch0, batch1, batch2])
+    result = adapter_psql.read()
+    # the dataframe comes back with 0 and 1 in the last column since SQL does not store booleans,
+    # so we explicitly convert the last column back to boolean for the comparison
+    result["f2"] = result["f2"].astype("boolean")
+    assert pa.Table.from_batches([batch0, batch1, batch2]) == pa.Table.from_pandas(
+        result
+    )
+
+    # test write, append, and read all
+    adapter_psql.write([batch0, batch1, batch2])
+    adapter_psql.append([batch2, batch0, batch1])
+    adapter_psql.append([batch1, batch2, batch0])
+    result = adapter_psql.read()
+    # the dataframe comes back with 0 and 1 in the last column since SQL does not store booleans,
+    # so we explicitly convert the last column back to boolean for the comparison
+    result["f2"] = result["f2"].astype("boolean")
+
+    assert pa.Table.from_batches(
+        [batch0, batch1, batch2, batch2, batch0, batch1, batch1, batch2, batch0]
+    ) == pa.Table.from_pandas(result)
diff --git a/tiled/adapters/sql.py b/tiled/adapters/sql.py
index 56fce5952..b57b5659b 100644
--- a/tiled/adapters/sql.py
+++ b/tiled/adapters/sql.py
@@ -86,9 +86,10 @@ def init_storage(
 
         Parameters
         ----------
-        data_uri : the uri of the data
-        structure : the structure of the data
+        storage : the storage backend (filesystem or SQL database) to write into
+        data_source : data source describing the data
+        path_parts : list of strings forming the path of the node
 
         Returns
         -------
         A modified copy of the data source
@@ -232,6 +232,32 @@ def append_partition(
         self.cur.adbc_ingest(self.table_name, reader, mode="append")
         self.conn.commit()
 
+    def append(
+        self,
+        data: Union[List[pyarrow.RecordBatch], pyarrow.RecordBatch, pandas.DataFrame],
+    ) -> None:
+        """
+        Append data to the table in the database.
+
+        Parameters
+        ----------
+        data : data to append; a single record batch, a list of record batches,
+            or a pandas dataframe.
+        """
+        if isinstance(data, pandas.DataFrame):
+            batches = pyarrow.Table.from_pandas(data).to_batches()
+        elif not isinstance(data, list):
+            batches = [data]
+        else:
+            batches = data
+
+        # All batches share one schema; column names are available as schema.names.
+        schema = batches[0].schema
+        reader = pyarrow.ipc.RecordBatchReader.from_batches(schema, batches)
+
+        self.cur.adbc_ingest(self.table_name, reader, mode="append")
+        self.conn.commit()
+
     def read(self, fields: Optional[Union[str, List[str]]] = None) -> pandas.DataFrame:
         """
         The concatenated data from given set of partitions as pyarrow table.
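
For orientation, here is a minimal sketch of how the reworked `init_storage` signature and the new `append` method fit together, mirroring the fixtures above; the sqlite URI, file path, and column name are illustrative placeholders, not values from this diff:

```python
import pyarrow as pa

from tiled.adapters.sql import SQLAdapter
from tiled.structures.core import StructureFamily
from tiled.structures.data_source import DataSource, Management, Storage
from tiled.structures.table import TableStructure

# Illustrative single-partition table with one integer column.
batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])
structure = TableStructure.from_arrow_table(
    pa.Table.from_batches([batch]), npartitions=1
)

# init_storage now takes a DataSource and a Storage rather than a bare URI,
# and returns a data source whose parameters carry the generated table name
# and dataset id.
data_source = SQLAdapter.init_storage(
    data_source=DataSource(
        management=Management.writable,
        mimetype="application/x-tiled-sql-table",
        structure_family=StructureFamily.table,
        structure=structure,
        assets=[],
    ),
    storage=Storage(filesystem=None, sql="sqlite:////tmp/example.sqlite"),
    path_parts=[],
)

adapter = SQLAdapter(
    "sqlite:////tmp/example.sqlite",  # placeholder database URI
    data_source.structure,
    data_source.parameters["table_name"],
    data_source.parameters["dataset_id"],
)

adapter.write(batch)            # replaces the table contents
adapter.append([batch, batch])  # accepts a batch, a list of batches, or a DataFrame
print(adapter.read())
```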