From 909796f761543398e113d30720eb2fe9121c2b31 Mon Sep 17 00:00:00 2001 From: Mike Zappitello Date: Wed, 21 Aug 2024 11:16:59 -0400 Subject: [PATCH] Patch after Rebase - Squash on Merge the remote files runtime utility was updated and merged on main while the tm to bus events branch was completed. after rebasing, this patch wires everything up correctly. --- .../bus_performance_manager/tm_ingestion.py | 67 ++++++++++--------- .../test_tm_ingestion.py | 27 ++++++-- tests/test_resources.py | 17 +++++ 3 files changed, 76 insertions(+), 35 deletions(-) diff --git a/src/lamp_py/bus_performance_manager/tm_ingestion.py b/src/lamp_py/bus_performance_manager/tm_ingestion.py index 56579534..b38d67b5 100644 --- a/src/lamp_py/bus_performance_manager/tm_ingestion.py +++ b/src/lamp_py/bus_performance_manager/tm_ingestion.py @@ -3,7 +3,16 @@ import pytz import polars as pl -from lamp_py.runtime_utils.remote_files import RemoteFileLocations +from lamp_py.runtime_utils.remote_files import ( + tm_geo_node_file, + tm_route_file, + tm_trip_file, + tm_vehicle_file, + tm_work_piece_file, + tm_block_file, + tm_run_file, + tm_operator_file, +) BOSTON_TZ = pytz.timezone("EST5EDT") UTC_TZ = pytz.utc @@ -48,28 +57,28 @@ def generate_tm_events(tm_files: List[str]) -> pl.DataFrame: """ # the geo node id is the transit master key and the geo node abbr is the # gtfs stop id - tm_geo_nodes = pl.scan_parquet( - RemoteFileLocations.tm_geo_node_file.get_s3_path() - ).select(["GEO_NODE_ID", "GEO_NODE_ABBR"]) + tm_geo_nodes = pl.scan_parquet(tm_geo_node_file.s3_uri).select( + "GEO_NODE_ID", "GEO_NODE_ABBR" + ) # the route id is the transit master key and the route abbr is the gtfs # route id. # NOTE: some of these route ids have leading zeros - tm_routes = pl.scan_parquet( - RemoteFileLocations.tm_route_file.get_s3_path() - ).select(["ROUTE_ID", "ROUTE_ABBR"]) + tm_routes = pl.scan_parquet(tm_route_file.s3_uri).select( + "ROUTE_ID", "ROUTE_ABBR" + ) # the trip id is the transit master key and the trip serial number is the # gtfs trip id. - tm_trips = pl.scan_parquet( - RemoteFileLocations.tm_trip_file.get_s3_path() - ).select(["TRIP_ID", "TRIP_SERIAL_NUMBER"]) + tm_trips = pl.scan_parquet(tm_trip_file.s3_uri).select( + "TRIP_ID", "TRIP_SERIAL_NUMBER" + ) # the vehicle id is the transit master key and the property tag is the # vehicle label - tm_vehicles = pl.scan_parquet( - RemoteFileLocations.tm_vehicle_file.get_s3_path() - ).select(["VEHICLE_ID", "PROPERTY_TAG"]) + tm_vehicles = pl.scan_parquet(tm_vehicle_file.s3_uri).select( + "VEHICLE_ID", "PROPERTY_TAG" + ) # pull stop crossing information for a given service date and join it with # other dataframes using the transit master keys. @@ -156,9 +165,7 @@ def get_daily_work_pieces(daily_work_piece_files: List[str]) -> pl.DataFrame: # can have the same run or block. I think its because a Piece of Work can # be scheduled for a single day of the week but we reuse Runs and Blocks # across different scheduled days. - tm_work_pieces = pl.scan_parquet( - RemoteFileLocations.tm_work_piece_file.get_s3_path() - ).select( + tm_work_pieces = pl.scan_parquet(tm_work_piece_file.s3_uri).select( "WORK_PIECE_ID", "BLOCK_ID", "RUN_ID", @@ -171,17 +178,17 @@ def get_daily_work_pieces(daily_work_piece_files: List[str]) -> pl.DataFrame: # Block Abbr is the ID the rest of the MBTA uses for this Block # Time Table Version Id is similar to our Static Schedule Version keys in # the Rail Performance Manager DB - tm_blocks = pl.scan_parquet( - RemoteFileLocations.tm_block_file.get_s3_path() - ).select("BLOCK_ID", "BLOCK_ABBR", "TIME_TABLE_VERSION_ID") + tm_blocks = pl.scan_parquet(tm_block_file.s3_uri).select( + "BLOCK_ID", "BLOCK_ABBR", "TIME_TABLE_VERSION_ID" + ) # Run Id is the TM Run Table Key # Run Designator is the ID the rest of the MBTA uses for this Run # Time Table Version Id is similar to our Static Schedule Version keys in # the Rail Performance Manager DB - tm_runs = pl.scan_parquet( - RemoteFileLocations.tm_run_file.get_s3_path() - ).select("RUN_ID", "RUN_DESIGNATOR", "TIME_TABLE_VERSION_ID") + tm_runs = pl.scan_parquet(tm_run_file.s3_uri).select( + "RUN_ID", "RUN_DESIGNATOR", "TIME_TABLE_VERSION_ID" + ) # Trip Id is the TM Trip Table Key # Block Id is the TM Block Table Key @@ -190,9 +197,7 @@ def get_daily_work_pieces(daily_work_piece_files: List[str]) -> pl.DataFrame: # join with the Work Pieces objects. # Time Table Version Id is similar to our Static Schedule Version keys in # the Rail Performance Manager DB - tm_trips = pl.scan_parquet( - RemoteFileLocations.tm_trip_file.get_s3_path() - ).select( + tm_trips = pl.scan_parquet(tm_trip_file.s3_uri).select( "TRIP_ID", "BLOCK_ID", "TRIP_SERIAL_NUMBER", @@ -249,15 +254,15 @@ def get_daily_work_pieces(daily_work_piece_files: List[str]) -> pl.DataFrame: # Operator Id is the TM Operator Table Key # Operator Logon Id is the Badge Number - tm_operators = pl.scan_parquet( - RemoteFileLocations.tm_operator_file.get_s3_path() - ).select("OPERATOR_ID", "ONBOARD_LOGON_ID") + tm_operators = pl.scan_parquet(tm_operator_file.s3_uri).select( + "OPERATOR_ID", "ONBOARD_LOGON_ID" + ) # Vehicle Id is the TM Vehicle Table Key # Property Tag is Vehicle Label used by the MBTA - tm_vehicles = pl.scan_parquet( - RemoteFileLocations.tm_vehicle_file.get_s3_path() - ).select("VEHICLE_ID", "PROPERTY_TAG") + tm_vehicles = pl.scan_parquet(tm_vehicle_file.s3_uri).select( + "VEHICLE_ID", "PROPERTY_TAG" + ) # Join Operator and Vehicle information to the Daily Work Pieces realtime_work_pieces = daily_work_piece.join( diff --git a/tests/bus_performance_manager/test_tm_ingestion.py b/tests/bus_performance_manager/test_tm_ingestion.py index 609a4456..06630e29 100644 --- a/tests/bus_performance_manager/test_tm_ingestion.py +++ b/tests/bus_performance_manager/test_tm_ingestion.py @@ -6,7 +6,13 @@ from lamp_py.bus_performance_manager.tm_ingestion import generate_tm_events -from ..test_resources import LocalFileLocaions +from ..test_resources import ( + tm_geo_node_file, + tm_route_file, + tm_trip_file, + tm_vehicle_file, + tm_stop_crossings, +) def test_tm_to_bus_events(monkeypatch: MonkeyPatch) -> None: @@ -14,11 +20,24 @@ def test_tm_to_bus_events(monkeypatch: MonkeyPatch) -> None: run tests on each file in the test files tm stop crossings directory """ monkeypatch.setattr( - "lamp_py.bus_performance_manager.tm_ingestion.RemoteFileLocations", - LocalFileLocaions, + "lamp_py.bus_performance_manager.tm_ingestion.tm_geo_node_file", + tm_geo_node_file, + ) + monkeypatch.setattr( + "lamp_py.bus_performance_manager.tm_ingestion.tm_route_file", + tm_route_file, + ) + monkeypatch.setattr( + "lamp_py.bus_performance_manager.tm_ingestion.tm_trip_file", + tm_trip_file, + ) + monkeypatch.setattr( + "lamp_py.bus_performance_manager.tm_ingestion.tm_vehicle_file", + tm_vehicle_file, ) - tm_sc_dir = LocalFileLocaions.tm_stop_crossing.get_s3_path() + tm_sc_dir = tm_stop_crossings.s3_uri + print(tm_sc_dir) assert os.path.exists(tm_sc_dir) for filename in os.listdir(tm_sc_dir): diff --git a/tests/test_resources.py b/tests/test_resources.py index a7ee1a00..95989979 100644 --- a/tests/test_resources.py +++ b/tests/test_resources.py @@ -61,3 +61,20 @@ def s3_uri(self) -> str: bucket=S3_SPRINGBOARD, prefix="RT_VEHICLE_POSITIONS", ) + +tm_stop_crossings = LocalS3Location( + bucket=S3_SPRINGBOARD, + prefix="TM/STOP_CROSSING", +) +tm_geo_node_file = LocalS3Location( + bucket=S3_SPRINGBOARD, prefix="TM/TMMAIN_GEO_NODE.parquet" +) +tm_route_file = LocalS3Location( + bucket=S3_SPRINGBOARD, prefix="TM/TMMAIN_ROUTE.parquet" +) +tm_trip_file = LocalS3Location( + bucket=S3_SPRINGBOARD, prefix="TM/TMMAIN_TRIP.parquet" +) +tm_vehicle_file = LocalS3Location( + bucket=S3_SPRINGBOARD, prefix="TM/TMMAIN_VEHICLE.parquet" +)