Skip to content

Commit

Permalink
editoast: add consist params validation
Browse files Browse the repository at this point in the history
Signed-off-by: Egor Berezovskiy <[email protected]>
  • Loading branch information
Wadjetz committed Feb 4, 2025
1 parent aa0fe72 commit 6681143
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 11 deletions.
84 changes: 73 additions & 11 deletions editoast/src/views/timetable/stdcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use tracing::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use utoipa::IntoParams;
use utoipa::ToSchema;
use validator::Validate;

use crate::core::conflict_detection::Conflict;
use crate::core::conflict_detection::TrainRequirements;
Expand Down Expand Up @@ -100,6 +99,8 @@ enum StdcmError {
TrainSimulationFail,
#[error("Path items are invalid")]
InvalidPathItems { items: Vec<InvalidPathItem> },
#[error("Request validation fail: {message}")]
RequestValidation { message: String },
}

#[derive(Debug, Default, Clone, Serialize, Deserialize, IntoParams, ToSchema)]
Expand Down Expand Up @@ -156,7 +157,6 @@ async fn stdcm(

let trace_id = Some(trace_id).filter(|trace_id| *trace_id != TraceId::INVALID);

stdcm_request.validate()?;
let mut conn = db_pool.get().await?;

let timetable_id = id;
Expand All @@ -178,14 +178,19 @@ async fn stdcm(
.await?
.into();

let towed_rolling_stock = stdcm_request
.get_towed_rolling_stock(&mut conn)
.await?
.map(From::from);

stdcm_request.validate_request_mass(&rolling_stock, &towed_rolling_stock)?;
stdcm_request.validate_consist_length(&rolling_stock, &towed_rolling_stock)?;

let physics_consist_parameters = PhysicsConsistParameters {
max_speed: stdcm_request.max_speed,
total_length: stdcm_request.total_length,
total_mass: stdcm_request.total_mass,
towed_rolling_stock: stdcm_request
.get_towed_rolling_stock(&mut conn)
.await?
.map(From::from),
towed_rolling_stock,
traction_engine: rolling_stock,
};

Expand Down Expand Up @@ -579,6 +584,10 @@ mod tests {
use rstest::rstest;
use serde_json::json;
use std::str::FromStr;
use uom::si::length::meter;
use uom::si::length::Length;
use uom::si::mass::kilogram;
use uom::si::quantities::Mass;
use uuid::Uuid;

use crate::core::conflict_detection::Conflict;
Expand Down Expand Up @@ -606,7 +615,12 @@ mod tests {

use super::*;

fn get_stdcm_payload(rolling_stock_id: i64, work_schedule_group_id: Option<i64>) -> Request {
fn get_stdcm_payload(
rolling_stock_id: i64,
work_schedule_group_id: Option<i64>,
total_mass: Option<f64>,
total_length: Option<f64>,
) -> Request {
Request {
start_time: Some(
DateTime::from_str("2024-01-01T10:00:00Z").expect("Failed to parse datetime"),
Expand Down Expand Up @@ -656,8 +670,8 @@ mod tests {
time_gap_before: 35000,
time_gap_after: 35000,
margin: Some(MarginValue::MinPer100Km(4.5)),
total_mass: None,
total_length: None,
total_mass: total_mass.map(|v| Mass::new::<kilogram>(v)),
total_length: total_length.map(|v| Length::new::<meter>(v)),
max_speed: None,
loading_gauge_type: None,
}
Expand Down Expand Up @@ -938,7 +952,53 @@ mod tests {

let request = app
.post(format!("/timetable/{}/stdcm?infra={}", timetable.id, small_infra.id).as_str())
.json(&get_stdcm_payload(rolling_stock.id, None));
.json(&get_stdcm_payload(rolling_stock.id, None, None, None));

let stdcm_response: StdcmResponse =
app.fetch(request).assert_status(StatusCode::OK).json_into();

if let PathfindingResult::Success(path) =
PathfindingResult::Success(pathfinding_result_success())
{
assert_eq!(
stdcm_response,
StdcmResponse::Success {
simulation: simulation_response(),
path,
departure_time: DateTime::from_str("2024-01-02T00:00:00Z")
.expect("Failed to parse datetime")
}
);
}
}

#[rstest]
async fn stdcm_request_mass_validation() {
let db_pool = DbConnectionPoolV2::for_tests();
// let mut core = core_mocking_client();
// core.stub("/v2/stdcm")
// .method(reqwest::Method::POST)
// .response(StatusCode::OK)
// .json(crate::core::stdcm::Response::Success {
// simulation: simulation_response(),
// path: pathfinding_result_success(),
// departure_time: DateTime::from_str("2024-01-02T00:00:00Z")
// .expect("Failed to parse datetime"),
// })
// .finish();

let app = TestAppBuilder::new()
.db_pool(db_pool.clone())
// .core_client(core.into())
.build();
let small_infra = create_small_infra(&mut db_pool.get_ok()).await;
let timetable = create_timetable(&mut db_pool.get_ok()).await;
let rolling_stock =
create_fast_rolling_stock(&mut db_pool.get_ok(), &Uuid::new_v4().to_string()).await;

let request = app
.post(format!("/timetable/{}/stdcm?infra={}", timetable.id, small_infra.id).as_str())
.json(&get_stdcm_payload(rolling_stock.id, None, Some(1.0), None));

let stdcm_response: StdcmResponse =
app.fetch(request).assert_status(StatusCode::OK).json_into();
Expand Down Expand Up @@ -986,7 +1046,7 @@ mod tests {

let request = app
.post(format!("/timetable/{}/stdcm?infra={}", timetable.id, small_infra.id).as_str())
.json(&get_stdcm_payload(rolling_stock.id, None));
.json(&get_stdcm_payload(rolling_stock.id, None, None, None));

let stdcm_response: StdcmResponse =
app.fetch(request).assert_status(StatusCode::OK).json_into();
Expand Down Expand Up @@ -1057,6 +1117,8 @@ mod tests {
.json(&get_stdcm_payload(
rolling_stock.id,
Some(work_schedule_group.id),
None,
None,
));

let stdcm_response: StdcmResponse =
Expand Down
65 changes: 65 additions & 0 deletions editoast/src/views/timetable/stdcm/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use chrono::Utc;
use editoast_common::units;
use editoast_models::DbConnection;
use editoast_schemas::rolling_stock::LoadingGaugeType;
use editoast_schemas::rolling_stock::RollingStock;
use editoast_schemas::rolling_stock::TowedRollingStock;
use editoast_schemas::train_schedule::Comfort;
use editoast_schemas::train_schedule::MarginValue;
use editoast_schemas::train_schedule::PathItem;
Expand All @@ -14,6 +16,11 @@ use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;
use units::quantities;
use uom::fmt::DisplayStyle;
use uom::si::length::meter;
use uom::si::length::Length;
use uom::si::mass::kilogram;
use uom::si::quantities::Mass;
use utoipa::ToSchema;
use validator::Validate;

Expand Down Expand Up @@ -297,6 +304,64 @@ impl Request {
.await?;
Ok(Some(towed_rolling_stock))
}

pub(super) fn validate_request_mass(
&self,
traction_engine: &RollingStock,
towed_rolling_stock: &Option<TowedRollingStock>,
) -> Result<()> {
let max = Mass::new::<kilogram>(10000000.0);
let consist_mass = traction_engine.mass
+ towed_rolling_stock
.as_ref()
.map(|t| t.mass)
.unwrap_or_default();
dbg!(&consist_mass);
if let Some(request_total_mass) = self.total_mass {
if request_total_mass < consist_mass || request_total_mass > max {
return Err(StdcmError::RequestValidation {
message: format!(
"The total weight must be between {} and {}t",
&consist_mass.into_format_args(kilogram, DisplayStyle::Description),
&max.into_format_args(kilogram, DisplayStyle::Description),
),
}
.into());
}
}

Ok(())
}

pub(super) fn validate_consist_length(
&self,
traction_engine: &RollingStock,
towed_rolling_stock: &Option<TowedRollingStock>,
) -> Result<()> {
let max = Length::new::<meter>(750.0);
let consist_length = traction_engine.length
+ towed_rolling_stock
.as_ref()
.map(|t| t.length)
.unwrap_or_default();

let consist_length = consist_length.floor::<meter>();

if let Some(request_total_length) = self.total_length {
if request_total_length < consist_length || request_total_length > max {
return Err(StdcmError::RequestValidation {
message: format!(
"The total length must be between {} and {}",
&consist_length.into_format_args(meter, DisplayStyle::Description),
&max.into_format_args(meter, DisplayStyle::Description),
),
}
.into());
}
}

Ok(())
}
}

impl<'de> Deserialize<'de> for Request {
Expand Down

0 comments on commit 6681143

Please sign in to comment.