Skip to content

Commit

Permalink
editoast: add consist params validation
Browse files Browse the repository at this point in the history
Signed-off-by: Egor Berezovskiy <[email protected]>
  • Loading branch information
Wadjetz committed Feb 4, 2025
1 parent 5d6174c commit e7b1c75
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 15 deletions.
25 changes: 25 additions & 0 deletions editoast/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4826,6 +4826,7 @@ components:
- $ref: '#/components/schemas/EditoastSpriteErrorsUnknownSignalingSystem'
- $ref: '#/components/schemas/EditoastStdcmErrorInfraNotFound'
- $ref: '#/components/schemas/EditoastStdcmErrorInvalidPathItems'
- $ref: '#/components/schemas/EditoastStdcmErrorRequestValidationFail'
- $ref: '#/components/schemas/EditoastStdcmErrorRollingStockNotFound'
- $ref: '#/components/schemas/EditoastStdcmErrorTimetableNotFound'
- $ref: '#/components/schemas/EditoastStdcmErrorTowedRollingStockNotFound'
Expand Down Expand Up @@ -6042,6 +6043,30 @@ components:
type: string
enum:
- editoast:stdcm_v2:InvalidPathItems
EditoastStdcmErrorRequestValidationFail:
type: object
required:
- type
- status
- message
properties:
context:
type: object
required:
- message
properties:
message:
type: string
message:
type: string
status:
type: integer
enum:
- 400
type:
type: string
enum:
- editoast:stdcm_v2:RequestValidationFail
EditoastStdcmErrorRollingStockNotFound:
type: object
required:
Expand Down
117 changes: 106 additions & 11 deletions editoast/src/views/timetable/stdcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use tracing::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use utoipa::IntoParams;
use utoipa::ToSchema;
use validator::Validate;

use crate::core::conflict_detection::Conflict;
use crate::core::conflict_detection::TrainRequirements;
Expand Down Expand Up @@ -100,6 +99,8 @@ enum StdcmError {
TrainSimulationFail,
#[error("Path items are invalid")]
InvalidPathItems { items: Vec<InvalidPathItem> },
#[error("Request validation failed: {message}")]
RequestValidationFail { message: String },
}

#[derive(Debug, Default, Clone, Serialize, Deserialize, IntoParams, ToSchema)]
Expand Down Expand Up @@ -156,7 +157,6 @@ async fn stdcm(

let trace_id = Some(trace_id).filter(|trace_id| *trace_id != TraceId::INVALID);

stdcm_request.validate()?;
let mut conn = db_pool.get().await?;

let timetable_id = id;
Expand All @@ -178,14 +178,18 @@ async fn stdcm(
.await?
.into();

let towed_rolling_stock = stdcm_request
.get_towed_rolling_stock(&mut conn)
.await?
.map(From::from);

stdcm_request.validate_consist(&rolling_stock, &towed_rolling_stock)?;

let physics_consist_parameters = PhysicsConsistParameters {
max_speed: stdcm_request.max_speed,
total_length: stdcm_request.total_length,
total_mass: stdcm_request.total_mass,
towed_rolling_stock: stdcm_request
.get_towed_rolling_stock(&mut conn)
.await?
.map(From::from),
towed_rolling_stock,
traction_engine: rolling_stock,
};

Expand Down Expand Up @@ -579,6 +583,10 @@ mod tests {
use rstest::rstest;
use serde_json::json;
use std::str::FromStr;
use uom::si::length::meter;
use uom::si::length::Length;
use uom::si::mass::kilogram;
use uom::si::quantities::Mass;
use uuid::Uuid;

use crate::core::conflict_detection::Conflict;
Expand All @@ -590,6 +598,7 @@ mod tests {
use crate::core::simulation::PhysicsConsist;
use crate::core::simulation::ReportTrain;
use crate::core::simulation::SpeedLimitProperties;
use crate::error::InternalError;
use crate::models::fixtures::create_fast_rolling_stock;
use crate::models::fixtures::create_simple_rolling_stock;
use crate::models::fixtures::create_small_infra;
Expand All @@ -606,7 +615,12 @@ mod tests {

use super::*;

fn get_stdcm_payload(rolling_stock_id: i64, work_schedule_group_id: Option<i64>) -> Request {
fn get_stdcm_payload(
rolling_stock_id: i64,
work_schedule_group_id: Option<i64>,
total_mass: Option<f64>,
total_length: Option<f64>,
) -> Request {
Request {
start_time: Some(
DateTime::from_str("2024-01-01T10:00:00Z").expect("Failed to parse datetime"),
Expand Down Expand Up @@ -656,8 +670,8 @@ mod tests {
time_gap_before: 35000,
time_gap_after: 35000,
margin: Some(MarginValue::MinPer100Km(4.5)),
total_mass: None,
total_length: None,
total_mass: total_mass.map(Mass::new::<kilogram>),
total_length: total_length.map(Length::new::<meter>),
max_speed: None,
loading_gauge_type: None,
}
Expand Down Expand Up @@ -938,7 +952,7 @@ mod tests {

let request = app
.post(format!("/timetable/{}/stdcm?infra={}", timetable.id, small_infra.id).as_str())
.json(&get_stdcm_payload(rolling_stock.id, None));
.json(&get_stdcm_payload(rolling_stock.id, None, None, None));

let stdcm_response: StdcmResponse =
app.fetch(request).assert_status(StatusCode::OK).json_into();
Expand All @@ -958,6 +972,85 @@ mod tests {
}
}

#[rstest]
async fn stdcm_request_mass_validation() {
let db_pool = DbConnectionPoolV2::for_tests();
let mut core = core_mocking_client();
core.stub("/v2/stdcm")
.method(reqwest::Method::POST)
.response(StatusCode::OK)
.json(crate::core::stdcm::Response::Success {
simulation: simulation_response(),
path: pathfinding_result_success(),
departure_time: DateTime::from_str("2024-01-02T00:00:00Z")
.expect("Failed to parse datetime"),
})
.finish();

let app = TestAppBuilder::new()
.db_pool(db_pool.clone())
.core_client(core.into())
.build();
let small_infra = create_small_infra(&mut db_pool.get_ok()).await;
let timetable = create_timetable(&mut db_pool.get_ok()).await;
let rolling_stock =
create_fast_rolling_stock(&mut db_pool.get_ok(), &Uuid::new_v4().to_string()).await;

let request = app
.post(format!("/timetable/{}/stdcm?infra={}", timetable.id, small_infra.id).as_str())
.json(&get_stdcm_payload(rolling_stock.id, None, Some(1.0), None));

let stdcm_response: InternalError = app
.fetch(request)
.assert_status(StatusCode::BAD_REQUEST)
.json_into();

assert_eq!(
stdcm_response.message,
"Request validation failed: The total weight must be between 900000 kilograms and 10000000 kilograms".to_owned()
);
}

#[rstest]
async fn stdcm_request_length_validation() {
let db_pool = DbConnectionPoolV2::for_tests();
let mut core = core_mocking_client();
core.stub("/v2/stdcm")
.method(reqwest::Method::POST)
.response(StatusCode::OK)
.json(crate::core::stdcm::Response::Success {
simulation: simulation_response(),
path: pathfinding_result_success(),
departure_time: DateTime::from_str("2024-01-02T00:00:00Z")
.expect("Failed to parse datetime"),
})
.finish();

let app = TestAppBuilder::new()
.db_pool(db_pool.clone())
.core_client(core.into())
.build();
let small_infra = create_small_infra(&mut db_pool.get_ok()).await;
let timetable = create_timetable(&mut db_pool.get_ok()).await;
let rolling_stock =
create_fast_rolling_stock(&mut db_pool.get_ok(), &Uuid::new_v4().to_string()).await;

let request = app
.post(format!("/timetable/{}/stdcm?infra={}", timetable.id, small_infra.id).as_str())
.json(&get_stdcm_payload(rolling_stock.id, None, None, Some(1.0)));

let stdcm_response: InternalError = app
.fetch(request)
.assert_status(StatusCode::BAD_REQUEST)
.json_into();

assert_eq!(
stdcm_response.message,
"Request validation failed: The total length must be between 400 meters and 750 meters"
.to_owned()
);
}

#[rstest]
async fn stdcm_return_conflicts() {
let db_pool = DbConnectionPoolV2::for_tests();
Expand Down Expand Up @@ -986,7 +1079,7 @@ mod tests {

let request = app
.post(format!("/timetable/{}/stdcm?infra={}", timetable.id, small_infra.id).as_str())
.json(&get_stdcm_payload(rolling_stock.id, None));
.json(&get_stdcm_payload(rolling_stock.id, None, None, None));

let stdcm_response: StdcmResponse =
app.fetch(request).assert_status(StatusCode::OK).json_into();
Expand Down Expand Up @@ -1057,6 +1150,8 @@ mod tests {
.json(&get_stdcm_payload(
rolling_stock.id,
Some(work_schedule_group.id),
None,
None,
));

let stdcm_response: StdcmResponse =
Expand Down
76 changes: 74 additions & 2 deletions editoast/src/views/timetable/stdcm/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use chrono::Utc;
use editoast_common::units;
use editoast_models::DbConnection;
use editoast_schemas::rolling_stock::LoadingGaugeType;
use editoast_schemas::rolling_stock::RollingStock;
use editoast_schemas::rolling_stock::TowedRollingStock;
use editoast_schemas::train_schedule::Comfort;
use editoast_schemas::train_schedule::MarginValue;
use editoast_schemas::train_schedule::PathItem;
Expand All @@ -14,8 +16,12 @@ use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;
use units::quantities;
use uom::fmt::DisplayStyle;
use uom::si::length::meter;
use uom::si::length::Length;
use uom::si::mass::kilogram;
use uom::si::quantities::Mass;
use utoipa::ToSchema;
use validator::Validate;

use crate::core::pathfinding::PathfindingInputError;
use crate::error::Result;
Expand Down Expand Up @@ -70,7 +76,7 @@ pub(crate) struct StepTimingData {

/// An STDCM request
#[editoast_derive::annotate_units]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Validate, ToSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)]
#[serde(remote = "Self")]
pub(crate) struct Request {
/// Deprecated, first step arrival time should be used instead
Expand Down Expand Up @@ -297,6 +303,72 @@ impl Request {
.await?;
Ok(Some(towed_rolling_stock))
}

pub(super) fn validate_consist(
&self,
traction_engine: &RollingStock,
towed_rolling_stock: &Option<TowedRollingStock>,
) -> Result<()> {
self.validate_consist_mass(traction_engine, towed_rolling_stock)?;
self.validate_consist_length(traction_engine, towed_rolling_stock)?;
Ok(())
}

fn validate_consist_mass(
&self,
traction_engine: &RollingStock,
towed_rolling_stock: &Option<TowedRollingStock>,
) -> Result<()> {
let max = Mass::new::<kilogram>(10000000.0);
let consist_mass = traction_engine.mass
+ towed_rolling_stock
.as_ref()
.map(|t| t.mass)
.unwrap_or_default();
if let Some(request_total_mass) = self.total_mass {
if request_total_mass < consist_mass || request_total_mass > max {
return Err(StdcmError::RequestValidationFail {
message: format!(
"The total weight must be between {} and {}",
&consist_mass.into_format_args(kilogram, DisplayStyle::Description),
&max.into_format_args(kilogram, DisplayStyle::Description),
),
}
.into());
}
}

Ok(())
}

fn validate_consist_length(
&self,
traction_engine: &RollingStock,
towed_rolling_stock: &Option<TowedRollingStock>,
) -> Result<()> {
let max = Length::new::<meter>(750.0);
let consist_length = traction_engine.length
+ towed_rolling_stock
.as_ref()
.map(|t| t.length)
.unwrap_or_default();
let consist_length = consist_length.floor::<meter>();

if let Some(request_total_length) = self.total_length {
if request_total_length < consist_length || request_total_length > max {
return Err(StdcmError::RequestValidationFail {
message: format!(
"The total length must be between {} and {}",
&consist_length.into_format_args(meter, DisplayStyle::Description),
&max.into_format_args(meter, DisplayStyle::Description),
),
}
.into());
}
}

Ok(())
}
}

impl<'de> Deserialize<'de> for Request {
Expand Down
3 changes: 2 additions & 1 deletion front/public/locales/en/errors.json
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@
"RollingStockNotFound": "Rolling stock '{{rolling_stock_id}}' does not exist",
"TowedRollingStockNotFound": "Towed rolling stock {towed_rolling_stock_id} does not exist",
"TrainSimulationFail": "Train simulation failed",
"TimetableNotFound": "Timetable '{{timetable_id}}' does not exist"
"TimetableNotFound": "Timetable '{{timetable_id}}' does not exist",
"RequestValidationFail": "Request validation failed"
},
"study": {
"Database": "Internal error (database)",
Expand Down
3 changes: 2 additions & 1 deletion front/public/locales/fr/errors.json
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@
"RollingStockNotFound": "Matériel roulant '{{rolling_stock_id}}' non trouvé",
"TowedRollingStockNotFound": "Matériel roulant remorqué {towed_rolling_stock_id} non trouvé",
"TrainSimulationFail": "Échec de la simulation du train",
"TimetableNotFound": "Grille horaire '{{timetable_id}}' non trouvée"
"TimetableNotFound": "Grille horaire '{{timetable_id}}' non trouvée",
"RequestValidationFail": "Échec de la validation de la requête"
},
"study": {
"Database": "Erreur interne (base de données)",
Expand Down

0 comments on commit e7b1c75

Please sign in to comment.