1
1
use std:: collections:: HashSet ;
2
2
use std:: ops:: DerefMut ;
3
- use std:: sync:: Arc ;
4
3
5
4
use diesel:: { dsl, prelude:: * } ;
6
- use diesel_async:: { scoped_futures:: ScopedFutureExt , RunQueryDsl } ;
5
+ use diesel_async:: { scoped_futures:: ScopedFutureExt as _ , RunQueryDsl } ;
7
6
use editoast_authz:: {
8
7
authorizer:: { StorageDriver , UserInfo } ,
9
8
roles:: { BuiltinRoleSet , RoleConfig } ,
10
9
} ;
11
- use editoast_models:: DbConnectionPoolV2 ;
10
+ use editoast_models:: DbConnection ;
12
11
13
12
use editoast_models:: tables:: * ;
14
13
use itertools:: Itertools as _;
15
14
use tracing:: Level ;
16
15
17
16
#[ derive( Clone ) ]
18
17
pub struct PgAuthDriver < B : BuiltinRoleSet + Send + Sync > {
19
- pool : Arc < DbConnectionPoolV2 > ,
18
+ conn : DbConnection ,
20
19
_role_set : std:: marker:: PhantomData < B > ,
21
20
}
22
21
23
22
impl < B : BuiltinRoleSet + Send + Sync > PgAuthDriver < B > {
24
- pub fn new ( pool : Arc < DbConnectionPoolV2 > ) -> Self {
23
+ pub fn new ( conn : DbConnection ) -> Self {
25
24
Self {
26
- pool ,
25
+ conn ,
27
26
_role_set : Default :: default ( ) ,
28
27
}
29
28
}
@@ -33,8 +32,6 @@ impl<B: BuiltinRoleSet + Send + Sync> PgAuthDriver<B> {
33
32
pub enum AuthDriverError {
34
33
#[ error( transparent) ]
35
34
DieselError ( #[ from] diesel:: result:: Error ) ,
36
- #[ error( transparent) ]
37
- PoolError ( #[ from] editoast_models:: db_connection_pool:: DatabasePoolError ) ,
38
35
}
39
36
40
37
impl < B : BuiltinRoleSet + Send + Sync > StorageDriver for PgAuthDriver < B > {
@@ -43,23 +40,21 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
43
40
44
41
#[ tracing:: instrument( skip_all, fields( %user_info) , ret( level = Level :: DEBUG ) , err) ]
45
42
async fn get_user_id ( & self , user_info : & UserInfo ) -> Result < Option < i64 > , Self :: Error > {
46
- let conn = self . pool . get ( ) . await ?;
47
43
let id = authn_user:: table
48
44
. select ( authn_user:: id)
49
45
. filter ( authn_user:: identity_id. eq ( & user_info. identity ) )
50
- . first :: < i64 > ( conn. write ( ) . await . deref_mut ( ) )
46
+ . first :: < i64 > ( self . conn . write ( ) . await . deref_mut ( ) )
51
47
. await
52
48
. optional ( ) ?;
53
49
Ok ( id)
54
50
}
55
51
56
52
#[ tracing:: instrument( skip_all, fields( %user_id) , ret( level = Level :: DEBUG ) , err) ]
57
53
async fn get_user_info ( & self , user_id : i64 ) -> Result < Option < UserInfo > , Self :: Error > {
58
- let conn = self . pool . get ( ) . await ?;
59
54
let info = authn_user:: table
60
55
. select ( ( authn_user:: identity_id, authn_user:: name) )
61
56
. filter ( authn_user:: id. eq ( user_id) )
62
- . first :: < ( String , Option < String > ) > ( conn. write ( ) . await . deref_mut ( ) )
57
+ . first :: < ( String , Option < String > ) > ( self . conn . write ( ) . await . deref_mut ( ) )
63
58
. await
64
59
. optional ( )
65
60
. map ( |res| {
@@ -73,9 +68,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
73
68
74
69
#[ tracing:: instrument( skip_all, fields( %user) , ret( level = Level :: DEBUG ) , err) ]
75
70
async fn ensure_user ( & self , user : & UserInfo ) -> Result < i64 , Self :: Error > {
76
- self . pool
77
- . get ( )
78
- . await ?
71
+ self . conn
79
72
. transaction ( |conn| {
80
73
async move {
81
74
let user_id = self . get_user_id ( user) . await ?;
@@ -118,16 +111,14 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
118
111
subject_id : i64 ,
119
112
roles_config : & RoleConfig < Self :: BuiltinRole > ,
120
113
) -> Result < HashSet < Self :: BuiltinRole > , Self :: Error > {
121
- let conn = self . pool . get ( ) . await ?;
122
-
123
114
let roles = authz_role:: table
124
115
. select ( authz_role:: role)
125
116
. left_join (
126
117
authn_group_membership:: table. on ( authn_group_membership:: user. eq ( subject_id) ) ,
127
118
)
128
119
. filter ( authz_role:: subject. eq ( subject_id) )
129
120
. or_filter ( authz_role:: subject. eq ( authn_group_membership:: group) )
130
- . load :: < String > ( conn. write ( ) . await . deref_mut ( ) )
121
+ . load :: < String > ( self . conn . write ( ) . await . deref_mut ( ) )
131
122
. await ?
132
123
. into_iter ( )
133
124
. map ( |role| {
@@ -147,8 +138,6 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
147
138
roles_config : & RoleConfig < Self :: BuiltinRole > ,
148
139
roles : HashSet < Self :: BuiltinRole > ,
149
140
) -> Result < ( ) , Self :: Error > {
150
- let conn = self . pool . get ( ) . await ?;
151
-
152
141
dsl:: insert_into ( authz_role:: table)
153
142
. values (
154
143
roles
@@ -163,7 +152,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
163
152
)
164
153
. on_conflict ( ( authz_role:: subject, authz_role:: role) )
165
154
. do_nothing ( )
166
- . execute ( conn. write ( ) . await . deref_mut ( ) )
155
+ . execute ( self . conn . write ( ) . await . deref_mut ( ) )
167
156
. await ?;
168
157
169
158
Ok ( ( ) )
@@ -176,8 +165,6 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
176
165
roles_config : & RoleConfig < Self :: BuiltinRole > ,
177
166
roles : HashSet < Self :: BuiltinRole > ,
178
167
) -> Result < HashSet < Self :: BuiltinRole > , Self :: Error > {
179
- let conn = self . pool . get ( ) . await ?;
180
-
181
168
let deleted_roles = dsl:: delete (
182
169
authz_role:: table
183
170
. filter ( authz_role:: subject. eq ( subject_id) )
@@ -186,7 +173,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
186
173
) ,
187
174
)
188
175
. returning ( authz_role:: role)
189
- . load :: < String > ( conn. write ( ) . await . deref_mut ( ) )
176
+ . load :: < String > ( self . conn . write ( ) . await . deref_mut ( ) )
190
177
. await ?
191
178
. into_iter ( )
192
179
. map ( |role| {
@@ -227,7 +214,7 @@ mod tests {
227
214
#[ rstest:: rstest]
228
215
async fn test_auth_driver ( ) {
229
216
let pool = DbConnectionPoolV2 :: for_tests ( ) ;
230
- let mut driver = PgAuthDriver :: < TestBuiltinRole > :: new ( pool. into ( ) ) ;
217
+ let mut driver = PgAuthDriver :: < TestBuiltinRole > :: new ( pool. get_ok ( ) ) ;
231
218
let config = default_test_config ( ) ;
232
219
233
220
let uid = driver
0 commit comments