diff --git a/telcell/auxilliary_models/rare_pair/predictor.py b/telcell/auxilliary_models/rare_pair/predictor.py index f553ad6..c7a8701 100644 --- a/telcell/auxilliary_models/rare_pair/predictor.py +++ b/telcell/auxilliary_models/rare_pair/predictor.py @@ -20,6 +20,7 @@ class Predictor: def __init__(self, models: Mapping[Tuple[int, Bin], CoverageModel]): + assert all(isinstance(key[0], int) for key in models.keys()) self.models = models self._bins = list(sorted(set(bin for _, bin in models.keys()))) @@ -40,7 +41,7 @@ def predict_probability_locations_given_antenna(self, normalized_probabilities_for_reference_area = [] for measurement in measurements: - model = self.models[(measurement.extra['mnc'], bin)] + model = self.models[(measurement.extra['cell'].mnc, bin)] measurement_area = model.measurement_area(measurement) if measurement_area.intersect(reference_area): normalized_probabilities_for_reference_area.append( @@ -60,7 +61,7 @@ def predict_probability_antenna_given_locations(self, :param delta_t: the time difference between the two antenna registrations :returns: the probability """ - model = self.models[(measurement.extra['mnc'], self.get_bin(delta_t))] + model = self.models[(measurement.extra['cell'].mnc, self.get_bin(delta_t))] return model.probabilities(measurement) def get_probability_e_h(self, diff --git a/telcell/celldb/cell_collection.py b/telcell/celldb/cell_collection.py index 0c3d62f..697afab 100644 --- a/telcell/celldb/cell_collection.py +++ b/telcell/celldb/cell_collection.py @@ -6,20 +6,10 @@ import geopy from telcell.cell_identity import CellIdentity +from telcell.celldb.models import CellInfo -class Properties(dict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __getattr__(self, item): - if item in self: - return self[item] - else: - return Properties() - - -class CellCollection(Iterable[Properties], Sized): +class CellCollection(Iterable[CellInfo], Sized): """ Collection of cell towers in a cellular network, that may be queried by cell id (e.g. for geocoding) or by coordinates (e.g. reverse geocoding). @@ -32,7 +22,7 @@ class CellCollection(Iterable[Properties], Sized): """ @abstractmethod - def get(self, date: datetime.datetime, ci: CellIdentity) -> Optional[Properties]: + def get(self, date: datetime.datetime, ci: CellIdentity) -> Optional[CellInfo]: """ Retrieve a specific antenna from database. diff --git a/telcell/celldb/duplicate_policy.py b/telcell/celldb/duplicate_policy.py index fb87380..90ad79b 100644 --- a/telcell/celldb/duplicate_policy.py +++ b/telcell/celldb/duplicate_policy.py @@ -2,7 +2,7 @@ from typing import Optional, Sequence from telcell.cell_identity import CellIdentity -from telcell.celldb.cell_collection import Properties +from telcell.celldb.models import CellInfo def get_duplicate_policy(name: str): @@ -12,20 +12,20 @@ def get_duplicate_policy(name: str): return globals()[name] -def exception(ci: CellIdentity, _results: Sequence[Properties]) -> Optional[Properties]: +def exception(ci: CellIdentity, _results: Sequence[CellInfo]) -> Optional[CellInfo]: raise ValueError(f"duplicate cell id {ci} (not allowed by current policy)") -def warn(ci: CellIdentity, results: Sequence[Properties]) -> Optional[Properties]: +def warn(ci: CellIdentity, results: Sequence[CellInfo]) -> Optional[CellInfo]: warnings.warn(f"duplicate cell id {ci}") return results[0] def take_first( - _ci: CellIdentity, results: Sequence[Properties] -) -> Optional[Properties]: + _ci: CellIdentity, results: Sequence[CellInfo] +) -> Optional[CellInfo]: return results[0] -def drop(_ci: CellIdentity, _results: Sequence[Properties]) -> Optional[Properties]: +def drop(_ci: CellIdentity, _results: Sequence[CellInfo]) -> Optional[CellInfo]: return None diff --git a/telcell/celldb/google.py b/telcell/celldb/google.py index 3af649a..994562b 100644 --- a/telcell/celldb/google.py +++ b/telcell/celldb/google.py @@ -13,7 +13,7 @@ Radio, ) from telcell.celldb import CellCollection -from telcell.celldb.cell_collection import Properties +from telcell.celldb.models import CellInfo def _ci_to_dict(cell: CellIdentity) -> dict[str, str | int]: @@ -48,7 +48,7 @@ def __init__(self, key: str, user_agent: str = "TestApp", cache_name: str = None else: self._session = requests_cache.CachedSession(cache_name) - def get(self, date: datetime.datetime, cell: CellIdentity) -> Properties: + def get(self, date: datetime.datetime, cell: CellIdentity) -> CellInfo: if cell.radio is None and isinstance(cell, EutranCellGlobalIdentity): info = self.get(date, CellIdentity.parse(f"{Radio.LTE.value}/{cell}")) if info is None: @@ -72,7 +72,7 @@ def get(self, date: datetime.datetime, cell: CellIdentity) -> Properties: ) accuracy = res["accuracy"] - return Properties(cell=cell, wgs84=point, accuracy=accuracy) + return CellInfo(cell=cell, wgs84=point, accuracy_m=accuracy) def search( self, diff --git a/telcell/celldb/models.py b/telcell/celldb/models.py new file mode 100644 index 0000000..ef021fe --- /dev/null +++ b/telcell/celldb/models.py @@ -0,0 +1,16 @@ +from typing import Optional + +import geopy +from pydantic import BaseModel, ConfigDict + +from telcell.cell_identity import CellIdentity +from telcell.geography import Angle + + +class CellInfo(BaseModel): + model_config = ConfigDict(frozen=True, extra="allow", arbitrary_types_allowed=True) + + cell: Optional[CellIdentity] = None + wgs84: Optional[geopy.Point] = None + azimuth: Optional[Angle] = None + accuracy_m: Optional[float] = None diff --git a/telcell/celldb/pgdatabase.py b/telcell/celldb/pgdatabase.py index f942e93..4759e04 100644 --- a/telcell/celldb/pgdatabase.py +++ b/telcell/celldb/pgdatabase.py @@ -13,12 +13,13 @@ EutranCellGlobalIdentity, ) from . import duplicate_policy -from .cell_collection import CellCollection, Properties +from .cell_collection import CellCollection +from .models import CellInfo from ..data.models import rd_to_point, point_to_rd from ..geography import Angle -def _build_antenna(row: Tuple) -> Properties: +def _build_antenna(row: Tuple) -> CellInfo: date_start, date_end, radio, mcc, mnc, lac, ci, eci, rdx, rdy, azimuth_degrees = row if radio == Radio.GSM.value or radio == Radio.UMTS.value: retrieved_ci = CellIdentity.create( @@ -37,7 +38,7 @@ def _build_antenna(row: Tuple) -> Properties: coords = rd_to_point(rdx, rdy) azimuth = Angle(degrees=azimuth_degrees) if azimuth_degrees is not None else None - return Properties(wgs84=coords, azimuth=azimuth, cell=retrieved_ci) + return CellInfo(wgs84=coords, azimuth=azimuth, cell=retrieved_ci) def _build_cell_identity_query(ci): @@ -110,7 +111,7 @@ def __init__( self._count_limit = _count_limit self._cur = None - def get(self, date: datetime.datetime, ci: CellIdentity) -> Optional[Properties]: + def get(self, date: datetime.datetime, ci: CellIdentity) -> Optional[CellInfo]: """ Retrieve a specific antenna from database. diff --git a/telcell/data/models.py b/telcell/data/models.py index ad799b3..d965ee7 100644 --- a/telcell/data/models.py +++ b/telcell/data/models.py @@ -3,13 +3,14 @@ from dataclasses import dataclass from datetime import datetime from functools import cached_property -from typing import Any, Mapping, Tuple, Sequence, Iterator, Optional +from typing import Any, Mapping, Tuple, Sequence, Iterator import geopy import geopy.distance import pyproj from pyproj import Proj, Geod, Transformer + RD = ( "+proj=sterea +lat_0=52.15616055555555 +lon_0=5.38763888888889 " "+k=0.999908 +x_0=155000 +y_0=463000 +ellps=bessel " @@ -79,7 +80,7 @@ class Measurement: """ coords: geopy.Point - timestamp: Optional[datetime] + timestamp: datetime extra: Mapping[str, Any] @property @@ -107,7 +108,7 @@ def __hash__(self): self.lat, self.lon, self.timestamp.date(), - *(_extra for _extra in self.extra.values()), + str(self.extra), ) ) diff --git a/telcell/data/parsers.py b/telcell/data/parsers.py index a803c14..f43c3aa 100644 --- a/telcell/data/parsers.py +++ b/telcell/data/parsers.py @@ -90,9 +90,9 @@ def parse_coverage_data_csv(path: Union[str, Path]) -> List[CoverageData]: timestamp=None, extra={ "bandwidth": row["antenna_bandwidth"], - "height": row["antenna_height"], - "azimuth": row["antenna_azimuth"], - "mnc": row["antenna_mnc"], + "height": float(row["antenna_height"]) if row["antenna_height"] != "" else None, + "azimuth": float(row["antenna_azimuth"]), + "mnc": int(row["antenna_mnc"]), "radio": row["antenna_radio"], }, ) diff --git a/telcell/geography.py b/telcell/geography.py index 9759b11..9f9a7f7 100644 --- a/telcell/geography.py +++ b/telcell/geography.py @@ -106,6 +106,8 @@ def normalize_angle(angle: Angle) -> Angle: def azimuth_deg(coord1: geopy.Point, coord2: geopy.Point) -> float: + assert isinstance(coord1, geopy.Point), f"argument 1: expected geopy.Point; found {type(coord1)}" + assert isinstance(coord2, geopy.Point), f"argument 2: expected geopy.Point; found {type(coord2)}" geodesic = pyproj.Geod(ellps="WGS84") fwd_azimuth, back_azimuth, distance = geodesic.inv( coord1.longitude, coord1.latitude, coord2.longitude, coord2.latitude diff --git a/telcell/models/rare_pair_feature_based.py b/telcell/models/rare_pair_feature_based.py index 1db105d..c4f90ca 100644 --- a/telcell/models/rare_pair_feature_based.py +++ b/telcell/models/rare_pair_feature_based.py @@ -42,13 +42,14 @@ def __init__(self, coverage_training_data: Sequence[CoverageData], transformer: def filter_track(track: Track, filter: Mapping) -> Track: measurements = track.measurements if 'mnc' in filter: - measurements = [m for m in measurements if m.extra['mnc'] in filter['mnc']] + measurements = [m for m in measurements if m.extra['cell'].mnc in filter['mnc']] return Track(measurements=measurements, device=track.device, owner=track.owner) - def predict_lr(self, track_a: Track, track_b: Track, **kwargs) \ + def predict_lr(self, track_a: Track, track_b: Track, filter: Mapping = None, **kwargs) \ -> Tuple[Optional[float], Optional[Mapping[str, Any]]]: - track_a = self.filter_track(track_a, filter=kwargs['filter']) - track_b = self.filter_track(track_b, filter=kwargs['filter']) + if filter: + track_a = self.filter_track(track_a, filter=filter) + track_b = self.filter_track(track_b, filter=filter) if not track_a or not track_b: return None, None switches = get_switches(track_a, track_b) diff --git a/tests/auxilliary_models/rare_pair/test_model.py b/tests/auxilliary_models/rare_pair/test_model.py index 0986498..a8ab577 100644 --- a/tests/auxilliary_models/rare_pair/test_model.py +++ b/tests/auxilliary_models/rare_pair/test_model.py @@ -10,6 +10,7 @@ ) from telcell.auxilliary_models.rare_pair.predictor import Predictor from telcell.auxilliary_models.rare_pair.utils import DISTANCE_STEP +from telcell.cell_identity import CellIdentity from telcell.data.models import Measurement @@ -17,7 +18,7 @@ def test_angle_distance_coverage_model(): test_measurement = Measurement( geopy.Point(latitude=52.0449566305567, longitude=4.3585472613577965), datetime.strptime("2023-01-01", "%Y-%m-%d"), - {"mnc": 4, "azimuth": 0}, + {"cell": CellIdentity.create(mnc=4), "azimuth": 0}, ) clf = DecisionTreeClassifier()