Patch after Rebase - Squash on Merge
The remote files runtime utility was updated and merged on main while
the TM to bus events branch was being completed. After rebasing, this
patch wires everything up correctly.
mzappitello committed Aug 21, 2024
1 parent de504d3 commit 909796f
Showing 3 changed files with 76 additions and 35 deletions.
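For context on the change itself: the old code reached through a single RemoteFileLocations container and called get_s3_path() on each entry, while the updated remote files utility exposes one module-level object per file with an s3_uri property. The sketch below illustrates the assumed shape of the new objects and the call-site migration; the class internals and the bucket name are illustrative, only the s3_uri property and the object names come from this diff.

# Illustrative sketch -- the real definitions live in
# lamp_py.runtime_utils.remote_files and may differ in detail.
from dataclasses import dataclass


@dataclass
class S3Location:
    bucket: str
    prefix: str

    @property
    def s3_uri(self) -> str:
        # each file object now renders its own ready-to-use URI
        return f"s3://{self.bucket}/{self.prefix}"


tm_geo_node_file = S3Location(bucket="example-bucket", prefix="TM/TMMAIN_GEO_NODE.parquet")

# old call site (removed): RemoteFileLocations.tm_geo_node_file.get_s3_path()
# new call site (added):   tm_geo_node_file.s3_uri
print(tm_geo_node_file.s3_uri)  # s3://example-bucket/TM/TMMAIN_GEO_NODE.parquet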
67 changes: 36 additions & 31 deletions src/lamp_py/bus_performance_manager/tm_ingestion.py
@@ -3,7 +3,16 @@
 import pytz
 import polars as pl

-from lamp_py.runtime_utils.remote_files import RemoteFileLocations
+from lamp_py.runtime_utils.remote_files import (
+    tm_geo_node_file,
+    tm_route_file,
+    tm_trip_file,
+    tm_vehicle_file,
+    tm_work_piece_file,
+    tm_block_file,
+    tm_run_file,
+    tm_operator_file,
+)

 BOSTON_TZ = pytz.timezone("EST5EDT")
 UTC_TZ = pytz.utc
@@ -48,28 +57,28 @@ def generate_tm_events(tm_files: List[str]) -> pl.DataFrame:
     """
     # the geo node id is the transit master key and the geo node abbr is the
     # gtfs stop id
-    tm_geo_nodes = pl.scan_parquet(
-        RemoteFileLocations.tm_geo_node_file.get_s3_path()
-    ).select(["GEO_NODE_ID", "GEO_NODE_ABBR"])
+    tm_geo_nodes = pl.scan_parquet(tm_geo_node_file.s3_uri).select(
+        "GEO_NODE_ID", "GEO_NODE_ABBR"
+    )

     # the route id is the transit master key and the route abbr is the gtfs
     # route id.
     # NOTE: some of these route ids have leading zeros
-    tm_routes = pl.scan_parquet(
-        RemoteFileLocations.tm_route_file.get_s3_path()
-    ).select(["ROUTE_ID", "ROUTE_ABBR"])
+    tm_routes = pl.scan_parquet(tm_route_file.s3_uri).select(
+        "ROUTE_ID", "ROUTE_ABBR"
+    )

     # the trip id is the transit master key and the trip serial number is the
     # gtfs trip id.
-    tm_trips = pl.scan_parquet(
-        RemoteFileLocations.tm_trip_file.get_s3_path()
-    ).select(["TRIP_ID", "TRIP_SERIAL_NUMBER"])
+    tm_trips = pl.scan_parquet(tm_trip_file.s3_uri).select(
+        "TRIP_ID", "TRIP_SERIAL_NUMBER"
+    )

     # the vehicle id is the transit master key and the property tag is the
     # vehicle label
-    tm_vehicles = pl.scan_parquet(
-        RemoteFileLocations.tm_vehicle_file.get_s3_path()
-    ).select(["VEHICLE_ID", "PROPERTY_TAG"])
+    tm_vehicles = pl.scan_parquet(tm_vehicle_file.s3_uri).select(
+        "VEHICLE_ID", "PROPERTY_TAG"
+    )

     # pull stop crossing information for a given service date and join it with
     # other dataframes using the transit master keys.
@@ -156,9 +165,7 @@ def get_daily_work_pieces(daily_work_piece_files: List[str]) -> pl.DataFrame:
     # can have the same run or block. I think its because a Piece of Work can
     # be scheduled for a single day of the week but we reuse Runs and Blocks
     # across different scheduled days.
-    tm_work_pieces = pl.scan_parquet(
-        RemoteFileLocations.tm_work_piece_file.get_s3_path()
-    ).select(
+    tm_work_pieces = pl.scan_parquet(tm_work_piece_file.s3_uri).select(
         "WORK_PIECE_ID",
         "BLOCK_ID",
         "RUN_ID",
@@ -171,17 +178,17 @@
     # Block Abbr is the ID the rest of the MBTA uses for this Block
     # Time Table Version Id is similar to our Static Schedule Version keys in
     # the Rail Performance Manager DB
-    tm_blocks = pl.scan_parquet(
-        RemoteFileLocations.tm_block_file.get_s3_path()
-    ).select("BLOCK_ID", "BLOCK_ABBR", "TIME_TABLE_VERSION_ID")
+    tm_blocks = pl.scan_parquet(tm_block_file.s3_uri).select(
+        "BLOCK_ID", "BLOCK_ABBR", "TIME_TABLE_VERSION_ID"
+    )

     # Run Id is the TM Run Table Key
     # Run Designator is the ID the rest of the MBTA uses for this Run
     # Time Table Version Id is similar to our Static Schedule Version keys in
     # the Rail Performance Manager DB
-    tm_runs = pl.scan_parquet(
-        RemoteFileLocations.tm_run_file.get_s3_path()
-    ).select("RUN_ID", "RUN_DESIGNATOR", "TIME_TABLE_VERSION_ID")
+    tm_runs = pl.scan_parquet(tm_run_file.s3_uri).select(
+        "RUN_ID", "RUN_DESIGNATOR", "TIME_TABLE_VERSION_ID"
+    )

     # Trip Id is the TM Trip Table Key
     # Block Id is the TM Block Table Key
@@ -190,9 +197,7 @@
     # join with the Work Pieces objects.
     # Time Table Version Id is similar to our Static Schedule Version keys in
     # the Rail Performance Manager DB
-    tm_trips = pl.scan_parquet(
-        RemoteFileLocations.tm_trip_file.get_s3_path()
-    ).select(
+    tm_trips = pl.scan_parquet(tm_trip_file.s3_uri).select(
         "TRIP_ID",
         "BLOCK_ID",
         "TRIP_SERIAL_NUMBER",
@@ -249,15 +254,15 @@

     # Operator Id is the TM Operator Table Key
     # Operator Logon Id is the Badge Number
-    tm_operators = pl.scan_parquet(
-        RemoteFileLocations.tm_operator_file.get_s3_path()
-    ).select("OPERATOR_ID", "ONBOARD_LOGON_ID")
+    tm_operators = pl.scan_parquet(tm_operator_file.s3_uri).select(
+        "OPERATOR_ID", "ONBOARD_LOGON_ID"
+    )

     # Vehicle Id is the TM Vehicle Table Key
     # Property Tag is Vehicle Label used by the MBTA
-    tm_vehicles = pl.scan_parquet(
-        RemoteFileLocations.tm_vehicle_file.get_s3_path()
-    ).select("VEHICLE_ID", "PROPERTY_TAG")
+    tm_vehicles = pl.scan_parquet(tm_vehicle_file.s3_uri).select(
+        "VEHICLE_ID", "PROPERTY_TAG"
+    )

     # Join Operator and Vehicle information to the Daily Work Pieces
     realtime_work_pieces = daily_work_piece.join(
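Aside: every block touched in this file follows the same polars lazy pattern of scanning a parquet file, selecting only the key columns, and joining on the TransitMaster keys before collecting. A small self-contained sketch of that pattern, using in-memory frames in place of the real parquet files; column names beyond the GEO_NODE keys shown in the diff are made up.

import polars as pl

# stand-ins for pl.scan_parquet(tm_geo_node_file.s3_uri).select(...) etc.
stop_crossings = pl.LazyFrame(
    {"TRIP_ID": [1, 2], "GEO_NODE_ID": [10, 11], "ACT_ARRIVAL_TIME": ["07:01", "07:05"]}
)
geo_nodes = pl.LazyFrame({"GEO_NODE_ID": [10, 11], "GEO_NODE_ABBR": ["1234", "5678"]})

# nothing is materialized until .collect(); polars prunes unused columns
events = (
    stop_crossings.join(geo_nodes, on="GEO_NODE_ID", how="left")
    .rename({"GEO_NODE_ABBR": "stop_id"})
    .collect()
)
print(events)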
27 changes: 23 additions & 4 deletions tests/bus_performance_manager/test_tm_ingestion.py
@@ -6,19 +6,38 @@

 from lamp_py.bus_performance_manager.tm_ingestion import generate_tm_events

-from ..test_resources import LocalFileLocaions
+from ..test_resources import (
+    tm_geo_node_file,
+    tm_route_file,
+    tm_trip_file,
+    tm_vehicle_file,
+    tm_stop_crossings,
+)


 def test_tm_to_bus_events(monkeypatch: MonkeyPatch) -> None:
     """
     run tests on each file in the test files tm stop crossings directory
     """
     monkeypatch.setattr(
-        "lamp_py.bus_performance_manager.tm_ingestion.RemoteFileLocations",
-        LocalFileLocaions,
+        "lamp_py.bus_performance_manager.tm_ingestion.tm_geo_node_file",
+        tm_geo_node_file,
     )
+    monkeypatch.setattr(
+        "lamp_py.bus_performance_manager.tm_ingestion.tm_route_file",
+        tm_route_file,
+    )
+    monkeypatch.setattr(
+        "lamp_py.bus_performance_manager.tm_ingestion.tm_trip_file",
+        tm_trip_file,
+    )
+    monkeypatch.setattr(
+        "lamp_py.bus_performance_manager.tm_ingestion.tm_vehicle_file",
+        tm_vehicle_file,
+    )

-    tm_sc_dir = LocalFileLocaions.tm_stop_crossing.get_s3_path()
+    tm_sc_dir = tm_stop_crossings.s3_uri
+    print(tm_sc_dir)
     assert os.path.exists(tm_sc_dir)

     for filename in os.listdir(tm_sc_dir):
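One thing worth noting about this test: because tm_ingestion imports each file object with "from lamp_py.runtime_utils.remote_files import ...", the names are re-bound in tm_ingestion's own namespace, so the patches must target the tm_ingestion module rather than remote_files. A tiny self-contained illustration of that rule, using throwaway in-memory modules; nothing below is from the LAMP codebase.

import sys
import types

import pytest


def test_patch_the_consumer_not_the_source(monkeypatch: pytest.MonkeyPatch) -> None:
    # fake "source" module, playing the role of remote_files
    source = types.ModuleType("fake_remote_files")
    source.data_file = "s3://bucket/remote.parquet"
    monkeypatch.setitem(sys.modules, "fake_remote_files", source)

    # fake "consumer" module, playing the role of tm_ingestion; the
    # assignment mimics what `from ... import data_file` does at import time
    consumer = types.ModuleType("fake_ingestion")
    consumer.data_file = source.data_file
    consumer.load = lambda: consumer.data_file
    monkeypatch.setitem(sys.modules, "fake_ingestion", consumer)

    # patching only the source would leave consumer.load() unchanged;
    # patching the consumer's own binding is what redirects it
    monkeypatch.setattr("fake_ingestion.data_file", "tests/local.parquet")
    assert consumer.load() == "tests/local.parquet"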
17 changes: 17 additions & 0 deletions tests/test_resources.py
@@ -61,3 +61,20 @@ def s3_uri(self) -> str:
     bucket=S3_SPRINGBOARD,
     prefix="RT_VEHICLE_POSITIONS",
 )
+
+tm_stop_crossings = LocalS3Location(
+    bucket=S3_SPRINGBOARD,
+    prefix="TM/STOP_CROSSING",
+)
+tm_geo_node_file = LocalS3Location(
+    bucket=S3_SPRINGBOARD, prefix="TM/TMMAIN_GEO_NODE.parquet"
+)
+tm_route_file = LocalS3Location(
+    bucket=S3_SPRINGBOARD, prefix="TM/TMMAIN_ROUTE.parquet"
+)
+tm_trip_file = LocalS3Location(
+    bucket=S3_SPRINGBOARD, prefix="TM/TMMAIN_TRIP.parquet"
+)
+tm_vehicle_file = LocalS3Location(
+    bucket=S3_SPRINGBOARD, prefix="TM/TMMAIN_VEHICLE.parquet"
+)
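The LocalS3Location class these definitions use is not shown in this hunk; only its s3_uri property name appears in the hunk header above, and the test asserts that the resulting path exists on disk. A minimal sketch consistent with that usage, assuming s3_uri resolves to a checked-in path under a local test files directory rather than a real S3 URI:

# Sketch only -- the real LocalS3Location already exists earlier in
# tests/test_resources.py and its internals may differ.
import os
from dataclasses import dataclass

TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "test_files")


@dataclass
class LocalS3Location:
    bucket: str
    prefix: str

    @property
    def s3_uri(self) -> str:
        # despite the name, this points at a local directory of test
        # parquet files so the tests never touch S3
        return os.path.join(TEST_FILES_DIR, self.bucket, self.prefix)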
