
Commit

tried to fix the tests
skarakuzu committed Jan 7, 2025
1 parent 7c0ca34 commit 8d26ce5
Showing 7 changed files with 30 additions and 22 deletions.
7 changes: 3 additions & 4 deletions tiled/_tests/adapters/test_arrow.py
@@ -5,8 +5,8 @@

from tiled.adapters.arrow import ArrowAdapter
from tiled.structures.core import StructureFamily
from tiled.structures.table import TableStructure
from tiled.structures.data_source import DataSource, Management, Storage
from tiled.structures.table import TableStructure

names = ["f0", "f1", "f2"]
data0 = [
@@ -28,7 +28,7 @@


@pytest.fixture
def data_source_from_init_storage() -> DataSource:
def data_source_from_init_storage() -> DataSource[TableStructure]:
table = pa.Table.from_arrays(data0, names)
structure = TableStructure.from_arrow_table(table, npartitions=3)
data_source = DataSource(
@@ -45,15 +45,14 @@ def data_source_from_init_storage() -> DataSource:


@pytest.fixture
def adapter(data_source_from_init_storage: DataSource) -> ArrowAdapter:
def adapter(data_source_from_init_storage: DataSource[TableStructure]) -> ArrowAdapter:
data_source = data_source_from_init_storage
return ArrowAdapter(
[asset.data_uri for asset in data_source.assets],
data_source.structure,
)



def test_attributes(adapter: ArrowAdapter) -> None:
assert adapter.structure().columns == names
assert adapter.structure().npartitions == 3
8 changes: 5 additions & 3 deletions tiled/_tests/adapters/test_sql.py
@@ -29,7 +29,7 @@


@pytest.fixture
def data_source_from_init_storage() -> DataSource:
def data_source_from_init_storage() -> DataSource[TableStructure]:
table = pa.Table.from_arrays(data0, names)
structure = TableStructure.from_arrow_table(table, npartitions=1)
data_source = DataSource(
@@ -46,7 +46,7 @@ def data_source_from_init_storage() -> DataSource:


@pytest.fixture
def adapter_sql(data_source_from_init_storage: DataSource) -> SQLAdapter:
def adapter_sql(
data_source_from_init_storage: DataSource[TableStructure],
) -> SQLAdapter:
data_uri = "sqlite://file://localhost" + tempfile.gettempdir() + "/test.db"
data_source = data_source_from_init_storage
return SQLAdapter(
@@ -114,7 +116,7 @@ def postgres_uri() -> str:

@pytest.fixture
def adapter_psql(
data_source_from_init_storage: DataSource, postgres_uri: str
data_source_from_init_storage: DataSource[TableStructure], postgres_uri: str
) -> SQLAdapter:
data_source = data_source_from_init_storage
return SQLAdapter(
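Both test modules above now annotate the fixtures as DataSource[TableStructure] rather than a bare DataSource. For context, a generic pydantic model of that shape (mirroring the DataSource(pydantic.BaseModel, Generic[StructureT]) declaration in tiled/server/schemas.py further down) can be parameterized as in the minimal sketch below. This assumes pydantic v2's generic-model support, and the field set here is a made-up simplification for illustration, not the real tiled class:

from typing import Generic, List, Optional, TypeVar

import pydantic

# Simplified, hypothetical stand-ins for illustration only; the real tiled
# classes carry more fields (assets, management, mimetype, ...).
StructureT = TypeVar("StructureT")


class TableStructure(pydantic.BaseModel):
    columns: List[str]
    npartitions: int


class DataSource(pydantic.BaseModel, Generic[StructureT]):
    # Parameterizing the model (DataSource[TableStructure]) tells type
    # checkers which structure type this data source carries.
    structure: Optional[StructureT] = None


ds = DataSource[TableStructure](
    structure=TableStructure(columns=["f0", "f1", "f2"], npartitions=3)
)
print(type(ds.structure).__name__)  # TableStructure
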
2 changes: 1 addition & 1 deletion tiled/adapters/arrow.py
@@ -81,7 +81,7 @@ def init_storage(
data_uri = str(storage.get("filesystem")) + "".join(
f"/{quote_plus(segment)}" for segment in path_parts
)
print('uri', data_uri)
print("uri", data_uri)
directory = path_from_uri(data_uri)
directory.mkdir(parents=True, exist_ok=True)
assets = [
3 changes: 1 addition & 2 deletions tiled/adapters/zarr.py
@@ -13,10 +13,9 @@

from ..adapters.utils import IndexersMixin
from ..iterviews import ItemsView, KeysView, ValuesView
from ..server.schemas import DataSource
from ..structures.array import ArrayStructure
from ..structures.core import Spec, StructureFamily
from ..structures.data_source import Asset, Storage
from ..structures.data_source import Asset, DataSource, Storage
from ..type_aliases import JSON, NDSlice
from ..utils import Conflicts, node_repr, path_from_uri
from .array import ArrayAdapter, slice_and_shape_from_block_and_chunks
5 changes: 5 additions & 0 deletions tiled/client/xarray.py
@@ -65,6 +65,7 @@ def _build_arrays(self, variables, optimize_wide_table):
array_structure = array_structures[name]
shape = array_structure.shape
spec_names = set(spec.name for spec in array_client.specs)
print("spec_names", spec_names)
if optimize_wide_table and (
(not shape) # empty
or (
@@ -73,6 +74,7 @@
)
):
if "xarray_coord" in spec_names:
print("hello 1")
coords[name] = (
array_client.dims,
coords_fetcher.register(name, array_client, array_structure),
@@ -91,6 +93,7 @@
)
else:
if "xarray_coord" in spec_names:
print("hello 2")
coords[name] = (
array_client.dims,
array_client.read(),
@@ -111,6 +114,8 @@

def read(self, variables=None, *, optimize_wide_table=True):
data_vars, coords = self._build_arrays(variables, optimize_wide_table)
print("corrds", coords)
print("datavars", data_vars)
return xarray.Dataset(
data_vars=data_vars, coords=coords, attrs=self.metadata["attrs"]
)
4 changes: 2 additions & 2 deletions tiled/serialization/table.py
@@ -7,7 +7,7 @@


@serialization_registry.register(StructureFamily.table, APACHE_ARROW_FILE_MIME_TYPE)
def serialize_arrow(df, metadata, preserve_index=False):
def serialize_arrow(df, metadata, preserve_index=True):
import pyarrow

if isinstance(df, dict):
@@ -30,7 +30,7 @@ def deserialize_arrow(buffer):
# There seems to be no official Parquet MIME type.
# https://issues.apache.org/jira/browse/PARQUET-1889
@serialization_registry.register(StructureFamily.table, "application/x-parquet")
def serialize_parquet(df, metadata, preserve_index=False):
def serialize_parquet(df, metadata, preserve_index=True):
import pyarrow.parquet

table = pyarrow.Table.from_pandas(df, preserve_index=preserve_index)
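For context on the preserve_index flag flipped above: in pyarrow.Table.from_pandas, preserve_index=True carries the pandas index along as an extra column in the Arrow table, while False drops it. A small standalone sketch of the difference, using a made-up DataFrame purely for illustration:

import pandas as pd
import pyarrow

# Made-up example frame with a named index.
df = pd.DataFrame({"a": [1, 2, 3]}, index=pd.Index([10, 20, 30], name="time"))

# With preserve_index=True the index survives the trip through Arrow;
# with preserve_index=False it is silently dropped.
kept = pyarrow.Table.from_pandas(df, preserve_index=True)
dropped = pyarrow.Table.from_pandas(df, preserve_index=False)

print(kept.column_names)     # ['a', 'time']
print(dropped.column_names)  # ['a']
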
23 changes: 13 additions & 10 deletions tiled/server/schemas.py
@@ -30,14 +30,6 @@
MAX_ALLOWED_SPECS = 20


STRUCTURE_TYPES = {
StructureFamily.array: ArrayStructure,
StructureFamily.awkward: AwkwardStructure,
StructureFamily.table: TableStructure,
StructureFamily.sparse: SparseStructure,
}


class Error(pydantic.BaseModel):
code: int
message: str
@@ -148,6 +140,15 @@ def from_orm(cls, orm: tiled.catalog.orm.Revision) -> Revision:
)


STRUCTURE_TYPES = {
StructureFamily.array: ArrayStructure,
StructureFamily.awkward: AwkwardStructure,
StructureFamily.table: TableStructure,
StructureFamily.sparse: SparseStructure,
StructureFamily.container: NodeStructure,
}


class DataSource(pydantic.BaseModel, Generic[StructureT]):
id: Optional[int] = None
structure_family: StructureFamily
@@ -467,8 +468,10 @@ def specs_uniqueness_validator(cls, v):
def narrow_strucutre_type(self):
"Convert the structure on each data_source from a dict to the appropriate pydantic model."
for data_source in self.data_sources:
structure_cls = STRUCTURE_TYPES[self.structure_family]
data_source.structure = structure_cls(**data_source.structure)
if self.structure_family != StructureFamily.container:
structure_cls = STRUCTURE_TYPES[self.structure_family]
if data_source.structure is not None:
data_source.structure = structure_cls(**data_source.structure)
return self


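The guarded narrowing in the hunk above can be read in isolation roughly as follows. This is a hypothetical standalone sketch (with a made-up FakeDataSource holder and a trimmed-down mapping), not the actual pydantic model validator:

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from tiled.structures.core import StructureFamily
from tiled.structures.table import TableStructure

# Subset of the mapping for illustration; the full STRUCTURE_TYPES covers
# array, awkward, sparse, table, and container.
STRUCTURE_TYPES = {StructureFamily.table: TableStructure}


@dataclass
class FakeDataSource:  # made-up holder standing in for the real DataSource
    structure: Optional[Dict[str, Any]] = None


def narrow_structures(structure_family, data_sources: List[FakeDataSource]):
    """Convert each data source's structure from a dict to the matching model."""
    for data_source in data_sources:
        # Containers are skipped, and a missing structure is left as None,
        # so neither case raises while narrowing.
        if structure_family != StructureFamily.container:
            structure_cls = STRUCTURE_TYPES[structure_family]
            if data_source.structure is not None:
                data_source.structure = structure_cls(**data_source.structure)
    return data_sources
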
