diff --git a/backend/.gitignore b/backend/.gitignore
index 5e92857..7e90627 100644
--- a/backend/.gitignore
+++ b/backend/.gitignore
@@ -4,4 +4,5 @@ container_*
 trips.txt
 positions.txt
 siri-vehicles.json
-gtfs/
\ No newline at end of file
+gtfs/
+env
\ No newline at end of file
diff --git a/backend/Cargo.lock b/backend/Cargo.lock
index 2ef3d2f..e6b7511 100644
--- a/backend/Cargo.lock
+++ b/backend/Cargo.lock
@@ -49,6 +49,21 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "alloc-no-stdlib"
+version = "2.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
+
+[[package]]
+name = "alloc-stdlib"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
+dependencies = [
+ "alloc-no-stdlib",
+]
+
 [[package]]
 name = "allocator-api2"
 version = "0.2.18"
@@ -100,6 +115,7 @@ version = "0.4.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fec134f64e2bc57411226dfc4e52dec859ddfc7e711fc5e07b612584f000e4aa"
 dependencies = [
+ "brotli",
  "flate2",
  "futures-core",
  "memchr",
@@ -161,7 +177,7 @@ dependencies = [
  "serde_urlencoded",
  "sync_wrapper 1.0.1",
  "tokio",
- "tower",
+ "tower 0.4.13",
  "tower-layer",
  "tower-service",
  "tracing",
@@ -213,6 +229,7 @@ dependencies = [
  "sqlx",
  "thiserror",
  "tokio",
+ "tower 0.5.0",
  "tower-http",
  "tracing",
  "tracing-subscriber",
@@ -278,6 +295,27 @@ dependencies = [
  "generic-array",
 ]
 
+[[package]]
+name = "brotli"
+version = "6.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b"
+dependencies = [
+ "alloc-no-stdlib",
+ "alloc-stdlib",
+ "brotli-decompressor",
+]
+
+[[package]]
+name = "brotli-decompressor"
+version = "4.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362"
+dependencies = [
+ "alloc-no-stdlib",
+ "alloc-stdlib",
+]
+
 [[package]]
 name = "bumpalo"
 version = "3.16.0"
@@ -1048,7 +1086,7 @@ dependencies = [
  "pin-project-lite",
  "socket2",
  "tokio",
- "tower",
+ "tower 0.4.13",
  "tower-service",
  "tracing",
 ]
@@ -2679,6 +2717,16 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "tower"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "36b837f86b25d7c0d7988f00a54e74739be6477f2aac6201b8f429a7569991b7"
+dependencies = [
+ "tower-layer",
+ "tower-service",
+]
+
 [[package]]
 name = "tower-http"
 version = "0.5.2"
diff --git a/backend/Cargo.toml b/backend/Cargo.toml
index 9af6a0e..240050c 100644
--- a/backend/Cargo.toml
+++ b/backend/Cargo.toml
@@ -9,7 +9,6 @@ axum = "0.7.5"
 chrono = { version = "0.4.38", features = ["serde"] }
 chrono-tz = "0.9.0"
 csv = "1.3.0"
-# geo-types = "0.7.13"
 futures = "0.3.30"
 geo = "0.28.0"
 http = "1.1.0"
@@ -34,10 +33,13 @@ sqlx = { version = "0.8.0", features = [
 ] }
 thiserror = "1.0.63"
 tokio = { version = "1.39.2", features = ["full"] }
+tower = "0.5.0"
 tower-http = { version = "0.5.2", features = [
     "trace",
     "timeout",
     "compression-gzip",
+    "compression-br",
+    "normalize-path",
 ] }
 tracing = "0.1.40"
 tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
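Editor's note on the new `tower-http` features: `compression-br` only adds brotli as a negotiable encoding; the existing `CompressionLayer::new()` call in `main.rs` then picks gzip or brotli per request from the `Accept-Encoding` header. A minimal standalone sketch (the route and bind address are illustrative, not from this PR):

```rust
use axum::{routing::get, Router};
use tower_http::compression::CompressionLayer;

#[tokio::main]
async fn main() {
    // With both "compression-gzip" and "compression-br" enabled in Cargo.toml,
    // the default layer serves whichever encoding the client advertises.
    let app: Router = Router::new()
        .route("/health", get(|| async { "ok" }))
        .layer(CompressionLayer::new());

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```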
diff --git a/backend/src/alerts.rs b/backend/src/alerts.rs
index d3b48a8..691f2d3 100644
--- a/backend/src/alerts.rs
+++ b/backend/src/alerts.rs
@@ -1,23 +1,18 @@
 use crate::{
-    feed::{self},
-    static_data::ROUTES,
-    trips::DecodeFeedError,
+    gtfs::decode,
+    train::{static_data::ROUTES, trips::DecodeFeedError},
 };
 use chrono::{DateTime, Utc};
-use prost::Message;
-use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
+use rayon::prelude::*;
 use sqlx::{PgPool, QueryBuilder};
-use std::{env::var, time::Duration};
-use tokio::{
-    fs::{create_dir, write},
-    time::sleep,
-};
+use std::time::Duration;
+use tokio::time::sleep;
 use uuid::Uuid;
 
 pub async fn import(pool: PgPool) {
     tokio::spawn(async move {
         loop {
-            if let Err(e) = decode(&pool).await {
+            if let Err(e) = parse_gtfs(&pool).await {
                 tracing::error!("Failed to decode feed: {:?}", e);
             }
             sleep(Duration::from_secs(10)).await;
@@ -106,20 +101,12 @@ pub struct AffectedEntity {
     pub sort_order: i32,
 }
 
-async fn decode(pool: &PgPool) -> Result<(), DecodeFeedError> {
-    let data = reqwest::Client::new()
-        .get("https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/camsys%2Fall-alerts")
-        .send()
-        .await?
-        .bytes()
-        .await?;
-    let feed = feed::FeedMessage::decode(data)?;
-
-    if var("DEBUG_GTFS").is_ok() {
-        let msgs = format!("{:#?}", feed);
-        create_dir("./gtfs").await.ok();
-        write("./gtfs/alerts.txt", msgs).await.unwrap();
-    }
+async fn parse_gtfs(pool: &PgPool) -> Result<(), DecodeFeedError> {
+    let feed = decode(
+        "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/camsys%2Fall-alerts",
+        "alerts",
+    )
+    .await?;
 
     let mut in_feed_ids = vec![];
     let mut cloned_ids: Vec<Uuid> = vec![];
@@ -220,14 +207,13 @@ async fn decode(pool: &PgPool) -> Result<(), DecodeFeedError> {
                     r
                 }
             })
-            .map(|r| {
+            .and_then(|r| {
                 if ROUTES.contains(&r.as_str()) {
                     Some(r)
                 } else {
                     None
                 }
-            })
-            .flatten();
+            });
 
         // check if route_id is in ROUTES; otherwise it's a bus route
         // TODO: improve this
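The `.map(...).flatten()` → `.and_then(...)` change above is behavior-preserving (clippy's `map_flatten` lint suggests exactly this rewrite). For a pure keep-or-drop predicate, `Option::filter` states the intent even more directly; a self-contained sketch with illustrative names:

```rust
// Keep the route id only if it appears in the known-routes list.
fn keep_known_route(route: Option<String>, known: &[&str]) -> Option<String> {
    route.filter(|r| known.contains(&r.as_str()))
}

fn main() {
    let known = ["A", "C", "E"];
    assert_eq!(keep_known_route(Some("A".into()), &known), Some("A".to_string()));
    // A bus route id is not in the train list, so it is dropped.
    assert_eq!(keep_known_route(Some("B44".into()), &known), None);
}
```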
diff --git a/backend/src/bus/positions.rs b/backend/src/bus/positions.rs
index 314ea16..370510e 100644
--- a/backend/src/bus/positions.rs
+++ b/backend/src/bus/positions.rs
@@ -1,14 +1,11 @@
-use crate::{bus::api_key, feed, trips::DecodeFeedError};
+use crate::{bus::api_key, gtfs::decode, train::trips::DecodeFeedError};
 use chrono::{DateTime, Utc};
-use prost::Message;
 use rayon::prelude::*;
 use serde::{Deserialize, Deserializer};
 use sqlx::{PgPool, QueryBuilder};
 use std::time::Duration;
 use tokio::time::sleep;
 
-// use std::io::Write;
-
 #[derive(Debug)]
 struct Position {
     vehicle_id: i32,
@@ -27,7 +24,7 @@ pub async fn import(pool: PgPool) {
     let pool1 = pool.clone();
     tokio::spawn(async move {
         loop {
-            match decode_feed(pool1.clone()).await {
+            match parse_gtfs(pool1.clone()).await {
                 Ok(_) => (),
                 Err(e) => {
                     tracing::error!("Error importing bus position data: {:?}", e);
@@ -40,7 +37,7 @@ pub async fn import(pool: PgPool) {
 
     tokio::spawn(async move {
         loop {
-            match decode_siri(pool.clone()).await {
+            match parse_siri(pool.clone()).await {
                 Ok(_) => (),
                 Err(e) => {
                     tracing::error!("Error importing SIRI bus data: {:?}", e);
@@ -52,19 +49,12 @@ pub async fn import(pool: PgPool) {
     });
 }
 
-pub async fn decode_feed(pool: PgPool) -> Result<(), DecodeFeedError> {
-    let data = reqwest::Client::new()
-        .get("https://gtfsrt.prod.obanyc.com/vehiclePositions")
-        .send()
-        .await?
-        .bytes()
-        .await?;
-
-    let feed = feed::FeedMessage::decode(data)?;
-
-    // let mut msgs = Vec::new();
-    // write!(msgs, "{:#?}", feed).unwrap();
-    // tokio::fs::write("./positions.txt", msgs).await.unwrap();
+pub async fn parse_gtfs(pool: PgPool) -> Result<(), DecodeFeedError> {
+    let feed = decode(
+        "https://gtfsrt.prod.obanyc.com/vehiclePositions",
+        "buspositions",
+    )
+    .await?;
 
     // let stop_ids = sqlx::query!("SELECT id FROM bus_stops")
     //     .fetch_all(&pool)
@@ -164,6 +154,41 @@ pub async fn decode_feed(pool: PgPool) -> Result<(), DecodeFeedError> {
     Ok(())
 }
 
+#[derive(Debug)]
+struct SiriPosition {
+    vehicle_id: i32,
+    mta_id: String,
+    progress_status: Option<String>,
+    passengers: Option<i32>,
+    capacity: Option<i32>,
+}
+
+impl From<MonitoredVehicleJourney> for SiriPosition {
+    fn from(value: MonitoredVehicleJourney) -> Self {
+        let vehicle_id: i32 = value.vehicle_ref.parse().unwrap();
+        // TODO: simplify
+        let capacity = value.monitored_call.and_then(|c| {
+            c.extensions.map(|e| {
+                (
+                    e.capacities.estimated_passenger_count,
+                    e.capacities.estimated_passenger_capacity,
+                )
+            })
+        });
+
+        let progress_status = value.progress_status.and_then(|s| s.into_iter().nth(0));
+        let mta_id = value.framed_vehicle_journey_ref.dated_vehicle_journey_ref;
+
+        Self {
+            vehicle_id,
+            mta_id,
+            progress_status,
+            passengers: capacity.map(|c| c.0),
+            capacity: capacity.map(|c| c.1),
+        }
+    }
+}
+
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "PascalCase")]
 struct ServiceDelivery {
@@ -238,9 +263,10 @@ struct Capacities {
 }
 
 // Need the SIRI feed to get progress status and capacities
-pub async fn decode_siri(pool: PgPool) -> Result<(), DecodeFeedError> {
+pub async fn parse_siri(pool: PgPool) -> Result<(), DecodeFeedError> {
     let siri_res = reqwest::Client::new()
         .get("https://api.prod.obanyc.com/api/siri/vehicle-monitoring.json")
+        .timeout(Duration::from_secs(29))
         .query(&[
             ("key", api_key()),
             ("version", "2"),
@@ -269,36 +295,29 @@ pub async fn decode_siri(pool: PgPool) -> Result<(), DecodeFeedError> {
         return Err(DecodeFeedError::Siri("no vehicles".to_string()));
     };
 
-    // TODO: make sure progress status is correct (that we only need to worry about statuses bc when rate is unknown/no progress its always layover/spooking)
-    for vehicle in vehicles.vehicle_activity {
-        let monitored_vehicle_journey = vehicle.monitored_vehicle_journey;
-        let capacities = monitored_vehicle_journey.monitored_call.and_then(|c| {
-            c.extensions.map(|e| {
-                (
-                    e.capacities.estimated_passenger_count,
-                    e.capacities.estimated_passenger_capacity,
-                )
-            })
-        });
-
-        let progress_status = monitored_vehicle_journey
-            .progress_status
-            .map(|s| s[0].clone());
-
-        let vehicle_id: i32 = monitored_vehicle_journey.vehicle_ref.parse().unwrap();
-        let trip_id = monitored_vehicle_journey
-            .framed_vehicle_journey_ref
-            .dated_vehicle_journey_ref;
+    let positions = vehicles
+        .vehicle_activity
+        .into_par_iter()
+        .map(|v| v.monitored_vehicle_journey.into())
+        .collect::<Vec<SiriPosition>>();
 
-        sqlx::query!(
-            "UPDATE bus_positions SET progress_status = $1, passengers = $2, capacity = $3 WHERE vehicle_id = $4 AND mta_id = $5",
-            progress_status,
-            capacities.map(|c| c.0),
-            capacities.map(|c| c.1),
-            vehicle_id,
-            trip_id
-        ).execute(&pool).await?;
-    }
+    // TODO: fix querybuilder issue "trailing junk after parameter at or near \"$5V\""
+    for p in positions {
+        sqlx::query!("UPDATE bus_positions SET progress_status = $1, passengers = $2, capacity = $3 WHERE vehicle_id = $4 AND mta_id = $5", p.progress_status, p.passengers, p.capacity, p.vehicle_id, p.mta_id).execute(&pool).await?;
+    }
+    // for positions in positions.chunks(1) {
+    //     dbg!(positions);
+    //     let mut query_builder = QueryBuilder::new("UPDATE bus_positions SET progress_status = $1, passengers = $2, capacity = $3 WHERE vehicle_id = $4 AND mta_id = $5");
+    //     query_builder.push_values(positions, |mut b, position| {
+    //         b.push_bind(&position.progress_status)
+    //             .push_bind(position.passengers)
+    //             .push_bind(position.capacity)
+    //             .push_bind(position.vehicle_id)
+    //             .push_bind(&position.mta_id);
+    //     });
+    //     let query = query_builder.build();
+    //     query.execute(&pool).await?;
+    // }
 
     // println!("{:#?}", service_delivery);
     // let mut progresses = Vec::new();
diff --git a/backend/src/bus/static_data.rs b/backend/src/bus/static_data.rs
index bd5e389..8b48c9f 100644
--- a/backend/src/bus/static_data.rs
+++ b/backend/src/bus/static_data.rs
@@ -156,9 +156,8 @@ pub async fn stops_and_routes(pool: &PgPool) {
     let query = query_builder.build();
     query.execute(pool).await.unwrap();
 
-    // prevent issues with query being too big
-    let chunk_size = 1000;
-    for chunk in stops.chunks(chunk_size) {
+    // Chunk to the maximum number of bind parameters allowed per query
+    for chunk in stops.chunks(65534 / 5) {
         let mut query_builder =
             QueryBuilder::new("INSERT INTO bus_stops (id, name, direction, lat, lon)");
         query_builder.push_values(chunk, |mut b, route| {
@@ -173,7 +172,7 @@ pub async fn stops_and_routes(pool: &PgPool) {
         query.execute(pool).await.unwrap();
     }
 
-    for chunk in route_stops.chunks(chunk_size) {
+    for chunk in route_stops.chunks(65534 / 5) {
         let mut query_builder = QueryBuilder::new(
             "INSERT INTO bus_route_stops (route_id, stop_id, stop_sequence, headsign, direction)",
         );
@@ -331,7 +330,7 @@ where
     D: Deserializer<'de>,
 {
     let polyline = String::deserialize(deserializer)?;
-    Ok(decode_polyline(&polyline, 5).map_err(serde::de::Error::custom)?)
+    decode_polyline(&polyline, 5).map_err(serde::de::Error::custom)
 }
 
 #[derive(Deserialize, Clone, Debug)]
diff --git a/backend/src/bus/trips.rs b/backend/src/bus/trips.rs
index 86330bb..bf23e7d 100644
--- a/backend/src/bus/trips.rs
+++ b/backend/src/bus/trips.rs
@@ -1,25 +1,19 @@
 use crate::{
-    feed::{self, TripUpdate},
-    trips::{DecodeFeedError, IntoStopTimeError, StopTimeUpdateWithTrip},
+    feed::TripUpdate,
+    gtfs::decode,
+    train::trips::{DecodeFeedError, IntoStopTimeError, StopTimeUpdateWithTrip},
 };
 use chrono::{DateTime, Utc};
-use prost::Message;
 use rayon::prelude::*;
 use sqlx::{PgPool, QueryBuilder};
-use std::env::var;
 use std::time::Duration;
-use tokio::{
-    fs::{create_dir, write},
-    time::sleep,
-};
+use tokio::time::sleep;
 use uuid::Uuid;
 
-// TODO: remove unwraps and handle errors
-
 pub async fn import(pool: PgPool) {
     tokio::spawn(async move {
         loop {
-            match decode_feed(&pool).await {
+            match parse_gtfs(&pool).await {
                 Ok(_) => (),
                 Err(e) => {
                     tracing::error!("Error importing bus trip data: {:?}", e);
@@ -153,24 +147,10 @@ impl<'a> TryFrom<StopTimeUpdateWithTrip<'a>> for StopTime {
     }
 }
 
-pub async fn decode_feed(pool: &PgPool) -> Result<(), DecodeFeedError> {
+pub async fn parse_gtfs(pool: &PgPool) -> Result<(), DecodeFeedError> {
     // after 30 seconds the obanyc api will return its own timeout error
     // TODO: fix decode error that shows up once in a while
-    let data = reqwest::Client::new()
-        .get("https://gtfsrt.prod.obanyc.com/tripUpdates")
-        .timeout(Duration::from_secs(29))
-        .send()
-        .await?
-        .bytes()
-        .await?;
-
-    let feed = feed::FeedMessage::decode(data)?;
-
-    if var("DEBUG_GTFS").is_ok() {
-        let msgs = format!("{:#?}", feed);
-        create_dir("./gtfs").await.ok();
-        write("./gtfs/bus-trips.txt", msgs).await.unwrap();
-    }
+    let feed = decode("https://gtfsrt.prod.obanyc.com/tripUpdates", "bustrips").await?;
 
     let mut trips: Vec<Trip> = vec![];
     let mut stop_times: Vec<StopTime> = vec![];
@@ -258,8 +238,9 @@ pub async fn decode_feed(pool: &PgPool) -> Result<(), DecodeFeedError> {
     let query = query_builder.build();
     query.execute(pool).await?;
 
-    // Insert stop times. Need to chunk otherwise its too big for a single query
-    for chunk in stop_times.chunks(1000) {
+    // Postgres allows at most 65534 bind parameters per query, and each stop time binds 5.
+    // https://docs.rs/sqlx/latest/sqlx/struct.QueryBuilder.html#method.push_bind
+    for chunk in stop_times.chunks(65534 / 5) {
         let mut query_builder = QueryBuilder::new(
             "INSERT INTO bus_stop_times (trip_id, stop_id, arrival, departure, stop_sequence) ",
         );
diff --git a/backend/src/gtfs.rs b/backend/src/gtfs.rs
new file mode 100644
index 0000000..4ed9e95
--- /dev/null
+++ b/backend/src/gtfs.rs
@@ -0,0 +1,30 @@
+use crate::{feed::FeedMessage, train::trips::DecodeFeedError};
+use prost::Message;
+use std::env::var;
+use std::sync::OnceLock;
+use tokio::fs::{create_dir, write};
+
+// https://stackoverflow.com/a/77249700
+pub fn debug_gtfs() -> &'static bool {
+    // Read the DEBUG_GTFS env var once and cache the result
+    static DEBUG_GTFS: OnceLock<bool> = OnceLock::new();
+    DEBUG_GTFS.get_or_init(|| var("DEBUG_GTFS").is_ok())
+}
+
+pub async fn decode(url: &str, name: &str) -> Result<FeedMessage, DecodeFeedError> {
+    let data = reqwest::Client::new()
+        .get(url)
+        .send()
+        .await?
+        .bytes()
+        .await?;
+
+    let feed = FeedMessage::decode(data)?;
+
+    if *debug_gtfs() {
+        let msgs = format!("{:#?}", feed);
+        create_dir("./gtfs").await.ok();
+        write(format!("./gtfs/{}.txt", name), msgs).await.unwrap();
+    }
+    Ok(feed)
+}
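The `OnceLock` memoization in `debug_gtfs()` could also cover the other per-call allocation in this new module: `decode()` constructs a fresh `reqwest::Client`, and with it a fresh connection pool, on every poll. A hedged sketch of the same pattern applied to the client, assuming no per-feed client configuration is needed:

```rust
use std::sync::OnceLock;

// One shared client; reqwest::Client is internally reference-counted and
// reuses connections across requests.
fn http_client() -> &'static reqwest::Client {
    static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
    CLIENT.get_or_init(reqwest::Client::new)
}

// inside decode():
// let data = http_client().get(url).send().await?.bytes().await?;
```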
diff --git a/backend/src/main.rs b/backend/src/main.rs
index 40164c0..6dacc02 100644
--- a/backend/src/main.rs
+++ b/backend/src/main.rs
@@ -1,5 +1,12 @@
-use axum::{body::Body, response::Response, routing::get, Router};
+use axum::{
+    body::Body,
+    extract::Request,
+    response::{IntoResponse, Response},
+    routing::get,
+    Router, ServiceExt,
+};
 use chrono::Utc;
+use http::StatusCode;
 use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
 use std::{
     convert::Infallible,
@@ -7,14 +14,17 @@
     time::Duration,
 };
 use tokio::time::sleep;
-use tower_http::{compression::CompressionLayer, trace::TraceLayer};
+use tower::Layer;
+use tower_http::{
+    compression::CompressionLayer, normalize_path::NormalizePathLayer, trace::TraceLayer,
+};
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
 mod alerts;
 mod bus;
+mod gtfs;
 mod routes;
-mod static_data;
-mod trips;
+mod train;
 
 pub mod feed {
     include!(concat!(env!("OUT_DIR"), "/transit_realtime.rs"));
@@ -37,13 +47,15 @@ async fn main() {
         .expect("DATABASE_URL env not set")
         .parse()
         .unwrap();
-    // pg_connect_option = pg_connect_option.disable_statement_logging();
 
     let pool = PgPoolOptions::new()
        .max_connections(100)
         .connect_with(pg_connect_option)
         .await
-        .unwrap();
-    sqlx::migrate!().run(&pool).await.unwrap();
+        .expect("Failed to create postgres pool");
+    sqlx::migrate!()
+        .run(&pool)
+        .await
+        .expect("Failed to run database migrations");
 
     let s_pool = pool.clone();
     tokio::spawn(async move {
@@ -79,7 +91,7 @@ async fn main() {
             tracing::info!("Updating stops and trips");
             bus::static_data::stops_and_routes(&s_pool).await;
-            static_data::stops_and_routes(&s_pool).await;
+            train::static_data::stops_and_routes(&s_pool).await;
 
             // remove old update_ats
             sqlx::query!("DELETE FROM last_update")
@@ -94,10 +106,9 @@ async fn main() {
         }
     });
 
+    train::trips::import(pool.clone()).await;
     bus::trips::import(pool.clone()).await;
     bus::positions::import(pool.clone()).await;
-
-    trips::import(pool.clone()).await;
     alerts::import(pool.clone()).await;
 
     let app = Router::new()
@@ -109,25 +120,39 @@ async fn main() {
                 Ok::<_, Infallible>(res)
             }),
         )
+        // trains
         .route("/stops", get(routes::stops::get))
         .route("/stops/times", get(routes::stops::times))
         .route("/trips", get(routes::trips::get))
         .route("/trips/:id", get(routes::trips::by_id))
-        .route("/alerts", get(routes::alerts::get))
         // bus stuff
         .route("/bus/stops", get(routes::bus::stops::get))
         .route("/bus/stops/times", get(routes::bus::stops::times))
         .route("/bus/trips", get(routes::bus::trips::get))
         .route("/bus/trips/:id", get(routes::bus::trips::by_id))
         .route("/bus/routes", get(routes::bus::routes::get))
+        // alerts
+        .route("/alerts", get(routes::alerts::get))
         .layer(TraceLayer::new_for_http())
         .layer(CompressionLayer::new())
-        .with_state(pool);
+        .with_state(pool)
+        .fallback(handler_404);
+
+    // The normalize-path layer must wrap the whole Router so it runs before routing
+    let app = NormalizePathLayer::trim_trailing_slash().layer(app);
 
     let listener =
-        tokio::net::TcpListener::bind(var("ADDRESS").unwrap_or_else(|_| "0.0.0.0:3055".into()))
+        tokio::net::TcpListener::bind(var("ADDRESS").unwrap_or_else(|_| "127.0.0.1:3055".into()))
             .await
             .unwrap();
     tracing::info!("listening on {}", listener.local_addr().unwrap());
-    axum::serve(listener, app).await.unwrap();
+    // The wrapped service is no longer a Router, so the type must be spelled out
+    // for into_make_service: https://github.com/tokio-rs/axum/discussions/2377
+    axum::serve(listener, ServiceExt::<Request>::into_make_service(app))
+        .await
+        .unwrap();
+}
+
+async fn handler_404() -> impl IntoResponse {
+    (StatusCode::NOT_FOUND, "not found")
 }
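Why the wrapping dance above: layers added with `Router::layer` run after routing, so a trailing-slash normalizer attached that way cannot change which route matches. Wrapping the finished `Router` makes normalization run first, and `ServiceExt::<Request>::into_make_service` recovers the type `axum::serve` expects. A minimal sketch of the same pattern in isolation (route and address are illustrative):

```rust
use axum::{extract::Request, routing::get, Router, ServiceExt};
use tower::Layer;
use tower_http::normalize_path::NormalizePathLayer;

#[tokio::main]
async fn main() {
    let router = Router::new().route("/stops", get(|| async { "stops" }));
    // After wrapping, a request for "/stops/" is rewritten to "/stops"
    // before the router sees it.
    let app = NormalizePathLayer::trim_trailing_slash().layer(router);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3056").await.unwrap();
    axum::serve(listener, ServiceExt::<Request>::into_make_service(app))
        .await
        .unwrap();
}
```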
diff --git a/backend/src/routes/alerts.rs b/backend/src/routes/alerts.rs
index 08232d7..0f7e763 100644
--- a/backend/src/routes/alerts.rs
+++ b/backend/src/routes/alerts.rs
@@ -32,7 +32,7 @@ pub async fn get(
     time: CurrentTime,
 ) -> Result {
     // the query differs depending on whether live data is requested
-
+    // TODO: allow specifying route ids and bus route ids
     let alerts = {
         if time.1 {
             sqlx::query_as!(
diff --git a/backend/src/routes/trips.rs b/backend/src/routes/trips.rs
index d20094f..c190c21 100644
--- a/backend/src/routes/trips.rs
+++ b/backend/src/routes/trips.rs
@@ -1,6 +1,6 @@
 use crate::{
     routes::{errors::ServerError, parse_list, CurrentTime},
-    trips::STOP_TIMES_RESPONSE,
+    train::trips::STOP_TIMES_RESPONSE,
 };
 use axum::{
     extract::{Path, State},
diff --git a/backend/src/train/mod.rs b/backend/src/train/mod.rs
new file mode 100644
index 0000000..7618c78
--- /dev/null
+++ b/backend/src/train/mod.rs
@@ -0,0 +1,2 @@
+pub mod static_data;
+pub mod trips;
diff --git a/backend/src/static_data.rs b/backend/src/train/static_data.rs
similarity index 100%
rename from backend/src/static_data.rs
rename to backend/src/train/static_data.rs
diff --git a/backend/src/trips.rs b/backend/src/train/trips.rs
similarity index 95%
rename from backend/src/trips.rs
rename to backend/src/train/trips.rs
index f5ed230..4922c83 100644
--- a/backend/src/trips.rs
+++ b/backend/src/train/trips.rs
@@ -1,14 +1,15 @@
-use crate::feed::{self, trip_update::StopTimeUpdate, TripDescriptor};
-use crate::routes::trips::Trip as TripRow;
+use crate::{
+    feed::{trip_update::StopTimeUpdate, TripDescriptor},
+    gtfs::decode,
+    routes::trips::Trip as TripRow,
+};
 use chrono::{DateTime, NaiveDateTime, NaiveTime, TimeZone, Utc};
 use once_cell::sync::Lazy;
-use prost::{DecodeError, Message};
+use prost::DecodeError;
 use rayon::prelude::*;
 use sqlx::{PgPool, QueryBuilder};
-use std::env::var;
 use std::time::Duration;
 use thiserror::Error;
-use tokio::fs::{create_dir, write};
 use tokio::sync::Mutex;
 use tokio::time::sleep;
 use uuid::Uuid;
@@ -57,11 +58,11 @@ pub static STOP_TIMES_RESPONSE: Lazy<Mutex<Vec<TripRow>>> = Lazy::new(|| Mutex::
 pub async fn import(pool: PgPool) {
     tokio::spawn(async move {
         loop {
-            let futures = (0..ENDPOINTS.len()).map(|i| decode_feed(&pool, ENDPOINTS[i]));
+            let futures = (0..ENDPOINTS.len()).map(|i| parse_gtfs(&pool, ENDPOINTS[i]));
             let _ = futures::future::join_all(futures).await;
             cache_stop_times(&pool).await.unwrap();
             // for endpoint in ENDPOINTS.iter() {
-            //     match decode_feed(&pool, endpoint).await {
+            //     match parse_gtfs(&pool, endpoint).await {
             //         Ok(_) => (),
             //         Err(e) => {
             //             tracing::error!("Error importing data: {:?}", e);
@@ -147,29 +148,6 @@ pub enum IntoTripError {
     StopId,
 }
 
-impl TripDescriptor {
-    // result is (route_id, express)
-    pub fn parse_route_id(&self) -> Result<(String, bool), IntoTripError> {
-        self.route_id
-            .as_ref()
-            .ok_or(IntoTripError::RouteId)
-            .map(|id| {
-                let mut route_id = id.to_owned();
-                if route_id == "SS" {
-                    route_id = "SI".to_string();
-                    // TODO: set express to true for SS
-                };
-
-                let mut express = false;
-                if route_id.ends_with('X') {
-                    route_id.pop();
-                    express = true;
-                }
-                (route_id, express)
-            })
-    }
-}
-
 impl TryFrom<TripDescriptor> for Trip {
     type Error = IntoTripError;
 
@@ -251,6 +229,29 @@ impl TryFrom<TripDescriptor> for Trip {
     }
 }
 
+impl TripDescriptor {
+    // result is (route_id, express)
+    pub fn parse_route_id(&self) -> Result<(String, bool), IntoTripError> {
+        self.route_id
+            .as_ref()
+            .ok_or(IntoTripError::RouteId)
+            .map(|id| {
+                let mut route_id = id.to_owned();
+                if route_id == "SS" {
+                    route_id = "SI".to_string();
+                    // TODO: set express to true for SS
+                };
+
+                let mut express = false;
+                if route_id.ends_with('X') {
+                    route_id.pop();
+                    express = true;
+                }
+                (route_id, express)
+            })
+    }
+}
+
 impl Trip {
     // finds trip in db by matching mta_id, train_id, created_at, and direction, returns true if found
     pub async fn find(&mut self, pool: &PgPool) -> Result {
@@ -362,25 +363,15 @@ pub struct StopTime {
     departure: DateTime<Utc>,
 }
 
-pub async fn decode_feed(pool: &PgPool, endpoint: &str) -> Result<(), DecodeFeedError> {
-    let data = reqwest::Client::new()
-        .get(format!(
+pub async fn parse_gtfs(pool: &PgPool, endpoint: &str) -> Result<(), DecodeFeedError> {
+    let feed = decode(
+        &format!(
             "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs{}",
             endpoint
-        ))
-        .send()
-        .await?
-        .bytes()
-        .await?;
-
-    let feed = feed::FeedMessage::decode(data)?;
-    if var("DEBUG_GTFS").is_ok() {
-        let msgs = format!("{:#?}", feed);
-        create_dir("./gtfs").await.ok();
-        write(format!("./gtfs/trains{}.txt", &endpoint), msgs)
-            .await
-            .unwrap();
-    }
+        ),
+        &format!("train{}", endpoint),
+    )
+    .await?;
 
     let mut trips: Vec<Trip> = Vec::new();
     let mut stop_times: Vec<StopTime> = Vec::new();
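One last nit on the train import loop: the index-based `(0..ENDPOINTS.len()).map(|i| ...)` can borrow the endpoints directly, and keeping `join_all` means the feeds are still fetched concurrently and a failure in one feed does not affect the others, since each future returns its own `Result`. A hedged in-context sketch using this file's own names (the per-endpoint error logging is my addition; the current code discards the results):

```rust
// Fetch every train feed concurrently, then log per-endpoint failures.
let futures = ENDPOINTS.iter().map(|endpoint| parse_gtfs(&pool, endpoint));
let results = futures::future::join_all(futures).await;
for (endpoint, result) in ENDPOINTS.iter().zip(results) {
    if let Err(e) = result {
        tracing::error!("Error importing {} feed: {:?}", endpoint, e);
    }
}
```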