From cb4b1383afadb7f5f80ff91819edcd9abf8f6251 Mon Sep 17 00:00:00 2001 From: Ian Reynolds Date: Thu, 18 Jan 2024 20:14:15 -0500 Subject: [PATCH] working but still not a huge jump in performance --- src/event.py | 4 ++-- src/gtfs.py | 21 ++++++++++++--------- src/timing.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 11 deletions(-) create mode 100644 src/timing.py diff --git a/src/event.py b/src/event.py index 90a927f..1b66ad7 100644 --- a/src/event.py +++ b/src/event.py @@ -165,11 +165,11 @@ def enrich_event(df: pd.DataFrame, gtfs_archive: gtfs.GtfsArchive): # get trips and stop times for this route specifically (slow to scan them all) route_id = df["route_id"].iloc[0] - trip_id = df["trip_id"].iloc[0] scheduled_trips_for_route = gtfs_archive.trips_by_route_id(route_id) - scheduled_stop_times_for_route = gtfs_archive.stop_times_by_trip_id(trip_id) + scheduled_stop_times_for_route = gtfs_archive.stop_times_by_route_id(route_id) headway_adjusted_df = gtfs.add_gtfs_headways(df, scheduled_trips_for_route, scheduled_stop_times_for_route) + # headway_adjusted_df = gtfs.add_gtfs_headways(df, gtfs_archive.trips, gtfs_archive.stop_times) # future warning: returning a series is actually the correct future behavior of to_pydatetime(), can drop the # context manager later with warnings.catch_warnings(): diff --git a/src/gtfs.py b/src/gtfs.py index 652f01e..3889d0d 100644 --- a/src/gtfs.py +++ b/src/gtfs.py @@ -14,6 +14,7 @@ from logger import set_up_logging import util +import timing logger = set_up_logging(__name__) @@ -55,13 +56,16 @@ class GtfsArchive: service_date: datetime.date def __post_init__(self): - self._stop_times_by_trip_id = _group_df_by_column(self.stop_times, "trip_id") + self._trips_empty = _get_empty_df_with_same_columns(self.trips) self._stop_times_empty = _get_empty_df_with_same_columns(self.stop_times) self._trips_by_route_id = _group_df_by_column(self.trips, "route_id") - self._trips_empty = 
_get_empty_df_with_same_columns(self.trips) + self._stop_times_by_route_id = {} + for route_id in self._trips_by_route_id.keys(): + trip_ids_for_route = self._trips_by_route_id[route_id].trip_id + self._stop_times_by_route_id[route_id] = self.stop_times[self.stop_times.trip_id.isin(trip_ids_for_route)] - def stop_times_by_trip_id(self, route_id: str): - return self._stop_times_by_trip_id.get(route_id, self._stop_times_empty) + def stop_times_by_route_id(self, route_id: str): + return self._stop_times_by_route_id.get(route_id, self._stop_times_empty) def trips_by_route_id(self, route_id: str): return self._trips_by_route_id.get(route_id, self._trips_empty) @@ -168,8 +172,9 @@ def read_gtfs(date: datetime.date) -> GtfsArchive: return GtfsArchive(trips=trips, stop_times=stop_times, stops=stops, service_date=date) +@timing.measure_time(report_frequency=0.1) @tracer.wrap() -def add_gtfs_headways(events_df: pd.DataFrame, all_trips: pd.DataFrame, all_stops: pd.DataFrame) -> pd.DataFrame: +def add_gtfs_headways(events_df: pd.DataFrame, trips: pd.DataFrame, stop_times: pd.DataFrame) -> pd.DataFrame: """ This will calculate scheduled headway and traveltime information from gtfs for the routes we care about, and then match our actual @@ -192,11 +197,11 @@ def add_gtfs_headways(events_df: pd.DataFrame, all_trips: pd.DataFrame, all_stop # we have to do this day-by-day because gtfs changes so often for service_date, days_events in events_df.groupby("service_date"): # filter out the trips of interest - relevant_trips = all_trips[all_trips.route_id.isin(days_events.route_id)] + relevant_trips = trips[trips.route_id.isin(days_events.route_id)] # take only the stops from those trips (adding route and dir info) trip_info = relevant_trips[["trip_id", "route_id", "direction_id"]] - gtfs_stops = all_stops.merge(trip_info, on="trip_id", how="right") + gtfs_stops = stop_times.merge(trip_info, on="trip_id", how="right") # calculate gtfs headways gtfs_stops = 
gtfs_stops.sort_values(by="arrival_time")
@@ -212,8 +217,6 @@ def add_gtfs_headways(events_df: pd.DataFrame, all_trips: pd.DataFrame, all_stop
         # assign each actual timepoint a scheduled headway
         # merge_asof 'backward' matches the previous scheduled value of 'arrival_time'
         days_events["arrival_time"] = days_events.event_time - pd.Timestamp(service_date).tz_localize("US/Eastern")
-        print("MERGING WITH")
-        print(gtfs_stops[RTE_DIR_STOP + ["arrival_time", "scheduled_headway"]])
         augmented_events = pd.merge_asof(
             days_events.sort_values(by="arrival_time"),
             gtfs_stops[RTE_DIR_STOP + ["arrival_time", "scheduled_headway"]],
diff --git a/src/timing.py b/src/timing.py
new file mode 100644
index 0000000..20c7b98
--- /dev/null
+++ b/src/timing.py
@@ -0,0 +1,42 @@
+from collections import deque
+from functools import wraps
+from random import random
+from time import perf_counter
+
+import numpy as np
+
+
+def measure_time(report_frequency: float = 1.0, trail_length: int = 1000):
+    """
+    Build a decorator that times every call of the wrapped function and, on a
+    random fraction of calls, prints statistics over the most recent timings.
+
+    report_frequency: probability (0.0 to 1.0) that a given call prints a report
+    trail_length: how many of the most recent timings the statistics cover
+    """
+    if trail_length < 1:
+        raise ValueError("trail_length must be at least 1")
+
+    def decorator(fn):
+        # bounded buffer: old timings are dropped on append, so memory stays
+        # flat even when reports are rare (low report_frequency)
+        exec_times = deque(maxlen=trail_length)
+
+        @wraps(fn)
+        def wrap(*args, **kw):
+            # perf_counter is monotonic, unlike time(), so wall-clock
+            # adjustments cannot produce negative or skewed durations
+            ts = perf_counter()
+            result = fn(*args, **kw)
+            exec_times.append(perf_counter() - ts)
+            if random() < report_frequency:
+                times = np.array(exec_times)
+                print(
+                    f"func {fn.__name__}: last={times[-1]:.3f}s min={times.min():.3f} "
+                    f"max={times.max():.3f} avg={times.mean():.3f}s std={times.std():.3f}s"
+                )
+            return result
+
+        return wrap
+
+    return decorator