Skip to content

Commit

Permalink
Add type hinting for constructors and from_geopandas (#399)
Browse files Browse the repository at this point in the history
Better type hinting in vscode! This should hopefully be a massive QOL
upgrade, at least in vscode. This doesn't show up in Jupyter though 🤷‍♂️


![image](https://github.com/developmentseed/lonboard/assets/15164633/7fc7ad3d-b291-4612-b3be-888eb708b307)

![image](https://github.com/developmentseed/lonboard/assets/15164633/6a474abe-f76c-4ab4-9a34-6611a2242ac0)
  • Loading branch information
kylebarron authored Mar 1, 2024
1 parent bdbc2a7 commit c9f1214
Show file tree
Hide file tree
Showing 9 changed files with 437 additions and 70 deletions.
133 changes: 127 additions & 6 deletions lonboard/_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple
from typing import (
TYPE_CHECKING,
List,
Optional,
Sequence,
Tuple,
)

import geopandas as gpd
import ipywidgets
Expand All @@ -34,13 +40,28 @@
NormalAccessor,
PyarrowTableTrait,
)
from lonboard.types.layer import (
BaseLayerKwargs,
BitmapLayerKwargs,
BitmapTileLayerKwargs,
HeatmapLayerKwargs,
PathLayerKwargs,
PointCloudLayerKwargs,
ScatterplotLayerKwargs,
SolidPolygonLayerKwargs,
)

if TYPE_CHECKING:
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

if sys.version_info >= (3, 12):
from typing import Unpack
else:
from typing_extensions import Unpack


class BaseLayer(BaseWidget):
# Note: these class attributes are **not** serialized to JS
Expand Down Expand Up @@ -185,7 +206,7 @@ def _add_extension_traits(self, extensions: Sequence[BaseExtension]):


def default_geoarrow_viewport(
table: pa.Table
table: pa.Table,
) -> Optional[Tuple[Bbox, WeightedCentroid]]:
# Note: in the ArcLayer we won't necessarily have a column with a geoarrow
# extension type/metadata
Expand Down Expand Up @@ -236,7 +257,11 @@ class BaseArrowLayer(BaseLayer):
table: traitlets.TraitType

def __init__(
self, *, table: pa.Table, _rows_per_chunk: Optional[int] = None, **kwargs
self,
*,
table: pa.Table,
_rows_per_chunk: Optional[int] = None,
**kwargs: Unpack[BaseLayerKwargs],
):
# Check for Arrow PyCapsule Interface
# https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
Expand Down Expand Up @@ -265,7 +290,11 @@ def __init__(

@classmethod
def from_geopandas(
cls, gdf: gpd.GeoDataFrame, *, auto_downcast: bool = True, **kwargs
cls,
gdf: gpd.GeoDataFrame,
*,
auto_downcast: bool = True,
**kwargs: Unpack[BaseLayerKwargs],
) -> Self:
"""Construct a Layer from a geopandas GeoDataFrame.
Expand Down Expand Up @@ -311,6 +340,9 @@ class BitmapLayer(BaseLayer):
```
"""

def __init__(self, **kwargs: BitmapLayerKwargs):
super().__init__(**kwargs) # type: ignore

_layer_type = traitlets.Unicode("bitmap").tag(sync=True)

image = traitlets.Unicode().tag(sync=True)
Expand Down Expand Up @@ -411,6 +443,9 @@ class BitmapTileLayer(BaseLayer):
```
"""

def __init__(self, **kwargs: BitmapTileLayerKwargs):
super().__init__(**kwargs) # type: ignore

_layer_type = traitlets.Unicode("bitmap-tile").tag(sync=True)

data = traitlets.Union(
Expand Down Expand Up @@ -607,6 +642,25 @@ class ScatterplotLayer(BaseArrowLayer):
```
"""

def __init__(
self,
*,
table: pa.Table,
_rows_per_chunk: Optional[int] = None,
**kwargs: Unpack[ScatterplotLayerKwargs],
):
super().__init__(table=table, _rows_per_chunk=_rows_per_chunk, **kwargs)

@classmethod
def from_geopandas(
cls,
gdf: gpd.GeoDataFrame,
*,
auto_downcast: bool = True,
**kwargs: Unpack[ScatterplotLayerKwargs],
) -> Self:
return super().from_geopandas(gdf=gdf, auto_downcast=auto_downcast, **kwargs)

_layer_type = traitlets.Unicode("scatterplot").tag(sync=True)

table = PyarrowTableTrait(
Expand Down Expand Up @@ -819,6 +873,25 @@ class PathLayer(BaseArrowLayer):
```
"""

def __init__(
self,
*,
table: pa.Table,
_rows_per_chunk: Optional[int] = None,
**kwargs: Unpack[PathLayerKwargs],
):
super().__init__(table=table, _rows_per_chunk=_rows_per_chunk, **kwargs)

@classmethod
def from_geopandas(
cls,
gdf: gpd.GeoDataFrame,
*,
auto_downcast: bool = True,
**kwargs: Unpack[PathLayerKwargs],
) -> Self:
return super().from_geopandas(gdf=gdf, auto_downcast=auto_downcast, **kwargs)

_layer_type = traitlets.Unicode("path").tag(sync=True)

table = PyarrowTableTrait(
Expand Down Expand Up @@ -960,6 +1033,25 @@ class PointCloudLayer(BaseArrowLayer):
```
"""

def __init__(
self,
*,
table: pa.Table,
_rows_per_chunk: Optional[int] = None,
**kwargs: Unpack[PointCloudLayerKwargs],
):
super().__init__(table=table, _rows_per_chunk=_rows_per_chunk, **kwargs)

@classmethod
def from_geopandas(
cls,
gdf: gpd.GeoDataFrame,
*,
auto_downcast: bool = True,
**kwargs: Unpack[PointCloudLayerKwargs],
) -> Self:
return super().from_geopandas(gdf=gdf, auto_downcast=auto_downcast, **kwargs)

_layer_type = traitlets.Unicode("point-cloud").tag(sync=True)

table = PyarrowTableTrait(
Expand Down Expand Up @@ -1056,6 +1148,25 @@ class SolidPolygonLayer(BaseArrowLayer):
```
"""

def __init__(
self,
*,
table: pa.Table,
_rows_per_chunk: Optional[int] = None,
**kwargs: Unpack[SolidPolygonLayerKwargs],
):
super().__init__(table=table, _rows_per_chunk=_rows_per_chunk, **kwargs)

@classmethod
def from_geopandas(
cls,
gdf: gpd.GeoDataFrame,
*,
auto_downcast: bool = True,
**kwargs: Unpack[SolidPolygonLayerKwargs],
) -> Self:
return super().from_geopandas(gdf=gdf, auto_downcast=auto_downcast, **kwargs)

_layer_type = traitlets.Unicode("solid-polygon").tag(sync=True)

table = PyarrowTableTrait(
Expand Down Expand Up @@ -1193,10 +1304,20 @@ class HeatmapLayer(BaseArrowLayer):
"""

def __init__(self, *args, table: pa.Table, **kwargs):
def __init__(self, *, table: pa.Table, **kwargs: Unpack[HeatmapLayerKwargs]):
# NOTE: we override the default for _rows_per_chunk because otherwise we render
# one heatmap per _chunk_ not for the entire dataset.
super().__init__(*args, table=table, _rows_per_chunk=len(self.table), **kwargs)
super().__init__(table=table, _rows_per_chunk=len(table), **kwargs)

@classmethod
def from_geopandas(
cls,
gdf: gpd.GeoDataFrame,
*,
auto_downcast: bool = True,
**kwargs: Unpack[HeatmapLayerKwargs],
) -> Self:
return super().from_geopandas(gdf=gdf, auto_downcast=auto_downcast, **kwargs)

_layer_type = traitlets.Unicode("heatmap").tag(sync=True)

Expand Down
15 changes: 13 additions & 2 deletions lonboard/_map.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import sys
from pathlib import Path
from typing import IO, Optional, Sequence, TextIO, Union
from typing import IO, TYPE_CHECKING, Optional, Sequence, TextIO, Union

import ipywidgets
import traitlets
Expand All @@ -12,6 +13,14 @@
from lonboard._layer import BaseLayer
from lonboard._viewport import compute_view
from lonboard.basemap import CartoBasemap
from lonboard.types.map import MapKwargs

if TYPE_CHECKING:
if sys.version_info >= (3, 12):
from typing import Unpack
else:
from typing_extensions import Unpack


# bundler yields lonboard/static/{index.js,styles.css}
bundler_output_dir = Path(__file__).parent / "static"
Expand Down Expand Up @@ -64,7 +73,9 @@ class Map(BaseAnyWidget):
```
"""

def __init__(self, layers: Union[BaseLayer, Sequence[BaseLayer]], **kwargs) -> None:
def __init__(
self, layers: Union[BaseLayer, Sequence[BaseLayer]], **kwargs: Unpack[MapKwargs]
) -> None:
"""Create a new Map.
Aside from the `layers` argument, pass keyword arguments for any other attribute
Expand Down
Loading

0 comments on commit c9f1214

Please sign in to comment.