diff --git a/editoast/editoast_derive/src/model.rs b/editoast/editoast_derive/src/model.rs index 2068a5dc5ea..9a242cd089f 100644 --- a/editoast/editoast_derive/src/model.rs +++ b/editoast/editoast_derive/src/model.rs @@ -84,7 +84,7 @@ fn create_functions(config: &Config) -> TokenStream { quote! { #[async_trait::async_trait] impl crate::models::Create for #model_name { - async fn create_conn(self, conn: &mut editoast_models::DbConnection) -> crate::error::Result { + async fn create_conn(self, conn: &mut editoast_models::DieselConnection) -> crate::error::Result { use #table::table; use diesel_async::RunQueryDsl; @@ -126,7 +126,7 @@ fn retrieve_functions(config: &Config) -> TokenStream { quote! { #[async_trait::async_trait] impl crate::models::Retrieve for #model_name { - async fn retrieve_conn(conn: &mut editoast_models::DbConnection, obj_id: i64) -> crate::error::Result> { + async fn retrieve_conn(conn: &mut editoast_models::DieselConnection, obj_id: i64) -> crate::error::Result> { use #table::table; use #table::dsl; use diesel_async::RunQueryDsl; @@ -174,7 +174,7 @@ fn delete_functions(config: &Config) -> TokenStream { quote! { #[async_trait::async_trait] impl crate::models::Delete for #model_name { - async fn delete_conn(conn: &mut editoast_models::DbConnection, obj_id: i64) -> crate::error::Result { + async fn delete_conn(conn: &mut editoast_models::DieselConnection, obj_id: i64) -> crate::error::Result { use #table::table; use #table::dsl; use diesel_async::RunQueryDsl; @@ -207,7 +207,7 @@ fn update_functions(config: &Config) -> TokenStream { quote! { #[async_trait::async_trait] impl crate::models::Update for #model_name { - async fn update_conn(self, conn: &mut editoast_models::DbConnection, obj_id: i64) -> crate::error::Result> { + async fn update_conn(self, conn: &mut editoast_models::DieselConnection, obj_id: i64) -> crate::error::Result> { use #table::table; match diesel::update(table.find(obj_id)).set(&self).get_result(conn).await diff --git a/editoast/editoast_derive/src/modelv2/codegen/count_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/count_impl.rs index 44f621ab587..3472c8e5f69 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/count_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/count_impl.rs @@ -28,6 +28,7 @@ impl ToTokens for CountImpl { use diesel::QueryDsl; use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; + use std::ops::DerefMut; let mut query = #table_mod::table.select(diesel::dsl::count_star()).into_boxed(); @@ -48,7 +49,7 @@ impl ToTokens for CountImpl { } } - Ok(query.get_result::(conn).await? as u64) + Ok(query.get_result::(conn.write().await.deref_mut()).await? as u64) } } diff --git a/editoast/editoast_derive/src/modelv2/codegen/create_batch_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/create_batch_impl.rs index 34a7439ec29..587ff917411 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/create_batch_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/create_batch_impl.rs @@ -38,6 +38,7 @@ impl ToTokens for CreateBatchImpl { ) -> crate::error::Result { use crate::modelsv2::Model; use #table_mod::dsl; + use std::ops::DerefMut; use diesel::prelude::*; use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; @@ -50,7 +51,7 @@ impl ToTokens for CreateBatchImpl { chunk => { diesel::insert_into(dsl::#table_name) .values(chunk) - .load_stream::<#row>(conn) + .load_stream::<#row>(conn.write().await.deref_mut()) .await .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? .await? diff --git a/editoast/editoast_derive/src/modelv2/codegen/create_batch_with_key_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/create_batch_with_key_impl.rs index 03456ac3143..775b4fd3031 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/create_batch_with_key_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/create_batch_with_key_impl.rs @@ -43,6 +43,7 @@ impl ToTokens for CreateBatchWithKeyImpl { ) -> crate::error::Result { use crate::models::Identifiable; use crate::modelsv2::Model; + use std::ops::DerefMut; use #table_mod::dsl; use diesel::prelude::*; use diesel_async::RunQueryDsl; @@ -56,7 +57,7 @@ impl ToTokens for CreateBatchWithKeyImpl { chunk => { diesel::insert_into(dsl::#table_name) .values(chunk) - .load_stream::<#row>(conn) + .load_stream::<#row>(conn.write().await.deref_mut()) .await .map(|s| { s.map_ok(|row| { diff --git a/editoast/editoast_derive/src/modelv2/codegen/create_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/create_impl.rs index 4a4911883fd..44ed1e2bb38 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/create_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/create_impl.rs @@ -28,9 +28,10 @@ impl ToTokens for CreateImpl { conn: &mut editoast_models::DbConnection, ) -> crate::error::Result<#model> { use diesel_async::RunQueryDsl; + use std::ops::DerefMut; diesel::insert_into(#table_mod::table) .values(&self) - .get_result::<#row>(conn) + .get_result::<#row>(conn.write().await.deref_mut()) .await .map(Into::into) .map_err(Into::into) diff --git a/editoast/editoast_derive/src/modelv2/codegen/delete_batch_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/delete_batch_impl.rs index 11b10dbd6c0..2b7ec5457e5 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/delete_batch_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/delete_batch_impl.rs @@ -38,6 +38,7 @@ impl ToTokens for DeleteBatchImpl { use #table_mod::dsl; use diesel::prelude::*; use diesel_async::RunQueryDsl; + use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); let counts = crate::chunked_for_libpq! { @@ -49,7 +50,7 @@ impl ToTokens for DeleteBatchImpl { for #id_ident in chunk.into_iter() { query = query.or_filter(#filters); } - query.execute(conn).await? + query.execute(conn.write().await.deref_mut()).await? } }; Ok(counts.into_iter().sum()) diff --git a/editoast/editoast_derive/src/modelv2/codegen/delete_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/delete_impl.rs index e13ff089ae1..2a7a7454920 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/delete_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/delete_impl.rs @@ -28,9 +28,10 @@ impl ToTokens for DeleteImpl { use diesel::prelude::*; use diesel_async::RunQueryDsl; use #table_mod::dsl; + use std::ops::DerefMut; let id = self.#primary_key; diesel::delete(#table_mod::table.find(id)) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await .map(|n| n == 1) .map_err(Into::into) diff --git a/editoast/editoast_derive/src/modelv2/codegen/delete_static_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/delete_static_impl.rs index c7fc8840a54..9b9981fa0f1 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/delete_static_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/delete_static_impl.rs @@ -35,10 +35,11 @@ impl ToTokens for DeleteStaticImpl { ) -> crate::error::Result { use diesel::prelude::*; use diesel_async::RunQueryDsl; + use std::ops::DerefMut; use #table_mod::dsl; tracing::Span::current().record("query_id", tracing::field::debug(#id_ref_ident)); diesel::delete(dsl::#table_name.#(filter(#eqs)).*) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await .map(|n| n == 1) .map_err(Into::into) diff --git a/editoast/editoast_derive/src/modelv2/codegen/exists_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/exists_impl.rs index 76176a6a0c8..4e0d0125065 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/exists_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/exists_impl.rs @@ -35,10 +35,11 @@ impl ToTokens for ExistsImpl { ) -> crate::error::Result { use diesel::prelude::*; use diesel_async::RunQueryDsl; + use std::ops::DerefMut; use #table_mod::dsl; tracing::Span::current().record("query_id", tracing::field::debug(#id_ref_ident)); diesel::select(diesel::dsl::exists(dsl::#table_name.#(filter(#eqs)).*)) - .get_result(conn) + .get_result(conn.write().await.deref_mut()) .await .map_err(Into::into) } diff --git a/editoast/editoast_derive/src/modelv2/codegen/list_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/list_impl.rs index f1e9ef0f066..2518a510e93 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/list_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/list_impl.rs @@ -34,6 +34,7 @@ impl ToTokens for ListImpl { use diesel::QueryDsl; use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; + use std::ops::DerefMut; let mut query = #table_mod::table.into_boxed(); @@ -58,7 +59,7 @@ impl ToTokens for ListImpl { } let results: Vec<#model> = query - .load_stream::<#row>(conn) + .load_stream::<#row>(conn.write().await.deref_mut()) .await? .map_ok(<#model as crate::modelsv2::prelude::Model>::from_row) .try_collect() diff --git a/editoast/editoast_derive/src/modelv2/codegen/retrieve_batch_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/retrieve_batch_impl.rs index 3a637bf0d11..52f773d8a00 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/retrieve_batch_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/retrieve_batch_impl.rs @@ -46,6 +46,7 @@ impl ToTokens for RetrieveBatchImpl { use diesel::prelude::*; use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; + use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); Ok(crate::chunked_for_libpq! { @@ -62,7 +63,7 @@ impl ToTokens for RetrieveBatchImpl { query = query.or_filter(#filters); } query - .load_stream::<#row>(conn) + .load_stream::<#row>(conn.write().await.deref_mut()) .await .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? .await? @@ -84,6 +85,7 @@ impl ToTokens for RetrieveBatchImpl { use diesel::prelude::*; use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; + use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); Ok(crate::chunked_for_libpq! { @@ -97,7 +99,7 @@ impl ToTokens for RetrieveBatchImpl { query = query.or_filter(#filters); } query - .load_stream::<#row>(conn) + .load_stream::<#row>(conn.write().await.deref_mut()) .await .map(|s| { s.map_ok(|row| { diff --git a/editoast/editoast_derive/src/modelv2/codegen/retrieve_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/retrieve_impl.rs index 5284430ddb8..80c08ce88ba 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/retrieve_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/retrieve_impl.rs @@ -38,10 +38,11 @@ impl ToTokens for RetrieveImpl { use diesel::prelude::*; use diesel_async::RunQueryDsl; use #table_mod::dsl; + use std::ops::DerefMut; tracing::Span::current().record("query_id", tracing::field::debug(#id_ref_ident)); dsl::#table_name .#(filter(#eqs)).* - .first::<#row>(conn) + .first::<#row>(conn.write().await.deref_mut()) .await .map(Into::into) .optional() diff --git a/editoast/editoast_derive/src/modelv2/codegen/update_batch_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/update_batch_impl.rs index 8eb568f6b7e..7d71e0d28a4 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/update_batch_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/update_batch_impl.rs @@ -51,6 +51,7 @@ impl ToTokens for UpdateBatchImpl { use diesel::prelude::*; use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; + use std::ops::DerefMut; let ids = ids.into_iter().collect::>(); tracing::Span::current().record("query_ids", tracing::field::debug(&ids)); Ok(crate::chunked_for_libpq! { @@ -69,7 +70,7 @@ impl ToTokens for UpdateBatchImpl { diesel::update(dsl::#table_name) .filter(dsl::#primary_key_column.eq_any(query)) .set(&self) - .load_stream::<#row>(conn) + .load_stream::<#row>(conn.write().await.deref_mut()) .await .map(|s| s.map_ok(<#model as Model>::from_row).try_collect::>())? .await? @@ -89,6 +90,7 @@ impl ToTokens for UpdateBatchImpl { use crate::models::Identifiable; use crate::modelsv2::Model; use #table_mod::dsl; + use std::ops::DerefMut; use diesel::prelude::*; use diesel_async::RunQueryDsl; use futures_util::stream::TryStreamExt; @@ -109,7 +111,7 @@ impl ToTokens for UpdateBatchImpl { diesel::update(dsl::#table_name) .filter(dsl::#primary_key_column.eq_any(query)) .set(&self) - .load_stream::<#row>(conn) + .load_stream::<#row>(conn.write().await.deref_mut()) .await .map(|s| { s.map_ok(|row| { diff --git a/editoast/editoast_derive/src/modelv2/codegen/update_impl.rs b/editoast/editoast_derive/src/modelv2/codegen/update_impl.rs index 97354184cdd..06eec25c866 100644 --- a/editoast/editoast_derive/src/modelv2/codegen/update_impl.rs +++ b/editoast/editoast_derive/src/modelv2/codegen/update_impl.rs @@ -40,11 +40,12 @@ impl ToTokens for UpdateImpl { ) -> crate::error::Result> { use diesel::prelude::*; use diesel_async::RunQueryDsl; + use std::ops::DerefMut; use #table_mod::dsl; tracing::Span::current().record("query_id", tracing::field::debug(#id_ref_ident)); diesel::update(dsl::#table_name.#(filter(#eqs)).*) .set(&self) - .get_result::<#row>(conn) + .get_result::<#row>(conn.write().await.deref_mut()) .await .map(Into::into) .optional() diff --git a/editoast/editoast_models/src/db_connection_pool.rs b/editoast/editoast_models/src/db_connection_pool.rs index e2d4e37595f..348ff3e1e25 100644 --- a/editoast/editoast_models/src/db_connection_pool.rs +++ b/editoast/editoast_models/src/db_connection_pool.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; +use std::ops::DerefMut; use std::sync::Arc; use diesel::sql_query; @@ -7,6 +9,8 @@ use diesel_async::pooled_connection::deadpool::Object; use diesel_async::pooled_connection::deadpool::Pool; use diesel_async::pooled_connection::AsyncDieselConnectionManager; use diesel_async::pooled_connection::ManagerConfig; +use diesel_async::scoped_futures::ScopedBoxFuture; +use diesel_async::AsyncConnection; use diesel_async::AsyncPgConnection; use diesel_async::RunQueryDsl; use futures::future::BoxFuture; @@ -17,21 +21,87 @@ use openssl::ssl::SslMethod; use openssl::ssl::SslVerifyMode; use url::Url; -#[cfg(feature = "testing")] use tokio::sync::OwnedRwLockWriteGuard; -#[cfg(feature = "testing")] use tokio::sync::RwLock; -use super::DbConnection; use super::DbConnectionPool; +use super::DieselConnection; pub type DbConnectionConfig = AsyncDieselConnectionManager; -#[cfg(feature = "testing")] -pub type DbConnectionV2 = OwnedRwLockWriteGuard>; +#[derive(Clone)] +pub struct DbConnection { + inner: Arc>>, +} + +pub struct WriteHandle { + guard: OwnedRwLockWriteGuard>, +} + +impl DbConnection { + pub fn new(inner: Arc>>) -> Self { + Self { inner } + } + + pub async fn write(&self) -> WriteHandle { + WriteHandle { + guard: self.inner.clone().write_owned().await, + } + } + + // Implementation of this function is taking a strong inspiration from + // https://docs.rs/diesel/2.1.6/src/diesel/connection/transaction_manager.rs.html#50-71 + // Sadly, this function is private so we can't use it. + // + // :WARNING: If you ever need to modify this function, please take a look at the + // original `diesel` function, they probably do it right more than us. + pub async fn transaction<'a, R, E, F>(self, callback: F) -> std::result::Result + where + F: FnOnce(Self) -> ScopedBoxFuture<'a, 'a, std::result::Result> + Send + 'a, + E: From + Send + 'a, + R: Send + 'a, + { + use diesel_async::TransactionManager as _; + + type TxManager = ::TransactionManager; -#[cfg(not(feature = "testing"))] -pub type DbConnectionV2 = Object; + { + let mut handle = self.write().await; + TxManager::begin_transaction(handle.deref_mut()).await?; + } + + match callback(self.clone()).await { + Ok(result) => { + let mut handle = self.write().await; + TxManager::commit_transaction(handle.deref_mut()).await?; + Ok(result) + } + Err(callback_error) => { + let mut handle = self.write().await; + match TxManager::rollback_transaction(handle.deref_mut()).await { + Ok(()) | Err(diesel::result::Error::BrokenTransactionManager) => { + Err(callback_error) + } + Err(rollback_error) => Err(rollback_error.into()), + } + } + } + } +} + +impl Deref for WriteHandle { + type Target = AsyncPgConnection; + + fn deref(&self) -> &Self::Target { + self.guard.deref() + } +} + +impl DerefMut for WriteHandle { + fn deref_mut(&mut self) -> &mut Self::Target { + self.guard.deref_mut() + } +} /// Wrapper for connection pooling with support for test connections on `cfg(test)` /// @@ -46,7 +116,7 @@ pub type DbConnectionV2 = Object; pub struct DbConnectionPoolV2 { pool: Arc>, #[cfg(feature = "testing")] - test_connection: Option>>>, + test_connection: Option, } #[cfg(feature = "testing")] @@ -79,20 +149,18 @@ impl DbConnectionPoolV2 { } #[cfg(feature = "testing")] - async fn get_connection(&self) -> Result { - let Some(test_connection) = &self.test_connection else { - panic!( - "Test connection not initialized in test DatabasePool -- was `for_tests` called?" - ); - }; - let connection = test_connection.clone().write_owned().await; - Ok(connection) + async fn get_connection(&self) -> Result { + Ok(self + .test_connection + .as_ref() + .expect("should already exist") + .clone()) } #[cfg(not(feature = "testing"))] - async fn get_connection(&self) -> Result { + async fn get_connection(&self) -> Result { let connection = self.pool.get().await?; - Ok(connection) + Ok(DbConnection::new(Arc::new(RwLock::new(connection)))) } /// Get a connection from the pool @@ -159,7 +227,7 @@ impl DbConnectionPoolV2 { /// - Don't declare a variable for a single-use connection: /// /// ``` - /// # async fn my_function_using_conn(conn: tokio::sync::OwnedRwLockWriteGuard>) { + /// # async fn my_function_using_conn(conn: &mut editoast_models::DbConnection) { /// # // Do something with the connection /// # } /// # @@ -167,10 +235,10 @@ impl DbConnectionPoolV2 { /// # async fn main() -> Result<(), editoast_models::db_connection_pool::DatabasePoolError> { /// let pool = editoast_models::DbConnectionPoolV2::for_tests(); /// // do - /// my_function_using_conn(pool.get().await?).await; + /// my_function_using_conn(&mut pool.get().await?).await; /// // instead of - /// let conn = pool.get().await?; - /// my_function_using_conn(conn).await; + /// let mut conn = pool.get().await?; + /// my_function_using_conn(&mut conn).await; /// # Ok(()) /// # } /// ``` @@ -178,10 +246,10 @@ impl DbConnectionPoolV2 { /// - If a connection is used repeatedly, prefer using explicit scoping: /// /// ``` - /// # async fn foo(conn: &mut tokio::sync::OwnedRwLockWriteGuard>) -> u8 { + /// # async fn foo(conn: &mut editoast_models::DbConnection) -> u8 { /// # 0 /// # } - /// # async fn bar(conn: &mut tokio::sync::OwnedRwLockWriteGuard>) -> u8 { + /// # async fn bar(conn: &mut editoast_models::DbConnection) -> u8 { /// # 42 /// # } /// # #[tokio::main] @@ -202,7 +270,7 @@ impl DbConnectionPoolV2 { /// /// ``` /// # trait DoSomething: Sized { - /// # async fn do_something(self, conn: tokio::sync::OwnedRwLockWriteGuard>) -> Result<(), editoast_models::db_connection_pool::DatabasePoolError> { + /// # async fn do_something(self, conn: &mut editoast_models::DbConnection) -> Result<(), editoast_models::db_connection_pool::DatabasePoolError> { /// # // Do something with the connection /// # Ok(()) /// # } @@ -216,15 +284,15 @@ impl DbConnectionPoolV2 { /// items.into_iter() /// .zip(pool.iter_conn()) /// .map(|(item, conn)| async move { - /// let conn = conn.await?; // note the await here - /// item.do_something(conn).await + /// let mut conn = conn.await?; // note the await here + /// item.do_something(&mut conn).await /// }); /// let results = futures::future::try_join_all(operations).await?; /// // you may acquire a new connection afterwards /// # Ok(()) /// # } /// ``` - pub async fn get(&self) -> Result { + pub async fn get(&self) -> Result { self.get_connection().await } @@ -235,7 +303,7 @@ impl DbConnectionPoolV2 { /// See [DbConnectionPoolV2::get] for more information on how connections should be used /// in tests. #[cfg(feature = "testing")] - pub fn get_ok(&self) -> DbConnectionV2 { + pub fn get_ok(&self) -> DbConnection { futures::executor::block_on(self.get()).expect("Failed to get test connection") } @@ -247,7 +315,7 @@ impl DbConnectionPoolV2 { /// /// ``` /// # trait DoSomething: Sized { - /// # async fn do_something(self, conn: tokio::sync::OwnedRwLockWriteGuard>) -> Result<(), editoast_models::db_connection_pool::DatabasePoolError> { + /// # async fn do_something(self, conn: &mut editoast_models::DbConnection) -> Result<(), editoast_models::db_connection_pool::DatabasePoolError> { /// # // Do something with the connection /// # Ok(()) /// # } @@ -261,8 +329,8 @@ impl DbConnectionPoolV2 { /// items.into_iter() /// .zip(pool.iter_conn()) /// .map(|(item, conn)| async move { - /// let conn = conn.await?; // note the await here - /// item.do_something(conn).await + /// let mut conn = conn.await?; // note the await here + /// item.do_something(&mut conn).await /// }); /// let results = futures::future::try_join_all(operations).await?; /// // you may acquire a new connection afterwards @@ -271,7 +339,7 @@ impl DbConnectionPoolV2 { /// ``` pub fn iter_conn( &self, - ) -> impl Iterator> + '_> + ) -> impl Iterator> + '_> { std::iter::repeat_with(|| self.get()) } @@ -298,11 +366,11 @@ impl DbConnectionPoolV2 { .await .expect("cannot begin a test transaction"); } - let test_connection = Arc::new(RwLock::new(conn)); + let test_connection = Some(DbConnection::new(Arc::new(RwLock::new(conn)))); Self { pool, - test_connection: Some(test_connection), + test_connection, } } @@ -312,7 +380,7 @@ impl DbConnectionPoolV2 { .unwrap_or_else(|_| String::from("postgresql://osrd:password@localhost/osrd")); let url = Url::parse(&url).expect("Failed to parse postgresql url"); let pool = - create_connection_pool(url, 1).expect("Failed to initialize test connection pool"); + create_connection_pool(url, 2).expect("Failed to initialize test connection pool"); futures::executor::block_on(Self::from_pool_test(Arc::new(pool), transaction)) } @@ -339,7 +407,9 @@ impl DbConnectionPoolV2 { pub struct PingError(#[from] diesel::result::Error); pub async fn ping_database(conn: &mut DbConnection) -> Result<(), PingError> { - sql_query("SELECT 1").execute(conn).await?; + sql_query("SELECT 1") + .execute(conn.write().await.deref_mut()) + .await?; Ok(()) } @@ -353,7 +423,7 @@ pub fn create_connection_pool( Ok(Pool::builder(manager).max_size(max_size).build()?) } -fn establish_connection(config: &str) -> BoxFuture> { +fn establish_connection(config: &str) -> BoxFuture> { let fut = async { let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap(); connector_builder.set_verify(SslVerifyMode::NONE); @@ -368,7 +438,7 @@ fn establish_connection(config: &str) -> BoxFuture; +type DieselConnection = AsyncPgConnection; +pub type DbConnectionPool = Pool; /// Generic error type to forward errors from the database /// diff --git a/editoast/src/generated_data/buffer_stop.rs b/editoast/src/generated_data/buffer_stop.rs index df57b1125d9..532674460f9 100644 --- a/editoast/src/generated_data/buffer_stop.rs +++ b/editoast/src/generated_data/buffer_stop.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use async_trait::async_trait; use diesel::delete; use diesel::query_dsl::methods::FilterDsl; @@ -6,6 +8,9 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::tables::infra_layer_buffer_stop::dsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -13,9 +18,6 @@ use crate::diesel::ExpressionMethods; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_buffer_stop::dsl; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct BufferStopLayer; @@ -28,7 +30,7 @@ impl GeneratedData for BufferStopLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { let _res = sql_query(include_str!("sql/generate_buffer_stop_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -49,7 +51,7 @@ impl GeneratedData for BufferStopLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(involved_objects.deleted)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -58,7 +60,7 @@ impl GeneratedData for BufferStopLayer { sql_query(include_str!("sql/insert_update_buffer_stop_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/generated_data/detector.rs b/editoast/src/generated_data/detector.rs index bb50af54394..9faf98a8812 100644 --- a/editoast/src/generated_data/detector.rs +++ b/editoast/src/generated_data/detector.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use async_trait::async_trait; use diesel::delete; use diesel::query_dsl::methods::FilterDsl; @@ -6,6 +8,9 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::tables::infra_layer_detector::dsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -13,9 +18,6 @@ use crate::diesel::ExpressionMethods; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_detector::dsl; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct DetectorLayer; @@ -28,7 +30,7 @@ impl GeneratedData for DetectorLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { sql_query(include_str!("sql/generate_detector_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -49,7 +51,7 @@ impl GeneratedData for DetectorLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(involved_objects.deleted)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -58,7 +60,7 @@ impl GeneratedData for DetectorLayer { sql_query(include_str!("sql/insert_update_detector_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/generated_data/electrification.rs b/editoast/src/generated_data/electrification.rs index 3dd3fd91cda..252424807a8 100644 --- a/editoast/src/generated_data/electrification.rs +++ b/editoast/src/generated_data/electrification.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use async_trait::async_trait; use diesel::delete; use diesel::query_dsl::methods::FilterDsl; @@ -6,6 +8,9 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::tables::infra_layer_electrification::dsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -13,9 +18,6 @@ use crate::diesel::ExpressionMethods; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_electrification::dsl; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct ElectrificationLayer; @@ -32,7 +34,7 @@ impl GeneratedData for ElectrificationLayer { ) -> Result<()> { sql_query(include_str!("sql/generate_electrification_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -59,7 +61,7 @@ impl GeneratedData for ElectrificationLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(objs)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -68,7 +70,7 @@ impl GeneratedData for ElectrificationLayer { sql_query(include_str!("sql/insert_electrification_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/generated_data/error/mod.rs b/editoast/src/generated_data/error/mod.rs index 106a512d0d0..2bf4e21353a 100644 --- a/editoast/src/generated_data/error/mod.rs +++ b/editoast/src/generated_data/error/mod.rs @@ -22,11 +22,15 @@ use diesel::sql_types::BigInt; use diesel::sql_types::Json; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::OSRDObject; +use editoast_schemas::primitives::ObjectType; use futures_util::Future; use itertools::Itertools; use serde_json::to_value; use sha1::Digest; use sha1::Sha1; +use std::ops::DerefMut; use tracing::warn; use super::GeneratedData; @@ -36,9 +40,6 @@ use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::Graph; use crate::infra_cache::InfraCache; use crate::infra_cache::ObjectCache; -use editoast_models::DbConnection; -use editoast_schemas::primitives::OSRDObject; -use editoast_schemas::primitives::ObjectType; editoast_common::schemas! { infra_error::schemas(), @@ -308,7 +309,7 @@ async fn retrieve_current_errors_hash( Ok(dsl::infra_layer_error .filter(dsl::infra_id.eq(infra_id)) .select(dsl::info_hash) - .load(conn) + .load(conn.write().await.deref_mut()) .await?) } @@ -324,7 +325,7 @@ async fn remove_errors_from_hashes( .filter(dsl::infra_id.eq(infra_id)) .filter(dsl::info_hash.eq_any(errors_hash)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; debug_assert_eq!(nb_deleted, errors_hash.len()); Ok(()) @@ -349,7 +350,7 @@ async fn create_errors( .bind::(infra_id) .bind::, _>(&errors_information) .bind::, _>(&errors_hash) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; assert_eq!(count, errors_hash.len()); } diff --git a/editoast/src/generated_data/mod.rs b/editoast/src/generated_data/mod.rs index a87f7414599..0e00ef300d4 100644 --- a/editoast/src/generated_data/mod.rs +++ b/editoast/src/generated_data/mod.rs @@ -59,7 +59,7 @@ pub trait GeneratedData { Self::table_name() )) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -74,7 +74,7 @@ pub trait GeneratedData { infra: i64, infra_cache: &InfraCache, ) -> Result<()> { - Self::refresh(pool.get().await?.deref_mut(), infra, infra_cache).await + Self::refresh(&mut pool.get().await?, infra, infra_cache).await } /// Search and update all objects that needs to be refreshed given a list of operation. @@ -101,7 +101,7 @@ pub async fn refresh_all( // It doesn’t seem to make a different when the generation step is ran separately // It isn’t clear why without analyze the Postgres server seems to run at 100% without halting sql_query("analyze") - .execute(db_pool.get().await?.deref_mut()) + .execute(&mut db_pool.get().await?.write().await.deref_mut()) .await?; debug!("⚙️ Infra {infra_id}: database analyzed"); futures::try_join!( @@ -165,7 +165,6 @@ pub async fn update_all( #[cfg(test)] pub mod tests { use rstest::rstest; - use std::ops::DerefMut; use crate::generated_data::clear_all; use crate::generated_data::refresh_all; @@ -179,7 +178,7 @@ pub mod tests { #[serial_test::serial] async fn refresh_all_test() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; assert!(refresh_all(db_pool.into(), infra.id, &Default::default()) .await .is_ok()); @@ -188,23 +187,18 @@ pub mod tests { #[rstest] async fn update_all_test() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - assert!(update_all( - db_pool.get_ok().deref_mut(), - infra.id, - &[], - &Default::default() - ) - .await - .is_ok()); + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + assert!( + update_all(&mut db_pool.get_ok(), infra.id, &[], &Default::default()) + .await + .is_ok() + ); } #[rstest] async fn clear_all_test() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - assert!(clear_all(db_pool.get_ok().deref_mut(), infra.id) - .await - .is_ok()); + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + assert!(clear_all(&mut db_pool.get_ok(), infra.id).await.is_ok()); } } diff --git a/editoast/src/generated_data/neutral_section.rs b/editoast/src/generated_data/neutral_section.rs index 966053a98b4..b6c43111936 100644 --- a/editoast/src/generated_data/neutral_section.rs +++ b/editoast/src/generated_data/neutral_section.rs @@ -1,13 +1,15 @@ +use std::ops::DerefMut; + use async_trait::async_trait; use diesel::sql_query; use diesel::sql_types::BigInt; use diesel_async::RunQueryDsl; +use editoast_models::DbConnection; use super::GeneratedData; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::DbConnection; pub struct NeutralSectionLayer; @@ -20,7 +22,7 @@ impl GeneratedData for NeutralSectionLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { sql_query(include_str!("sql/generate_neutral_section_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } diff --git a/editoast/src/generated_data/neutral_sign.rs b/editoast/src/generated_data/neutral_sign.rs index 74bbbd5c46e..379145aa5e1 100644 --- a/editoast/src/generated_data/neutral_sign.rs +++ b/editoast/src/generated_data/neutral_sign.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use async_trait::async_trait; use diesel::delete; use diesel::query_dsl::methods::FilterDsl; @@ -6,6 +8,9 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::tables::infra_layer_neutral_sign::dsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -13,9 +18,6 @@ use crate::diesel::ExpressionMethods; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_neutral_sign::dsl; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct NeutralSignLayer; @@ -28,7 +30,7 @@ impl GeneratedData for NeutralSignLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { sql_query(include_str!("sql/generate_neutral_sign_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -52,7 +54,7 @@ impl GeneratedData for NeutralSignLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(objs)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -61,7 +63,7 @@ impl GeneratedData for NeutralSignLayer { sql_query(include_str!("sql/insert_neutral_sign_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/generated_data/operational_point.rs b/editoast/src/generated_data/operational_point.rs index 956fadc65e5..618ddbc475e 100644 --- a/editoast/src/generated_data/operational_point.rs +++ b/editoast/src/generated_data/operational_point.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use async_trait::async_trait; use diesel::delete; use diesel::query_dsl::methods::FilterDsl; @@ -6,6 +8,9 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::tables::infra_layer_operational_point::dsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -13,9 +18,6 @@ use crate::diesel::ExpressionMethods; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_operational_point::dsl; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct OperationalPointLayer; @@ -28,7 +30,7 @@ impl GeneratedData for OperationalPointLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { sql_query(include_str!("sql/generate_operational_point_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -55,7 +57,7 @@ impl GeneratedData for OperationalPointLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(objs)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -64,7 +66,7 @@ impl GeneratedData for OperationalPointLayer { sql_query(include_str!("sql/insert_operational_point_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/generated_data/psl_sign.rs b/editoast/src/generated_data/psl_sign.rs index 77a316f7bec..c774725c7d0 100644 --- a/editoast/src/generated_data/psl_sign.rs +++ b/editoast/src/generated_data/psl_sign.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use async_trait::async_trait; use diesel::delete; use diesel::query_dsl::methods::FilterDsl; @@ -6,6 +8,9 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::tables::infra_layer_psl_sign::dsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -13,9 +18,6 @@ use crate::diesel::ExpressionMethods; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_psl_sign::dsl; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct PSLSignLayer; @@ -28,7 +30,7 @@ impl GeneratedData for PSLSignLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { sql_query(include_str!("sql/generate_psl_sign_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -54,7 +56,7 @@ impl GeneratedData for PSLSignLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(objs)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -63,7 +65,7 @@ impl GeneratedData for PSLSignLayer { sql_query(include_str!("sql/insert_psl_sign_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/generated_data/signal.rs b/editoast/src/generated_data/signal.rs index 123dd8e060f..f6fc146594e 100644 --- a/editoast/src/generated_data/signal.rs +++ b/editoast/src/generated_data/signal.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::ops::DerefMut; use async_trait::async_trait; use diesel::delete; @@ -8,6 +9,10 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Nullable; use diesel::sql_types::Text; +use editoast_models::tables::infra_layer_signal::dsl; +use editoast_models::DbConnection; +use editoast_schemas::infra::LogicalSignal; +use editoast_schemas::primitives::ObjectType; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -17,10 +22,6 @@ use crate::generated_data::sprite_config::SpriteConfig; use crate::generated_data::sprite_config::SpriteConfigs; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_signal::dsl; -use editoast_models::DbConnection; -use editoast_schemas::infra::LogicalSignal; -use editoast_schemas::primitives::ObjectType; pub struct SignalLayer; @@ -77,7 +78,7 @@ async fn generate_signaling_system_and_sprite<'a, T: Iterator .bind::(signaling_system) .bind::, _>(sprite_id) .bind::, _>(signals) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) @@ -94,7 +95,7 @@ impl GeneratedData for SignalLayer { sql_query(include_str!("sql/generate_signal_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; generate_signaling_system_and_sprite( conn, @@ -124,7 +125,7 @@ impl GeneratedData for SignalLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(involved_objects.deleted)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -134,7 +135,7 @@ impl GeneratedData for SignalLayer { sql_query(include_str!("sql/insert_update_signal_layer.sql")) .bind::(infra) .bind::, _>(&updated_signals) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; generate_signaling_system_and_sprite( conn, diff --git a/editoast/src/generated_data/speed_section.rs b/editoast/src/generated_data/speed_section.rs index 2a3afa95a2f..d8a8869abe6 100644 --- a/editoast/src/generated_data/speed_section.rs +++ b/editoast/src/generated_data/speed_section.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use async_trait::async_trait; use diesel::delete; use diesel::query_dsl::methods::FilterDsl; @@ -6,6 +8,9 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::tables::infra_layer_speed_section::dsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -13,9 +18,6 @@ use crate::diesel::ExpressionMethods; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_speed_section::dsl; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct SpeedSectionLayer; @@ -28,7 +30,7 @@ impl GeneratedData for SpeedSectionLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { sql_query(include_str!("sql/generate_speed_section_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -55,7 +57,7 @@ impl GeneratedData for SpeedSectionLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(objs)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -64,7 +66,7 @@ impl GeneratedData for SpeedSectionLayer { sql_query(include_str!("sql/insert_speed_section_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/generated_data/switch.rs b/editoast/src/generated_data/switch.rs index bd62fadcf9f..ee83d26ec5a 100644 --- a/editoast/src/generated_data/switch.rs +++ b/editoast/src/generated_data/switch.rs @@ -4,14 +4,15 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; +use std::ops::DerefMut; use super::utils::InvolvedObjects; use super::GeneratedData; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct SwitchLayer; @@ -24,7 +25,7 @@ impl GeneratedData for SwitchLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { sql_query(include_str!("sql/generate_switch_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -46,7 +47,7 @@ impl GeneratedData for SwitchLayer { )) .bind::(infra) .bind::, _>(involved_objects.deleted.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -55,7 +56,7 @@ impl GeneratedData for SwitchLayer { sql_query(include_str!("sql/insert_update_switch_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/generated_data/track_section.rs b/editoast/src/generated_data/track_section.rs index bc1ec627e8f..d6a89d60877 100644 --- a/editoast/src/generated_data/track_section.rs +++ b/editoast/src/generated_data/track_section.rs @@ -6,6 +6,10 @@ use diesel::sql_types::Array; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::tables::infra_layer_track_section::dsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectType; +use std::ops::DerefMut; use super::utils::InvolvedObjects; use super::GeneratedData; @@ -13,9 +17,6 @@ use crate::diesel::ExpressionMethods; use crate::error::Result; use crate::infra_cache::operation::CacheOperation; use crate::infra_cache::InfraCache; -use editoast_models::tables::infra_layer_track_section::dsl; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectType; pub struct TrackSectionLayer; @@ -28,7 +29,7 @@ impl GeneratedData for TrackSectionLayer { async fn generate(conn: &mut DbConnection, infra: i64, _cache: &InfraCache) -> Result<()> { sql_query(include_str!("sql/generate_track_section_layer.sql")) .bind::(infra) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; Ok(()) } @@ -49,7 +50,7 @@ impl GeneratedData for TrackSectionLayer { .filter(dsl::infra_id.eq(infra)) .filter(dsl::obj_id.eq_any(involved_objects.deleted)), ) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } @@ -58,7 +59,7 @@ impl GeneratedData for TrackSectionLayer { sql_query(include_str!("sql/insert_update_track_section_layer.sql")) .bind::(infra) .bind::, _>(involved_objects.updated.into_iter().collect::>()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await?; } Ok(()) diff --git a/editoast/src/infra_cache/mod.rs b/editoast/src/infra_cache/mod.rs index e462e59dd5c..de34a5ca773 100644 --- a/editoast/src/infra_cache/mod.rs +++ b/editoast/src/infra_cache/mod.rs @@ -43,6 +43,7 @@ use enum_map::EnumMap; use geos::geojson::Geometry; pub use graph::Graph; use itertools::Itertools as _; +use std::ops::DerefMut; use thiserror::Error; use crate::error::Result; @@ -436,7 +437,7 @@ impl InfraCache { FROM infra_object_track_section WHERE infra_id = $1", ) .bind::(infra_id) - .load::(conn) + .load::(conn.write().await.deref_mut()) .await? .into_iter() .try_for_each(|track| infra_cache.add::(track.into()))?; @@ -445,7 +446,7 @@ impl InfraCache { sql_query( "SELECT obj_id, data->>'track' AS track, (data->>'position')::float AS position, data->'logical_signals' as logical_signals FROM infra_object_signal WHERE infra_id = $1") .bind::(infra_id) - .load::(conn).await?.into_iter().try_for_each(|signal| + .load::(conn.write().await.deref_mut()).await?.into_iter().try_for_each(|signal| infra_cache.add(signal) )?; @@ -471,7 +472,7 @@ impl InfraCache { sql_query( "SELECT obj_id, data->>'parts' AS parts FROM infra_object_operational_point WHERE infra_id = $1") .bind::(infra_id) - .load::(conn).await?.into_iter().try_for_each(|op| + .load::(&mut conn.write().await).await?.into_iter().try_for_each(|op| infra_cache.add::(op.into()) )?; @@ -479,12 +480,12 @@ impl InfraCache { sql_query( "SELECT obj_id, data->>'switch_type' AS switch_type, data->>'ports' AS ports FROM infra_object_switch WHERE infra_id = $1") .bind::(infra_id) - .load::(conn).await?.into_iter().try_for_each(|switch| + .load::(&mut conn.write().await).await?.into_iter().try_for_each(|switch| infra_cache.add::(switch.into()) )?; // Load switch types references - find_all_schemas::<_, Vec>(conn, infra_id) + find_all_schemas::<_, Vec>(&mut conn.clone(), infra_id) .await? .into_iter() .try_for_each(|switch_type| infra_cache.add::(switch_type))?; @@ -500,7 +501,7 @@ impl InfraCache { sql_query( "SELECT obj_id, data->>'track' AS track, (data->>'position')::float AS position FROM infra_object_detector WHERE infra_id = $1") .bind::(infra_id) - .load::(conn).await?.into_iter().try_for_each(|detector| + .load::(&mut conn.write().await).await?.into_iter().try_for_each(|detector| infra_cache.add(detector) )?; @@ -508,7 +509,7 @@ impl InfraCache { sql_query( "SELECT obj_id, data->>'track' AS track, (data->>'position')::float AS position FROM infra_object_buffer_stop WHERE infra_id = $1") .bind::(infra_id) - .load::(conn).await?.into_iter().try_for_each(|buffer_stop| + .load::(&mut conn.write().await).await?.into_iter().try_for_each(|buffer_stop| infra_cache.add(buffer_stop) )?; @@ -533,7 +534,7 @@ impl InfraCache { return Ok(infra_cache); } // Cache miss - infra_caches.insert_new(infra.id, InfraCache::load(conn, infra).await?); + infra_caches.insert_new(infra.id, InfraCache::load(&mut conn.clone(), infra).await?); Ok(infra_caches.get(&infra.id).unwrap()) } @@ -549,7 +550,7 @@ impl InfraCache { return Ok(infra_cache); } // Cache miss - infra_caches.insert_new(infra.id, InfraCache::load(conn, infra).await?); + infra_caches.insert_new(infra.id, InfraCache::load(&mut conn.clone(), infra).await?); Ok(infra_caches.get_mut(&infra.id).unwrap()) } @@ -885,7 +886,6 @@ pub mod tests { use pretty_assertions::assert_eq; use rstest::rstest; use std::collections::HashMap; - use std::ops::DerefMut; use super::OperationalPointCache; use crate::infra_cache::object_cache::BufferStopCache; @@ -920,14 +920,10 @@ pub mod tests { #[rstest] async fn load_track_section() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let track = create_infra_object( - db_pool.get_ok().deref_mut(), - infra.id, - TrackSection::default(), - ) - .await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let track = + create_infra_object(&mut db_pool.get_ok(), infra.id, TrackSection::default()).await; + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -938,10 +934,9 @@ pub mod tests { #[rstest] async fn load_signal() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let signal = - create_infra_object(db_pool.get_ok().deref_mut(), infra.id, Signal::default()).await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let signal = create_infra_object(&mut db_pool.get_ok(), infra.id, Signal::default()).await; + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -953,9 +948,9 @@ pub mod tests { #[rstest] async fn load_speed_section() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; let speed = create_infra_object( - db_pool.get_ok().deref_mut(), + &mut db_pool.get_ok(), infra.id, SpeedSection { track_ranges: vec![Default::default()], @@ -963,7 +958,7 @@ pub mod tests { }, ) .await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -975,10 +970,9 @@ pub mod tests { #[rstest] async fn load_route() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let route = - create_infra_object(db_pool.get_ok().deref_mut(), infra.id, Route::default()).await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let route = create_infra_object(&mut db_pool.get_ok(), infra.id, Route::default()).await; + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -988,9 +982,9 @@ pub mod tests { #[rstest] async fn load_operational_point() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; let op = create_infra_object( - db_pool.get_ok().deref_mut(), + &mut db_pool.get_ok(), infra.id, OperationalPoint { parts: vec![Default::default()], @@ -999,7 +993,7 @@ pub mod tests { ) .await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -1011,9 +1005,9 @@ pub mod tests { #[rstest] async fn load_switch() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; let switch = create_infra_object( - db_pool.get_ok().deref_mut(), + &mut db_pool.get_ok(), infra.id, Switch { ports: HashMap::from([("port".into(), Default::default())]), @@ -1021,7 +1015,7 @@ pub mod tests { }, ) .await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -1031,14 +1025,10 @@ pub mod tests { #[rstest] async fn load_switch_type() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let s_type = create_infra_object( - db_pool.get_ok().deref_mut(), - infra.id, - SwitchType::default(), - ) - .await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let s_type = + create_infra_object(&mut db_pool.get_ok(), infra.id, SwitchType::default()).await; + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -1048,10 +1038,10 @@ pub mod tests { #[rstest] async fn load_detector() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; let detector = - create_infra_object(db_pool.get_ok().deref_mut(), infra.id, Detector::default()).await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + create_infra_object(&mut db_pool.get_ok(), infra.id, Detector::default()).await; + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -1063,14 +1053,9 @@ pub mod tests { #[rstest] async fn load_buffer_stop() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let bs = create_infra_object( - db_pool.get_ok().deref_mut(), - infra.id, - BufferStop::default(), - ) - .await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let bs = create_infra_object(&mut db_pool.get_ok(), infra.id, BufferStop::default()).await; + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -1082,9 +1067,9 @@ pub mod tests { #[rstest] async fn load_electrification() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; let electrification = create_infra_object( - db_pool.get_ok().deref_mut(), + &mut db_pool.get_ok(), infra.id, Electrification { track_ranges: vec![Default::default()], @@ -1093,7 +1078,7 @@ pub mod tests { ) .await; - let infra_cache = InfraCache::load(db_pool.get_ok().deref_mut(), &infra) + let infra_cache = InfraCache::load(&mut db_pool.get_ok(), &infra) .await .unwrap(); @@ -1387,13 +1372,13 @@ pub mod tests { #[rstest] async fn load_infra_cache() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; let infra_caches = CHashMap::new(); - InfraCache::get_or_load(db_pool.get_ok().deref_mut(), &infra_caches, &infra) + InfraCache::get_or_load(&mut db_pool.get_ok(), &infra_caches, &infra) .await .unwrap(); assert_eq!(infra_caches.len(), 1); - InfraCache::get_or_load(db_pool.get_ok().deref_mut(), &infra_caches, &infra) + InfraCache::get_or_load(&mut db_pool.get_ok(), &infra_caches, &infra) .await .unwrap(); assert_eq!(infra_caches.len(), 1); @@ -1402,13 +1387,13 @@ pub mod tests { #[rstest] async fn load_infra_cache_mut() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; let infra_caches = CHashMap::new(); - InfraCache::get_or_load_mut(db_pool.get_ok().deref_mut(), &infra_caches, &infra) + InfraCache::get_or_load_mut(&mut db_pool.get_ok(), &infra_caches, &infra) .await .unwrap(); assert_eq!(infra_caches.len(), 1); - InfraCache::get_or_load_mut(db_pool.get_ok().deref_mut(), &infra_caches, &infra) + InfraCache::get_or_load_mut(&mut db_pool.get_ok(), &infra_caches, &infra) .await .unwrap(); assert_eq!(infra_caches.len(), 1); diff --git a/editoast/src/infra_cache/operation/create.rs b/editoast/src/infra_cache/operation/create.rs index 708102ac87b..066f764bac7 100644 --- a/editoast/src/infra_cache/operation/create.rs +++ b/editoast/src/infra_cache/operation/create.rs @@ -3,14 +3,15 @@ use diesel::sql_types::BigInt; use diesel::sql_types::Json; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; - -use super::OperationError; -use crate::error::Result; -use crate::modelsv2::get_table; use editoast_models::DbConnection; use editoast_schemas::infra::InfraObject; use editoast_schemas::primitives::OSRDIdentified; use editoast_schemas::primitives::OSRDObject; +use std::ops::DerefMut; + +use super::OperationError; +use crate::error::Result; +use crate::modelsv2::get_table; pub async fn apply_create_operation<'r>( infra_object: &'r InfraObject, @@ -27,7 +28,7 @@ pub async fn apply_create_operation<'r>( .bind::(infra_id) .bind::(infra_object.get_id()) .bind::(infra_object.get_data()) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await .map(|idx| (idx, infra_object)) .map_err(|err| err.into()) @@ -46,7 +47,6 @@ pub mod tests { use editoast_schemas::infra::Switch; use editoast_schemas::infra::SwitchType; use editoast_schemas::infra::TrackSection; - use std::ops::DerefMut; macro_rules! test_create_object { ($obj:ident) => { @@ -54,11 +54,11 @@ pub mod tests { #[rstest::rstest] async fn []() { let db_pool = editoast_models::DbConnectionPoolV2::for_tests(); - let infra = crate::modelsv2::fixtures::create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = crate::modelsv2::fixtures::create_empty_infra(&mut db_pool.get_ok()).await; let infra_object = editoast_schemas::infra::InfraObject::$obj { railjson: $obj::default(), }; - let result = crate::infra_cache::operation::create::apply_create_operation(&infra_object, infra.id, db_pool.get_ok().deref_mut()).await; + let result = crate::infra_cache::operation::create::apply_create_operation(&infra_object, infra.id, &mut db_pool.get_ok()).await; assert!(result.is_ok(), "Failed to create a {}", stringify!($obj)); } } diff --git a/editoast/src/infra_cache/operation/delete.rs b/editoast/src/infra_cache/operation/delete.rs index fa1543a5d30..16939dd8739 100644 --- a/editoast/src/infra_cache/operation/delete.rs +++ b/editoast/src/infra_cache/operation/delete.rs @@ -2,15 +2,16 @@ use diesel::sql_query; use diesel::sql_types::BigInt; use diesel::sql_types::Text; use diesel_async::RunQueryDsl; +use editoast_models::DbConnection; +use editoast_schemas::primitives::ObjectRef; +use editoast_schemas::primitives::ObjectType; use serde::Deserialize; use serde::Serialize; +use std::ops::DerefMut; use super::OperationError; use crate::error::Result; use crate::modelsv2::get_table; -use editoast_models::DbConnection; -use editoast_schemas::primitives::ObjectRef; -use editoast_schemas::primitives::ObjectType; #[derive(Clone, Debug, PartialEq, Eq, Hash, Deserialize, Serialize, utoipa::ToSchema)] #[serde(deny_unknown_fields)] @@ -28,7 +29,7 @@ impl DeleteOperation { )) .bind::(&self.obj_id) .bind::(&infra_id) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await { Ok(1) => Ok(()), @@ -93,16 +94,16 @@ mod tests { use std::ops::DerefMut; let db_pool = DbConnectionPoolV2::for_tests(); - let infra = crate::modelsv2::fixtures::create_empty_infra(db_pool.get_ok().deref_mut()).await; + let infra = crate::modelsv2::fixtures::create_empty_infra(&mut db_pool.get_ok()).await; let railjson_object = editoast_schemas::infra::InfraObject::$obj { railjson: $obj::default(), }; - let result = crate::infra_cache::operation::create::apply_create_operation(&railjson_object, infra.id, db_pool.get_ok().deref_mut()).await; + let result = crate::infra_cache::operation::create::apply_create_operation(&railjson_object, infra.id, &mut db_pool.get_ok()).await; assert!(result.is_ok(), "Failed to create a {}", stringify!($obj)); let object_deletion: crate::infra_cache::operation::delete::DeleteOperation = railjson_object.get_ref().into(); - let result = object_deletion.apply(infra.id, db_pool.get_ok().deref_mut()).await; + let result = object_deletion.apply(infra.id, &mut db_pool.get_ok()).await; assert!(result.is_ok(), "Failed to delete a {}", stringify!($obj)); let res_del = diesel::sql_query(format!( @@ -111,7 +112,7 @@ mod tests { railjson_object.get_id(), infra.id )) - .get_result::(db_pool.get_ok().deref_mut()).await.unwrap(); + .get_result::(&mut db_pool.get_ok().write().await.deref_mut()).await.unwrap(); pretty_assertions::assert_eq!(res_del.nb, 0); } diff --git a/editoast/src/infra_cache/operation/update.rs b/editoast/src/infra_cache/operation/update.rs index 684674f0114..dc2fda85df1 100644 --- a/editoast/src/infra_cache/operation/update.rs +++ b/editoast/src/infra_cache/operation/update.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use diesel::result::Error as DieselError; use diesel::sql_query; use diesel::sql_types::BigInt; @@ -6,6 +8,10 @@ use diesel::sql_types::Jsonb; use diesel::sql_types::Text; use diesel::QueryableByName; use diesel_async::RunQueryDsl; +use editoast_models::DbConnection; +use editoast_schemas::infra::InfraObject; +use editoast_schemas::primitives::OSRDIdentified; +use editoast_schemas::primitives::ObjectType; use json_patch::Patch; use serde::Deserialize; use serde::Serialize; @@ -16,10 +22,6 @@ use serde_json::Value; use super::OperationError; use crate::error::Result; use crate::modelsv2::get_table; -use editoast_models::DbConnection; -use editoast_schemas::infra::InfraObject; -use editoast_schemas::primitives::OSRDIdentified; -use editoast_schemas::primitives::ObjectType; #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, utoipa::ToSchema)] #[serde(deny_unknown_fields)] @@ -40,7 +42,7 @@ impl UpdateOperation { )) .bind::(infra_id) .bind::(&self.obj_id) - .get_result(conn) + .get_result(conn.write().await.deref_mut()) .await { Ok(obj) => obj, @@ -65,7 +67,7 @@ impl UpdateOperation { .bind::(obj.data) .bind::(infra_id) .bind::(&self.obj_id) - .execute(conn) + .execute(conn.write().await.deref_mut()) .await { Ok(1) => Ok(railjson_obj), @@ -157,13 +159,9 @@ mod tests { #[rstest] async fn valid_update_track() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let track = create_infra_object( - db_pool.get_ok().deref_mut(), - infra.id, - TrackSection::default(), - ) - .await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let track = + create_infra_object(&mut db_pool.get_ok(), infra.id, TrackSection::default()).await; let update_track = UpdateOperation { obj_id: track.get_id().clone(), obj_type: ObjectType::TrackSection, @@ -176,7 +174,7 @@ mod tests { }; assert!(update_track - .apply(infra.id, db_pool.get_ok().deref_mut()) + .apply(infra.id, &mut db_pool.get_ok()) .await .is_ok()); @@ -185,7 +183,7 @@ mod tests { track.get_id(), infra.id )) - .get_result::(db_pool.get_ok().deref_mut()).await.unwrap(); + .get_result::(db_pool.get_ok().write().await.deref_mut()).await.unwrap(); assert_eq!(updated_length.val, 80.0); } @@ -193,13 +191,9 @@ mod tests { #[rstest] async fn invalid_update_track() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let track = create_infra_object( - db_pool.get_ok().deref_mut(), - infra.id, - TrackSection::default(), - ) - .await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let track = + create_infra_object(&mut db_pool.get_ok(), infra.id, TrackSection::default()).await; let update_track = UpdateOperation { obj_id: track.get_id().clone(), obj_type: ObjectType::TrackSection, @@ -210,9 +204,7 @@ mod tests { ) .unwrap(), }; - let res = update_track - .apply(infra.id, db_pool.get_ok().deref_mut()) - .await; + let res = update_track.apply(infra.id, &mut db_pool.get_ok()).await; assert!(res.is_err()); assert_eq!( @@ -224,9 +216,8 @@ mod tests { #[rstest] async fn valid_update_signal() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let signal = - create_infra_object(db_pool.get_ok().deref_mut(), infra.id, Signal::default()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let signal = create_infra_object(&mut db_pool.get_ok(), infra.id, Signal::default()).await; let update_signal = UpdateOperation { obj_id: signal.get_id().clone(), obj_type: ObjectType::Signal, @@ -239,7 +230,7 @@ mod tests { }; assert!(update_signal - .apply(infra.id, db_pool.get_ok().deref_mut()) + .apply(infra.id, &mut db_pool.get_ok()) .await .is_ok()); @@ -248,7 +239,7 @@ mod tests { signal.get_id(), infra.id )) - .get_result::(db_pool.get_ok().deref_mut()).await.unwrap(); + .get_result::(db_pool.get_ok().write().await.deref_mut()).await.unwrap(); assert_eq!(updated_length.val, 15.0); } @@ -256,9 +247,8 @@ mod tests { #[rstest] async fn valid_update_switch_extension() { let db_pool = DbConnectionPoolV2::for_tests(); - let infra = create_empty_infra(db_pool.get_ok().deref_mut()).await; - let switch = - create_infra_object(db_pool.get_ok().deref_mut(), infra.id, Switch::default()).await; + let infra = create_empty_infra(&mut db_pool.get_ok()).await; + let switch = create_infra_object(&mut db_pool.get_ok(), infra.id, Switch::default()).await; let update_switch = UpdateOperation { obj_id: switch.get_id().clone(), obj_type: ObjectType::Switch, @@ -271,7 +261,7 @@ mod tests { }; assert!(update_switch - .apply(infra.id, db_pool.get_ok().deref_mut()) + .apply(infra.id, &mut db_pool.get_ok()) .await .is_ok()); @@ -280,7 +270,7 @@ mod tests { switch.get_id(), infra.id )) - .get_result::