diff --git a/editoast/editoast_models/src/db_connection_pool.rs b/editoast/editoast_models/src/db_connection_pool.rs
index 9bc32b66d51..55e2eb1acaf 100644
--- a/editoast/editoast_models/src/db_connection_pool.rs
+++ b/editoast/editoast_models/src/db_connection_pool.rs
@@ -134,6 +134,14 @@ impl DbConnectionPoolV2 {
     /// # }
     /// ```
     ///
+    /// ### Deadlocks
+    ///
+    /// We observed PostgreSQL deadlocks in our tests,
+    /// especially those using the `empty_infra` and `small_infra` fixtures.
+    /// Annotating the affected tests with `#[serial_test::serial]` resolved the issue;
+    /// increasing the deadlock timeout did not help,
+    /// nor did randomizing the `infra_id` with `rand`.
+    ///
     /// ## Guidelines
     ///
     /// To prevent these issues, prefer the following patterns:
diff --git a/editoast/src/fixtures.rs b/editoast/src/fixtures.rs
index 4946a72e33f..4f5cb1c2dcd 100644
--- a/editoast/src/fixtures.rs
+++ b/editoast/src/fixtures.rs
@@ -4,6 +4,7 @@ pub mod tests {
     use std::io::Cursor;
     use std::ops::{Deref, DerefMut};
     use std::sync::Arc;
+    use uuid::Uuid;
 
     use editoast_models::create_connection_pool;
     use editoast_models::DbConnection;
@@ -261,8 +262,7 @@ pub mod tests {
         } = scenario_fixture_set().await;
 
         let pathfinding = pathfinding(db_pool()).await;
-        let mut rs_name = "fast_rolling_stock_".to_string();
-        rs_name.push_str(name);
+        let rs_name = format!("fast_rolling_stock_{}_{name}", Uuid::new_v4());
         let rolling_stock = named_fast_rolling_stock(&rs_name, db_pool()).await;
         let ts_model = make_train_schedule(
             db_pool(),
@@ -474,7 +474,7 @@ pub mod tests {
         Infra::changeset()
             .name("small_infra".to_owned())
             .last_railjson_version()
-            .persist(railjson, db_pool)
+            .persist(railjson, db_pool.get().await.unwrap().deref_mut())
             .await
             .unwrap()
     }
diff --git a/editoast/src/generated_data/mod.rs b/editoast/src/generated_data/mod.rs
index aa45c89b0ac..d4f51e6a02d 100644
--- a/editoast/src/generated_data/mod.rs
+++ b/editoast/src/generated_data/mod.rs
@@ -30,6 +30,7 @@ use operational_point::OperationalPointLayer;
 use psl_sign::PSLSignLayer;
 use signal::SignalLayer;
 use speed_section::SpeedSectionLayer;
+use std::ops::DerefMut;
 use std::sync::Arc;
 use switch::SwitchLayer;
 use tracing::debug;
@@ -39,7 +40,7 @@ use crate::error::Result;
 use crate::infra_cache::operation::CacheOperation;
 use crate::infra_cache::InfraCache;
 use editoast_models::DbConnection;
-use editoast_models::DbConnectionPool;
+use editoast_models::DbConnectionPoolV2;
 
 editoast_common::schemas! {
     error::schemas(),
@@ -68,12 +69,11 @@ pub trait GeneratedData {
     }
 
     async fn refresh_pool(
-        pool: Arc<DbConnectionPool>,
+        pool: Arc<DbConnectionPoolV2>,
         infra: i64,
         infra_cache: &InfraCache,
     ) -> Result<()> {
-        let mut conn = pool.get().await?;
-        Self::refresh(&mut conn, infra, infra_cache).await
+        Self::refresh(pool.get().await?.deref_mut(), infra, infra_cache).await
     }
 
     /// Search and update all objects that needs to be refreshed given a list of operation.
@@ -86,37 +86,39 @@ pub trait GeneratedData {
 }
 
 /// Refresh all the generated data of a given infra
+#[tracing::instrument(level = "debug", skip_all, fields(infra_id))]
 pub async fn refresh_all(
-    db_pool: Arc<DbConnectionPool>,
-    infra: i64,
+    db_pool: Arc<DbConnectionPoolV2>,
+    infra_id: i64,
     infra_cache: &InfraCache,
 ) -> Result<()> {
     // The other layers depend on track section layer.
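// Editorial aside, not part of the patch: a minimal sketch of the test pattern the new
// `Deadlocks` documentation above describes, mirroring the updated `refresh_all_test`
// later in this diff. It assumes the `DbConnectionPoolV2::for_tests`/`get_ok` helpers and
// the `create_empty_infra` fixture shown in other hunks; the test name is hypothetical.
#[rstest]
// Serialize deadlock-prone tests instead of raising the deadlock timeout or randomizing ids.
#[serial_test::serial]
async fn refresh_generated_data_sketch() {
    let db_pool = DbConnectionPoolV2::for_tests();
    // Borrow a pooled connection only for the duration of a single call.
    let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await;
    assert!(refresh_all(db_pool.into(), empty_infra.id, &Default::default())
        .await
        .is_ok());
}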
    // We must wait until its completion before running the other requests in parallel
-    TrackSectionLayer::refresh_pool(db_pool.clone(), infra, infra_cache).await?;
-    debug!("⚙️ Infra {infra}: track section layer is generated");
-    let mut conn = db_pool.get().await?;
+    TrackSectionLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache).await?;
+    debug!("⚙️ Infra {infra_id}: track section layer is generated");
     // The analyze step significantly improves the performance when importing and generating together
     // It doesn't seem to make a difference when the generation step is run separately
     // It isn't clear why, but without analyze the Postgres server seems to run at 100% without halting
-    sql_query("analyze").execute(&mut conn).await?;
-    debug!("⚙️ Infra {infra}: database analyzed");
+    sql_query("analyze")
+        .execute(db_pool.get().await?.deref_mut())
+        .await?;
+    debug!("⚙️ Infra {infra_id}: database analyzed");
     futures::try_join!(
-        SpeedSectionLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        SignalLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        SwitchLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        BufferStopLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        ElectrificationLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        DetectorLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        OperationalPointLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        PSLSignLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        NeutralSectionLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
-        NeutralSignLayer::refresh_pool(db_pool.clone(), infra, infra_cache),
+        SpeedSectionLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        SignalLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        SwitchLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        BufferStopLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        ElectrificationLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        DetectorLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        OperationalPointLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        PSLSignLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        NeutralSectionLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
+        NeutralSignLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache),
     )?;
-    debug!("⚙️ Infra {infra}: object layers is generated");
+    debug!("⚙️ Infra {infra_id}: object layers are generated");
     // The error layer depends on the other layers and must be executed at the end.
- ErrorLayer::refresh_pool(db_pool.clone(), infra, infra_cache).await?; - debug!("⚙️ Infra {infra}: errors layer is generated"); + ErrorLayer::refresh_pool(db_pool.clone(), infra_id, infra_cache).await?; + debug!("⚙️ Infra {infra_id}: errors layer is generated"); Ok(()) } @@ -164,18 +166,20 @@ pub mod tests { use rstest::rstest; use std::ops::DerefMut; - use crate::fixtures::tests::db_pool; use crate::generated_data::clear_all; use crate::generated_data::refresh_all; use crate::generated_data::update_all; use crate::modelsv2::fixtures::create_empty_infra; use editoast_models::DbConnectionPoolV2; - #[rstest] // Slow test + #[rstest] + // Slow test + // PostgreSQL deadlock can happen in this test, see section `Deadlock` of [DbConnectionPoolV2::get] for more information + #[serial_test::serial] async fn refresh_all_test() { - let db_pool_v2 = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool_v2.get_ok().deref_mut()).await; - assert!(refresh_all(db_pool(), infra.id, &Default::default()) + let db_pool = DbConnectionPoolV2::for_tests(); + let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + assert!(refresh_all(db_pool.into(), infra.id, &Default::default()) .await .is_ok()); } diff --git a/editoast/src/main.rs b/editoast/src/main.rs index 6a2b2b32f0e..b156ea82bd2 100644 --- a/editoast/src/main.rs +++ b/editoast/src/main.rs @@ -64,6 +64,7 @@ pub use redis_utils::{RedisClient, RedisConnection}; use std::error::Error; use std::fs::File; use std::io::{BufReader, IsTerminal}; +use std::ops::DerefMut; use std::process::exit; use std::sync::Arc; use std::{env, fs}; @@ -220,12 +221,12 @@ async fn run() -> Result<(), Box> { SearchCommands::Refresh(args) => refresh_search_tables(args, db_pool.pool_v1()).await, }, Commands::Infra(subcommand) => match subcommand { - InfraCommands::Clone(args) => clone_infra(args, db_pool.pool_v1()).await, + InfraCommands::Clone(args) => clone_infra(args, db_pool.into()).await, InfraCommands::Clear(args) => clear_infra(args, db_pool.pool_v1(), redis_config).await, InfraCommands::Generate(args) => { - generate_infra(args, db_pool.pool_v1(), redis_config).await + generate_infra(args, db_pool.into(), redis_config).await } - InfraCommands::ImportRailjson(args) => import_railjson(args, db_pool.pool_v1()).await, + InfraCommands::ImportRailjson(args) => import_railjson(args, db_pool.into()).await, }, Commands::Timetables(subcommand) => match subcommand { TimetablesCommands::Import(args) => trains_import(args, db_pool.pool_v1()).await, @@ -484,19 +485,18 @@ async fn batch_retrieve_infras( /// This command refresh all infra given as input (if no infra given then refresh all of them) async fn generate_infra( args: GenerateArgs, - db_pool: Arc, + db_pool: Arc, redis_config: RedisConfig, ) -> Result<(), Box> { - let mut conn = db_pool.get().await?; let mut infras = vec![]; if args.infra_ids.is_empty() { // Retrieve all available infra - for infra in Infra::all(&mut conn).await { + for infra in Infra::all(db_pool.get().await?.deref_mut()).await { infras.push(infra); } } else { // Retrieve given infras - infras = batch_retrieve_infras(&mut conn, &args.infra_ids).await?; + infras = batch_retrieve_infras(db_pool.get().await?.deref_mut(), &args.infra_ids).await?; } for mut infra in infras { println!( @@ -504,7 +504,7 @@ async fn generate_infra( infra.name.clone().bold(), infra.id ); - let infra_cache = InfraCache::load(&mut conn, &infra).await?; + let infra_cache = InfraCache::load(db_pool.get().await?.deref_mut(), &infra).await?; if infra 
.refresh(db_pool.clone(), args.force, &infra_cache) .await? @@ -578,10 +578,9 @@ async fn import_rolling_stock( async fn clone_infra( infra_args: InfraCloneArgs, - db_pool: Arc, + db_pool: Arc, ) -> Result<(), Box> { - let conn = &mut db_pool.get().await?; - let infra = Infra::retrieve(conn, infra_args.id as i64) + let infra = Infra::retrieve(db_pool.get().await?.deref_mut(), infra_args.id as i64) .await? .ok_or_else(|| { // When EditoastError will be removed from the models crate, @@ -594,7 +593,9 @@ async fn clone_infra( let new_name = infra_args .new_name .unwrap_or_else(|| format!("{} (clone)", infra.name)); - let cloned_infra = infra.clone(conn, new_name).await?; + let cloned_infra = infra + .clone(db_pool.get().await?.deref_mut(), new_name) + .await?; println!( "✅ Infra {} (ID: {}) was successfully cloned", cloned_infra.name.bold(), @@ -605,7 +606,7 @@ async fn clone_infra( async fn import_railjson( args: ImportRailjsonArgs, - db_pool: Arc, + db_pool: Arc, ) -> Result<(), Box> { let railjson_file = match File::open(args.railjson_path.clone()) { Ok(file) => file, @@ -629,18 +630,19 @@ async fn import_railjson( let railjson: RailJson = serde_json::from_reader(BufReader::new(railjson_file))?; println!("🍞 Importing infra {infra_name}"); - let mut infra = infra.persist(railjson, db_pool.clone()).await?; + let mut infra = infra + .persist(railjson, db_pool.get().await?.deref_mut()) + .await?; - let mut conn = db_pool.get().await?; infra - .bump_version(&mut conn) + .bump_version(db_pool.get().await?.deref_mut()) .await .map_err(|_| InfraApiError::NotFound { infra_id: infra.id })?; println!("✅ Infra {infra_name}[{}] saved!", infra.id); // Generate only if the was set if args.generate { - let infra_cache = InfraCache::load(&mut conn, &infra).await?; + let infra_cache = InfraCache::load(db_pool.get().await?.deref_mut(), &infra).await?; infra.refresh(db_pool, true, &infra_cache).await?; println!( "✅ Infra {infra_name}[{}] generated data refreshed!", @@ -882,9 +884,6 @@ mod tests { get_trainschedule_json_array, TestFixture, }; use crate::modelsv2::RollingStockModel; - use diesel::sql_query; - use diesel::sql_types::Text; - use diesel_async::RunQueryDsl; use modelsv2::DeleteStatic; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; @@ -1081,7 +1080,7 @@ mod tests { } #[rstest] - async fn import_railjson_ko_file_not_found(db_pool: Arc) { + async fn import_railjson_ko_file_not_found() { // GIVEN let railjson_path = "non/existing/railjson/file/location"; let args: ImportRailjsonArgs = ImportRailjsonArgs { @@ -1091,7 +1090,7 @@ mod tests { }; // WHEN - let result = import_railjson(args.clone(), db_pool).await; + let result = import_railjson(args.clone(), DbConnectionPoolV2::for_tests().into()).await; // THEN assert!(result.is_err()); @@ -1106,7 +1105,7 @@ mod tests { } #[rstest] - async fn import_railjson_ok(db_pool: Arc) { + async fn import_railjson_ok() { // GIVEN let railjson = Default::default(); let file = generate_temp_file::(&railjson); @@ -1124,18 +1123,10 @@ mod tests { }; // WHEN - let result = import_railjson(args, db_pool.clone()).await; + let result = import_railjson(args, DbConnectionPoolV2::for_tests().into()).await; // THEN assert!(result.is_ok()); - - // CLEANUP - let mut conn = db_pool.get().await.unwrap(); - sql_query("DELETE FROM infra WHERE name = $1") - .bind::(infra_name) - .execute(&mut conn) - .await - .unwrap(); } fn generate_temp_file(object: &T) -> NamedTempFile { diff --git a/editoast/src/modelsv2/fixtures.rs b/editoast/src/modelsv2/fixtures.rs index 
73b3c60abe0..2d42f27464e 100644 --- a/editoast/src/modelsv2/fixtures.rs +++ b/editoast/src/modelsv2/fixtures.rs @@ -2,6 +2,7 @@ use std::io::Cursor; use chrono::Utc; use editoast_schemas::infra::InfraObject; +use editoast_schemas::infra::RailJson; use editoast_schemas::primitives::OSRDObject; use editoast_schemas::train_schedule::TrainScheduleBase; @@ -21,6 +22,8 @@ use crate::modelsv2::Tags; use crate::views::rolling_stocks::rolling_stock_form::RollingStockForm; use crate::views::v2::train_schedule::TrainScheduleForm; use editoast_models::DbConnection; +use editoast_models::DbConnectionPool; +use editoast_models::DbConnectionPoolV2; pub fn project_changeset(name: &str) -> Changeset { Project::changeset() @@ -282,3 +285,16 @@ where assert!(result.is_ok(), "Failed to create a {object_type}"); railjson_object } + +pub async fn create_small_infra(conn: &mut DbConnection) -> Infra { + let railjson: RailJson = serde_json::from_str(include_str!( + "../../../tests/data/infras/small_infra/infra.json" + )) + .unwrap(); + Infra::changeset() + .name("small_infra".to_owned()) + .last_railjson_version() + .persist(railjson, conn) + .await + .unwrap() +} diff --git a/editoast/src/modelsv2/infra.rs b/editoast/src/modelsv2/infra.rs index ad744f12fdd..c96726831a5 100644 --- a/editoast/src/modelsv2/infra.rs +++ b/editoast/src/modelsv2/infra.rs @@ -6,6 +6,8 @@ mod speed_limit_tags; mod split_track_section_with_data; mod voltage; +use std::ops::DerefMut; + use chrono::NaiveDateTime; use chrono::Utc; use derivative::Derivative; @@ -14,6 +16,7 @@ use diesel::sql_query; use diesel::sql_types::BigInt; use diesel::ExpressionMethods; use diesel::QueryDsl; +use diesel_async::AsyncConnection; use diesel_async::RunQueryDsl; use editoast_derive::ModelV2; use serde::Deserialize; @@ -36,7 +39,7 @@ use crate::modelsv2::railjson::persist_railjson; use crate::modelsv2::Create; use crate::tables::infra::dsl; use editoast_models::DbConnection; -use editoast_models::DbConnectionPool; +use editoast_models::DbConnectionPoolV2; use editoast_schemas::infra::RailJson; use editoast_schemas::infra::RAILJSON_VERSION; use editoast_schemas::primitives::ObjectType; @@ -72,16 +75,11 @@ pub struct Infra { } impl InfraChangeset { - pub async fn persist( - self, - railjson: RailJson, - db_pool: Arc, - ) -> Result { - let conn = &mut db_pool.get().await?; + pub async fn persist(self, railjson: RailJson, conn: &mut DbConnection) -> Result { let infra = self.create(conn).await?; // TODO: lock infra for update debug!("🛤 Begin importing all railjson objects"); - if let Err(e) = persist_railjson(db_pool, infra.id, railjson).await { + if let Err(e) = persist_railjson(conn, infra.id, railjson).await { error!("Could not import infrastructure {}. 
Rolling back", infra.id); infra.delete(conn).await?; return Err(e); @@ -123,7 +121,7 @@ impl Infra { } pub async fn clone(&self, conn: &mut DbConnection, new_name: String) -> Result { - conn.build_transaction().run(|conn| Box::pin(async { + conn.transaction(|conn| Box::pin(async { // Duplicate infra shell let cloned_infra = ::clone(self) .into_changeset() @@ -216,7 +214,7 @@ impl Infra { /// If refreshed you need to call `invalidate_after_refresh` to invalidate layer cache pub async fn refresh( &mut self, - db_pool: Arc, + db_pool: Arc, force: bool, infra_cache: &InfraCache, ) -> Result { @@ -233,8 +231,8 @@ impl Infra { generated_data::refresh_all(db_pool.clone(), self.id, infra_cache).await?; // Update generated infra version - let mut conn = db_pool.get().await?; - self.bump_generated_version(&mut conn).await?; + self.bump_generated_version(db_pool.get().await?.deref_mut()) + .await?; Ok(true) } @@ -315,9 +313,6 @@ pub mod tests { use super::Infra; use crate::error::EditoastError; - use crate::fixtures::tests::db_pool; - use crate::fixtures::tests::small_infra; - use crate::fixtures::tests::IntoFixture; use crate::modelsv2::fixtures::create_empty_infra; use crate::modelsv2::infra::DEFAULT_INFRA_VERSION; use crate::modelsv2::prelude::*; @@ -338,24 +333,28 @@ pub mod tests { } #[rstest] + // PostgreSQL deadlock can happen in this test, see section `Deadlock` of [DbConnectionPoolV2::get] for more information + #[serial_test::serial] async fn clone_infra_with_new_name_returns_new_cloned_infra() { // GIVEN - let db_pool = db_pool(); - let small_infra = small_infra(db_pool.clone()).await; + let db_pool = DbConnectionPoolV2::for_tests(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; let infra_new_name = "clone_infra_with_new_name_returns_new_cloned_infra".to_string(); // WHEN - let mut conn = db_pool.get().await.unwrap(); - let result = small_infra.clone(&mut conn, infra_new_name.clone()).await; + let result = empty_infra + .clone(db_pool.get_ok().deref_mut(), infra_new_name.clone()) + .await + .expect("could not clone infra"); // THEN - let infra = result.expect("could not clone infra").into_fixture(db_pool); - assert_eq!(infra.name, infra_new_name); + assert_eq!(result.name, infra_new_name); } #[rstest] + #[serial_test::serial] async fn persists_railjson_ko_version() { - let pool = db_pool(); + let db_pool = DbConnectionPoolV2::for_tests(); let railjson_with_invalid_version = RailJson { version: "0".to_string(), ..Default::default() @@ -363,7 +362,7 @@ pub mod tests { let res = Infra::changeset() .name("test".to_owned()) .last_railjson_version() - .persist(railjson_with_invalid_version, pool) + .persist(railjson_with_invalid_version, db_pool.get_ok().deref_mut()) .await; assert!(res.is_err()); let expected_error = RailJsonError::UnsupportedVersion { @@ -395,14 +394,13 @@ pub mod tests { version: RAILJSON_VERSION.to_string(), }; - let pool = db_pool(); + let db_pool = DbConnectionPoolV2::for_tests(); let infra = Infra::changeset() .name("persist_railjson_ok_infra".to_owned()) .last_railjson_version() - .persist(railjson.clone(), pool.clone()) + .persist(railjson.clone(), db_pool.get_ok().deref_mut()) .await - .expect("could not persist infra") - .into_fixture(pool.clone()); + .expect("could not persist infra"); // THEN assert_eq!(infra.railjson_version, railjson.version); @@ -412,51 +410,94 @@ pub mod tests { objects } - let conn = &mut pool.get().await.unwrap(); let id = infra.id; assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + 
find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.buffer_stops) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.routes) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.extended_switch_types) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.switches) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.track_sections) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.speed_sections) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.neutral_sections) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.electrifications) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.signals) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.detectors) ); assert_eq!( - sort::(find_all_schemas(conn, id).await.unwrap()), + sort::( + find_all_schemas(db_pool.get_ok().deref_mut(), id) + .await + .unwrap() + ), sort(railjson.operational_points) ); } diff --git a/editoast/src/modelsv2/railjson.rs b/editoast/src/modelsv2/railjson.rs index 309c434d9c4..743a99d077a 100644 --- a/editoast/src/modelsv2/railjson.rs +++ b/editoast/src/modelsv2/railjson.rs @@ -1,15 +1,12 @@ -use std::sync::Arc; - use editoast_derive::EditoastError; use editoast_schemas::infra::RailJson; use editoast_schemas::infra::RAILJSON_VERSION; -use crate::error::InternalError; use crate::error::Result; use crate::modelsv2::infra_objects::*; use crate::modelsv2::prelude::*; +use diesel_async::AsyncConnection; use editoast_models::DbConnection; -use editoast_models::DbConnectionPool; #[derive(Debug, thiserror::Error, EditoastError)] #[editoast_error(base_id = "railjson")] @@ -22,26 +19,11 @@ pub enum RailJsonError { /// /// All objects are attached to a given infra. /// -/// #### `/!\ ATTENTION /!\` On failure this function does NOT rollback the insertions! pub async fn persist_railjson( - db_pool: Arc, + connection: &mut DbConnection, infra_id: i64, railjson: RailJson, ) -> Result<()> { - macro_rules! 
persist { - ($model:ident, $objects:expr) => { - async { - let conn = &mut db_pool.get().await.map_err(Into::::into)?; - let _ = $model::create_batch::<_, Vec<_>>( - conn, - $model::from_infra_schemas(infra_id, $objects), - ) - .await?; - Ok(()) - } - }; - } - let RailJson { version, track_sections, @@ -56,6 +38,7 @@ pub async fn persist_railjson( extended_switch_types, neutral_sections, } = railjson; + if version != RAILJSON_VERSION { return Err(RailJsonError::UnsupportedVersion { actual: version, @@ -63,20 +46,80 @@ pub async fn persist_railjson( } .into()); } - futures::try_join!( - persist!(TrackSectionModel, track_sections), - persist!(BufferStopModel, buffer_stops), - persist!(ElectrificationModel, electrifications), - persist!(DetectorModel, detectors), - persist!(OperationalPointModel, operational_points), - persist!(RouteModel, routes), - persist!(SignalModel, signals), - persist!(SwitchModel, switches), - persist!(SpeedSectionModel, speed_sections), - persist!(SwitchTypeModel, extended_switch_types), - persist!(NeutralSectionModel, neutral_sections), - ) - .map(|_| ()) + + connection + .transaction(|conn| { + Box::pin(async { + let _ = TrackSectionModel::create_batch::<_, Vec<_>>( + conn, + TrackSectionModel::from_infra_schemas(infra_id, track_sections), + ) + .await?; + + let _ = BufferStopModel::create_batch::<_, Vec<_>>( + conn, + BufferStopModel::from_infra_schemas(infra_id, buffer_stops), + ) + .await?; + + let _ = ElectrificationModel::create_batch::<_, Vec<_>>( + conn, + ElectrificationModel::from_infra_schemas(infra_id, electrifications), + ) + .await?; + + let _ = DetectorModel::create_batch::<_, Vec<_>>( + conn, + DetectorModel::from_infra_schemas(infra_id, detectors), + ) + .await?; + + let _ = OperationalPointModel::create_batch::<_, Vec<_>>( + conn, + OperationalPointModel::from_infra_schemas(infra_id, operational_points), + ) + .await?; + + let _ = RouteModel::create_batch::<_, Vec<_>>( + conn, + RouteModel::from_infra_schemas(infra_id, routes), + ) + .await?; + + let _ = SignalModel::create_batch::<_, Vec<_>>( + conn, + SignalModel::from_infra_schemas(infra_id, signals), + ) + .await?; + + let _ = SwitchModel::create_batch::<_, Vec<_>>( + conn, + SwitchModel::from_infra_schemas(infra_id, switches), + ) + .await?; + + let _ = SpeedSectionModel::create_batch::<_, Vec<_>>( + conn, + SpeedSectionModel::from_infra_schemas(infra_id, speed_sections), + ) + .await?; + + let _ = SwitchTypeModel::create_batch::<_, Vec<_>>( + conn, + SwitchTypeModel::from_infra_schemas(infra_id, extended_switch_types), + ) + .await?; + + let _ = NeutralSectionModel::create_batch::<_, Vec<_>>( + conn, + NeutralSectionModel::from_infra_schemas(infra_id, neutral_sections), + ) + .await?; + + Ok(()) + }) + }) + .await } pub async fn find_all_schemas(conn: &mut DbConnection, infra_id: i64) -> Result diff --git a/editoast/src/views/infra/auto_fixes/mod.rs b/editoast/src/views/infra/auto_fixes/mod.rs index 24994fdf1c6..44de240c7f8 100644 --- a/editoast/src/views/infra/auto_fixes/mod.rs +++ b/editoast/src/views/infra/auto_fixes/mod.rs @@ -1,5 +1,6 @@ use std::collections::hash_map::Entry; use std::collections::hash_map::HashMap; +use std::ops::DerefMut; use actix_web::get; use actix_web::web::Data; @@ -27,7 +28,7 @@ use crate::modelsv2::prelude::*; use crate::modelsv2::Infra; use crate::views::infra::InfraApiError; use crate::views::infra::InfraIdParam; -use editoast_models::DbConnectionPool; +use editoast_models::DbConnectionPoolV2; use editoast_schemas::infra::InfraObject; use 
editoast_schemas::primitives::OSRDIdentified as _; use editoast_schemas::primitives::OSRDObject; @@ -86,18 +87,19 @@ crate::routes! { async fn list_auto_fixes( infra: Path, infra_caches: Data>, - db_pool: Data, + db_pool: Data, ) -> Result>> { let infra_id = infra.into_inner(); - let mut conn = db_pool.get().await?; - let infra = - Infra::retrieve_or_fail(&mut conn, infra_id, || InfraApiError::NotFound { infra_id }) - .await?; + let infra = Infra::retrieve_or_fail(db_pool.get().await?.deref_mut(), infra_id, || { + InfraApiError::NotFound { infra_id } + }) + .await?; // accepting the early release of ReadGuard as it's anyway released when sending the suggestions (so before edit) - let mut infra_cache_clone = InfraCache::get_or_load(&mut conn, &infra_caches, &infra) - .await? - .clone(); + let mut infra_cache_clone = + InfraCache::get_or_load(db_pool.get().await?.deref_mut(), &infra_caches, &infra) + .await? + .clone(); let mut fixes = vec![]; for _ in 0..MAX_AUTO_FIXES_ITERATIONS { @@ -320,26 +322,25 @@ mod tests { use actix_http::Request; use actix_http::StatusCode; - use actix_web::test::call_service; - use actix_web::test::read_body_json; use actix_web::test::TestRequest; use editoast_schemas::infra::BufferStop; use editoast_schemas::infra::BufferStopExtension; - use serde_json::json; + use pretty_assertions::assert_eq; + use std::ops::DerefMut; use super::*; - use crate::fixtures::tests::db_pool; - use crate::fixtures::tests::empty_infra; - use crate::fixtures::tests::small_infra; use crate::generated_data::infra_error::InfraErrorType; use crate::infra_cache::object_cache::BufferStopCache; use crate::infra_cache::object_cache::DetectorCache; use crate::infra_cache::object_cache::SignalCache; + use crate::infra_cache::operation::create::apply_create_operation; use crate::infra_cache::operation::DeleteOperation; use crate::infra_cache::operation::Operation; use crate::infra_cache::InfraCacheEditoastError; + use crate::modelsv2::fixtures::create_empty_infra; + use crate::modelsv2::fixtures::create_small_infra; use crate::views::infra::errors::query_errors; - use crate::views::tests::create_test_service; + use crate::views::test_app::TestAppBuilder; use editoast_schemas::infra::ApplicableDirectionsTrackRange; use editoast_schemas::infra::Detector; use editoast_schemas::infra::Electrification; @@ -359,19 +360,6 @@ mod tests { use editoast_schemas::primitives::ObjectRef; use editoast_schemas::primitives::ObjectType; - async fn get_infra_cache(infra: &Infra) -> InfraCache { - InfraCache::load(&mut db_pool().get().await.unwrap(), infra) - .await - .unwrap() - } - - async fn force_refresh(infra: &mut Infra) { - infra - .refresh(db_pool(), true, &get_infra_cache(infra).await) - .await - .unwrap(); - } - fn auto_fixes_request(infra_id: i64) -> Request { TestRequest::get() .uri(format!("/infra/{infra_id}/auto_fixes").as_str()) @@ -380,56 +368,75 @@ mod tests { #[rstest::rstest] async fn test_no_fix() { - let app = create_test_service().await; - let small_infra = small_infra(db_pool()).await; - let small_infra_id = small_infra.id(); + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; + let small_infra_id = small_infra.id; - let response = call_service(&app, auto_fixes_request(small_infra_id)).await; + let operations: Vec = app + .fetch(auto_fixes_request(small_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); - assert_eq!(response.status(), StatusCode::OK); - let operations: Vec = 
read_body_json(response).await; assert!(operations.is_empty()); } #[rstest::rstest] async fn test_fix_invalid_ref_puntual_objects() { // GIVEN - let app = create_test_service().await; - let mut small_infra = small_infra(db_pool()).await; - let small_infra_id = small_infra.id(); - let conn = &mut db_pool().get().await.unwrap(); - force_refresh(&mut small_infra).await; + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let mut small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; + let small_infra_id = small_infra.id; + let mut infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &small_infra) + .await + .expect("Failed to get infra cache"); + small_infra + .refresh(db_pool.clone(), true, &infra_cache) + .await + .expect("Failed to refresh infra"); // Check the only initial issues are "overlapping_speed_sections" warnings - let (infra_errors_before_all, before_all_count) = query_errors(conn, &small_infra).await; + let (infra_errors_before_all, before_all_count) = + query_errors(db_pool.get_ok().deref_mut(), &small_infra).await; assert!(infra_errors_before_all .iter() .all(|e| matches!(e.sub_type, InfraErrorType::OverlappingSpeedSections { .. }))); // Remove a track - let deletion = Operation::Delete(DeleteOperation { + let delete_operation = DeleteOperation { obj_id: "TA1".to_string(), obj_type: ObjectType::TrackSection, - }); - let req_del = TestRequest::post() - .uri(format!("/infra/{small_infra_id}/").as_str()) - .set_json(json!([deletion])) - .to_request(); - assert_eq!(call_service(&app, req_del).await.status(), StatusCode::OK); + }; + let deletion = Operation::Delete(delete_operation.clone()); + let _ = deletion + .apply(small_infra_id, db_pool.get_ok().deref_mut()) + .await + .expect("Failed to delete a track"); + infra_cache + .apply_operations(&vec![CacheOperation::Delete(delete_operation.into())]) + .expect("Failed to apply operations"); + small_infra + .refresh(db_pool.clone(), true, &infra_cache) + .await + .expect("Failed to refresh infra"); // Check that some new issues appeared - let (infra_errors_before_fix, before_fix_count) = query_errors(conn, &small_infra).await; + let (infra_errors_before_fix, before_fix_count) = + query_errors(db_pool.get_ok().deref_mut(), &small_infra).await; assert!(before_fix_count > before_all_count); // WHEN - let response = call_service(&app, auto_fixes_request(small_infra_id)).await; + let operations: Vec = app + .fetch(auto_fixes_request(small_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); // THEN - let (infra_errors_after_fix, _) = query_errors(conn, &small_infra).await; + let (infra_errors_after_fix, _) = + query_errors(db_pool.get_ok().deref_mut(), &small_infra).await; assert_eq!(infra_errors_after_fix, infra_errors_before_fix); - assert_eq!(response.status(), StatusCode::OK); - let operations: Vec = read_body_json(response).await; assert!(operations.contains(&Operation::Delete(DeleteOperation { obj_id: "SA0".to_string(), obj_type: ObjectType::Signal, @@ -450,24 +457,25 @@ mod tests { #[rstest::rstest] async fn test_fix_invalid_ref_route_entry_exit() { - let app = create_test_service().await; - let small_infra = small_infra(db_pool()).await; - let small_infra_id = small_infra.id(); + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; + let small_infra_id = small_infra.id; // Remove a buffer stop let deletion = Operation::Delete(DeleteOperation { obj_id: "buffer_stop.4".to_string(), 
obj_type: ObjectType::BufferStop, }); - let req_del = TestRequest::post() - .uri(format!("/infra/{small_infra_id}/").as_str()) - .set_json(json!([deletion])) - .to_request(); - assert_eq!(call_service(&app, req_del).await.status(), StatusCode::OK); + deletion + .apply(small_infra_id, db_pool.get_ok().deref_mut()) + .await + .expect("Failed to delete BufferStop"); - let response = call_service(&app, auto_fixes_request(small_infra_id)).await; + let operations: Vec = app + .fetch(auto_fixes_request(small_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); - assert_eq!(response.status(), StatusCode::OK); - let operations: Vec = read_body_json(response).await; assert!(operations.contains(&Operation::Delete(DeleteOperation { obj_id: "rt.DE0->buffer_stop.4".to_string(), obj_type: ObjectType::Route, @@ -680,19 +688,12 @@ mod tests { assert!(operations.is_empty()); } - fn get_create_operation_request(railjson: InfraObject, infra_id: i64) -> Request { - let create_operation = Operation::Create(Box::new(railjson)); - TestRequest::post() - .uri(format!("/infra/{infra_id}/").as_str()) - .set_json(json!([create_operation])) - .to_request() - } - #[rstest::rstest] async fn invalid_switch_ports() { - let app = create_test_service().await; - let small_infra = small_infra(db_pool()).await; - let small_infra_id = small_infra.id(); + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; + let small_infra_id = small_infra.id; let ports = HashMap::from([ ("WRONG".into(), TrackEndpoint::new("TA1", Endpoint::End)), @@ -705,17 +706,19 @@ mod tests { ..Default::default() }; // Create an invalid switch - let req_create = - get_create_operation_request(invalid_switch.clone().into(), small_infra_id); - assert_eq!( - call_service(&app, req_create).await.status(), - StatusCode::OK - ); - - let response = call_service(&app, auto_fixes_request(small_infra_id)).await; - assert_eq!(response.status(), StatusCode::OK); + apply_create_operation( + &invalid_switch.clone().into(), + small_infra_id, + db_pool.get_ok().deref_mut(), + ) + .await + .expect("Failed to create invalid_switch object"); + + let operations: Vec = app + .fetch(auto_fixes_request(small_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); - let operations: Vec = read_body_json(response).await; assert!(operations.contains(&Operation::Delete(DeleteOperation { obj_id: invalid_switch.get_id().to_string(), obj_type: ObjectType::Switch, @@ -724,9 +727,10 @@ mod tests { #[rstest::rstest] async fn odd_buffer_stop_location() { - let app = create_test_service().await; - let empty_infra = empty_infra(db_pool()).await; - let empty_infra_id = empty_infra.id(); + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let empty_infra_id = empty_infra.id; // Create an odd buffer stops (to a track endpoint linked by a switch) let track: InfraObject = TrackSection { @@ -757,18 +761,16 @@ mod tests { } .into(); for obj in [&track, &bs_start, &bs_stop, &bs_odd] { - let req_create = get_create_operation_request(obj.clone(), empty_infra_id); - assert_eq!( - call_service(&app, req_create).await.status(), - StatusCode::OK - ); + apply_create_operation(obj, empty_infra_id, db_pool.get_ok().deref_mut()) + .await + .expect("Failed to create object"); } - let response = call_service(&app, auto_fixes_request(empty_infra_id)).await; + let operations: Vec = app + 
.fetch(auto_fixes_request(empty_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); - assert_eq!(response.status(), StatusCode::OK); - - let operations: Vec = read_body_json(response).await; assert!(operations.contains(&Operation::Delete(DeleteOperation { obj_id: bs_odd.get_id().clone(), obj_type: ObjectType::BufferStop, @@ -777,26 +779,25 @@ mod tests { #[rstest::rstest] async fn empty_object() { - let app = create_test_service().await; - let empty_infra = empty_infra(db_pool()).await; - let empty_infra_id = empty_infra.id(); + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let empty_infra_id = empty_infra.id; let electrification: InfraObject = Electrification::default().into(); let operational_point = OperationalPoint::default().into(); let speed_section = SpeedSection::default().into(); for obj in [&electrification, &operational_point, &speed_section] { - let req_create = get_create_operation_request(obj.clone(), empty_infra_id); - assert_eq!( - call_service(&app, req_create).await.status(), - StatusCode::OK - ); + apply_create_operation(obj, empty_infra_id, db_pool.get_ok().deref_mut()) + .await + .expect("Failed to create object"); } - let response = call_service(&app, auto_fixes_request(empty_infra_id)).await; - assert_eq!(response.status(), StatusCode::OK); - - let operations: Vec = read_body_json(response).await; + let operations: Vec = app + .fetch(auto_fixes_request(empty_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); for obj in [&electrification, &operational_point, &speed_section] { assert!(operations.contains(&Operation::Delete(DeleteOperation { @@ -808,9 +809,10 @@ mod tests { #[rstest::rstest] async fn out_of_range_must_be_ignored() { - let app = create_test_service().await; - let empty_infra = empty_infra(db_pool()).await; - let empty_infra_id = empty_infra.id(); + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let empty_infra_id = empty_infra.id; let track: InfraObject = TrackSection { id: "test_track".into(), @@ -857,17 +859,15 @@ mod tests { .into(); for obj in [&track, &electrification, &operational_point, &speed_section] { - let req_create = get_create_operation_request(obj.clone(), empty_infra_id); - assert_eq!( - call_service(&app, req_create).await.status(), - StatusCode::OK - ); + apply_create_operation(obj, empty_infra_id, db_pool.get_ok().deref_mut()) + .await + .expect("Failed to create object"); } - let response = call_service(&app, auto_fixes_request(empty_infra_id)).await; - assert_eq!(response.status(), StatusCode::OK); - - let operations: Vec = read_body_json(response).await; + let operations: Vec = app + .fetch(auto_fixes_request(empty_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); for obj in [&track, &electrification, &operational_point, &speed_section] { assert!(!operations.contains(&Operation::Delete(DeleteOperation { @@ -881,9 +881,10 @@ mod tests { #[case(250., 1)] #[case(1250., 5)] async fn out_of_range_must_be_deleted(#[case] pos: f64, #[case] error_count: usize) { - let app = create_test_service().await; - let empty_infra = empty_infra(db_pool()).await; - let empty_infra_id = empty_infra.id(); + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let empty_infra_id = empty_infra.id; let track: InfraObject = 
TrackSection { id: "test_track".into(), @@ -918,17 +919,16 @@ mod tests { .into(); for obj in [&track, &signal, &detector, &buffer_stop] { - let req_create = get_create_operation_request(obj.clone(), empty_infra_id); - assert_eq!( - call_service(&app, req_create).await.status(), - StatusCode::OK - ); + apply_create_operation(obj, empty_infra_id, db_pool.get_ok().deref_mut()) + .await + .expect("Failed to create object"); } - let response = call_service(&app, auto_fixes_request(empty_infra_id)).await; - assert_eq!(response.status(), StatusCode::OK); + let operations: Vec = app + .fetch(auto_fixes_request(empty_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); - let operations: Vec = read_body_json(response).await; assert_eq!(operations.len(), error_count); if !operations.len() == 5 { @@ -944,9 +944,10 @@ mod tests { #[rstest::rstest] async fn missing_track_extremity_buffer_stop_fix() { // GIVEN - let app = create_test_service().await; - let empty_infra = empty_infra(db_pool()).await; - let empty_infra_id = empty_infra.id(); + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let empty_infra_id = empty_infra.id; let track: InfraObject = TrackSection { id: "track_with_no_buffer_stops".into(), @@ -954,19 +955,17 @@ mod tests { ..Default::default() } .into(); - let req_create = get_create_operation_request(track.clone(), empty_infra_id); - assert_eq!( - call_service(&app, req_create).await.status(), - StatusCode::OK - ); + apply_create_operation(&track, empty_infra_id, db_pool.get_ok().deref_mut()) + .await + .expect("Failed to create track section object"); // WHEN - let response = call_service(&app, auto_fixes_request(empty_infra_id)).await; + let operations: Vec = app + .fetch(auto_fixes_request(empty_infra_id)) + .assert_status(StatusCode::OK) + .json_into(); // THEN - assert_eq!(response.status(), StatusCode::OK); - - let operations: Vec = read_body_json(response).await; assert_eq!(operations.len(), 2); let mut positions = vec![]; for operation in operations { diff --git a/editoast/src/views/infra/edition.rs b/editoast/src/views/infra/edition.rs index d5a2f493df6..9c2cda13b98 100644 --- a/editoast/src/views/infra/edition.rs +++ b/editoast/src/views/infra/edition.rs @@ -3,6 +3,7 @@ use actix_web::web::Data; use actix_web::web::Json; use actix_web::web::Path; use chashmap::CHashMap; +use diesel_async::AsyncConnection; use editoast_derive::EditoastError; use editoast_schemas::infra::ApplicableDirectionsTrackRange; use editoast_schemas::infra::DirectionalTrackRange; @@ -19,6 +20,7 @@ use itertools::Itertools; use json_patch::{AddOperation, Patch, PatchOperation, RemoveOperation, ReplaceOperation}; use serde_json::json; use std::collections::HashMap; +use std::ops::DerefMut; use thiserror::Error; use tracing::error; use tracing::info; @@ -41,7 +43,7 @@ use crate::views::infra::InfraApiError; use crate::views::infra::InfraIdParam; use crate::RedisClient; use editoast_models::DbConnection; -use editoast_models::DbConnectionPool; +use editoast_models::DbConnectionPoolV2; use editoast_schemas::infra::InfraObject; crate::routes! { @@ -71,20 +73,27 @@ crate::routes! 
{ pub async fn edit<'a>( infra: Path, operations: Json>, - db_pool: Data, + db_pool: Data, infra_caches: Data>, redis_client: Data, map_layers: Data, ) -> Result>> { let infra_id = infra.infra_id; - let mut conn = db_pool.get().await?; // TODO: lock for update - let mut infra = - Infra::retrieve_or_fail(&mut conn, infra_id, || InfraApiError::NotFound { infra_id }) + let mut infra = Infra::retrieve_or_fail(db_pool.get().await?.deref_mut(), infra_id, || { + InfraApiError::NotFound { infra_id } + }) + .await?; + let mut infra_cache = + InfraCache::get_or_load_mut(db_pool.get().await?.deref_mut(), &infra_caches, &infra) .await?; - let mut infra_cache = InfraCache::get_or_load_mut(&mut conn, &infra_caches, &infra).await?; - let operation_results = - apply_edit(&mut conn, &mut infra, &operations, &mut infra_cache).await?; + let operation_results = apply_edit( + db_pool.get().await?.deref_mut(), + &mut infra, + &operations, + &mut infra_cache, + ) + .await?; let mut conn = redis_client.get_connection().await?; map::invalidate_all( @@ -109,7 +118,7 @@ pub async fn edit<'a>( pub async fn split_track_section<'a>( infra: Path, payload: Json, - db_pool: Data, + db_pool: Data, infra_caches: Data>, redis_client: Data, map_layers: Data, @@ -121,12 +130,15 @@ pub async fn split_track_section<'a>( offset = payload.offset, "Splitting track section" ); - let conn = &mut db_pool.get().await?; // Check the infra - let mut infra = - Infra::retrieve_or_fail(conn, infra_id, || InfraApiError::NotFound { infra_id }).await?; - let mut infra_cache = InfraCache::get_or_load_mut(conn, &infra_caches, &infra).await?; + let mut infra = Infra::retrieve_or_fail(db_pool.get().await?.deref_mut(), infra_id, || { + InfraApiError::NotFound { infra_id } + }) + .await?; + let mut infra_cache = + InfraCache::get_or_load_mut(db_pool.get().await?.deref_mut(), &infra_caches, &infra) + .await?; // Get tracks cache if it exists let tracksection_cached = infra_cache.get_track_section(&payload.track)?.clone(); @@ -145,7 +157,11 @@ pub async fn split_track_section<'a>( // Calling the DB to get the full object and also the split geo let result = infra - .get_split_track_section_with_data(conn, payload.track.clone(), distance_fraction) + .get_split_track_section_with_data( + db_pool.get().await?.deref_mut(), + payload.track.clone(), + distance_fraction, + ) .await?; let tracksection_data = result.expect("Failed to retrieve split track section data. 
Ensure the track ID and distance fraction are valid.").clone(); let tracksection = tracksection_data.railjson.as_ref().clone(); @@ -306,7 +322,13 @@ pub async fn split_track_section<'a>( })); // Apply operations - apply_edit(conn, &mut infra, &operations, &mut infra_cache).await?; + apply_edit( + db_pool.get().await?.deref_mut(), + &mut infra, + &operations, + &mut infra_cache, + ) + .await?; let mut conn = redis_client.get_connection().await?; map::invalidate_all( &mut conn, @@ -823,8 +845,7 @@ async fn apply_edit( // Apply modifications in one transaction connection - .build_transaction() - .run(|conn| { + .transaction(|conn| { Box::pin(async { let mut railjsons = vec![]; let mut cache_operations = vec![]; @@ -886,8 +907,6 @@ enum EditionError { #[cfg(test)] pub mod tests { use actix_web::http::StatusCode; - use actix_web::test::call_and_read_body_json; - use actix_web::test::call_service; use actix_web::test::TestRequest; use pretty_assertions::assert_eq; use rstest::rstest; @@ -897,103 +916,100 @@ pub mod tests { use crate::fixtures::tests::small_infra; use crate::generated_data::infra_error::InfraError; use crate::generated_data::infra_error::InfraErrorType; + use crate::modelsv2::fixtures::create_small_infra; use crate::modelsv2::infra::ObjectQueryable; use crate::views::infra::errors::query_errors; - use crate::views::tests::create_test_service; + use crate::views::test_app::TestAppBuilder; #[rstest] async fn split_track_section_should_return_404_with_bad_infra() { // Init - let app = create_test_service().await; + let app = TestAppBuilder::default_app(); // Make a call with a bad infra ID - let req = TestRequest::post() + let request = TestRequest::post() .uri("/infra/123456789/split_track_section/") .set_json(json!({ "track": String::from("INVALID-ID"), "offset": 1, })) .to_request(); - let res = call_service(&app, req).await; // Check that we receive a 404 - assert_eq!(res.status(), StatusCode::NOT_FOUND); + app.fetch(request).assert_status(StatusCode::NOT_FOUND); } #[rstest] async fn split_track_section_should_return_404_with_bad_id() { // Init - let pg_db_pool = db_pool(); - let small_infra = small_infra(pg_db_pool.clone()).await; - let app = create_test_service().await; + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; // Make a call with a bad ID - let req = TestRequest::post() - .uri(format!("/infra/{}/split_track_section", small_infra.id()).as_str()) + let request = TestRequest::post() + .uri(format!("/infra/{}/split_track_section", small_infra.id).as_str()) .set_json(json!({ "track":"INVALID-ID", "offset": 1, })) .to_request(); - let res = call_service(&app, req).await; // Check that we receive a 404 - assert_eq!(res.status(), StatusCode::NOT_FOUND); + app.fetch(request).assert_status(StatusCode::NOT_FOUND); } #[rstest] async fn split_track_section_should_fail_with_bad_distance() { // Init - let pg_db_pool = db_pool(); - let small_infra = small_infra(pg_db_pool.clone()).await; - let app = create_test_service().await; + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; // Make a call with a bad distance - let req = TestRequest::post() - .uri(format!("/infra/{}/split_track_section", small_infra.id()).as_str()) + let request = TestRequest::post() + .uri(format!("/infra/{}/split_track_section", small_infra.id).as_str()) .set_json(json!({ "track": "TA0", "offset": 5000000, })) 
.to_request(); - let res = call_service(&app, req).await; // Check that we receive an error - assert_eq!(res.status(), StatusCode::BAD_REQUEST); + app.fetch(request).assert_status(StatusCode::BAD_REQUEST); } #[rstest] async fn split_track_section_should_work() { // Init - let pg_db_pool = db_pool(); - let conn = &mut pg_db_pool.get().await.unwrap(); - let small_infra = small_infra(pg_db_pool.clone()).await; - let app = create_test_service().await; + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; // Refresh the infra to get the good number of infra errors let req_refresh = TestRequest::post() - .uri(format!("/infra/refresh/?infras={}&force=true", small_infra.id()).as_str()) + .uri(format!("/infra/refresh/?infras={}&force=true", small_infra.id).as_str()) .to_request(); - call_service(&app, req_refresh).await; + app.fetch(req_refresh).assert_status(StatusCode::OK); // Get infra errors - let (init_errors, _) = query_errors(conn, &small_infra).await; + let (init_errors, _) = query_errors(db_pool.get_ok().deref_mut(), &small_infra).await; // Make a call to split the track section - let req = TestRequest::post() - .uri(format!("/infra/{}/split_track_section", small_infra.id()).as_str()) + let request = TestRequest::post() + .uri(format!("/infra/{}/split_track_section", small_infra.id).as_str()) .set_json(json!({ "track": "TA0", "offset": 1000000, })) .to_request(); - let res: Vec = call_and_read_body_json(&app, req).await; + let res: Vec = app.fetch(request).assert_status(StatusCode::OK).json_into(); // Check the response assert_eq!(res.len(), 2); // Check that infra errors has not increased with the split (omit route error for now) - let (errors, _) = query_errors(conn, &small_infra).await; + let (errors, _) = query_errors(db_pool.get_ok().deref_mut(), &small_infra).await; let errors_without_routes: Vec = errors .into_iter() .filter(|e| { diff --git a/editoast/src/views/infra/mod.rs b/editoast/src/views/infra/mod.rs index 9316f198e80..52ff99ba3af 100644 --- a/editoast/src/views/infra/mod.rs +++ b/editoast/src/views/infra/mod.rs @@ -46,7 +46,6 @@ use crate::modelsv2::Infra; use crate::views::pagination::PaginatedList as _; use crate::views::pagination::PaginationQueryParam; use crate::RedisClient; -use editoast_models::DbConnectionPool; use editoast_models::DbConnectionPoolV2; use editoast_schemas::infra::SwitchType; @@ -126,13 +125,12 @@ struct RefreshResponse { )] #[post("/refresh")] async fn refresh( - db_pool: Data, + db_pool: Data, redis_client: Data, Query(query_params): Query, infra_caches: Data>, map_layers: Data, ) -> Result> { - let mut conn = db_pool.get().await?; // Use a transaction to give scope to infra list lock let RefreshQueryParams { force, @@ -141,11 +139,13 @@ async fn refresh( let infras_list = if infras.is_empty() { // Retrieve all available infra - Infra::all(&mut conn).await + Infra::all(db_pool.get().await?.deref_mut()).await } else { // Retrieve given infras - Infra::retrieve_batch_or_fail(&mut conn, infras, |missing| InfraApiError::NotFound { - infra_id: missing.into_iter().next().unwrap(), + Infra::retrieve_batch_or_fail(db_pool.get().await?.deref_mut(), infras, |missing| { + InfraApiError::NotFound { + infra_id: missing.into_iter().next().unwrap(), + } }) .await? 
}; @@ -155,7 +155,9 @@ async fn refresh( let mut infra_refreshed = vec![]; for mut infra in infras_list { - let infra_cache = InfraCache::get_or_load(&mut conn, &infra_caches, &infra).await?; + let infra_cache = + InfraCache::get_or_load(db_pool.get().await?.deref_mut(), &infra_caches, &infra) + .await?; if infra.refresh(db_pool.clone(), force, &infra_cache).await? { infra_refreshed.push(infra.id); } @@ -191,7 +193,7 @@ struct InfraListResponse { )] #[get("")] async fn list( - db_pool: Data, + db_pool: Data, core: Data, pagination_params: Query, ) -> Result> { @@ -330,15 +332,16 @@ struct CloneQuery { #[post("/clone")] async fn clone( params: Path, - db_pool: Data, + db_pool: Data, Query(CloneQuery { name }): Query, ) -> Result> { - let conn = &mut db_pool.get().await?; - let infra = Infra::retrieve_or_fail(conn, params.infra_id, || InfraApiError::NotFound { - infra_id: params.infra_id, + let infra = Infra::retrieve_or_fail(db_pool.get().await?.deref_mut(), params.infra_id, || { + InfraApiError::NotFound { + infra_id: params.infra_id, + } }) .await?; - let cloned_infra = infra.clone(conn, name).await?; + let cloned_infra = infra.clone(db_pool.get().await?.deref_mut(), name).await?; Ok(Json(cloned_infra.id)) } @@ -620,9 +623,6 @@ pub async fn fetch_all_infra_states( pub mod tests { use actix_http::Request; use actix_web::http::StatusCode; - use actix_web::test as actix_test; - use actix_web::test::call_service; - use actix_web::test::read_body_json; use actix_web::test::TestRequest; use diesel::sql_query; use diesel::sql_types::BigInt; @@ -631,30 +631,20 @@ pub mod tests { use rstest::rstest; use serde_json::json; use std::ops::DerefMut; - use std::sync::Arc; use strum::IntoEnumIterator; use super::*; - use crate::assert_status_and_read; use crate::core::mocking::MockingClient; - use crate::fixtures::tests::db_pool; - use crate::fixtures::tests::empty_infra; - use crate::fixtures::tests::small_infra; - use crate::fixtures::tests::IntoFixture; - use crate::fixtures::tests::TestFixture; use crate::generated_data; use crate::infra_cache::operation::create::apply_create_operation; - use crate::infra_cache::operation::Operation; use crate::modelsv2::fixtures::create_empty_infra; use crate::modelsv2::fixtures::create_rolling_stock_with_energy_sources; + use crate::modelsv2::fixtures::create_small_infra; use crate::modelsv2::get_geometry_layer_table; use crate::modelsv2::get_table; use crate::modelsv2::infra::DEFAULT_INFRA_VERSION; use crate::views::test_app::TestAppBuilder; - use crate::views::tests::create_test_service; - use crate::views::tests::create_test_service_with_core_client; use editoast_schemas::infra::Electrification; - use editoast_schemas::infra::InfraObject; use editoast_schemas::infra::Speed; use editoast_schemas::infra::SpeedSection; use editoast_schemas::infra::SwitchType; @@ -667,29 +657,22 @@ pub mod tests { .to_request() } - pub fn create_object_request(infra_id: i64, obj: InfraObject) -> Request { - let operation = Operation::Create(Box::new(obj)); - TestRequest::post() - .uri(format!("/infra/{infra_id}/").as_str()) - .set_json(json!([operation])) - .to_request() - } - #[rstest] - async fn infra_clone_empty(db_pool: Arc) { - let conn = &mut db_pool.get().await.unwrap(); - let infra = empty_infra(db_pool.clone()).await; - let app = create_test_service().await; - let req = TestRequest::post() - .uri(format!("/infra/{}/clone/?name=cloned_infra", infra.id).as_str()) + #[serial_test::serial] + async fn infra_clone_empty() { + let app = TestAppBuilder::default_app(); + let db_pool 
= app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + + let request = TestRequest::post() + .uri(format!("/infra/{}/clone/?name=cloned_infra", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app, req).await; - let cloned_infra_id: i64 = assert_status_and_read!(response, StatusCode::OK); - let cloned_infra = Infra::retrieve(conn, cloned_infra_id) + + let cloned_infra_id: i64 = app.fetch(request).assert_status(StatusCode::OK).json_into(); + let cloned_infra = Infra::retrieve(db_pool.get_ok().deref_mut(), cloned_infra_id) .await .unwrap() - .expect("infra was not cloned") - .into_fixture(db_pool); + .expect("infra was not cloned"); assert_eq!(cloned_infra.name, "cloned_infra"); } @@ -700,45 +683,42 @@ pub mod tests { } #[rstest] // Slow test - async fn infra_clone(db_pool: Arc) { - let app = create_test_service().await; - let small_infra = small_infra(db_pool.clone()).await; + #[serial_test::serial] + async fn infra_clone() { + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; let small_infra_id = small_infra.id; - let conn = &mut db_pool.get().await.unwrap(); - let infra_cache = InfraCache::load(conn, &small_infra).await.unwrap(); + let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &small_infra) + .await + .unwrap(); generated_data::refresh_all(db_pool.clone(), small_infra_id, &infra_cache) .await .unwrap(); - let switch_type: InfraObject = SwitchType { + let switch_type = SwitchType { id: "test_switch_type".into(), ..Default::default() } .into(); - - let create_operation = TestRequest::post() - .uri(format!("/infra/{small_infra_id}/").as_str()) - .set_json(json!([Operation::Create(Box::new(switch_type))])) - .to_request(); - - assert_eq!( - call_service(&app, create_operation).await.status(), - StatusCode::OK - ); + apply_create_operation(&switch_type, small_infra_id, db_pool.get_ok().deref_mut()) + .await + .expect("Failed to create switch_type object"); let req_clone = TestRequest::post() .uri(format!("/infra/{}/clone/?name=cloned_infra", small_infra_id).as_str()) .to_request(); - let response = call_service(&app, req_clone).await; - assert_eq!(response.status(), StatusCode::OK); - let cloned_infra_id: i64 = read_body_json(response).await; - let _cloned_infra = Infra::retrieve(conn, cloned_infra_id) + let cloned_infra_id: i64 = app + .fetch(req_clone) + .assert_status(StatusCode::OK) + .json_into(); + + let _cloned_infra = Infra::retrieve(db_pool.get_ok().deref_mut(), cloned_infra_id) .await .unwrap() - .expect("infra was not cloned") - .into_fixture(db_pool); + .expect("infra was not cloned"); let mut tables = vec!["infra_layer_error"]; for object in ObjectType::iter() { @@ -757,7 +737,7 @@ pub mod tests { table )) .bind::(inf_id) - .get_result::(conn) + .get_result::(db_pool.get_ok().deref_mut()) .await .unwrap(); @@ -788,39 +768,45 @@ pub mod tests { let db_pool = app.db_pool(); let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let response = call_service(&app.service, delete_infra_request(empty_infra.id)).await; - assert_eq!(response.status(), StatusCode::NO_CONTENT); + app.fetch(delete_infra_request(empty_infra.id)) + .assert_status(StatusCode::NO_CONTENT); - let response = call_service(&app.service, delete_infra_request(empty_infra.id)).await; - assert_eq!(response.status(), StatusCode::NOT_FOUND); + app.fetch(delete_infra_request(empty_infra.id)) + 
.assert_status(StatusCode::NOT_FOUND); } - #[actix_test] + #[rstest] async fn infra_list() { + let db_pool = DbConnectionPoolV2::for_tests(); let mut core = MockingClient::new(); core.stub("/cache_status") .method(reqwest::Method::POST) .response(StatusCode::OK) .body("{}") .finish(); - let app = create_test_service_with_core_client(core).await; - let req = TestRequest::get().uri("/infra/").to_request(); - let response = call_service(&app, req).await; - assert_eq!(response.status(), StatusCode::OK); + + let app = TestAppBuilder::new() + .db_pool(db_pool.clone()) + .core_client(core.into()) + .build(); + let request = TestRequest::get().uri("/infra/").to_request(); + + app.fetch(request).assert_status(StatusCode::OK); } #[rstest] - async fn default_infra_create(db_pool: Arc) { + async fn default_infra_create() { let app = TestAppBuilder::default_app(); - let req = TestRequest::post() + + let request = TestRequest::post() .uri("/infra") .set_json(json!({ "name": "create_infra_test" })) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::CREATED); - let infra = read_body_json::(response) - .await - .into_fixture(db_pool); + let infra: Infra = app + .fetch(request) + .assert_status(StatusCode::CREATED) + .json_into(); + assert_eq!(infra.name, "create_infra_test"); assert_eq!(infra.railjson_version, RAILJSON_VERSION); assert_eq!(infra.version, DEFAULT_INFRA_VERSION); @@ -847,8 +833,8 @@ pub mod tests { let req = TestRequest::get() .uri(format!("/infra/{}", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::OK); + + app.fetch(req).assert_status(StatusCode::OK); empty_infra .delete(db_pool.get_ok().deref_mut()) @@ -858,8 +844,8 @@ pub mod tests { let req = TestRequest::get() .uri(format!("/infra/{}", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::NOT_FOUND); + + app.fetch(req).assert_status(StatusCode::NOT_FOUND); } #[rstest] @@ -872,9 +858,9 @@ pub mod tests { .uri(format!("/infra/{}", empty_infra.id).as_str()) .set_json(json!({"name": "rename_test"})) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::OK); - let infra: Infra = read_body_json(response).await; + + let infra: Infra = app.fetch(req).assert_status(StatusCode::OK).json_into(); + assert_eq!(infra.name, "rename_test"); } @@ -884,30 +870,35 @@ pub mod tests { } #[rstest] - async fn infra_refresh(#[future] empty_infra: TestFixture) { - let empty_infra = empty_infra.await; - let app = create_test_service().await; + async fn infra_refresh() { + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let req = TestRequest::post() .uri(format!("/infra/refresh/?infras={}", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app, req).await; - assert_eq!(response.status(), StatusCode::OK); - let refreshed_infras: InfraRefreshedResponse = read_body_json(response).await; + + let refreshed_infras: InfraRefreshedResponse = + app.fetch(req).assert_status(StatusCode::OK).json_into(); assert_eq!(refreshed_infras.infra_refreshed, vec![empty_infra.id]); } - #[rstest] // Slow test - async fn infra_refresh_force(#[future] empty_infra: TestFixture) { - let empty_infra = empty_infra.await; - let app = create_test_service().await; + 
#[rstest] + // Slow test + // A PostgreSQL deadlock can happen in this test; see the `Deadlocks` section of [DbConnectionPoolV2::get] for more information + #[serial_test::serial] + async fn infra_refresh_force() { + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; let req = TestRequest::post() - .uri(format!("/infra/refresh/?infras={}&force=true", empty_infra.id()).as_str()) + .uri(format!("/infra/refresh/?infras={}&force=true", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app, req).await; - assert_eq!(response.status(), StatusCode::OK); - let refreshed_infras: InfraRefreshedResponse = read_body_json(response).await; - assert!(refreshed_infras.infra_refreshed.contains(&empty_infra.id())); + let refreshed_infras: InfraRefreshedResponse = + app.fetch(req).assert_status(StatusCode::OK).json_into(); + assert!(refreshed_infras.infra_refreshed.contains(&empty_infra.id)); } #[rstest] @@ -928,9 +919,10 @@ pub mod tests { let req = TestRequest::get() .uri(format!("/infra/{}/speed_limit_tags/", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::OK); - let speed_limit_tags: Vec<String> = read_body_json(response).await; + + let speed_limit_tags: Vec<String> = + app.fetch(req).assert_status(StatusCode::OK).json_into(); + assert_eq!(speed_limit_tags, vec!["test_tag"]); } @@ -970,9 +962,9 @@ pub mod tests { .await; let req = TestRequest::get().uri("/infra/voltages/").to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::OK); - let voltages: Vec<String> = read_body_json(response).await; + + let voltages: Vec<String> = app.fetch(req).assert_status(StatusCode::OK).json_into(); + assert!(voltages.len() >= 3); assert!(voltages.contains(&String::from("0V"))); assert!(voltages.contains(&String::from("1V"))); @@ -1018,15 +1010,13 @@ pub mod tests { .as_str(), ) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::OK); if !include_rolling_stock_modes { - let voltages: Vec<String> = read_body_json(response).await; + let voltages: Vec<String> = app.fetch(req).assert_status(StatusCode::OK).json_into(); assert_eq!(voltages[0], "0"); assert_eq!(voltages.len(), 1); } else { - let voltages: Vec<String> = read_body_json(response).await; + let voltages: Vec<String> = app.fetch(req).assert_status(StatusCode::OK).json_into(); assert!(voltages.contains(&String::from("25000V"))); assert!(voltages.len() >= 2); } @@ -1041,9 +1031,10 @@ pub mod tests { let req = TestRequest::get() .uri(format!("/infra/{}/switch_types/", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::OK); - let switch_types: Vec<SwitchType> = read_body_json(response).await; + + let switch_types: Vec<SwitchType> = + app.fetch(req).assert_status(StatusCode::OK).json_into(); + assert_eq!(switch_types.len(), 5); } @@ -1067,8 +1058,8 @@ pub mod tests { let req = TestRequest::post() .uri(format!("/infra/{}/lock/", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::NO_CONTENT); + + app.fetch(req).assert_status(StatusCode::NO_CONTENT); // Check lock let infra = Infra::retrieve(db_pool.get_ok().deref_mut(), empty_infra.id) @@ -1081,8 +1072,8 @@ pub mod tests { let req = TestRequest::post() .uri(format!("/infra/{}/unlock/", empty_infra.id).as_str())
.to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::NO_CONTENT); + + app.fetch(req).assert_status(StatusCode::NO_CONTENT); // Check lock let infra = Infra::retrieve(db_pool.get_ok().deref_mut(), empty_infra.id) @@ -1111,7 +1102,7 @@ pub mod tests { let req = TestRequest::post() .uri(format!("/infra/{}/load", empty_infra.id).as_str()) .to_request(); - let response = call_service(&app.service, req).await; - assert_eq!(response.status(), StatusCode::NO_CONTENT); + + app.fetch(req).assert_status(StatusCode::NO_CONTENT); } } diff --git a/editoast/src/views/infra/railjson.rs b/editoast/src/views/infra/railjson.rs index feee2d2fdfe..d5f2edd6718 100644 --- a/editoast/src/views/infra/railjson.rs +++ b/editoast/src/views/infra/railjson.rs @@ -15,6 +15,7 @@ use enum_map::EnumMap; use futures::future::try_join_all; use serde::Deserialize; use serde::Serialize; +use std::ops::DerefMut; use strum::IntoEnumIterator; use thiserror::Error; use utoipa::IntoParams; @@ -26,7 +27,7 @@ use crate::modelsv2::prelude::*; use crate::modelsv2::Infra; use crate::views::infra::InfraApiError; use crate::views::infra::InfraIdParam; -use editoast_models::DbConnectionPool; +use editoast_models::DbConnectionPoolV2; use editoast_schemas::primitives::ObjectType; crate::routes! { @@ -53,12 +54,13 @@ enum ListErrorsRailjson { #[get("/{infra_id}/railjson")] async fn get_railjson( infra: Path, - db_pool: Data, + db_pool: Data, ) -> Result { let infra_id = infra.infra_id; - let conn = &mut db_pool.get().await?; - let infra_meta = - Infra::retrieve_or_fail(conn, infra_id, || InfraApiError::NotFound { infra_id }).await?; + let infra_meta = Infra::retrieve_or_fail(db_pool.get().await?.deref_mut(), infra_id, || { + InfraApiError::NotFound { infra_id } + }) + .await?; let futures: Vec<_> = ObjectType::iter() .map(|object_type| (object_type, db_pool.get())) @@ -150,7 +152,7 @@ struct PostRailjsonResponse { async fn post_railjson( params: Query, railjson: Json, - db_pool: Data, + db_pool: Data, infra_caches: Data>, ) -> Result> { if railjson.version != RAILJSON_VERSION { @@ -162,17 +164,18 @@ async fn post_railjson( let mut infra = Infra::changeset() .name(params.name.clone()) .last_railjson_version() - .persist(railjson, db_pool.clone()) + .persist(railjson, db_pool.get().await?.deref_mut()) .await?; let infra_id = infra.id; - let mut conn = db_pool.get().await?; infra - .bump_version(&mut conn) + .bump_version(db_pool.get().await?.deref_mut()) .await .map_err(|_| InfraApiError::NotFound { infra_id })?; if params.generate_data { - let infra_cache = InfraCache::get_or_load(&mut conn, &infra_caches, &infra).await?; + let infra_cache = + InfraCache::get_or_load(db_pool.get().await?.deref_mut(), &infra_caches, &infra) + .await?; infra.refresh(db_pool, true, &infra_cache).await?; } @@ -183,43 +186,47 @@ async fn post_railjson( mod tests { use actix_http::StatusCode; use actix_web::test as actix_test; - use actix_web::test::call_service; - use actix_web::test::read_body_json; - use rstest::*; - use std::sync::Arc; + use pretty_assertions::assert_eq; + use rstest::rstest; use super::*; - use crate::fixtures::tests::db_pool; - use crate::fixtures::tests::empty_infra; - use crate::fixtures::tests::TestFixture; - use crate::views::infra::tests::create_object_request; - use crate::views::tests::create_test_service; + use crate::infra_cache::operation::create::apply_create_operation; + use crate::modelsv2::fixtures::create_empty_infra; + use crate::views::test_app::TestAppBuilder; use 
editoast_schemas::infra::SwitchType; #[rstest] + // A PostgreSQL deadlock can happen in this test; see the `Deadlocks` section of [DbConnectionPoolV2::get] for more information #[serial_test::serial] - async fn test_get_railjson(#[future] empty_infra: TestFixture) { - let empty_infra = empty_infra.await; - let app = create_test_service().await; + async fn test_get_railjson() { + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let empty_infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let req = create_object_request(empty_infra.id(), SwitchType::default().into()); - let response = call_service(&app, req).await; - assert!(response.status().is_success()); + apply_create_operation( + &SwitchType::default().into(), + empty_infra.id, + db_pool.get_ok().deref_mut(), + ) + .await + .expect("Failed to create SwitchType object"); - let req = actix_test::TestRequest::get() - .uri(&format!("/infra/{}/railjson", empty_infra.id())) + let request = actix_test::TestRequest::get() + .uri(&format!("/infra/{}/railjson", empty_infra.id)) .to_request(); - let response = call_service(&app, req).await; - assert_eq!(response.status(), StatusCode::OK); - let railjson: RailJson = read_body_json(response).await; + + let railjson: RailJson = app.fetch(request).assert_status(StatusCode::OK).json_into(); + assert_eq!(railjson.version, RAILJSON_VERSION); assert_eq!(railjson.extended_switch_types.len(), 1); } #[rstest] + // A PostgreSQL deadlock can happen in this test; see the `Deadlocks` section of [DbConnectionPoolV2::get] for more information #[serial_test::serial] - async fn test_post_railjson(db_pool: Arc<DbConnectionPool>) { - let app = create_test_service().await; + async fn test_post_railjson() { + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); let railjson = RailJson { buffer_stops: (0..10).map(|_| Default::default()).collect(), @@ -239,11 +246,13 @@ mod tests { .uri("/infra/railjson?name=post_railjson_test") .set_json(&railjson) .to_request(); - let response = call_service(&app, req).await; - assert_eq!(response.status(), StatusCode::OK); - let res: PostRailjsonResponse = read_body_json(response).await; - let conn = &mut db_pool.get().await.unwrap(); - assert!(Infra::delete_static(conn, res.infra).await.unwrap()); + let res: PostRailjsonResponse = app.fetch(req).assert_status(StatusCode::OK).json_into(); + + assert!( + Infra::delete_static(db_pool.get_ok().deref_mut(), res.infra) + .await + .unwrap() + ); } } diff --git a/editoast/src/views/infra/routes.rs b/editoast/src/views/infra/routes.rs index 6dc996fb9f4..739c239070a 100644 --- a/editoast/src/views/infra/routes.rs +++ b/editoast/src/views/infra/routes.rs @@ -22,7 +22,6 @@ use crate::modelsv2::Infra; use crate::views::infra::InfraApiError; use crate::views::infra::InfraIdParam; use crate::views::params::List; -use editoast_models::DbConnectionPool; use editoast_models::DbConnectionPoolV2; crate::routes!
{ @@ -192,12 +191,13 @@ async fn get_routes_track_ranges<'a>( async fn get_routes_nodes( params: Path, infra_caches: Data>, - db_pool: Data, + db_pool: Data, Json(node_states): Json>>, ) -> Result> { - let conn = &mut db_pool.get().await?; - let infra = Infra::retrieve_or_fail(conn, params.infra_id, || InfraApiError::NotFound { - infra_id: params.infra_id, + let infra = Infra::retrieve_or_fail(db_pool.get().await?.deref_mut(), params.infra_id, || { + InfraApiError::NotFound { + infra_id: params.infra_id, + } }) .await?; @@ -205,7 +205,8 @@ async fn get_routes_nodes( return Ok(Json(RoutesFromNodesPositions::default())); } - let infra_cache = InfraCache::get_or_load(conn, &infra_caches, &infra).await?; + let infra_cache = + InfraCache::get_or_load(db_pool.get().await?.deref_mut(), &infra_caches, &infra).await?; let routes_cache = infra_cache.routes(); let filtered_routes = routes_cache @@ -262,7 +263,6 @@ async fn get_routes_nodes( #[cfg(test)] mod tests { use actix_http::StatusCode; - use actix_web::test::call_service; use actix_web::test::TestRequest; use pretty_assertions::assert_eq; use rstest::rstest; @@ -270,16 +270,13 @@ mod tests { use std::collections::HashMap; use std::collections::HashSet; - use crate::assert_status_and_read; - use crate::fixtures::tests::db_pool; - use crate::fixtures::tests::small_infra; use crate::infra_cache::operation::create::apply_create_operation; use crate::modelsv2::fixtures::create_empty_infra; + use crate::modelsv2::fixtures::create_small_infra; use crate::views::infra::routes::RoutesFromNodesPositions; use crate::views::infra::routes::RoutesResponse; use crate::views::infra::routes::WaypointType; use crate::views::test_app::TestAppBuilder; - use crate::views::tests::create_test_service; use editoast_schemas::infra::BufferStop; use editoast_schemas::infra::Detector; use editoast_schemas::infra::Route; @@ -354,8 +351,9 @@ mod tests { ), ]; - let app = create_test_service().await; - let small_infra = small_infra(db_pool()).await; + let app = TestAppBuilder::default_app(); + let db_pool = app.db_pool(); + let small_infra = create_small_infra(db_pool.get_ok().deref_mut()).await; fn compare_result(got: RoutesFromNodesPositions, expected: RoutesFromNodesPositions) { let mut got_routes = got.routes; @@ -393,12 +391,13 @@ mod tests { available_node_positions: expected.1.into_iter().collect::>(), }; let request = TestRequest::post() - .uri(&format!("/infra/{}/routes/nodes", small_infra.id())) + .uri(&format!("/infra/{}/routes/nodes", small_infra.id)) .set_json(¶ms) .to_request(); println!("{request:?} body:\n {params}"); - let response = call_service(&app, request).await; - let got: RoutesFromNodesPositions = assert_status_and_read!(response, StatusCode::OK); + + let got: RoutesFromNodesPositions = + app.fetch(request).assert_status(StatusCode::OK).json_into(); compare_result(got, expected_result) } } diff --git a/editoast/src/views/pathfinding/electrical_profiles.rs b/editoast/src/views/pathfinding/electrical_profiles.rs index b0980766a64..2903e5fd7cc 100644 --- a/editoast/src/views/pathfinding/electrical_profiles.rs +++ b/editoast/src/views/pathfinding/electrical_profiles.rs @@ -219,7 +219,6 @@ mod tests { } #[rstest] - #[serial_test::serial] async fn test_map_electrical_profiles( #[future] electrical_profile_set: TestFixture, ) { @@ -253,7 +252,6 @@ mod tests { } #[rstest] - #[serial_test::serial] async fn test_view_electrical_profiles_on_path( db_pool: Arc, #[future] empty_infra: TestFixture,