Commit

changes in test_sql.py
skarakuzu committed Dec 16, 2024
1 parent b506988 commit 01a6f18
Showing 3 changed files with 110 additions and 31 deletions.
2 changes: 1 addition & 1 deletion tiled/_tests/adapters/test_arrow.py
@@ -29,7 +29,7 @@
def adapter() -> ArrowAdapter:
table = pa.Table.from_arrays(data0, names)
structure = TableStructure.from_arrow_table(table, npartitions=3)
assets = ArrowAdapter.init_storage(data_uri, structure=structure)
assets = ArrowAdapter.init_storage(data_uri, structure=structure, path_parts=[])
return ArrowAdapter([asset.data_uri for asset in assets], structure=structure)
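
For context, a minimal standalone sketch of the updated call follows: init_storage now takes a path_parts argument, which this test passes as an empty list. The import path, sample data, and storage URI below are illustrative assumptions, not taken from this commit.

import pyarrow as pa

from tiled.adapters.arrow import ArrowAdapter  # import path assumed from the repo layout
from tiled.structures.table import TableStructure

table = pa.Table.from_arrays([pa.array([1, 2, 3])], names=["f0"])
structure = TableStructure.from_arrow_table(table, npartitions=3)

# path_parts is now required; these tests pass an empty list.
assets = ArrowAdapter.init_storage(
    "file://localhost/tmp/test", structure=structure, path_parts=[]
)
adapter = ArrowAdapter([asset.data_uri for asset in assets], structure=structure)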


101 changes: 74 additions & 27 deletions tiled/_tests/adapters/test_sql.py
@@ -6,6 +6,8 @@
import pytest

from tiled.adapters.sql import SQLAdapter
from tiled.structures.core import StructureFamily
from tiled.structures.data_source import DataSource, Management, Storage
from tiled.structures.table import TableStructure

names = ["f0", "f1", "f2"]
@@ -26,33 +28,33 @@
batch2 = pa.record_batch(data2, names=names)


def test_invalid_uri() -> None:
data_uri = "/some_random_uri/test.db"
@pytest.fixture
def data_source_from_init_storage() -> DataSource:
table = pa.Table.from_arrays(data0, names)
structure = TableStructure.from_arrow_table(table, npartitions=1)
asset = SQLAdapter.init_storage(data_uri, structure=structure)
with pytest.raises(
ValueError,
match="The database uri must start with either `sqlite://` or `postgresql://` ",
):
SQLAdapter(asset.data_uri, structure=structure)


def test_invalid_structure() -> None:
data_uri = "/some_random_uri/test.db"
table = pa.Table.from_arrays(data0, names)
structure = TableStructure.from_arrow_table(table, npartitions=3)
with pytest.raises(ValueError, match="The SQL adapter must have only 1 partition"):
SQLAdapter.init_storage(data_uri, structure=structure)
data_source = DataSource(
management=Management.writable,
mimetype="application/x-tiled-sql-table",
structure_family=StructureFamily.table,
structure=structure,
assets=[],
)
storage = Storage(filesystem=None, sql="sqlite:////tmp/test.sqlite")
return SQLAdapter.init_storage(
data_source=data_source, storage=storage, path_parts=[]
)


@pytest.fixture
def adapter_sql() -> SQLAdapter:
def adapter_sql(data_source_from_init_storage: DataSource) -> SQLAdapter:
data_uri = "sqlite://file://localhost" + tempfile.gettempdir() + "/test.db"
table = pa.Table.from_arrays(data0, names)
structure = TableStructure.from_arrow_table(table, npartitions=1)
asset = SQLAdapter.init_storage(data_uri, structure=structure)
return SQLAdapter(asset.data_uri, structure=structure)
data_source = data_source_from_init_storage
return SQLAdapter(
data_uri,
data_source.structure,
data_source.parameters["table_name"],
data_source.parameters["dataset_id"],
)


def test_attributes(adapter_sql: SQLAdapter) -> None:
@@ -61,7 +63,7 @@ def test_attributes(adapter_sql: SQLAdapter) -> None:
assert isinstance(adapter_sql.conn, adbc_driver_sqlite.dbapi.AdbcSqliteConnection)


def test_write_read(adapter_sql: SQLAdapter) -> None:
def test_write_read_sql(adapter_sql: SQLAdapter) -> None:
# test writing and reading it
adapter_sql.write(batch0)
result = adapter_sql.read()
@@ -111,11 +113,16 @@ def postgres_uri() -> str:


@pytest.fixture
def adapter_psql(postgres_uri: str) -> SQLAdapter:
table = pa.Table.from_arrays(data0, names)
structure = TableStructure.from_arrow_table(table, npartitions=1)
asset = SQLAdapter.init_storage(postgres_uri, structure=structure)
return SQLAdapter(asset.data_uri, structure=structure)
def adapter_psql(
data_source_from_init_storage: DataSource, postgres_uri: str
) -> SQLAdapter:
data_source = data_source_from_init_storage
return SQLAdapter(
postgres_uri,
data_source.structure,
data_source.parameters["table_name"],
data_source.parameters["dataset_id"],
)


def test_psql(postgres_uri: str, adapter_psql: SQLAdapter) -> None:
@@ -124,3 +131,43 @@ def test_psql(postgres_uri: str, adapter_psql: SQLAdapter) -> None:
# assert isinstance(
# adapter_psql.conn, adbc_driver_postgresql.dbapi.AdbcSqliteConnection
# )


def test_write_read_psql(adapter_psql: SQLAdapter) -> None:
# test writing and reading it
adapter_psql.write(batch0)
result = adapter_psql.read()
# the pandas dataframe gives the last column of the data as 0 and 1 since SQL does not store booleans,
# so we explicitly convert the last column to boolean for testing purposes
result["f2"] = result["f2"].astype("boolean")

assert pa.Table.from_arrays(data0, names) == pa.Table.from_pandas(result)

adapter_psql.write([batch0, batch1])
result = adapter_psql.read()
# the pandas dataframe gives the last column of the data as 0 and 1 since SQL does not store booleans,
# so we explicitly convert the last column to boolean for testing purposes
result["f2"] = result["f2"].astype("boolean")
assert pa.Table.from_batches([batch0, batch1]) == pa.Table.from_pandas(result)

adapter_psql.write([batch0, batch1, batch2])
result = adapter_psql.read()
# the pandas dataframe gives the last column of the data as 0 and 1 since SQL does not store booleans,
# so we explicitly convert the last column to boolean for testing purposes
result["f2"] = result["f2"].astype("boolean")
assert pa.Table.from_batches([batch0, batch1, batch2]) == pa.Table.from_pandas(
result
)

# test write, append, and read all
adapter_psql.write([batch0, batch1, batch2])
adapter_psql.append([batch2, batch0, batch1])
adapter_psql.append([batch1, batch2, batch0])
result = adapter_psql.read()
# the pandas dataframe gives the last column of the data as 0 and 1 since SQL does not store booleans,
# so we explicitly convert the last column to boolean for testing purposes
result["f2"] = result["f2"].astype("boolean")

assert pa.Table.from_batches(
[batch0, batch1, batch2, batch2, batch0, batch1, batch1, batch2, batch0]
) == pa.Table.from_pandas(result)
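
The repeated astype("boolean") cast in these tests exists because the SQL backends return the boolean column as 0/1 integers. A minimal standalone illustration of the cast, with synthetic data and no tiled dependency:

import pandas as pd

# SQL backends hand booleans back as 0/1 integers.
result = pd.DataFrame({"f2": [0, 1, 1]})
result["f2"] = result["f2"].astype("boolean")  # pandas nullable BooleanDtype
print(result["f2"].tolist())  # [False, True, True]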
38 changes: 35 additions & 3 deletions tiled/adapters/sql.py
@@ -86,9 +86,9 @@ def init_storage(
Parameters
----------
data_uri : the uri of the data
structure : the structure of the data
storage: the storage kind
data_source : data source describing the data
path_parts: ?
Returns
-------
A modified copy of the data source
@@ -232,6 +232,38 @@ def append_partition(
self.cur.adbc_ingest(self.table_name, reader, mode="append")
self.conn.commit()

def append(
self,
data: Union[List[pyarrow.record_batch], pyarrow.record_batch, pandas.DataFrame],
) -> None:
"""
"Function to write the data as arrow format."
Parameters
----------
data : data to append into the database. Can be a list of record batch, or pandas dataframe.
table_name: string indicating the name of the table to ingest data in the database.
Returns
-------
"""
if isinstance(data, pandas.DataFrame):
    table = pyarrow.Table.from_pandas(data)
    batches = table.to_batches()
elif not isinstance(data, list):
    batches = [data]
else:
    batches = data

schema = batches[0].schema  # the list of column names can be obtained from schema.names

reader = pyarrow.ipc.RecordBatchReader.from_batches(schema, batches)

self.cur.adbc_ingest(self.table_name, reader, mode="append")
self.conn.commit()
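
A minimal usage sketch of append (the constructor arguments mirror the fixtures in test_sql.py above; the table name and dataset id here are hypothetical placeholders, not values from this commit):

import pyarrow

from tiled.adapters.sql import SQLAdapter
from tiled.structures.table import TableStructure

data = [pyarrow.array([1, 2, 3])]
structure = TableStructure.from_arrow_table(
    pyarrow.Table.from_arrays(data, names=["f0"]), npartitions=1
)

adapter = SQLAdapter(
    "sqlite://file://localhost/tmp/test.db",  # URI form copied from the sqlite fixture
    structure,
    "table_data",  # hypothetical table_name
    1,             # hypothetical dataset_id
)

batch = pyarrow.record_batch(data, names=["f0"])
adapter.append(batch)           # a single record batch...
adapter.append([batch, batch])  # ...or a list of batches
result = adapter.read()         # concatenated rows as a pandas dataframe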

def read(self, fields: Optional[Union[str, List[str]]] = None) -> pandas.DataFrame:
"""
The concatenated data from the given set of partitions as a pandas dataframe.
