1
1
use std:: collections:: HashSet ;
2
2
use std:: ops:: DerefMut ;
3
- use std:: sync:: Arc ;
4
3
5
4
use diesel:: { dsl, prelude:: * } ;
6
- use diesel_async:: { scoped_futures:: ScopedFutureExt , RunQueryDsl } ;
5
+ use diesel_async:: { scoped_futures:: ScopedFutureExt as _ , RunQueryDsl } ;
7
6
use editoast_authz:: {
8
7
authorizer:: { StorageDriver , UserInfo } ,
9
8
roles:: { BuiltinRoleSet , RoleConfig } ,
10
9
} ;
11
- use editoast_models:: DbConnectionPoolV2 ;
10
+ use editoast_models:: DbConnection ;
12
11
13
12
use editoast_models:: tables:: * ;
14
13
use itertools:: Itertools as _;
15
14
use tracing:: Level ;
16
15
17
16
#[ derive( Clone ) ]
18
17
pub struct PgAuthDriver < B : BuiltinRoleSet + Send + Sync > {
19
- pool : Arc < DbConnectionPoolV2 > ,
18
+ conn : DbConnection ,
20
19
_role_set : std:: marker:: PhantomData < B > ,
21
20
}
22
21
23
22
impl < B : BuiltinRoleSet + Send + Sync > PgAuthDriver < B > {
24
- pub fn new ( pool : Arc < DbConnectionPoolV2 > ) -> Self {
23
+ pub fn new ( conn : DbConnection ) -> Self {
25
24
Self {
26
- pool ,
25
+ conn ,
27
26
_role_set : Default :: default ( ) ,
28
27
}
29
28
}
@@ -43,23 +42,21 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
43
42
44
43
#[ tracing:: instrument( skip_all, fields( %user_info) , ret( level = Level :: DEBUG ) , err) ]
45
44
async fn get_user_id ( & self , user_info : & UserInfo ) -> Result < Option < i64 > , Self :: Error > {
46
- let conn = self . pool . get ( ) . await ?;
47
45
let id = authn_user:: table
48
46
. select ( authn_user:: id)
49
47
. filter ( authn_user:: identity_id. eq ( & user_info. identity ) )
50
- . first :: < i64 > ( conn. write ( ) . await . deref_mut ( ) )
48
+ . first :: < i64 > ( self . conn . write ( ) . await . deref_mut ( ) )
51
49
. await
52
50
. optional ( ) ?;
53
51
Ok ( id)
54
52
}
55
53
56
54
#[ tracing:: instrument( skip_all, fields( %user_id) , ret( level = Level :: DEBUG ) , err) ]
57
55
async fn get_user_info ( & self , user_id : i64 ) -> Result < Option < UserInfo > , Self :: Error > {
58
- let conn = self . pool . get ( ) . await ?;
59
56
let info = authn_user:: table
60
57
. select ( ( authn_user:: identity_id, authn_user:: name) )
61
58
. filter ( authn_user:: id. eq ( user_id) )
62
- . first :: < ( String , Option < String > ) > ( conn. write ( ) . await . deref_mut ( ) )
59
+ . first :: < ( String , Option < String > ) > ( self . conn . write ( ) . await . deref_mut ( ) )
63
60
. await
64
61
. optional ( )
65
62
. map ( |res| {
@@ -73,9 +70,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
73
70
74
71
#[ tracing:: instrument( skip_all, fields( %user) , ret( level = Level :: DEBUG ) , err) ]
75
72
async fn ensure_user ( & self , user : & UserInfo ) -> Result < i64 , Self :: Error > {
76
- self . pool
77
- . get ( )
78
- . await ?
73
+ self . conn
79
74
. transaction ( |conn| {
80
75
async move {
81
76
let user_id = self . get_user_id ( user) . await ?;
@@ -118,16 +113,14 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
118
113
subject_id : i64 ,
119
114
roles_config : & RoleConfig < Self :: BuiltinRole > ,
120
115
) -> Result < HashSet < Self :: BuiltinRole > , Self :: Error > {
121
- let conn = self . pool . get ( ) . await ?;
122
-
123
116
let roles = authz_role:: table
124
117
. select ( authz_role:: role)
125
118
. left_join (
126
119
authn_group_membership:: table. on ( authn_group_membership:: user. eq ( subject_id) ) ,
127
120
)
128
121
. filter ( authz_role:: subject. eq ( subject_id) )
129
122
. or_filter ( authz_role:: subject. eq ( authn_group_membership:: group) )
130
- . load :: < String > ( conn. write ( ) . await . deref_mut ( ) )
123
+ . load :: < String > ( self . conn . write ( ) . await . deref_mut ( ) )
131
124
. await ?
132
125
. into_iter ( )
133
126
. map ( |role| {
@@ -147,8 +140,6 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
147
140
roles_config : & RoleConfig < Self :: BuiltinRole > ,
148
141
roles : HashSet < Self :: BuiltinRole > ,
149
142
) -> Result < ( ) , Self :: Error > {
150
- let conn = self . pool . get ( ) . await ?;
151
-
152
143
dsl:: insert_into ( authz_role:: table)
153
144
. values (
154
145
roles
@@ -163,7 +154,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
163
154
)
164
155
. on_conflict ( ( authz_role:: subject, authz_role:: role) )
165
156
. do_nothing ( )
166
- . execute ( conn. write ( ) . await . deref_mut ( ) )
157
+ . execute ( self . conn . write ( ) . await . deref_mut ( ) )
167
158
. await ?;
168
159
169
160
Ok ( ( ) )
@@ -176,8 +167,6 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
176
167
roles_config : & RoleConfig < Self :: BuiltinRole > ,
177
168
roles : HashSet < Self :: BuiltinRole > ,
178
169
) -> Result < HashSet < Self :: BuiltinRole > , Self :: Error > {
179
- let conn = self . pool . get ( ) . await ?;
180
-
181
170
let deleted_roles = dsl:: delete (
182
171
authz_role:: table
183
172
. filter ( authz_role:: subject. eq ( subject_id) )
@@ -186,7 +175,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
186
175
) ,
187
176
)
188
177
. returning ( authz_role:: role)
189
- . load :: < String > ( conn. write ( ) . await . deref_mut ( ) )
178
+ . load :: < String > ( self . conn . write ( ) . await . deref_mut ( ) )
190
179
. await ?
191
180
. into_iter ( )
192
181
. map ( |role| {
@@ -227,7 +216,7 @@ mod tests {
227
216
#[ rstest:: rstest]
228
217
async fn test_auth_driver ( ) {
229
218
let pool = DbConnectionPoolV2 :: for_tests ( ) ;
230
- let mut driver = PgAuthDriver :: < TestBuiltinRole > :: new ( pool. into ( ) ) ;
219
+ let mut driver = PgAuthDriver :: < TestBuiltinRole > :: new ( pool. get_ok ( ) ) ;
231
220
let config = default_test_config ( ) ;
232
221
233
222
let uid = driver
0 commit comments