Skip to content

Commit c1359a6

Browse files
committed
editoast: rename DbConnection to DieselConnection
1 parent f2f2cc9 commit c1359a6

File tree

5 files changed

+160
-17
lines changed

5 files changed

+160
-17
lines changed

editoast/editoast_derive/src/model.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ fn create_functions(config: &Config) -> TokenStream {
8484
quote! {
8585
#[async_trait::async_trait]
8686
impl crate::models::Create for #model_name {
87-
async fn create_conn(self, conn: &mut editoast_models::DbConnection) -> crate::error::Result<Self> {
87+
async fn create_conn(self, conn: &mut editoast_models::DieselConnection) -> crate::error::Result<Self> {
8888
use #table::table;
8989
use diesel_async::RunQueryDsl;
9090

@@ -126,7 +126,7 @@ fn retrieve_functions(config: &Config) -> TokenStream {
126126
quote! {
127127
#[async_trait::async_trait]
128128
impl crate::models::Retrieve for #model_name {
129-
async fn retrieve_conn(conn: &mut editoast_models::DbConnection, obj_id: i64) -> crate::error::Result<Option<Self>> {
129+
async fn retrieve_conn(conn: &mut editoast_models::DieselConnection, obj_id: i64) -> crate::error::Result<Option<Self>> {
130130
use #table::table;
131131
use #table::dsl;
132132
use diesel_async::RunQueryDsl;
@@ -174,7 +174,7 @@ fn delete_functions(config: &Config) -> TokenStream {
174174
quote! {
175175
#[async_trait::async_trait]
176176
impl crate::models::Delete for #model_name {
177-
async fn delete_conn(conn: &mut editoast_models::DbConnection, obj_id: i64) -> crate::error::Result<bool> {
177+
async fn delete_conn(conn: &mut editoast_models::DieselConnection, obj_id: i64) -> crate::error::Result<bool> {
178178
use #table::table;
179179
use #table::dsl;
180180
use diesel_async::RunQueryDsl;
@@ -207,7 +207,7 @@ fn update_functions(config: &Config) -> TokenStream {
207207
quote! {
208208
#[async_trait::async_trait]
209209
impl crate::models::Update for #model_name {
210-
async fn update_conn(self, conn: &mut editoast_models::DbConnection, obj_id: i64) -> crate::error::Result<Option<Self>> {
210+
async fn update_conn(self, conn: &mut editoast_models::DieselConnection, obj_id: i64) -> crate::error::Result<Option<Self>> {
211211
use #table::table;
212212

213213
match diesel::update(table.find(obj_id)).set(&self).get_result(conn).await

editoast/editoast_models/src/db_connection_pool.rs

+140-7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::ops::Deref;
2+
use std::ops::DerefMut;
13
use std::sync::Arc;
24

35
use diesel::sql_query;
@@ -7,6 +9,8 @@ use diesel_async::pooled_connection::deadpool::Object;
79
use diesel_async::pooled_connection::deadpool::Pool;
810
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
911
use diesel_async::pooled_connection::ManagerConfig;
12+
use diesel_async::scoped_futures::ScopedBoxFuture;
13+
use diesel_async::AsyncConnection;
1014
use diesel_async::AsyncPgConnection;
1115
use diesel_async::RunQueryDsl;
1216
use futures::future::BoxFuture;
@@ -17,13 +21,12 @@ use openssl::ssl::SslMethod;
1721
use openssl::ssl::SslVerifyMode;
1822
use url::Url;
1923

20-
#[cfg(feature = "testing")]
2124
use tokio::sync::OwnedRwLockWriteGuard;
22-
#[cfg(feature = "testing")]
2325
use tokio::sync::RwLock;
2426

2527
use super::DbConnection;
2628
use super::DbConnectionPool;
29+
use super::DieselConnection;
2730

2831
pub type DbConnectionConfig = AsyncDieselConnectionManager<AsyncPgConnection>;
2932

@@ -33,6 +36,84 @@ pub type DbConnectionV2 = OwnedRwLockWriteGuard<Object<AsyncPgConnection>>;
3336
#[cfg(not(feature = "testing"))]
3437
pub type DbConnectionV2 = Object<AsyncPgConnection>;
3538

39+
#[derive(Clone)]
40+
pub struct DbConnectionV3 {
41+
inner: Arc<RwLock<Object<AsyncPgConnection>>>,
42+
}
43+
44+
pub struct WriteHandle {
45+
guard: OwnedRwLockWriteGuard<Object<AsyncPgConnection>>,
46+
}
47+
48+
impl DbConnectionV3 {
49+
pub fn new(inner: Arc<RwLock<Object<AsyncPgConnection>>>) -> Self {
50+
Self { inner }
51+
}
52+
53+
pub async fn write(&self) -> WriteHandle {
54+
WriteHandle {
55+
guard: self.inner.clone().write_owned().await,
56+
}
57+
}
58+
59+
pub async fn transaction<'a, R, E, F>(self, callback: F) -> std::result::Result<R, E>
60+
where
61+
F: FnOnce(Self) -> ScopedBoxFuture<'a, 'a, std::result::Result<R, E>> + Send + 'a,
62+
E: From<diesel::result::Error> + Send + 'a,
63+
R: Send + 'a,
64+
{
65+
use diesel_async::TransactionManager as _;
66+
67+
type TxManager = <AsyncPgConnection as AsyncConnection>::TransactionManager;
68+
69+
{
70+
let mut handle = self.write().await;
71+
TxManager::begin_transaction(
72+
handle.deref_mut(),
73+
)
74+
.await?;
75+
}
76+
77+
match callback(self.clone()).await {
78+
Ok(result) => {
79+
let mut handle = self.write().await;
80+
TxManager::commit_transaction(
81+
handle.deref_mut(),
82+
)
83+
.await?;
84+
Ok(result)
85+
}
86+
Err(callback_error) => {
87+
let mut handle = self.write().await;
88+
match TxManager::rollback_transaction(
89+
handle.deref_mut(),
90+
)
91+
.await
92+
{
93+
Ok(()) | Err(diesel::result::Error::BrokenTransactionManager) => {
94+
Err(callback_error)
95+
}
96+
Err(rollback_error) => Err(rollback_error.into()),
97+
}
98+
}
99+
}
100+
}
101+
}
102+
103+
impl Deref for WriteHandle {
104+
type Target = AsyncPgConnection;
105+
106+
fn deref(&self) -> &Self::Target {
107+
self.guard.deref()
108+
}
109+
}
110+
111+
impl DerefMut for WriteHandle {
112+
fn deref_mut(&mut self) -> &mut Self::Target {
113+
self.guard.deref_mut()
114+
}
115+
}
116+
36117
/// Wrapper for connection pooling with support for test connections on `cfg(test)`
37118
///
38119
/// # Testing pool
@@ -47,6 +128,8 @@ pub struct DbConnectionPoolV2 {
47128
pool: Arc<Pool<AsyncPgConnection>>,
48129
#[cfg(feature = "testing")]
49130
test_connection: Option<Arc<RwLock<Object<AsyncPgConnection>>>>,
131+
#[cfg(feature = "testing")]
132+
test_connection_v3: Option<DbConnectionV3>,
50133
}
51134

52135
#[cfg(feature = "testing")]
@@ -61,8 +144,20 @@ impl Default for DbConnectionPoolV2 {
61144
pub struct DatabasePoolBuildError(#[from] diesel_async::pooled_connection::deadpool::BuildError);
62145

63146
#[derive(Debug, thiserror::Error)]
64-
#[error("an error occurred while getting a connection from the database pool: '{0}'")]
65-
pub struct DatabasePoolError(#[from] diesel_async::pooled_connection::deadpool::PoolError);
147+
pub enum DatabasePoolError {
148+
#[error("an error occurred while getting a connection from the database pool: '{0}'")]
149+
Pool(#[from] diesel_async::pooled_connection::deadpool::PoolError),
150+
#[error("an error occured while querying the database: {0}")]
151+
DieselError(#[from] diesel::result::Error),
152+
}
153+
154+
#[derive(Debug, thiserror::Error)]
155+
pub enum DatabaseTransactionError {
156+
#[error("an error occurred while getting a connection from the database pool: '{0}'")]
157+
Pool(#[from] DatabasePoolError),
158+
#[error("an error occured while querying the database: {0}")]
159+
DieselError(#[from] diesel::result::Error),
160+
}
66161

67162
impl DbConnectionPoolV2 {
68163
/// Get inner pool for retro compatibility
@@ -89,12 +184,27 @@ impl DbConnectionPoolV2 {
89184
Ok(connection)
90185
}
91186

187+
#[cfg(feature = "testing")]
188+
async fn get_connection_v3(&self) -> Result<DbConnectionV3, DatabasePoolError> {
189+
Ok(self
190+
.test_connection_v3
191+
.as_ref()
192+
.expect("should already exist")
193+
.clone())
194+
}
195+
92196
#[cfg(not(feature = "testing"))]
93197
async fn get_connection(&self) -> Result<DbConnectionV2, DatabasePoolError> {
94198
let connection = self.pool.get().await?;
95199
Ok(connection)
96200
}
97201

202+
#[cfg(not(feature = "testing"))]
203+
async fn get_connection_v3(&self) -> Result<DbConnectionV3, DatabasePoolError> {
204+
let connection = self.pool.get().await?;
205+
Ok(DbConnectionV3::new(Arc::new(RwLock::new(connection))))
206+
}
207+
98208
/// Get a connection from the pool
99209
///
100210
/// This function behaves differently in test mode.
@@ -228,6 +338,15 @@ impl DbConnectionPoolV2 {
228338
self.get_connection().await
229339
}
230340

341+
pub async fn get_v3(&self) -> Result<DbConnectionV3, DatabasePoolError> {
342+
self.get_connection_v3().await
343+
}
344+
345+
#[cfg(feature = "testing")]
346+
pub fn get_ok_v3(&self) -> DbConnectionV3 {
347+
futures::executor::block_on(self.get_v3()).expect("Failed to get test connection")
348+
}
349+
231350
/// Gets a test connection from the pool synchronously, failing if the connection is not available
232351
///
233352
/// In unit tests, this is the preferred way to get a connection
@@ -300,9 +419,23 @@ impl DbConnectionPoolV2 {
300419
}
301420
let test_connection = Arc::new(RwLock::new(conn));
302421

422+
// Conn v3
423+
let mut conn_v3 = pool
424+
.get()
425+
.await
426+
.expect("cannot acquire a connection in the test pool");
427+
if transaction {
428+
conn_v3
429+
.begin_test_transaction()
430+
.await
431+
.expect("cannot begin a test transaction");
432+
}
433+
let test_connection_v3 = Some(DbConnectionV3::new(Arc::new(RwLock::new(conn_v3))));
434+
303435
Self {
304436
pool,
305437
test_connection: Some(test_connection),
438+
test_connection_v3,
306439
}
307440
}
308441

@@ -312,7 +445,7 @@ impl DbConnectionPoolV2 {
312445
.unwrap_or_else(|_| String::from("postgresql://osrd:password@localhost/osrd"));
313446
let url = Url::parse(&url).expect("Failed to parse postgresql url");
314447
let pool =
315-
create_connection_pool(url, 1).expect("Failed to initialize test connection pool");
448+
create_connection_pool(url, 2).expect("Failed to initialize test connection pool");
316449
futures::executor::block_on(Self::from_pool_test(Arc::new(pool), transaction))
317450
}
318451

@@ -353,7 +486,7 @@ pub fn create_connection_pool(
353486
Ok(Pool::builder(manager).max_size(max_size).build()?)
354487
}
355488

356-
fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DbConnection>> {
489+
fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DieselConnection>> {
357490
let fut = async {
358491
let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
359492
connector_builder.set_verify(SslVerifyMode::NONE);
@@ -368,7 +501,7 @@ fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DbConnection
368501
tracing::error!("connection error: {}", e);
369502
}
370503
});
371-
DbConnection::try_from(client).await
504+
DieselConnection::try_from(client).await
372505
};
373506
fut.boxed()
374507
}

editoast/editoast_models/src/lib.rs

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
1-
use diesel_async::{pooled_connection::deadpool::Pool, AsyncPgConnection};
1+
use db_connection_pool::DatabasePoolError;
2+
use diesel_async::pooled_connection::deadpool::Pool;
3+
use diesel_async::AsyncPgConnection;
24

35
pub mod db_connection_pool;
46
pub mod tables;
57

68
pub use db_connection_pool::DbConnectionPoolV2;
9+
pub use db_connection_pool::DbConnectionV3;
710

811
pub type DbConnection = AsyncPgConnection;
9-
pub type DbConnectionPool = Pool<DbConnection>;
12+
pub type DieselConnection = AsyncPgConnection;
13+
pub type DbConnectionPool = Pool<DieselConnection>;
1014

1115
/// Generic error type to forward errors from the database
1216
///
1317
/// Useful for functions which only points of failure are the DB calls.
1418
#[derive(Debug, thiserror::Error)]
15-
#[error("an error occured while querying the database: {0}")]
16-
pub struct DatabaseError(#[from] diesel::result::Error);
19+
pub enum DatabaseError {
20+
#[error("an error occured while querying the database: {0}")]
21+
DieselError(#[from] diesel::result::Error),
22+
#[error("an error occured while retrieving a connection from the pool: {0}")]
23+
DatabasePoolError(#[from] DatabasePoolError),
24+
}

editoast/src/modelsv2/prelude/create.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use std::fmt::Debug;
22

3+
use editoast_models::DbConnection;
4+
use editoast_models::DbConnectionV3;
5+
36
use crate::error::EditoastError;
47
use crate::error::Result;
5-
use editoast_models::DbConnection;
68

79
/// Describes how a [Model](super::Model) can be created in the database
810
///

editoast/src/modelsv2/prelude/delete.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
use editoast_models::DbConnection;
12
use crate::error::EditoastError;
23
use crate::error::Result;
3-
use editoast_models::DbConnection;
44

55
/// Describes how a [Model](super::Model) can be deleted from the database
66
///

0 commit comments

Comments
 (0)