Skip to content

Commit d23f3c7

Browse files
younesschrififlomonster
authored andcommitted
editoast: compute train simulation in batch
Signed-off-by: Youness CHRIFI ALAOUI <[email protected]>
1 parent 5bd1d85 commit d23f3c7

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

editoast/src/views/train_schedule.rs

+32-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::collections::HashMap;
55
use std::collections::HashSet;
66
use std::hash::Hash;
77
use std::hash::Hasher;
8+
use std::iter;
89
use std::sync::Arc;
910

1011
use axum::extract::Json;
@@ -86,6 +87,8 @@ editoast_common::schemas! {
8687
projection::schemas(),
8788
}
8889

90+
pub const TRAIN_SIZE_BATCH: usize = 500;
91+
8992
#[derive(Debug, Error, EditoastError)]
9093
#[editoast_error(base_id = "train_schedule")]
9194
#[allow(clippy::enum_variant_names)] // Variant have the same postfix by chance, it's not a problem
@@ -383,28 +386,45 @@ pub async fn train_simulation_batch(
383386
electrical_profile_set_id: Option<i64>,
384387
) -> Result<Vec<(SimulationResponse, PathfindingResult)>> {
385388
// Compute path
389+
390+
let train_batches = train_schedules.chunks(TRAIN_SIZE_BATCH);
391+
386392
let rolling_stocks_ids = train_schedules
387393
.iter()
388394
.map::<String, _>(|t| t.rolling_stock_name.clone());
389395

390-
let (rolling_stocks, _): (Vec<_>, HashSet<String>) =
391-
RollingStockModel::retrieve_batch(conn, rolling_stocks_ids).await?;
396+
let rolling_stocks: Vec<_> =
397+
RollingStockModel::retrieve_batch_unchecked(&mut conn.clone(), rolling_stocks_ids).await?;
392398

393399
let consists: Vec<PhysicsConsistParameters> = rolling_stocks
394400
.into_iter()
395401
.map(|rs| PhysicsConsistParameters::from_traction_engine(rs.into()))
396402
.collect();
397403

398-
consist_train_simulation_batch(
399-
conn,
400-
valkey_client,
401-
core.clone(),
402-
infra,
403-
train_schedules,
404-
&consists,
405-
electrical_profile_set_id,
406-
)
407-
.await
404+
let consists_ref = &consists;
405+
let futures: Vec<_> = train_batches
406+
.zip(iter::repeat(conn.clone()))
407+
.map(|(chunk, conn)| {
408+
let valkey_client = valkey_client.clone();
409+
let core = core.clone();
410+
async move {
411+
consist_train_simulation_batch(
412+
&mut conn.clone(),
413+
valkey_client.clone(),
414+
core.clone(),
415+
infra,
416+
chunk,
417+
consists_ref,
418+
electrical_profile_set_id,
419+
)
420+
.await
421+
}
422+
})
423+
.collect();
424+
425+
let results = futures::future::try_join_all(futures).await.unwrap();
426+
let results = results.into_iter().flatten().collect();
427+
Ok(results)
408428
}
409429

410430
pub async fn consist_train_simulation_batch(

0 commit comments

Comments
 (0)