Skip to content

Commit 9c8b0a1

Browse files
committed
editoast: provide PgAuthDriver a connection instead of the pool
Signed-off-by: Leo Valais <[email protected]>
1 parent 3c6a80d commit 9c8b0a1

File tree

3 files changed

+18
-26
lines changed

3 files changed

+18
-26
lines changed

editoast/editoast_models/src/db_connection_pool.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl DbConnection {
5555
//
5656
// :WARNING: If you ever need to modify this function, please take a look at the
5757
// original `diesel` function, they probably do it right more than us.
58-
pub async fn transaction<'a, R, E, F>(self, callback: F) -> std::result::Result<R, E>
58+
pub async fn transaction<'a, R, E, F>(&self, callback: F) -> std::result::Result<R, E>
5959
where
6060
F: FnOnce(Self) -> ScopedBoxFuture<'a, 'a, std::result::Result<R, E>> + Send + 'a,
6161
E: From<diesel::result::Error> + Send + 'a,

editoast/src/models/auth.rs

+12-23
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
11
use std::collections::HashSet;
22
use std::ops::DerefMut;
3-
use std::sync::Arc;
43

54
use diesel::{dsl, prelude::*};
6-
use diesel_async::{scoped_futures::ScopedFutureExt, RunQueryDsl};
5+
use diesel_async::{scoped_futures::ScopedFutureExt as _, RunQueryDsl};
76
use editoast_authz::{
87
authorizer::{StorageDriver, UserInfo},
98
roles::{BuiltinRoleSet, RoleConfig},
109
};
11-
use editoast_models::DbConnectionPoolV2;
10+
use editoast_models::DbConnection;
1211

1312
use editoast_models::tables::*;
1413
use itertools::Itertools as _;
1514
use tracing::Level;
1615

1716
#[derive(Clone)]
1817
pub struct PgAuthDriver<B: BuiltinRoleSet + Send + Sync> {
19-
pool: Arc<DbConnectionPoolV2>,
18+
conn: DbConnection,
2019
_role_set: std::marker::PhantomData<B>,
2120
}
2221

2322
impl<B: BuiltinRoleSet + Send + Sync> PgAuthDriver<B> {
24-
pub fn new(pool: Arc<DbConnectionPoolV2>) -> Self {
23+
pub fn new(conn: DbConnection) -> Self {
2524
Self {
26-
pool,
25+
conn,
2726
_role_set: Default::default(),
2827
}
2928
}
@@ -43,23 +42,21 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
4342

4443
#[tracing::instrument(skip_all, fields(%user_info), ret(level = Level::DEBUG), err)]
4544
async fn get_user_id(&self, user_info: &UserInfo) -> Result<Option<i64>, Self::Error> {
46-
let conn = self.pool.get().await?;
4745
let id = authn_user::table
4846
.select(authn_user::id)
4947
.filter(authn_user::identity_id.eq(&user_info.identity))
50-
.first::<i64>(conn.write().await.deref_mut())
48+
.first::<i64>(self.conn.write().await.deref_mut())
5149
.await
5250
.optional()?;
5351
Ok(id)
5452
}
5553

5654
#[tracing::instrument(skip_all, fields(%user_id), ret(level = Level::DEBUG), err)]
5755
async fn get_user_info(&self, user_id: i64) -> Result<Option<UserInfo>, Self::Error> {
58-
let conn = self.pool.get().await?;
5956
let info = authn_user::table
6057
.select((authn_user::identity_id, authn_user::name))
6158
.filter(authn_user::id.eq(user_id))
62-
.first::<(String, Option<String>)>(conn.write().await.deref_mut())
59+
.first::<(String, Option<String>)>(self.conn.write().await.deref_mut())
6360
.await
6461
.optional()
6562
.map(|res| {
@@ -73,9 +70,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
7370

7471
#[tracing::instrument(skip_all, fields(%user), ret(level = Level::DEBUG), err)]
7572
async fn ensure_user(&self, user: &UserInfo) -> Result<i64, Self::Error> {
76-
self.pool
77-
.get()
78-
.await?
73+
self.conn
7974
.transaction(|conn| {
8075
async move {
8176
let user_id = self.get_user_id(user).await?;
@@ -118,16 +113,14 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
118113
subject_id: i64,
119114
roles_config: &RoleConfig<Self::BuiltinRole>,
120115
) -> Result<HashSet<Self::BuiltinRole>, Self::Error> {
121-
let conn = self.pool.get().await?;
122-
123116
let roles = authz_role::table
124117
.select(authz_role::role)
125118
.left_join(
126119
authn_group_membership::table.on(authn_group_membership::user.eq(subject_id)),
127120
)
128121
.filter(authz_role::subject.eq(subject_id))
129122
.or_filter(authz_role::subject.eq(authn_group_membership::group))
130-
.load::<String>(conn.write().await.deref_mut())
123+
.load::<String>(self.conn.write().await.deref_mut())
131124
.await?
132125
.into_iter()
133126
.map(|role| {
@@ -147,8 +140,6 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
147140
roles_config: &RoleConfig<Self::BuiltinRole>,
148141
roles: HashSet<Self::BuiltinRole>,
149142
) -> Result<(), Self::Error> {
150-
let conn = self.pool.get().await?;
151-
152143
dsl::insert_into(authz_role::table)
153144
.values(
154145
roles
@@ -163,7 +154,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
163154
)
164155
.on_conflict((authz_role::subject, authz_role::role))
165156
.do_nothing()
166-
.execute(conn.write().await.deref_mut())
157+
.execute(self.conn.write().await.deref_mut())
167158
.await?;
168159

169160
Ok(())
@@ -176,8 +167,6 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
176167
roles_config: &RoleConfig<Self::BuiltinRole>,
177168
roles: HashSet<Self::BuiltinRole>,
178169
) -> Result<HashSet<Self::BuiltinRole>, Self::Error> {
179-
let conn = self.pool.get().await?;
180-
181170
let deleted_roles = dsl::delete(
182171
authz_role::table
183172
.filter(authz_role::subject.eq(subject_id))
@@ -186,7 +175,7 @@ impl<B: BuiltinRoleSet + Send + Sync> StorageDriver for PgAuthDriver<B> {
186175
),
187176
)
188177
.returning(authz_role::role)
189-
.load::<String>(conn.write().await.deref_mut())
178+
.load::<String>(self.conn.write().await.deref_mut())
190179
.await?
191180
.into_iter()
192181
.map(|role| {
@@ -227,7 +216,7 @@ mod tests {
227216
#[rstest::rstest]
228217
async fn test_auth_driver() {
229218
let pool = DbConnectionPoolV2::for_tests();
230-
let mut driver = PgAuthDriver::<TestBuiltinRole>::new(pool.into());
219+
let mut driver = PgAuthDriver::<TestBuiltinRole>::new(pool.get_ok());
231220
let config = default_test_config();
232221

233222
let uid = driver

editoast/src/views/mod.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ async fn make_authorizer(
123123
if roles.is_superuser() {
124124
return Ok(Authorizer::new_superuser(
125125
roles,
126-
PgAuthDriver::<BuiltinRole>::new(db_pool.clone()),
126+
PgAuthDriver::<BuiltinRole>::new(db_pool.get().await?),
127127
));
128128
}
129129
let Some(header) = headers.get("x-remote-user") else {
@@ -140,7 +140,7 @@ async fn make_authorizer(
140140
name: name.to_owned(),
141141
},
142142
roles,
143-
PgAuthDriver::<BuiltinRole>::new(db_pool.clone()),
143+
PgAuthDriver::<BuiltinRole>::new(db_pool.get().await?),
144144
)
145145
.await?;
146146
Ok(authorizer)
@@ -175,6 +175,9 @@ pub enum AuthorizationError {
175175
AuthError(
176176
#[from] <PgAuthDriver<BuiltinRole> as editoast_authz::authorizer::StorageDriver>::Error,
177177
),
178+
#[error(transparent)]
179+
#[editoast_error(status = 500)]
180+
DbError(#[from] editoast_models::db_connection_pool::DatabasePoolError),
178181
}
179182

180183
#[derive(Debug, Error, EditoastError)]

0 commit comments

Comments
 (0)