@@ -5,6 +5,7 @@ use std::collections::HashMap;
5
5
use std:: collections:: HashSet ;
6
6
use std:: hash:: Hash ;
7
7
use std:: hash:: Hasher ;
8
+ use std:: iter;
8
9
use std:: sync:: Arc ;
9
10
10
11
use axum:: extract:: Json ;
@@ -86,6 +87,8 @@ editoast_common::schemas! {
86
87
projection:: schemas( ) ,
87
88
}
88
89
90
+ pub const TRAIN_SIZE_BATCH : usize = 500 ;
91
+
89
92
#[ derive( Debug , Error , EditoastError ) ]
90
93
#[ editoast_error( base_id = "train_schedule" ) ]
91
94
#[ allow( clippy:: enum_variant_names) ] // Variant have the same postfix by chance, it's not a problem
@@ -383,28 +386,45 @@ pub async fn train_simulation_batch(
383
386
electrical_profile_set_id : Option < i64 > ,
384
387
) -> Result < Vec < ( SimulationResponse , PathfindingResult ) > > {
385
388
// Compute path
389
+
390
+ let train_batches = train_schedules. chunks ( TRAIN_SIZE_BATCH ) ;
391
+
386
392
let rolling_stocks_ids = train_schedules
387
393
. iter ( )
388
394
. map :: < String , _ > ( |t| t. rolling_stock_name . clone ( ) ) ;
389
395
390
- let ( rolling_stocks, _ ) : ( Vec < _ > , HashSet < String > ) =
391
- RollingStockModel :: retrieve_batch ( conn, rolling_stocks_ids) . await ?;
396
+ let rolling_stocks: Vec < _ > =
397
+ RollingStockModel :: retrieve_batch_unchecked ( & mut conn. clone ( ) , rolling_stocks_ids) . await ?;
392
398
393
399
let consists: Vec < PhysicsConsistParameters > = rolling_stocks
394
400
. into_iter ( )
395
401
. map ( |rs| PhysicsConsistParameters :: from_traction_engine ( rs. into ( ) ) )
396
402
. collect ( ) ;
397
403
398
- consist_train_simulation_batch (
399
- conn,
400
- valkey_client,
401
- core. clone ( ) ,
402
- infra,
403
- train_schedules,
404
- & consists,
405
- electrical_profile_set_id,
406
- )
407
- . await
404
+ let consists_ref = & consists;
405
+ let futures: Vec < _ > = train_batches
406
+ . zip ( iter:: repeat ( conn. clone ( ) ) )
407
+ . map ( |( chunk, conn) | {
408
+ let valkey_client = valkey_client. clone ( ) ;
409
+ let core = core. clone ( ) ;
410
+ async move {
411
+ consist_train_simulation_batch (
412
+ & mut conn. clone ( ) ,
413
+ valkey_client. clone ( ) ,
414
+ core. clone ( ) ,
415
+ infra,
416
+ chunk,
417
+ consists_ref,
418
+ electrical_profile_set_id,
419
+ )
420
+ . await
421
+ }
422
+ } )
423
+ . collect ( ) ;
424
+
425
+ let results = futures:: future:: try_join_all ( futures) . await . unwrap ( ) ;
426
+ let results = results. into_iter ( ) . flatten ( ) . collect ( ) ;
427
+ Ok ( results)
408
428
}
409
429
410
430
pub async fn consist_train_simulation_batch (
0 commit comments