Skip to content

Commit

Permalink
editoast: parallelise errors layer generation
Browse files Browse the repository at this point in the history
  • Loading branch information
younesschrifi committed Nov 27, 2023
1 parent b6ead5b commit c68c840
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 97 deletions.
1 change: 1 addition & 0 deletions editoast/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions editoast/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ itertools = "0.11.0"
utoipa.workspace = true
paste = "1.0.14"
url = "2.4.1"
rayon = "1"

[dev-dependencies]
async-std = { version = "1.12.0", features = ["attributes", "tokio1"] }
Expand Down
189 changes: 97 additions & 92 deletions editoast/src/generated_data/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ pub mod switch_types;
pub mod switches;
pub mod track_sections;

use std::collections::HashMap;

use async_trait::async_trait;
use diesel::sql_query;
use diesel::sql_types::{Array, BigInt, Json};
use diesel_async::{AsyncPgConnection as PgConnection, RunQueryDsl};
use serde_json::to_value;
use std::collections::HashMap;

use super::GeneratedData;
use crate::error::Result;
Expand Down Expand Up @@ -89,10 +88,10 @@ impl<Ctx> GlobalErrorGenerator<Ctx> {
/// Generate errors given static object and global error generators.
/// This function assume that object error generators list isn't empty and sorted by priority.
/// Global errors are generated at the end.
fn generate_errors<Ctx: Default>(
async fn generate_errors<Ctx: Default>(
object_type: ObjectType,
infra_cache: &InfraCache,
graph: &Graph,
graph: &Graph<'_>,
object_err_generators: &'static ObjectErrorGenerators<Ctx>,
global_err_generators: &'static GlobalErrorGenerators<Ctx>,
) -> Vec<InfraError> {
Expand Down Expand Up @@ -143,90 +142,84 @@ fn generate_errors<Ctx: Default>(
errors
}

pub fn generate_infra_errors(infra_cache: &InfraCache) -> Vec<InfraError> {
pub async fn generate_infra_errors(infra_cache: &InfraCache) -> Vec<InfraError> {
// Create a graph for topological errors
let graph = Graph::load(infra_cache);

// Generate the errors
let mut infra_errors = generate_errors(
ObjectType::TrackSection,
infra_cache,
&graph,
&track_sections::OBJECT_GENERATORS,
&[],
);
infra_errors.extend(generate_errors(
ObjectType::Signal,
infra_cache,
&graph,
&signals::OBJECT_GENERATORS,
&[],
));
infra_errors.extend(generate_errors(
ObjectType::SpeedSection,
infra_cache,
&graph,
&speed_sections::OBJECT_GENERATORS,
&speed_sections::GLOBAL_GENERATORS,
));

infra_errors.extend(generate_errors(
ObjectType::SwitchType,
infra_cache,
&graph,
&switch_types::OBJECT_GENERATORS,
&[],
));

infra_errors.extend(generate_errors(
ObjectType::Detector,
infra_cache,
&graph,
&detectors::OBJECT_GENERATORS,
&[],
));
infra_errors.extend(generate_errors(
ObjectType::BufferStop,
infra_cache,
&graph,
&buffer_stops::OBJECT_GENERATORS,
&buffer_stops::GLOBAL_GENERATORS,
));

infra_errors.extend(generate_errors(
ObjectType::OperationalPoint,
infra_cache,
&graph,
&operational_points::OBJECT_GENERATORS,
&[],
));

infra_errors.extend(generate_errors(
ObjectType::Route,
infra_cache,
&graph,
&routes::OBJECT_GENERATORS,
&routes::GLOBAL_GENERATORS,
));

infra_errors.extend(generate_errors(
ObjectType::Switch,
infra_cache,
&graph,
&switches::OBJECT_GENERATORS,
&[],
));
infra_errors.extend(generate_errors(
ObjectType::Catenary,
infra_cache,
&graph,
&catenaries::OBJECT_GENERATORS,
&catenaries::GLOBAL_GENERATORS,
));

// TODO: generate neutralSections errors

infra_errors
let mut futures = vec![
generate_errors(
ObjectType::TrackSection,
infra_cache,
&graph,
&track_sections::OBJECT_GENERATORS,
&[],
),
generate_errors(
ObjectType::Signal,
infra_cache,
&graph,
&signals::OBJECT_GENERATORS,
&[],
),
generate_errors(
ObjectType::SpeedSection,
infra_cache,
&graph,
&speed_sections::OBJECT_GENERATORS,
&speed_sections::GLOBAL_GENERATORS,
),
generate_errors(
ObjectType::SwitchType,
infra_cache,
&graph,
&switch_types::OBJECT_GENERATORS,
&[],
),
generate_errors(
ObjectType::Detector,
infra_cache,
&graph,
&detectors::OBJECT_GENERATORS,
&[],
),
generate_errors(
ObjectType::BufferStop,
infra_cache,
&graph,
&buffer_stops::OBJECT_GENERATORS,
&buffer_stops::GLOBAL_GENERATORS,
),
generate_errors(
ObjectType::OperationalPoint,
infra_cache,
&graph,
&operational_points::OBJECT_GENERATORS,
&[],
),
// generate_errors(
// ObjectType::Switch,
// infra_cache,
// &graph,
// &switches::OBJECT_GENERATORS,
// &[],
// ),
generate_errors(
ObjectType::Catenary,
infra_cache,
&graph,
&catenaries::OBJECT_GENERATORS,
&catenaries::GLOBAL_GENERATORS,
),
// generate_errors(
// ObjectType::Route,
// infra_cache,
// &graph,
// &routes::OBJECT_GENERATORS,
// &routes::GLOBAL_GENERATORS,
// ),
];
let infra_errors = futures::future::join_all(futures).await;
infra_errors.into_iter().flatten().collect()
}

/// Get sql query that insert errors given an object type
Expand Down Expand Up @@ -284,11 +277,10 @@ impl GeneratedData for ErrorLayer {
infra_id: i64,
infra_cache: &InfraCache,
) -> Result<()> {
let infra_errors = generate_infra_errors(infra_cache);
let infra_errors = generate_infra_errors(infra_cache).await;

// Insert errors in DB
insert_errors(conn, infra_id, infra_errors).await?;

Ok(())
}

Expand All @@ -305,6 +297,8 @@ impl GeneratedData for ErrorLayer {

#[cfg(test)]
mod test {
use rstest::rstest;

use super::{
buffer_stops, catenaries, detectors, generate_errors, operational_points, routes, signals,
speed_sections, switch_types, switches, track_sections, Graph,
Expand All @@ -314,8 +308,8 @@ mod test {

use crate::schema::ObjectType;

#[test]
fn small_infra_cache_validation() {
#[rstest]
async fn small_infra_cache_validation() {
let small_infra_cache = create_small_infra_cache();

let graph = Graph::load(&small_infra_cache);
Expand All @@ -328,6 +322,7 @@ mod test {
&track_sections::OBJECT_GENERATORS,
&[],
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::Signal,
Expand All @@ -336,6 +331,7 @@ mod test {
&signals::OBJECT_GENERATORS,
&[],
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::SpeedSection,
Expand All @@ -344,6 +340,7 @@ mod test {
&speed_sections::OBJECT_GENERATORS,
&speed_sections::GLOBAL_GENERATORS,
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::SwitchType,
Expand All @@ -352,6 +349,7 @@ mod test {
&switch_types::OBJECT_GENERATORS,
&[],
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::Detector,
Expand All @@ -360,6 +358,7 @@ mod test {
&detectors::OBJECT_GENERATORS,
&[],
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::BufferStop,
Expand All @@ -368,6 +367,7 @@ mod test {
&buffer_stops::OBJECT_GENERATORS,
&buffer_stops::GLOBAL_GENERATORS,
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::Route,
Expand All @@ -376,6 +376,7 @@ mod test {
&routes::OBJECT_GENERATORS,
&routes::GLOBAL_GENERATORS,
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::OperationalPoint,
Expand All @@ -384,6 +385,7 @@ mod test {
&operational_points::OBJECT_GENERATORS,
&[],
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::Switch,
Expand All @@ -392,6 +394,7 @@ mod test {
&switches::OBJECT_GENERATORS,
&[],
)
.await
.is_empty());
assert!(generate_errors(
ObjectType::Catenary,
Expand All @@ -400,11 +403,12 @@ mod test {
&catenaries::OBJECT_GENERATORS,
&catenaries::GLOBAL_GENERATORS,
)
.await
.is_empty());
}

#[test]
fn error_priority_check() {
#[rstest]
async fn error_priority_check() {
let mut small_infra_cache = create_small_infra_cache();
let bf = create_buffer_stop_cache("BF_error", "E", 530.0);
small_infra_cache.add(bf);
Expand All @@ -416,7 +420,8 @@ mod test {
&graph,
&buffer_stops::OBJECT_GENERATORS,
&[],
);
)
.await;
assert_eq!(1, errors.len());
}
}
4 changes: 2 additions & 2 deletions editoast/src/schema/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::schema::ObjectRef;
use serde::{Deserialize, Serialize};
use strum_macros::EnumVariantNames;

#[derive(Serialize, Deserialize, PartialEq, Debug)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct InfraError {
obj_id: String,
Expand All @@ -14,7 +14,7 @@ pub struct InfraError {
sub_type: InfraErrorType,
}

#[derive(Serialize, Deserialize, PartialEq, Debug, EnumVariantNames)]
#[derive(Serialize, Deserialize, PartialEq, Debug, EnumVariantNames, Clone)]
#[strum(serialize_all = "snake_case")]
#[serde(tag = "error_type", rename_all = "snake_case", deny_unknown_fields)]
pub enum InfraErrorType {
Expand Down
6 changes: 3 additions & 3 deletions editoast/src/views/infra/auto_fixes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async fn list_auto_fixes(

let mut fixes = vec![];
for _ in 0..MAX_AUTO_FIXES_ITERATIONS {
let new_fixes = fix_infra(&mut infra_cache_clone)?;
let new_fixes = fix_infra(&mut infra_cache_clone).await?;
if new_fixes.is_empty() {
// Every possible error is fixed
return Ok(WebJson(fixes));
Expand All @@ -58,8 +58,8 @@ async fn list_auto_fixes(
Err(AutoFixesEditoastError::MaximumIterationReached().into())
}

fn fix_infra(infra_cache: &mut InfraCache) -> Result<Vec<Operation>> {
let infra_errors = generate_infra_errors(infra_cache);
async fn fix_infra(infra_cache: &mut InfraCache) -> Result<Vec<Operation>> {
let infra_errors = generate_infra_errors(infra_cache).await;

let mut delete_fixes_already_retained = HashSet::new();
let mut all_fixes = vec![];
Expand Down

0 comments on commit c68c840

Please sign in to comment.