Skip to content

Commit 512f065

Browse files
editoast: fix projection endpoint
Co-authored-by: Youness CHRIFI ALAOUI <[email protected]> Signed-off-by: Florian Amsallem <[email protected]> Signed-off-by: Youness CHRIFI ALAOUI <[email protected]>
1 parent 4c340c5 commit 512f065

File tree

2 files changed

+126
-36
lines changed

2 files changed

+126
-36
lines changed

editoast/src/views/train_schedule.rs

+95-6
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,8 @@ async fn get_path(
796796
#[cfg(test)]
797797
mod tests {
798798
use axum::http::StatusCode;
799+
use chrono::DateTime;
800+
use chrono::Utc;
799801
use pretty_assertions::assert_eq;
800802
use rstest::rstest;
801803
use serde_json::json;
@@ -920,8 +922,7 @@ mod tests {
920922
)
921923
}
922924

923-
async fn app_infra_id_train_schedule_id_for_simulation_tests() -> (TestApp, i64, i64) {
924-
let db_pool = DbConnectionPoolV2::for_tests();
925+
fn mocked_core_pathfinding_sim_and_proj(train_id: i64) -> MockingClient {
925926
let mut core = MockingClient::new();
926927
core.stub("/v2/pathfinding/blocks")
927928
.method(reqwest::Method::POST)
@@ -975,10 +976,18 @@ mod tests {
975976
}
976977
}))
977978
.finish();
978-
let app = TestAppBuilder::new()
979-
.db_pool(db_pool.clone())
980-
.core_client(core.into())
981-
.build();
979+
core.stub("/v2/signal_projection")
980+
.method(reqwest::Method::POST)
981+
.response(StatusCode::OK)
982+
.json(json!({
983+
"signal_updates": {train_id.to_string(): [] },
984+
}))
985+
.finish();
986+
core
987+
}
988+
989+
async fn app_infra_id_train_schedule_id_for_simulation_tests() -> (TestApp, i64, i64) {
990+
let db_pool = DbConnectionPoolV2::for_tests();
982991
let small_infra = create_small_infra(&mut db_pool.get_ok()).await;
983992
let rolling_stock =
984993
create_fast_rolling_stock(&mut db_pool.get_ok(), "simulation_rolling_stock").await;
@@ -997,6 +1006,11 @@ mod tests {
9971006
.create(&mut db_pool.get_ok())
9981007
.await
9991008
.expect("Failed to create train schedule");
1009+
let core = mocked_core_pathfinding_sim_and_proj(train_schedule.id);
1010+
let app = TestAppBuilder::new()
1011+
.db_pool(db_pool.clone())
1012+
.core_client(core.into())
1013+
.build();
10001014
(app, small_infra.id, train_schedule.id)
10011015
}
10021016

@@ -1024,4 +1038,79 @@ mod tests {
10241038
}));
10251039
app.fetch(request).assert_status(StatusCode::OK);
10261040
}
1041+
1042+
#[derive(Deserialize)]
1043+
struct PartialProjectPathTrainResult {
1044+
departure_time: DateTime<Utc>,
1045+
// Ignore the rest of the payload
1046+
}
1047+
1048+
#[rstest]
1049+
async fn train_schedule_project_path() {
1050+
// SETUP
1051+
let db_pool = DbConnectionPoolV2::for_tests();
1052+
1053+
let small_infra = create_small_infra(&mut db_pool.get_ok()).await;
1054+
let rolling_stock =
1055+
create_fast_rolling_stock(&mut db_pool.get_ok(), "simulation_rolling_stock").await;
1056+
let timetable = create_timetable(&mut db_pool.get_ok()).await;
1057+
let train_schedule_base: TrainScheduleBase = TrainScheduleBase {
1058+
rolling_stock_name: rolling_stock.name.clone(),
1059+
..serde_json::from_str(include_str!("../tests/train_schedules/simple.json"))
1060+
.expect("Unable to parse")
1061+
};
1062+
let train_schedule: Changeset<TrainSchedule> = TrainScheduleForm {
1063+
timetable_id: Some(timetable.id),
1064+
train_schedule: train_schedule_base.clone(),
1065+
}
1066+
.into();
1067+
let train_schedule_valid = train_schedule
1068+
.create(&mut db_pool.get_ok())
1069+
.await
1070+
.expect("Failed to create train schedule");
1071+
1072+
let train_schedule_fail: Changeset<TrainSchedule> = TrainScheduleForm {
1073+
timetable_id: Some(timetable.id),
1074+
train_schedule: TrainScheduleBase {
1075+
rolling_stock_name: "fail".to_string(),
1076+
start_time: DateTime::from_timestamp(0, 0).unwrap(),
1077+
..train_schedule_base.clone()
1078+
},
1079+
}
1080+
.into();
1081+
1082+
let train_schedule_fail = train_schedule_fail
1083+
.create(&mut db_pool.get_ok())
1084+
.await
1085+
.expect("Failed to create train schedule");
1086+
1087+
let core = mocked_core_pathfinding_sim_and_proj(train_schedule_valid.id);
1088+
let app = TestAppBuilder::new()
1089+
.db_pool(db_pool.clone())
1090+
.core_client(core.into())
1091+
.build();
1092+
1093+
// TEST
1094+
let request = app.post("/train_schedule/project_path").json(&json!({
1095+
"infra_id": small_infra.id,
1096+
"electrical_profile_set_id": null,
1097+
"ids": vec![train_schedule_fail.id, train_schedule_valid.id],
1098+
"path": {
1099+
"track_section_ranges": [
1100+
{"track_section": "TA1", "begin": 0, "end": 100, "direction": "START_TO_STOP"}
1101+
],
1102+
"routes": [],
1103+
"blocks": []
1104+
}
1105+
}));
1106+
let response: HashMap<i64, PartialProjectPathTrainResult> =
1107+
app.fetch(request).assert_status(StatusCode::OK).json_into();
1108+
1109+
// EXPECT
1110+
assert_eq!(response.len(), 1);
1111+
assert_eq!(
1112+
response[&train_schedule_valid.id].departure_time,
1113+
train_schedule_base.start_time
1114+
);
1115+
}
10271116
}

editoast/src/views/train_schedule/projection.rs

+31-30
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ async fn project_path(
198198
let mut trains_hash_values = vec![];
199199
let mut trains_details = vec![];
200200

201-
for (sim, pathfinding_result) in simulations {
201+
for (train, (sim, pathfinding_result)) in izip!(&trains, simulations) {
202202
let track_ranges = match pathfinding_result {
203203
PathfindingResult::Success(PathfindingResultSuccess {
204204
track_section_ranges,
@@ -221,6 +221,7 @@ async fn project_path(
221221
} = report_train;
222222

223223
let train_details = TrainSimulationDetails {
224+
train_id: train.id,
224225
positions,
225226
times,
226227
signal_critical_positions,
@@ -242,17 +243,13 @@ async fn project_path(
242243
let cached_projections: Vec<Option<CachedProjectPathTrainResult>> =
243244
valkey_conn.json_get_bulk(&trains_hash_values).await?;
244245

245-
let mut hit_cache: HashMap<i64, CachedProjectPathTrainResult> = HashMap::new();
246-
let mut miss_cache = HashMap::new();
247-
for (train_details, projection, train_id) in izip!(
248-
trains_details,
249-
cached_projections,
250-
trains.iter().map(|t| t.id)
251-
) {
246+
let mut hit_cache = vec![];
247+
let mut miss_cache = vec![];
248+
for (train_details, projection) in izip!(&trains_details, cached_projections) {
252249
if let Some(cached) = projection {
253-
hit_cache.insert(train_id, cached);
250+
hit_cache.push((cached, train_details.train_id));
254251
} else {
255-
miss_cache.insert(train_id, train_details.clone());
252+
miss_cache.push(train_details.clone());
256253
}
257254
}
258255

@@ -277,47 +274,47 @@ async fn project_path(
277274
let signal_updates = signal_updates?;
278275

279276
// 3. Store the projection in the cache (using pipeline)
280-
let trains_hash_values: HashMap<_, _> = trains
277+
let trains_hash_values: HashMap<_, _> = trains_details
281278
.iter()
282-
.map(|t| t.id)
279+
.map(|t| t.train_id)
283280
.zip(trains_hash_values)
284281
.collect();
285282
let mut new_items = vec![];
286-
for id in miss_cache.keys() {
287-
let hash = &trains_hash_values[id];
283+
for train_id in miss_cache.iter().map(|t| t.train_id) {
284+
let hash = &trains_hash_values[&train_id];
288285
let cached_value = CachedProjectPathTrainResult {
289286
space_time_curves: space_time_curves
290-
.get(id)
287+
.get(&train_id)
291288
.expect("Space time curves not available for train")
292289
.clone(),
293290
signal_updates: signal_updates
294-
.get(id)
291+
.get(&train_id)
295292
.expect("Signal update not available for train")
296293
.clone(),
297294
};
298-
hit_cache.insert(*id, cached_value.clone());
295+
hit_cache.push((cached_value.clone(), train_id));
299296
new_items.push((hash, cached_value));
300297
}
301298
valkey_conn.json_set_bulk(&new_items).await?;
302299

303300
let train_map: HashMap<i64, TrainSchedule> = trains.into_iter().map(|ts| (ts.id, ts)).collect();
304301

305302
// 4.1 Fetch rolling stock length
306-
let mut project_path_result = HashMap::new();
307303
let rolling_stock_length: HashMap<_, _> = rolling_stocks
308304
.into_iter()
309305
.map(|rs| (rs.name, rs.length))
310306
.collect();
311307

312308
// 4.2 Build the projection response
313-
for (id, cached) in hit_cache {
314-
let train = train_map.get(&id).expect("Train not found");
309+
let mut project_path_result = HashMap::new();
310+
for (cached, train_id) in hit_cache {
311+
let train = train_map.get(&train_id).expect("Train not found");
315312
let length = rolling_stock_length
316313
.get(&train.rolling_stock_name)
317314
.expect("Rolling stock length not found");
318315

319316
project_path_result.insert(
320-
id,
317+
train_id,
321318
ProjectPathTrainResult {
322319
departure_time: train.start_time,
323320
rolling_stock_length: (length * 1000.).round() as u64,
@@ -332,6 +329,7 @@ async fn project_path(
332329
/// Input for the projection of a train schedule on a path
333330
#[derive(Debug, Clone, Hash)]
334331
struct TrainSimulationDetails {
332+
train_id: i64,
335333
positions: Vec<u64>,
336334
times: Vec<u64>,
337335
train_path: Vec<TrackRange>,
@@ -346,7 +344,7 @@ async fn compute_batch_signal_updates<'a>(
346344
path_track_ranges: &'a Vec<TrackRange>,
347345
path_routes: &'a Vec<Identifier>,
348346
path_blocks: &'a Vec<Identifier>,
349-
trains_details: &'a HashMap<i64, TrainSimulationDetails>,
347+
trains_details: &'a [TrainSimulationDetails],
350348
) -> Result<HashMap<i64, Vec<SignalUpdate>>> {
351349
if trains_details.is_empty() {
352350
return Ok(HashMap::new());
@@ -359,13 +357,13 @@ async fn compute_batch_signal_updates<'a>(
359357
blocks: path_blocks,
360358
train_simulations: trains_details
361359
.iter()
362-
.map(|(id, details)| {
360+
.map(|detail| {
363361
(
364-
*id,
362+
detail.train_id,
365363
TrainSimulation {
366-
signal_critical_positions: &details.signal_critical_positions,
367-
zone_updates: &details.zone_updates,
368-
simulation_end_time: details.times[details.times.len() - 1],
364+
signal_critical_positions: &detail.signal_critical_positions,
365+
zone_updates: &detail.zone_updates,
366+
simulation_end_time: detail.times[detail.times.len() - 1],
369367
},
370368
)
371369
})
@@ -377,14 +375,14 @@ async fn compute_batch_signal_updates<'a>(
377375

378376
/// Compute space time curves of a list of train schedules
379377
async fn compute_batch_space_time_curves<'a>(
380-
trains_details: &HashMap<i64, TrainSimulationDetails>,
378+
trains_details: &Vec<TrainSimulationDetails>,
381379
path_projection: &PathProjection<'a>,
382380
) -> HashMap<i64, Vec<SpaceTimeCurve>> {
383381
let mut space_time_curves = HashMap::new();
384382

385-
for (train_id, train_detail) in trains_details {
383+
for train_detail in trains_details {
386384
space_time_curves.insert(
387-
*train_id,
385+
train_detail.train_id,
388386
compute_space_time_curves(train_detail, path_projection),
389387
);
390388
}
@@ -584,6 +582,7 @@ mod tests {
584582
];
585583

586584
let project_path_input = TrainSimulationDetails {
585+
train_id: 0,
587586
positions,
588587
times,
589588
train_path,
@@ -618,6 +617,7 @@ mod tests {
618617
];
619618

620619
let project_path_input = TrainSimulationDetails {
620+
train_id: 0,
621621
positions: positions.clone(),
622622
times: times.clone(),
623623
train_path,
@@ -655,6 +655,7 @@ mod tests {
655655
let path_projection = PathProjection::new(&path);
656656

657657
let project_path_input = TrainSimulationDetails {
658+
train_id: 0,
658659
positions,
659660
times,
660661
train_path,

0 commit comments

Comments
 (0)