Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup stdcm editoast endpoint 🚀 #10692

Merged
merged 4 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 121 additions & 18 deletions editoast/src/models/timetable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use diesel::prelude::*;
use diesel::sql_query;
use diesel::sql_types::Array;
use diesel::sql_types::BigInt;
use diesel::sql_types::Timestamptz;
use diesel_async::RunQueryDsl;
use editoast_derive::Model;
use futures_util::stream::TryStreamExt;
Expand Down Expand Up @@ -55,27 +56,42 @@ impl Timetable {
.map_err(Into::into)
}

pub async fn schedules_before_date(
/// This function will return all train schedules in a timetable that runs within the time window.
///
/// **IMPORTANT**: The filter is based on the scheduled arrival time and not the actual simulated arrival time.
///
/// The diagram below shows a list of trains in a timetable:
/// `?`: unscheduled arrival times.
/// `|`: scheduled arrival times.
///
/// ```
/// min_time max_time
/// Time Window |----------------------|
/// Train 1 ✅ |-------------|
/// Train 2 ❌ |-------|
/// Train 3 ✅ |----------|
/// Train 4 ✅ |-------?
/// Train 5 ✅ |---------?
/// Train 6 ❌ |------?
/// ```
pub async fn schedules_in_time_window(
self,
conn: &mut DbConnection,
time: DateTime<Utc>,
min_time: DateTime<Utc>,
max_time: DateTime<Utc>,
) -> Result<Vec<TrainSchedule>> {
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use editoast_models::tables::train_schedule::dsl;

let train_schedules = dsl::train_schedule
.filter(dsl::start_time.le(time))
.filter(dsl::timetable_id.eq(self.id))
.load_stream::<Row<TrainSchedule>>(conn.write().await.deref_mut())
.await?
.map_ok(|ts| ts.into())
.try_collect::<Vec<TrainSchedule>>()
.await;
match train_schedules {
Ok(train_schedules) => Ok(train_schedules),
Err(err) => Err(err.into()),
}
let train_schedules = sql_query(include_str!(
"timetable/sql/get_train_schedules_in_time_window.sql"
))
.bind::<BigInt, _>(self.id)
.bind::<Timestamptz, _>(max_time)
.bind::<Timestamptz, _>(min_time)
.load_stream::<Row<TrainSchedule>>(conn.write().await.deref_mut())
.await?
.map_ok(|ts| ts.into())
.try_collect::<Vec<TrainSchedule>>()
.await;
train_schedules.map_err(|e| e.into())
}
}

Expand Down Expand Up @@ -117,3 +133,90 @@ impl From<TimetableWithTrains> for Timetable {
}
}
}

#[cfg(test)]
pub mod tests {
use chrono::Duration;
use chrono::TimeZone;
use chrono::Utc;
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::collections::HashSet;

use super::*;
use crate::models::fixtures::{create_timetable, simple_train_schedule_base};
use crate::models::train_schedule::TrainScheduleChangeset;
use editoast_models::DbConnectionPoolV2;

#[rstest]
async fn test_schedules_in_time_window() {
let db_pool = DbConnectionPoolV2::for_tests();
let timetable = create_timetable(&mut db_pool.get_ok()).await;
// Note that this train has a last arrival at PT50M
let min_time = Utc.with_ymd_and_hms(2025, 1, 1, 12, 0, 0).unwrap();
let max_time = Utc.with_ymd_and_hms(2025, 1, 1, 14, 0, 0).unwrap();
let base_ts = simple_train_schedule_base();
TrainScheduleChangeset::from(base_ts.clone())
.timetable_id(timetable.id)
.train_name("Train 1".into())
.start_time(min_time - Duration::minutes(20))
.create(&mut db_pool.get_ok())
.await
.unwrap();
TrainScheduleChangeset::from(base_ts.clone())
.timetable_id(timetable.id)
.train_name("Train 2".into())
.start_time(min_time - Duration::hours(2))
.create(&mut db_pool.get_ok())
.await
.unwrap();
TrainScheduleChangeset::from(base_ts.clone())
.timetable_id(timetable.id)
.train_name("Train 3".into())
.start_time(max_time - Duration::minutes(5))
.create(&mut db_pool.get_ok())
.await
.unwrap();
TrainScheduleChangeset::from(base_ts.clone())
.timetable_id(timetable.id)
.train_name("Train 4".into())
.start_time(min_time - Duration::hours(2))
.schedule(vec![])
.create(&mut db_pool.get_ok())
.await
.unwrap();
TrainScheduleChangeset::from(base_ts.clone())
.timetable_id(timetable.id)
.train_name("Train 5".into())
.start_time(max_time - Duration::minutes(10))
.schedule(vec![])
.create(&mut db_pool.get_ok())
.await
.unwrap();
TrainScheduleChangeset::from(base_ts.clone())
.timetable_id(timetable.id)
.train_name("Train 6".into())
.start_time(max_time + Duration::minutes(10))
.schedule(vec![])
.create(&mut db_pool.get_ok())
.await
.unwrap();

// Test
let train_schedules = timetable
.schedules_in_time_window(&mut db_pool.get_ok(), min_time, max_time)
.await
.expect("Failed to get train schedules in time window");

// Expected: Train 1, Train 3, Train 4, Train 5
assert_eq!(train_schedules.len(), 4);
let train_names: HashSet<_> = train_schedules
.into_iter()
.map(|ts| ts.train_name)
.collect();
assert!(train_names.contains("Train 1"));
assert!(train_names.contains("Train 3"));
assert!(train_names.contains("Train 4"));
assert!(train_names.contains("Train 5"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
WITH last_path_step AS (
SELECT id,
jsonb_path_query(path, '$[last].id')::text as step_id
FROM train_schedule
WHERE train_schedule.start_time <= $2
AND train_schedule.timetable_id = $1
),
scheduled_path_step AS (
SELECT id,
jsonb_path_query(schedule, '$[*]') as schedule_point
FROM train_schedule
WHERE train_schedule.start_time <= $2
AND train_schedule.timetable_id = $1
),
arrival_time AS (
SELECT last_path_step.id,
(scheduled_path_step.schedule_point->>'arrival')::interval as arrival
FROM last_path_step
LEFT JOIN scheduled_path_step ON last_path_step.id = scheduled_path_step.id
AND last_path_step.step_id = (scheduled_path_step.schedule_point->'at')::text
)
SELECT train_schedule.id,
train_schedule.train_name,
train_schedule.labels,
train_schedule.rolling_stock_name,
train_schedule.timetable_id,
train_schedule.start_time,
train_schedule.schedule,
train_schedule.margins,
train_schedule.initial_speed,
train_schedule.comfort,
train_schedule.path,
train_schedule.constraint_distribution,
train_schedule.speed_limit_tag,
train_schedule.power_restrictions,
train_schedule.options
FROM train_schedule
LEFT JOIN arrival_time ON train_schedule.id = arrival_time.id
WHERE train_schedule.start_time <= $2
AND train_schedule.timetable_id = $1
AND (
arrival_time.arrival IS NULL
OR train_schedule.start_time + arrival_time.arrival >= $3
)
1 change: 1 addition & 0 deletions editoast/src/models/train_schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::models::prelude::*;
#[derive(Debug, Default, Clone, Model)]
#[model(table = editoast_models::tables::train_schedule)]
#[model(gen(ops = crud, batch_ops = crd, list))]
#[model(row(derive(diesel::QueryableByName)))]
pub struct TrainSchedule {
pub id: i64,
pub train_name: String,
Expand Down
5 changes: 5 additions & 0 deletions editoast/src/views/path/path_item_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use super::pathfinding::PathfindingResult;
type TrackOffsetResult = std::result::Result<Vec<Vec<TrackOffset>>, PathfindingResult>;

/// Gather information about several path items, factorizing db calls.
#[derive(Default)]
pub struct PathItemCache {
uic_to_ops: HashMap<i64, Vec<OperationalPointModel>>,
trigram_to_ops: HashMap<String, Vec<OperationalPointModel>>,
Expand All @@ -37,6 +38,10 @@ impl PathItemCache {
infra_id: i64,
path_items: &[&PathItemLocation],
) -> Result<PathItemCache> {
if path_items.is_empty() {
return Ok(PathItemCache::default());
}

let (trigrams, ops_uic, ops_id) = collect_path_item_ids(path_items);
let uic_to_ops = retrieve_op_from_uic(conn, infra_id, &ops_uic).await?;
let trigram_to_ops = retrieve_op_from_trigrams(conn, infra_id, &trigrams).await?;
Expand Down
59 changes: 2 additions & 57 deletions editoast/src/views/timetable/stdcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,65 +216,10 @@ async fn stdcm(
})
.await?;

// Filter trains
// The goal is to filter out as many trains as possible whose schedules overlap
// with the LMR train being searched for.
// The diagram below shows an LMR train inserted into a timetable.

// '?': unscheduled arrival times.
// '|': scheduled arrival times.
// tA: earliest_departure_time
// tB: latest_simulation_end
//
// tA tB
// LMR Train |----------------------|
// Train 1 |--------------|
// Train 2 |------------|
// |----------? Train 3
// Train 4 |-------?
// Train 5 |---------?
// |----------? Train 6

// Step 1 (SQL Filter):
// Trains that depart after the latest arrival time of the LMR train are excluded.
// In this example, Train 3 and Train 6 are filtered out.

// It's not easy to write an SQL query to filter trains when the train departure time < latest_simulation_ended
// because there are two cases : when the train departure time > tA (Step 2) and the train departure time < tA (Step 3).

// Step 2 (Rust filter) :
// If the train departure time > LMR train departure (tA), the train is kept (e.g., train_5)
// Step 3 (Rust filter) :
// For trains departing before the LMR train departure (tA):

// If the train's arrival time is unscheduled (?), the train is kept (e.g., Train 4 and Train 5).
// If the train's arrival time is scheduled (|), the train is kept only if its arrival time is after the LMR train's earliest departure time.
// Train 1 is kept and train 2 is filtered out.

// Step 1
let mut train_schedules = timetable
.schedules_before_date(&mut conn, latest_simulation_end)
let train_schedules = timetable
.schedules_in_time_window(&mut conn, earliest_departure_time, latest_simulation_end)
.await?;

train_schedules.retain(|train_schedule| {
// Step 2 and 3
train_schedule.start_time >= earliest_departure_time
|| train_schedule
.schedule
.last()
.and_then(|last_schedule_item| {
train_schedule.path.last().and_then(|last_path_item| {
(last_schedule_item.at == last_path_item.id).then_some(last_schedule_item)
})
})
.and_then(|last_schedule_item| {
last_schedule_item.arrival.clone().map(|arrival| {
train_schedule.start_time + *arrival > earliest_departure_time
})
})
.unwrap_or(true)
});

// 3. Get scheduled train requirements
let simulations: Vec<_> = train_simulation_batch(
&mut conn,
Expand Down
40 changes: 24 additions & 16 deletions editoast/src/views/train_schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use tracing::info;
use tracing::Instrument;
use utoipa::IntoParams;
use utoipa::ToSchema;

Expand Down Expand Up @@ -87,7 +88,7 @@ editoast_common::schemas! {
projection::schemas(),
}

pub const TRAIN_SIZE_BATCH: usize = 250;
pub const TRAIN_SIZE_BATCH: usize = 100;

#[derive(Debug, Error, EditoastError)]
#[editoast_error(base_id = "train_schedule")]
Expand Down Expand Up @@ -404,30 +405,37 @@ pub async fn train_simulation_batch(
.map(|rs| PhysicsConsistParameters::from_traction_engine(rs.into()))
.collect();

let consists_ref = &consists;
let futures: Vec<_> = train_batches
.zip(iter::repeat(conn.clone()))
.map(|(chunk, conn)| {
let valkey_client = valkey_client.clone();
let core = core.clone();
async move {
consist_train_simulation_batch(
&mut conn.clone(),
valkey_client.clone(),
core.clone(),
infra,
chunk,
consists_ref,
electrical_profile_set_id,
)
.await
}
let consists = consists.clone();
let infra = <Infra as Clone>::clone(infra);
let chunk = chunk.to_vec();
tokio::spawn(
async move {
consist_train_simulation_batch(
&mut conn.clone(),
valkey_client.clone(),
core.clone(),
&infra,
&chunk,
&consists,
electrical_profile_set_id,
)
.await
}
.in_current_span(),
)
})
.collect();

let results = futures::future::try_join_all(futures).await.unwrap();
let results = results.into_iter().flatten().collect();
Ok(results)
results
.into_iter()
.flatten_ok()
.collect::<Result<Vec<_>, _>>()
}

#[tracing::instrument(skip_all, fields(nb_trains = train_schedules.len()))]
Expand Down
Loading