diff --git a/CELLDB.md b/CELLDB.md index 239c602..40b5765 100644 --- a/CELLDB.md +++ b/CELLDB.md @@ -6,6 +6,7 @@ information on the actual positions of the cell antennas, as well as other meta data. The following assumes you have such a database in a readable CSV format. + # Import cell database Install Postgres and Postgis and remember credentials. @@ -27,6 +28,7 @@ nano celldb.yaml python -m telcell.celldb --config celldb.yaml import < celldb.csv ``` + # Usage ```py @@ -36,7 +38,7 @@ with script_helper.get_cell_database("celldb.yaml", on_duplicate=duplicate_polic # retrieve cell info my_cell = CellIdentity.create(radio="GSM", mcc=99, mnc=99, eci=123456) - cellinfo = db.get(my_cell, date=datetime.datetime.now()) + cellinfo = db.get(ci=my_cell, date=datetime.datetime.now()) if cellinfo is None: print("cell not found") else: @@ -54,3 +56,21 @@ with script_helper.get_cell_database("celldb.yaml", on_duplicate=duplicate_polic for cellinfo in nearby_gsm_cells: print(cellinfo) ``` + + +# Cell database CSV + +A cell database file is expected to be comma-separated with the column names +in the header, and the following columns: + +* date_start: a timestamp in ISO format of when the antenna became operational +* date_end: a timestamp in ISO format of when the antenna was decommissioned +* radio: the radio technology of the antenna (e.g. GSM, UMTS, LTE, NR) +* mcc: the Mobile Country Code (MCC) +* mnc: the Mobile Network Code (MNC) +* lac: the Location Area Code (LAC), in case of GSM, UMTS +* ci: the Cell Identity (CI), in case of GSM, UMTS +* eci: the evolved Cell Identity (eCI), in case of LTE, NR +* lon: the longitude of the antenna position (WGS84) +* lat: the latitude of the antenna position (WGS84) +* azimuth: the transmission direction of the antenna in degrees, relative to north diff --git a/requirements.txt b/requirements.txt index 706cb41..7666728 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ +geopy lir lrbenchmark>=0.1.2 numpy pydeck scikit-learn streamlit +psycopg2-binary pyproj tensorflow~=2.13.0 confidence diff --git a/telcell/celldb/cell_collection.py b/telcell/celldb/cell_collection.py index b33c27b..0c3d62f 100644 --- a/telcell/celldb/cell_collection.py +++ b/telcell/celldb/cell_collection.py @@ -58,15 +58,16 @@ def search( """ Given a Point, find antennas that are in reach from this point sorted by the distance from the grid point. - :param coords: Point for which nearby antennas are retrieved - :param distance_limit_m: antennas should be within this range - :param date: used to select active antennas - :param radio: antennas should be limited to this radio technology, e.g.: LTE, UMTS, GSM - :param mcc: antennas should be limited to this mcc - :param mnc: antennas should be limited to this mnc - :param count_limit: maximum number of antennas to return - :param exclude: antenna that should be excluded from the retrieved antennas - :return: retrieved antennas within reach from the Point + :param coords: antenna selection criteria are relative to this point + :param distance_limit_m: select antennas within this range of `coords` + :param distance_lower_limit_m: select antennas beyond this range of `coords` + :param date: select antennas which are valid at this date + :param radio: select antennas using this radio technology, e.g.: LTE, UMTS, GSM + :param mcc: select antennas with this MCC + :param mnc: select antennas with this MNC + :param count_limit: return at most this number of antennas + :param exclude: excluded antennas with this `CellIdentity` + :return: retrieved selected antennas """ raise NotImplementedError diff --git a/telcell/celldb/pgdatabase.py b/telcell/celldb/pgdatabase.py index b55b0ca..38abf80 100644 --- a/telcell/celldb/pgdatabase.py +++ b/telcell/celldb/pgdatabase.py @@ -14,29 +14,11 @@ ) from . import duplicate_policy from .cell_collection import CellCollection, Properties -from ..data.models import RD_TO_WGS84, WGS84_TO_RD +from ..data.models import rd_to_point, point_to_rd from ..geography import Angle -RD_X_RANGE = (7000, 300000) -RD_Y_RANGE = (289000, 629000) - -def rd_to_point(x: int, y: int) -> geopy.Point: - if not (RD_Y_RANGE[0] <= y <= RD_Y_RANGE[1]) or not ( - RD_X_RANGE[0] <= x <= RD_X_RANGE[1] - ): - warnings.warn( - f"rijksdriehoek coordinates {x}, {y} outside range: x={RD_X_RANGE}, y={RD_Y_RANGE}" - ) - c = RD_TO_WGS84.transform(x, y) - return geopy.Point(longitude=c[0], latitude=c[1]) - - -def point_to_rd(point: geopy.Point) -> Tuple[int, int]: - return WGS84_TO_RD.transform(point.longitude, point.latitude) - - -def _build_antenna(row): +def _build_antenna(row: Tuple) -> Properties: date_start, date_end, radio, mcc, mnc, lac, ci, eci, rdx, rdy, azimuth_degrees = row if radio == Radio.GSM.value or radio == Radio.UMTS.value: retrieved_ci = CellIdentity.create( @@ -92,11 +74,11 @@ class PgCollection(CellCollection): def __init__( self, con, - qwhere=None, - qargs=None, - qorder=None, - count_limit: int = None, on_duplicate: Callable = duplicate_policy.warn, + _qwhere=None, + _qargs=None, + _qorder=None, + _count_limit: int = None, ): """ Initializes a `PgDatabase` object. @@ -110,20 +92,20 @@ def __init__( - `celldb.duplicate_policy.warn`: same as `take_first`, and emits a warning - `celldb.duplicate_policy.exception`: throws an exception - @param con: an active postgres connection - @param qwhere: - @param qargs: - @param qorder: - @param count_limit: - @param on_duplicate: policy when the cell database has two or more hits for the same cell in a call to `get()`. + :param con: an active postgres connection + :param on_duplicate: policy when the cell database has two or more hits for the same cell in a call to `get()`. + :param _qwhere: for private use: add criteria in WHERE clause + :param _qargs: for private use: add query arguments + :param _qorder: for private use: add criteria in ORDER clause + :param _count_limit: for private use: add item limit """ self._con = con - self._qwhere = qwhere or ["TRUE"] - self._qargs = qargs or [] - self._qorder = qorder or "" - self._count_limit = count_limit - self._cur = None self._on_duplicate = on_duplicate + self._qwhere = _qwhere or ["TRUE"] + self._qargs = _qargs or [] + self._qorder = _qorder or "" + self._count_limit = _count_limit + self._cur = None def get(self, date: datetime.datetime, ci: CellIdentity) -> Optional[Properties]: """ @@ -136,36 +118,13 @@ def get(self, date: datetime.datetime, ci: CellIdentity) -> Optional[Properties] if isinstance(date, datetime.date): date = datetime.datetime.combine(date, datetime.datetime.min.time()) - qwhere = list(self._qwhere) - qargs = list(self._qargs) - - if date is not None: - qwhere = qwhere + [ - "(date_start is NULL OR %s >= date_start) AND (date_end is NULL OR %s < date_end)" - ] - qargs = qargs + [date, date] - - add_qwhere, add_qargs = _build_cell_identity_query(ci) - qwhere.append(add_qwhere) - qargs.extend(add_qargs) - - with self._con.cursor() as cur: - cur.execute( - f""" - SELECT date_start, date_end, radio, mcc, mnc, lac, ci, eci, ST_X(rd), ST_Y(rd), azimuth - FROM antenna_light - WHERE {' AND '.join(qw for qw in qwhere)} - """, - qargs, - ) - - results = [_build_antenna(row) for row in cur.fetchall()] - if len(results) == 0: - return None - elif len(results) > 1: - return self._on_duplicate(ci, results) - else: - return results[0] + results = list(self.search(date=date)) + if len(results) == 0: + return None + elif len(results) > 1: + return self._on_duplicate(ci, results) + else: + return results[0] def search( self, @@ -180,20 +139,6 @@ def search( random_order: bool = False, exclude: Optional[List[CellIdentity]] = None, ) -> CellCollection: - """ - Given a Point, find antennas that are in reach from this point sorted by the distance from the grid point. - - :param coords: Point for which nearby antennas are retrieved - :param distance_limit_m: antennas should be within this range - :param date: select antennas that were valid at `date` - :param radio: antennas should be limited to this radio technology, e.g.: LTE, UMTS, GSM (accepts `str` or list - of `str`) - :param mcc: antennas should be limited to this mcc - :param mnc: antennas should be limited to this mnc - :param count_limit: maximum number of antennas to return - :param exclude: antenna that should be excluded from the retrieved antennas - :return: retrieved antennas within reach from the Point - """ qwhere = list(self._qwhere) qargs = list(self._qargs) @@ -240,7 +185,7 @@ def search( count_limit = count_limit if count_limit is not None else self._count_limit return PgCollection( - self._con, qwhere, qargs, qorder, count_limit, self._on_duplicate + self._con, self._on_duplicate, qwhere, qargs, qorder, count_limit ) def close(self): @@ -292,10 +237,29 @@ def __len__(self): def csv_import(con, flo, progress: Callable = lambda x: x): + """ + Import antenna data into a Postgres database from a CSV file. + + :param con: an open database connection + :param flo: a file like object pointing to CSV data + :param progress: an optional progress bar (like tqdm), or `None` + """ create_table(con) - reader = csv.reader(flo) - next(reader) # skip header + fieldnames = [ + "date_start", + "date_end", + "radio", + "mcc", + "mnc", + "lac", + "ci", + "eci", + "lon", + "lat", + "azimuth", + ] + reader = csv.DictReader(flo, fieldnames=fieldnames) with con.cursor() as cur: for i, row in enumerate(progress(list(reader))): @@ -312,7 +276,7 @@ def csv_import(con, flo, progress: Callable = lambda x: x): lon, lat, azimuth, - ) = [c if c != "" else None for c in row] + ) = [row[f] if row[f] != "" else None for f in fieldnames] lon, lat = float(lon), float(lat) assert math.isfinite(lon), f"invalid number for longitude: {lon}" assert math.isfinite(lat), f"invalid number for latitude: {lat}" @@ -344,6 +308,13 @@ def csv_import(con, flo, progress: Callable = lambda x: x): def csv_export(con, flo): + """ + Export antenna data in a Postgres database to a CSV file. + + :param con: an open database connection + :param flo: a file like object where CSV data will be written + """ + sql_x = "ST_X(rd)" sql_y = "ST_Y(rd)" sql_lon = f""" diff --git a/telcell/data/models.py b/telcell/data/models.py index a4868a6..878b7bf 100644 --- a/telcell/data/models.py +++ b/telcell/data/models.py @@ -1,47 +1,74 @@ from __future__ import annotations import math +import warnings from dataclasses import dataclass from datetime import datetime from functools import cached_property from typing import Any, Mapping, Tuple, Sequence, Iterator, Union +import geopy import pyproj from pyproj import Proj, Geod, Transformer -RD = ("+proj=sterea +lat_0=52.15616055555555 +lon_0=5.38763888888889 " - "+k=0.999908 +x_0=155000 +y_0=463000 +ellps=bessel " - "+towgs84=565.237,50.0087,465.658,-0.406857,0.350733,-1.87035,4.0812 " - "+units=m +no_defs") -GOOGLE = ('+proj=merc +a=6378137 +b=6378137 +lat_ts=0.0 ' - '+lon_0=0.0 +x_0=0.0 +y_0=0 +k=1.0 +units=m ' - '+nadgrids=@null +no_defs +over') -WGS84 = '+proj=latlong +datum=WGS84' +RD = ( + "+proj=sterea +lat_0=52.15616055555555 +lon_0=5.38763888888889 " + "+k=0.999908 +x_0=155000 +y_0=463000 +ellps=bessel " + "+towgs84=565.237,50.0087,465.658,-0.406857,0.350733,-1.87035,4.0812 " + "+units=m +no_defs" +) +GOOGLE = ( + "+proj=merc +a=6378137 +b=6378137 +lat_ts=0.0 " + "+lon_0=0.0 +x_0=0.0 +y_0=0 +k=1.0 +units=m " + "+nadgrids=@null +no_defs +over" +) +WGS84 = "+proj=latlong +datum=WGS84" rd_projection = Proj(RD) google_projection = Proj(GOOGLE) wgs84_projection = Proj(WGS84) -geodesic = Geod('+ellps=sphere') +geodesic = Geod("+ellps=sphere") WGS84_TO_RD = Transformer.from_proj(wgs84_projection, rd_projection) RD_TO_WGS84 = Transformer.from_proj(rd_projection, wgs84_projection) -# TODO: Realistic boundsx -rd_x_range = (1000, 350000) -rd_y_range = (1000, 700000) -GEOD_WGS84 = pyproj.Geod(ellps='WGS84') +# TODO: more accurate range would be: x (7000, 300000) and y (289000, 629000) +RD_X_RANGE = (1000, 300000) +RD_Y_RANGE = (1000, 629000) +GEOD_WGS84 = pyproj.Geod(ellps="WGS84") -def approximately_equal(first, second, tolerance=.0001): +def approximately_equal(first, second, tolerance=0.0001): return abs(first - second) < tolerance +def rd_to_point(x: int, y: int) -> geopy.Point: + if not (RD_Y_RANGE[0] <= y <= RD_Y_RANGE[1]) or not ( + RD_X_RANGE[0] <= x <= RD_X_RANGE[1] + ): + warnings.warn( + f"rijksdriehoek coordinates {x}, {y} outside range: x={RD_X_RANGE}, y={RD_Y_RANGE}" + ) + c = RD_TO_WGS84.transform(x, y) + return geopy.Point(longitude=c[0], latitude=c[1]) + + +def point_to_rd(point: geopy.Point) -> Tuple[int, int]: + return WGS84_TO_RD.transform(point.longitude, point.latitude) + + @dataclass(frozen=True) class RDPoint: x: float y: float def __post_init__(self): - if self.x < rd_x_range[0] or self.x > rd_x_range[1] or self.y < \ - rd_y_range[0] or self.y > rd_y_range[1]: - raise ValueError(f'Invalid rijksdriehoek coordinates: ({self.x=}, {self.y=}); ' - f'allowed range: x={rd_x_range}, y={rd_y_range}.') + if ( + self.x < RD_X_RANGE[0] + or self.x > RD_X_RANGE[1] + or self.y < RD_Y_RANGE[0] + or self.y > RD_Y_RANGE[1] + ): + raise ValueError( + f"Invalid rijksdriehoek coordinates: ({self.x=}, {self.y=}); " + f"allowed range: x={RD_X_RANGE}, y={RD_Y_RANGE}." + ) @property def xy(self) -> Tuple[float, float]: @@ -53,13 +80,17 @@ def convert_to_wgs84(self) -> Point: def distance(self, other: Union[RDPoint, Point]) -> float: other_rd = other.convert_to_rd() if isinstance(other, Point) else other - return math.sqrt(math.pow(self.x - other_rd.x, 2) + math.pow(self.y - other_rd.y, 2)) + return math.sqrt( + math.pow(self.x - other_rd.x, 2) + math.pow(self.y - other_rd.y, 2) + ) - def approx_equal(self, other: Union[RDPoint, Point], tolerance_m: float = 1) -> bool: + def approx_equal( + self, other: Union[RDPoint, Point], tolerance_m: float = 1 + ) -> bool: return self.distance(other) < tolerance_m def __repr__(self): - return f'RDPoint(x={self.x}, y={self.y})' + return f"RDPoint(x={self.x}, y={self.y})" @dataclass(frozen=True) @@ -69,7 +100,7 @@ class Point: def __post_init__(self): if self.lat < -90 or self.lat > 90 or self.lon < -180 or self.lon > 180: - raise ValueError(f'Invalid wgs84 coordinates: ({self.lat=}, {self.lon=}).') + raise ValueError(f"Invalid wgs84 coordinates: ({self.lat=}, {self.lon=}).") @property def latlon(self) -> Tuple[float, float]: @@ -82,7 +113,9 @@ def convert_to_rd(self) -> RDPoint: def distance(self, other: Union[RDPoint, Point]) -> float: self_rd = self.convert_to_rd() other_rd = other.convert_to_rd() if isinstance(other, Point) else other - return math.sqrt(math.pow(self_rd.x - other_rd.x, 2) + math.pow(self_rd.y - other_rd.y, 2)) + return math.sqrt( + math.pow(self_rd.x - other_rd.x, 2) + math.pow(self_rd.y - other_rd.y, 2) + ) def approx_equal(self, other: Union[RDPoint, Point], tolerance_m: int = 1) -> bool: return self.distance(other) < tolerance_m @@ -94,7 +127,7 @@ def __eq__(self, other): return False def __repr__(self): - return f'Point(lat={self.lat}, lon={self.lon})' + return f"Point(lat={self.lat}, lon={self.lon})" @dataclass(eq=True, frozen=True) @@ -108,6 +141,7 @@ class Measurement: this measurement. These could for example inform the accuracy or uncertainty of the measured WGS84 coordinates. """ + coords: Point timestamp: datetime extra: Mapping[str, Any] @@ -132,8 +166,14 @@ def __str__(self): return f"<{self.timestamp}: ({self.lat}, {self.lon})>" def __hash__(self): - return hash((self.lat, self.lon, self.timestamp.date(), - *(_extra for _extra in self.extra.values()))) + return hash( + ( + self.lat, + self.lon, + self.timestamp.date(), + *(_extra for _extra in self.extra.values()), + ) + ) @dataclass @@ -145,6 +185,7 @@ class Track: :param device: The name of the device. :param measurements: A series of measurements ordered by timestamp. """ + owner: str device: str measurements: Sequence[Measurement] diff --git a/telcell/geography.py b/telcell/geography.py index c96ad6b..9759b11 100644 --- a/telcell/geography.py +++ b/telcell/geography.py @@ -7,9 +7,6 @@ import pyproj -GEOD_WGS84 = pyproj.Geod(ellps="WGS84") - - class Angle: """ Class to represent an angle between two lines. Implements mathematical operators and can work with degrees and @@ -121,8 +118,8 @@ def azimuth(coord1: geopy.Point, coord2: geopy.Point) -> Angle: Calculates the azimuth of the line between two points. That is, the angle between the line from the first point northward and the line from the first point to the second. - @param coord1: the coordinates of the first point - @param coord2: the coordinates of the second point - @return: the azimuth of the line that connects the first point to the second + :param coord1: the coordinates of the first point + :param coord2: the coordinates of the second point + :return: the azimuth of the line that connects the first point to the second """ return Angle(degrees=azimuth_deg(coord1, coord2)) diff --git a/tests/auxilliary_models/test_geography.py b/tests/auxilliary_models/test_geography.py index 415b0d5..41e6dba 100644 --- a/tests/auxilliary_models/test_geography.py +++ b/tests/auxilliary_models/test_geography.py @@ -4,7 +4,12 @@ import pytest import tensorflow as tf -from telcell.auxilliary_models.geography import GridPoint, manhattan_distance, Grid, EmptyGrid +from telcell.auxilliary_models.geography import ( + GridPoint, + manhattan_distance, + Grid, + EmptyGrid, +) def test_distance(): @@ -24,7 +29,7 @@ def test_move(): p = GridPoint(10000, 11000) assert p == p assert p == p.move(0, 0) - assert p == p.move(.0, .0) + assert p == p.move(0.0, 0.0) def test_stick_to_resolution(): @@ -34,7 +39,9 @@ def test_stick_to_resolution(): p = GridPoint(10000, 11000) assert p == GridPoint(9510, 10510).stick_to_resolution(1000) assert p == GridPoint(10000.4, 11000.4).stick_to_resolution(1000) - assert GridPoint(100000, 100000) == GridPoint(90000, 110000).stick_to_resolution(100000) + assert GridPoint(100000, 100000) == GridPoint(90000, 110000).stick_to_resolution( + 100000 + ) def test_grid(): @@ -42,33 +49,54 @@ def test_grid(): Test that values (with cut-out) has correct number of np.nan and 1. Also tests specific exceptions """ - _grid = Grid(100, 10, GridPoint(1000, 1000), np.ones((10, 10)), - (GridPoint(1040, 1040), GridPoint(1060, 1060))) + _grid = Grid( + 100, + 10, + GridPoint(1000, 1000), + np.ones((10, 10)), + (GridPoint(1040, 1040), GridPoint(1060, 1060)), + ) assert np.sum(np.isnan(_grid.values)) == 4 - assert np.sum(_grid.values == 1.) == 96 + assert np.sum(_grid.values == 1.0) == 96 # Cut out (sw) does not align with resolutions with pytest.raises(ValueError): - _grid = Grid(100, 10, GridPoint(1000, 1000), np.ones((10, 10)), - (GridPoint(1045, 1045), GridPoint(1060, 1060))) + _grid = Grid( + 100, + 10, + GridPoint(1000, 1000), + np.ones((10, 10)), + (GridPoint(1045, 1045), GridPoint(1060, 1060)), + ) # Cut out (ne) does not align with resolutions with pytest.raises(ValueError): - _grid = Grid(100, 10, GridPoint(1000, 1000), np.ones((10, 10)), - (GridPoint(1040, 1040), GridPoint(1065, 1065))) + _grid = Grid( + 100, + 10, + GridPoint(1000, 1000), + np.ones((10, 10)), + (GridPoint(1040, 1040), GridPoint(1065, 1065)), + ) # Array does not fit shape of values with pytest.raises(ValueError): - _grid = Grid(100, 10, GridPoint(1000, 1000), np.ones((11, 11)), - (GridPoint(1040, 1040), GridPoint(1060, 1060))) + _grid = Grid( + 100, + 10, + GridPoint(1000, 1000), + np.ones((11, 11)), + (GridPoint(1040, 1040), GridPoint(1060, 1060)), + ) def test_edges(): """ Test that southwest and northeast of Grid are correct """ - _grid = EmptyGrid(100, 10, GridPoint(1000, 1000), - (GridPoint(1040, 1040), GridPoint(1060, 1060))) + _grid = EmptyGrid( + 100, 10, GridPoint(1000, 1000), (GridPoint(1040, 1040), GridPoint(1060, 1060)) + ) assert _grid.southwest == GridPoint(1000, 1000) assert _grid.northeast == GridPoint(1100, 1100) @@ -78,24 +106,58 @@ def test_coords(): """ Test correct y and x coordinates """ - _grid = EmptyGrid(100, 10, GridPoint(1000, 10000), - (GridPoint(1040, 10040), GridPoint(1060, 10060))) - assert _grid.x_coords == [1005, 1015, 1025, 1035, 1045, 1055, 1065, 1075, 1085, 1095] - assert _grid.y_coords == [10005, 10015, 10025, 10035, 10045, 10055, 10065, 10075, 10085, 10095] + _grid = EmptyGrid( + 100, + 10, + GridPoint(1000, 10000), + (GridPoint(1040, 10040), GridPoint(1060, 10060)), + ) + assert _grid.x_coords == [ + 1005, + 1015, + 1025, + 1035, + 1045, + 1055, + 1065, + 1075, + 1085, + 1095, + ] + assert _grid.y_coords == [ + 10005, + 10015, + 10025, + 10035, + 10045, + 10055, + 10065, + 10075, + 10085, + 10095, + ] def test_value_for_coord(): """ Test that values assigned to values match the expected coordinates (GridPoint) """ - _grid = EmptyGrid(100, 10, GridPoint(1000, 10000), - (GridPoint(1040, 10040), GridPoint(1060, 10060))) + _grid = EmptyGrid( + 100, + 10, + GridPoint(1000, 10000), + (GridPoint(1040, 10040), GridPoint(1060, 10060)), + ) # Set random values to values random_grid = np.random.rand(*_grid.grid_shape) - _grid = Grid(100, 10, GridPoint(1000, 10000), - random_grid, - (GridPoint(1040, 10040), GridPoint(1060, 10060))) + _grid = Grid( + 100, + 10, + GridPoint(1000, 10000), + random_grid, + (GridPoint(1040, 10040), GridPoint(1060, 10060)), + ) assert _grid.get_value_for_coord(GridPoint(1001, 10001)) == random_grid[0, 0] assert _grid.get_value_for_coord(GridPoint(1001, 10099)) == random_grid[-1, 0] @@ -105,7 +167,7 @@ def test_value_for_coord(): assert np.isnan(_grid.get_value_for_coord(GridPoint(1049, 10049))) # Coordinates outside values are invalid - with pytest.raises(Exception, match='is not within a section of Grid'): + with pytest.raises(Exception, match="is not within a section of Grid"): _grid.get_value_for_coord(GridPoint(2001, 10001)) @@ -114,23 +176,31 @@ def test_value_for_center(): Test that values assigned to the values match the expected sections based on the coordinates of the centers of these sections """ - _grid = EmptyGrid(100, 10, GridPoint(1000, 10000), - (GridPoint(1040, 10040), GridPoint(1060, 10060))) + _grid = EmptyGrid( + 100, + 10, + GridPoint(1000, 10000), + (GridPoint(1040, 10040), GridPoint(1060, 10060)), + ) random_grid = np.random.rand(*_grid.grid_shape) - _grid = Grid(100, 10, GridPoint(1000, 10000), - random_grid, - (GridPoint(1040, 10040), GridPoint(1060, 10060))) + _grid = Grid( + 100, + 10, + GridPoint(1000, 10000), + random_grid, + (GridPoint(1040, 10040), GridPoint(1060, 10060)), + ) assert _grid.get_value_for_center(GridPoint(1005, 10005)) == random_grid[0, 0] assert np.isnan(_grid.get_value_for_center(GridPoint(1045, 10045))) # is not a center of a section - with pytest.raises(ValueError, match='is not a center within Grid'): + with pytest.raises(ValueError, match="is not a center within Grid"): _grid.get_value_for_center(GridPoint(1001, 10005)) # is not (a center) within the values - with pytest.raises(ValueError, match='is not a center within Grid'): + with pytest.raises(ValueError, match="is not a center within Grid"): _grid.get_value_for_center(GridPoint(2005, 10005)) @@ -138,8 +208,13 @@ def test_sum(): """ Test that sum is valid (and ignores np.nan) """ - _grid = Grid(100, 10, GridPoint(1000, 1000), np.ones((10, 10)), - (GridPoint(1040, 1040), GridPoint(1060, 1060))) + _grid = Grid( + 100, + 10, + GridPoint(1000, 1000), + np.ones((10, 10)), + (GridPoint(1040, 1040), GridPoint(1060, 1060)), + ) assert _grid.sum() == 96 @@ -147,8 +222,13 @@ def test_scale_grid_values(): """ Test hat values are correctly scaled """ - _grid = Grid(100, 10, GridPoint(1000, 1000), np.ones((10, 10)), - (GridPoint(1040, 1040), GridPoint(1060, 1060))) + _grid = Grid( + 100, + 10, + GridPoint(1000, 1000), + np.ones((10, 10)), + (GridPoint(1040, 1040), GridPoint(1060, 1060)), + ) _scaled_grid = _grid.scale_grid_values(4) assert _scaled_grid.sum() == 96 * 4 @@ -157,8 +237,9 @@ def test_meshgrid_coords(): """ Test that meshgrid has correct coordinates and correspond with expected section in values """ - _grid = EmptyGrid(500, 10, GridPoint(1000, 1000), - (GridPoint(1200, 1200), GridPoint(1300, 1300))) + _grid = EmptyGrid( + 500, 10, GridPoint(1000, 1000), (GridPoint(1200, 1200), GridPoint(1300, 1300)) + ) x_vals, y_vals = _grid.coords_mesh_grid() @@ -183,7 +264,7 @@ def test_meshgrid_coords(): assert pytest.approx(y_se) == 1005 # south # Check that tf and numpy implementation are equal - x_tf_vals, y_tf_vals = _grid.coords_mesh_grid('tf') + x_tf_vals, y_tf_vals = _grid.coords_mesh_grid("tf") assert np.array_equal(x_tf_vals.numpy(), x_vals) assert np.array_equal(y_tf_vals.numpy(), y_vals) @@ -198,8 +279,9 @@ def test_area(): """ Test Area of corresponding Grid """ - _grid = EmptyGrid(500, 10, GridPoint(1000, 1000), - (GridPoint(1200, 1200), GridPoint(1300, 1300))) + _grid = EmptyGrid( + 500, 10, GridPoint(1000, 1000), (GridPoint(1200, 1200), GridPoint(1300, 1300)) + ) assert GridPoint(1000, 1000) == _grid.southwest @@ -208,31 +290,32 @@ def test_move_values(): """ Test that movement of values results in correct values (asserted by summing values) """ - _norm_grid = Grid(500, 10, GridPoint(1000, 1000), np.ones((50, 50))).normalize(1) + p = GridPoint(10000, 500000) + _norm_grid = Grid(500, 10, p, np.ones((50, 50))).normalize(1) _norm_grid = _norm_grid.move(_norm_grid.southwest.move(0, 0)) assert pytest.approx(_norm_grid.sum()) == 1 # moves values half the diameter to the right, removing half of valid values # padding with zeros, should make the sum half of what it was _norm_grid = _norm_grid.move(_norm_grid.southwest.move(250, 0)) - assert pytest.approx(_norm_grid.sum()) == .5 + assert pytest.approx(_norm_grid.sum()) == 0.5 # moving it back doesn't bring back it values _norm_grid = _norm_grid.move(_norm_grid.southwest.move(-250, 0)) - assert pytest.approx(_norm_grid.sum()) == .5 + assert pytest.approx(_norm_grid.sum()) == 0.5 - _norm_grid = Grid(500, 10, GridPoint(1000, 1000), np.ones((50, 50))).normalize(1) + _norm_grid = Grid(500, 10, p, np.ones((50, 50))).normalize(1) _norm_grid = _norm_grid.move(_norm_grid.southwest.move(250, 250)) - assert pytest.approx(_norm_grid.sum()) == .25 + assert pytest.approx(_norm_grid.sum()) == 0.25 - _norm_grid = Grid(500, 10, GridPoint(1000, 1000), np.ones((50, 50))).normalize(1) + _norm_grid = Grid(500, 10, p, np.ones((50, 50))).normalize(1) _norm_grid = _norm_grid.move(_norm_grid.southwest.move(100, 300)) - assert pytest.approx(_norm_grid.sum()) == .32 + assert pytest.approx(_norm_grid.sum()) == 0.32 - _norm_grid = Grid(500, 10, GridPoint(1000, 1000), np.ones((50, 50))).normalize(1) + _norm_grid = Grid(500, 10, p, np.ones((50, 50))).normalize(1) _norm_grid = _norm_grid.move(_norm_grid.southwest.move(500, 500)) assert pytest.approx(_norm_grid.sum()) == 0 - _norm_grid = Grid(500, 10, GridPoint(1000, 1000), np.ones((50, 50))).normalize(1) + _norm_grid = Grid(500, 10, p, np.ones((50, 50))).normalize(1) with pytest.raises(ValueError): _norm_grid = _norm_grid.move(_norm_grid.southwest.move(505, 505))