1
+ use std:: ops:: Deref ;
2
+ use std:: ops:: DerefMut ;
1
3
use std:: sync:: Arc ;
2
4
3
5
use diesel:: sql_query;
@@ -7,6 +9,8 @@ use diesel_async::pooled_connection::deadpool::Object;
7
9
use diesel_async:: pooled_connection:: deadpool:: Pool ;
8
10
use diesel_async:: pooled_connection:: AsyncDieselConnectionManager ;
9
11
use diesel_async:: pooled_connection:: ManagerConfig ;
12
+ use diesel_async:: scoped_futures:: ScopedBoxFuture ;
13
+ use diesel_async:: AsyncConnection ;
10
14
use diesel_async:: AsyncPgConnection ;
11
15
use diesel_async:: RunQueryDsl ;
12
16
use futures:: future:: BoxFuture ;
@@ -17,13 +21,12 @@ use openssl::ssl::SslMethod;
17
21
use openssl:: ssl:: SslVerifyMode ;
18
22
use url:: Url ;
19
23
20
- #[ cfg( feature = "testing" ) ]
21
24
use tokio:: sync:: OwnedRwLockWriteGuard ;
22
- #[ cfg( feature = "testing" ) ]
23
25
use tokio:: sync:: RwLock ;
24
26
25
27
use super :: DbConnection ;
26
28
use super :: DbConnectionPool ;
29
+ use super :: DieselConnection ;
27
30
28
31
pub type DbConnectionConfig = AsyncDieselConnectionManager < AsyncPgConnection > ;
29
32
@@ -33,6 +36,84 @@ pub type DbConnectionV2 = OwnedRwLockWriteGuard<Object<AsyncPgConnection>>;
33
36
#[ cfg( not( feature = "testing" ) ) ]
34
37
pub type DbConnectionV2 = Object < AsyncPgConnection > ;
35
38
39
+ #[ derive( Clone ) ]
40
+ pub struct DbConnectionV3 {
41
+ inner : Arc < RwLock < Object < AsyncPgConnection > > > ,
42
+ }
43
+
44
+ pub struct WriteHandle {
45
+ guard : OwnedRwLockWriteGuard < Object < AsyncPgConnection > > ,
46
+ }
47
+
48
+ impl DbConnectionV3 {
49
+ pub fn new ( inner : Arc < RwLock < Object < AsyncPgConnection > > > ) -> Self {
50
+ Self { inner }
51
+ }
52
+
53
+ pub async fn write ( & self ) -> WriteHandle {
54
+ WriteHandle {
55
+ guard : self . inner . clone ( ) . write_owned ( ) . await ,
56
+ }
57
+ }
58
+
59
+ pub async fn transaction < ' a , R , E , F > ( self , callback : F ) -> std:: result:: Result < R , E >
60
+ where
61
+ F : FnOnce ( Self ) -> ScopedBoxFuture < ' a , ' a , std:: result:: Result < R , E > > + Send + ' a ,
62
+ E : From < diesel:: result:: Error > + Send + ' a ,
63
+ R : Send + ' a ,
64
+ {
65
+ use diesel_async:: TransactionManager as _;
66
+
67
+ type TxManager = <AsyncPgConnection as AsyncConnection >:: TransactionManager ;
68
+
69
+ {
70
+ let mut handle = self . write ( ) . await ;
71
+ TxManager :: begin_transaction (
72
+ handle. deref_mut ( ) ,
73
+ )
74
+ . await ?;
75
+ }
76
+
77
+ match callback ( self . clone ( ) ) . await {
78
+ Ok ( result) => {
79
+ let mut handle = self . write ( ) . await ;
80
+ TxManager :: commit_transaction (
81
+ handle. deref_mut ( ) ,
82
+ )
83
+ . await ?;
84
+ Ok ( result)
85
+ }
86
+ Err ( callback_error) => {
87
+ let mut handle = self . write ( ) . await ;
88
+ match TxManager :: rollback_transaction (
89
+ handle. deref_mut ( ) ,
90
+ )
91
+ . await
92
+ {
93
+ Ok ( ( ) ) | Err ( diesel:: result:: Error :: BrokenTransactionManager ) => {
94
+ Err ( callback_error)
95
+ }
96
+ Err ( rollback_error) => Err ( rollback_error. into ( ) ) ,
97
+ }
98
+ }
99
+ }
100
+ }
101
+ }
102
+
103
+ impl Deref for WriteHandle {
104
+ type Target = AsyncPgConnection ;
105
+
106
+ fn deref ( & self ) -> & Self :: Target {
107
+ self . guard . deref ( )
108
+ }
109
+ }
110
+
111
+ impl DerefMut for WriteHandle {
112
+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
113
+ self . guard . deref_mut ( )
114
+ }
115
+ }
116
+
36
117
/// Wrapper for connection pooling with support for test connections on `cfg(test)`
37
118
///
38
119
/// # Testing pool
@@ -47,6 +128,8 @@ pub struct DbConnectionPoolV2 {
47
128
pool : Arc < Pool < AsyncPgConnection > > ,
48
129
#[ cfg( feature = "testing" ) ]
49
130
test_connection : Option < Arc < RwLock < Object < AsyncPgConnection > > > > ,
131
+ #[ cfg( feature = "testing" ) ]
132
+ test_connection_v3 : Option < DbConnectionV3 > ,
50
133
}
51
134
52
135
#[ cfg( feature = "testing" ) ]
@@ -61,8 +144,20 @@ impl Default for DbConnectionPoolV2 {
61
144
pub struct DatabasePoolBuildError ( #[ from] diesel_async:: pooled_connection:: deadpool:: BuildError ) ;
62
145
63
146
#[ derive( Debug , thiserror:: Error ) ]
64
- #[ error( "an error occurred while getting a connection from the database pool: '{0}'" ) ]
65
- pub struct DatabasePoolError ( #[ from] diesel_async:: pooled_connection:: deadpool:: PoolError ) ;
147
+ pub enum DatabasePoolError {
148
+ #[ error( "an error occurred while getting a connection from the database pool: '{0}'" ) ]
149
+ Pool ( #[ from] diesel_async:: pooled_connection:: deadpool:: PoolError ) ,
150
+ #[ error( "an error occured while querying the database: {0}" ) ]
151
+ DieselError ( #[ from] diesel:: result:: Error ) ,
152
+ }
153
+
154
+ #[ derive( Debug , thiserror:: Error ) ]
155
+ pub enum DatabaseTransactionError {
156
+ #[ error( "an error occurred while getting a connection from the database pool: '{0}'" ) ]
157
+ Pool ( #[ from] DatabasePoolError ) ,
158
+ #[ error( "an error occured while querying the database: {0}" ) ]
159
+ DieselError ( #[ from] diesel:: result:: Error ) ,
160
+ }
66
161
67
162
impl DbConnectionPoolV2 {
68
163
/// Get inner pool for retro compatibility
@@ -89,12 +184,27 @@ impl DbConnectionPoolV2 {
89
184
Ok ( connection)
90
185
}
91
186
187
+ #[ cfg( feature = "testing" ) ]
188
+ async fn get_connection_v3 ( & self ) -> Result < DbConnectionV3 , DatabasePoolError > {
189
+ Ok ( self
190
+ . test_connection_v3
191
+ . as_ref ( )
192
+ . expect ( "should already exist" )
193
+ . clone ( ) )
194
+ }
195
+
92
196
#[ cfg( not( feature = "testing" ) ) ]
93
197
async fn get_connection ( & self ) -> Result < DbConnectionV2 , DatabasePoolError > {
94
198
let connection = self . pool . get ( ) . await ?;
95
199
Ok ( connection)
96
200
}
97
201
202
+ #[ cfg( not( feature = "testing" ) ) ]
203
+ async fn get_connection_v3 ( & self ) -> Result < DbConnectionV3 , DatabasePoolError > {
204
+ let connection = self . pool . get ( ) . await ?;
205
+ Ok ( DbConnectionV3 :: new ( Arc :: new ( RwLock :: new ( connection) ) ) )
206
+ }
207
+
98
208
/// Get a connection from the pool
99
209
///
100
210
/// This function behaves differently in test mode.
@@ -228,6 +338,15 @@ impl DbConnectionPoolV2 {
228
338
self . get_connection ( ) . await
229
339
}
230
340
341
+ pub async fn get_v3 ( & self ) -> Result < DbConnectionV3 , DatabasePoolError > {
342
+ self . get_connection_v3 ( ) . await
343
+ }
344
+
345
+ #[ cfg( feature = "testing" ) ]
346
+ pub fn get_ok_v3 ( & self ) -> DbConnectionV3 {
347
+ futures:: executor:: block_on ( self . get_v3 ( ) ) . expect ( "Failed to get test connection" )
348
+ }
349
+
231
350
/// Gets a test connection from the pool synchronously, failing if the connection is not available
232
351
///
233
352
/// In unit tests, this is the preferred way to get a connection
@@ -300,9 +419,23 @@ impl DbConnectionPoolV2 {
300
419
}
301
420
let test_connection = Arc :: new ( RwLock :: new ( conn) ) ;
302
421
422
+ // Conn v3
423
+ let mut conn_v3 = pool
424
+ . get ( )
425
+ . await
426
+ . expect ( "cannot acquire a connection in the test pool" ) ;
427
+ if transaction {
428
+ conn_v3
429
+ . begin_test_transaction ( )
430
+ . await
431
+ . expect ( "cannot begin a test transaction" ) ;
432
+ }
433
+ let test_connection_v3 = Some ( DbConnectionV3 :: new ( Arc :: new ( RwLock :: new ( conn_v3) ) ) ) ;
434
+
303
435
Self {
304
436
pool,
305
437
test_connection : Some ( test_connection) ,
438
+ test_connection_v3,
306
439
}
307
440
}
308
441
@@ -312,7 +445,7 @@ impl DbConnectionPoolV2 {
312
445
. unwrap_or_else ( |_| String :: from ( "postgresql://osrd:password@localhost/osrd" ) ) ;
313
446
let url = Url :: parse ( & url) . expect ( "Failed to parse postgresql url" ) ;
314
447
let pool =
315
- create_connection_pool ( url, 1 ) . expect ( "Failed to initialize test connection pool" ) ;
448
+ create_connection_pool ( url, 2 ) . expect ( "Failed to initialize test connection pool" ) ;
316
449
futures:: executor:: block_on ( Self :: from_pool_test ( Arc :: new ( pool) , transaction) )
317
450
}
318
451
@@ -353,7 +486,7 @@ pub fn create_connection_pool(
353
486
Ok ( Pool :: builder ( manager) . max_size ( max_size) . build ( ) ?)
354
487
}
355
488
356
- fn establish_connection ( config : & str ) -> BoxFuture < ConnectionResult < DbConnection > > {
489
+ fn establish_connection ( config : & str ) -> BoxFuture < ConnectionResult < DieselConnection > > {
357
490
let fut = async {
358
491
let mut connector_builder = SslConnector :: builder ( SslMethod :: tls ( ) ) . unwrap ( ) ;
359
492
connector_builder. set_verify ( SslVerifyMode :: NONE ) ;
@@ -368,7 +501,7 @@ fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DbConnection
368
501
tracing:: error!( "connection error: {}" , e) ;
369
502
}
370
503
} ) ;
371
- DbConnection :: try_from ( client) . await
504
+ DieselConnection :: try_from ( client) . await
372
505
} ;
373
506
fut. boxed ( )
374
507
}
0 commit comments