Skip to content

Commit 5b79743

Browse files
committed
editoast: destructure AppState early in handlers for consistency
Signed-off-by: Leo Valais <[email protected]>
1 parent c339b26 commit 5b79743

File tree

10 files changed

+114
-96
lines changed

10 files changed

+114
-96
lines changed

editoast/src/views/infra/mod.rs

+38-30
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,13 @@ struct RefreshResponse {
123123
)
124124
)]
125125
async fn refresh(
126-
app_state: State<AppState>,
126+
State(AppState {
127+
db_pool,
128+
valkey: valkey_client,
129+
infra_caches,
130+
map_layers,
131+
..
132+
}): State<AppState>,
127133
Extension(auth): AuthenticationExt,
128134
Query(query_params): Query<RefreshQueryParams>,
129135
) -> Result<Json<RefreshResponse>> {
@@ -135,11 +141,6 @@ async fn refresh(
135141
return Err(AuthorizationError::Unauthorized.into());
136142
}
137143

138-
let db_pool = app_state.db_pool.clone();
139-
let valkey_client = app_state.valkey.clone();
140-
let infra_caches = app_state.infra_caches.clone();
141-
let map_layers = app_state.map_layers.clone();
142-
143144
// Use a transaction to give scope to infra list lock
144145
let RefreshQueryParams {
145146
force,
@@ -160,7 +161,6 @@ async fn refresh(
160161
};
161162

162163
// Refresh each infras
163-
let db_pool = db_pool;
164164
let mut infra_refreshed = vec![];
165165

166166
for mut infra in infras_list {
@@ -201,7 +201,11 @@ struct InfraListResponse {
201201
),
202202
)]
203203
async fn list(
204-
app_state: State<AppState>,
204+
State(AppState {
205+
db_pool,
206+
osrdyne_client,
207+
..
208+
}): State<AppState>,
205209
Extension(auth): AuthenticationExt,
206210
pagination_params: Query<PaginationQueryParams>,
207211
) -> Result<Json<InfraListResponse>> {
@@ -212,8 +216,6 @@ async fn list(
212216
if !authorized {
213217
return Err(AuthorizationError::Unauthorized.into());
214218
}
215-
let db_pool = app_state.db_pool.clone();
216-
let osrdyne_client = app_state.osrdyne_client.clone();
217219

218220
let settings = pagination_params
219221
.validate(1000)?
@@ -295,7 +297,11 @@ struct InfraIdParam {
295297
),
296298
)]
297299
async fn get(
298-
app_state: State<AppState>,
300+
State(AppState {
301+
db_pool,
302+
osrdyne_client,
303+
..
304+
}): State<AppState>,
299305
Extension(auth): AuthenticationExt,
300306
Path(infra): Path<InfraIdParam>,
301307
) -> Result<Json<InfraWithState>> {
@@ -307,9 +313,6 @@ async fn get(
307313
return Err(AuthorizationError::Unauthorized.into());
308314
}
309315

310-
let db_pool = app_state.db_pool.clone();
311-
let osrdyne_client = app_state.osrdyne_client.clone();
312-
313316
let infra_id = infra.infra_id;
314317
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
315318
InfraApiError::NotFound { infra_id }
@@ -344,7 +347,7 @@ impl From<InfraCreateForm> for Changeset<Infra> {
344347
),
345348
)]
346349
async fn create(
347-
db_pool: State<DbConnectionPoolV2>,
350+
State(db_pool): State<DbConnectionPoolV2>,
348351
Extension(auth): AuthenticationExt,
349352
Json(infra_form): Json<InfraCreateForm>,
350353
) -> Result<impl IntoResponse> {
@@ -381,7 +384,7 @@ struct CloneQuery {
381384
async fn clone(
382385
Extension(auth): AuthenticationExt,
383386
Path(params): Path<InfraIdParam>,
384-
db_pool: State<DbConnectionPoolV2>,
387+
State(db_pool): State<DbConnectionPoolV2>,
385388
Query(CloneQuery { name }): Query<CloneQuery>,
386389
) -> Result<Json<i64>> {
387390
let authorized = auth
@@ -421,7 +424,11 @@ async fn clone(
421424
),
422425
)]
423426
async fn delete(
424-
app_state: State<AppState>,
427+
State(AppState {
428+
db_pool,
429+
infra_caches,
430+
..
431+
}): State<AppState>,
425432
Extension(auth): AuthenticationExt,
426433
infra: Path<InfraIdParam>,
427434
) -> Result<impl IntoResponse> {
@@ -433,8 +440,6 @@ async fn delete(
433440
return Err(AuthorizationError::Unauthorized.into());
434441
}
435442

436-
let db_pool = app_state.db_pool.clone();
437-
let infra_caches = app_state.infra_caches.clone();
438443
let infra_id = infra.infra_id;
439444
if Infra::fast_delete_static(db_pool.get().await?, infra_id).await? {
440445
infra_caches.remove(&infra_id);
@@ -468,7 +473,7 @@ impl From<InfraPatchForm> for Changeset<Infra> {
468473
),
469474
)]
470475
async fn put(
471-
db_pool: State<DbConnectionPoolV2>,
476+
State(db_pool): State<DbConnectionPoolV2>,
472477
Extension(auth): AuthenticationExt,
473478
Path(infra): Path<i64>,
474479
Json(patch): Json<InfraPatchForm>,
@@ -501,7 +506,11 @@ async fn put(
501506
)
502507
)]
503508
async fn get_switch_types(
504-
app_state: State<AppState>,
509+
State(AppState {
510+
db_pool,
511+
infra_caches,
512+
..
513+
}): State<AppState>,
505514
Extension(auth): AuthenticationExt,
506515
Path(infra): Path<InfraIdParam>,
507516
) -> Result<Json<Vec<SwitchType>>> {
@@ -513,9 +522,7 @@ async fn get_switch_types(
513522
return Err(AuthorizationError::Unauthorized.into());
514523
}
515524

516-
let db_pool = app_state.db_pool.clone();
517525
let conn = &mut db_pool.get().await?;
518-
let infra_caches = app_state.infra_caches.clone();
519526

520527
let infra = Infra::retrieve_or_fail(conn, infra.infra_id, || InfraApiError::NotFound {
521528
infra_id: infra.infra_id,
@@ -546,7 +553,7 @@ async fn get_switch_types(
546553
async fn get_speed_limit_tags(
547554
Extension(auth): AuthenticationExt,
548555
Path(infra): Path<InfraIdParam>,
549-
db_pool: State<DbConnectionPoolV2>,
556+
State(db_pool): State<DbConnectionPoolV2>,
550557
) -> Result<Json<Vec<String>>> {
551558
let authorized = auth
552559
.check_roles([BuiltinRole::InfraRead].into())
@@ -590,7 +597,7 @@ async fn get_voltages(
590597
Extension(auth): AuthenticationExt,
591598
Path(infra): Path<InfraIdParam>,
592599
Query(param): Query<GetVoltagesQueryParams>,
593-
db_pool: State<DbConnectionPoolV2>,
600+
State(db_pool): State<DbConnectionPoolV2>,
594601
) -> Result<Json<Vec<String>>> {
595602
let authorized = auth
596603
.check_roles([BuiltinRole::InfraRead].into())
@@ -623,7 +630,7 @@ async fn get_voltages(
623630
)
624631
)]
625632
async fn get_all_voltages(
626-
db_pool: State<DbConnectionPoolV2>,
633+
State(db_pool): State<DbConnectionPoolV2>,
627634
Extension(auth): AuthenticationExt,
628635
) -> Result<Json<Vec<String>>> {
629636
let authorized = auth
@@ -712,7 +719,11 @@ async fn unlock(
712719
)
713720
)]
714721
async fn load(
715-
app_state: State<AppState>,
722+
State(AppState {
723+
db_pool,
724+
core_client,
725+
..
726+
}): State<AppState>,
716727
Extension(auth): AuthenticationExt,
717728
Path(path): Path<InfraIdParam>,
718729
) -> Result<impl IntoResponse> {
@@ -724,9 +735,6 @@ async fn load(
724735
return Err(AuthorizationError::Unauthorized.into());
725736
}
726737

727-
let db_pool = app_state.db_pool.clone();
728-
let core_client = app_state.core_client.clone();
729-
730738
let infra_id = path.infra_id;
731739
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
732740
InfraApiError::NotFound { infra_id }

editoast/src/views/infra/pathfinding.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ struct QueryParam {
9696
)
9797
)]
9898
async fn pathfinding_view(
99-
app_state: State<AppState>,
99+
State(AppState {
100+
db_pool,
101+
infra_caches,
102+
..
103+
}): State<AppState>,
100104
Extension(auth): AuthenticationExt,
101105
Path(infra): Path<InfraIdParam>,
102106
Query(params): Query<QueryParam>,
@@ -110,9 +114,6 @@ async fn pathfinding_view(
110114
return Err(AuthorizationError::Unauthorized.into());
111115
}
112116

113-
let db_pool = app_state.db_pool.clone();
114-
let infra_caches = app_state.infra_caches.clone();
115-
116117
// Parse and check input
117118
let infra_id = infra.infra_id;
118119
let number = params.number.unwrap_or(DEFAULT_NUMBER_OF_PATHS);

editoast/src/views/infra/railjson.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ enum ListErrorsRailjson {
5555
)]
5656
async fn get_railjson(
5757
Path(infra): Path<InfraIdParam>,
58-
db_pool: State<DbConnectionPoolV2>,
58+
State(db_pool): State<DbConnectionPoolV2>,
5959
Extension(auth): AuthenticationExt,
6060
) -> Result<impl IntoResponse> {
6161
let authorized = auth
@@ -168,7 +168,11 @@ struct PostRailjsonResponse {
168168
)
169169
)]
170170
async fn post_railjson(
171-
app_state: State<AppState>,
171+
State(AppState {
172+
db_pool,
173+
infra_caches,
174+
..
175+
}): State<AppState>,
172176
Extension(auth): AuthenticationExt,
173177
Query(params): Query<PostRailjsonQueryParams>,
174178
Json(railjson): Json<RailJson>,
@@ -181,8 +185,6 @@ async fn post_railjson(
181185
return Err(AuthorizationError::Unauthorized.into());
182186
}
183187

184-
let db_pool = app_state.db_pool.clone();
185-
let infra_caches = app_state.infra_caches.clone();
186188
if railjson.version != RAILJSON_VERSION {
187189
return Err(ListErrorsRailjson::WrongRailjsonVersionProvided.into());
188190
}

editoast/src/views/infra/routes.rs

+13-8
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ struct RoutesResponse {
6767
)]
6868
async fn get_routes_from_waypoint(
6969
Path(path): Path<RoutesFromWaypointParams>,
70-
db_pool: State<DbConnectionPoolV2>,
70+
State(db_pool): State<DbConnectionPoolV2>,
7171
Extension(auth): AuthenticationExt,
7272
) -> Result<Json<RoutesResponse>> {
7373
let authorized = auth
@@ -145,7 +145,11 @@ struct RoutesFromNodesPositions {
145145
),
146146
)]
147147
async fn get_routes_track_ranges(
148-
app_state: State<AppState>,
148+
State(AppState {
149+
db_pool,
150+
infra_caches,
151+
..
152+
}): State<AppState>,
149153
Extension(auth): AuthenticationExt,
150154
Path(infra): Path<i64>,
151155
Query(params): Query<RouteTrackRangesParams>,
@@ -158,8 +162,8 @@ async fn get_routes_track_ranges(
158162
return Err(AuthorizationError::Unauthorized.into());
159163
}
160164

161-
let db_pool = app_state.db_pool.clone();
162-
let infra_caches = app_state.infra_caches.clone();
165+
let db_pool = db_pool.clone();
166+
let infra_caches = infra_caches.clone();
163167
let infra_id = infra;
164168
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
165169
InfraApiError::NotFound { infra_id }
@@ -206,7 +210,11 @@ async fn get_routes_track_ranges(
206210
),
207211
)]
208212
async fn get_routes_nodes(
209-
app_state: State<AppState>,
213+
State(AppState {
214+
db_pool,
215+
infra_caches,
216+
..
217+
}): State<AppState>,
210218
Extension(auth): AuthenticationExt,
211219
Path(params): Path<InfraIdParam>,
212220
Json(node_states): Json<HashMap<String, Option<String>>>,
@@ -219,9 +227,6 @@ async fn get_routes_nodes(
219227
return Err(AuthorizationError::Unauthorized.into());
220228
}
221229

222-
let db_pool = app_state.db_pool.clone();
223-
let infra_caches = app_state.infra_caches.clone();
224-
225230
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, params.infra_id, || {
226231
InfraApiError::NotFound {
227232
infra_id: params.infra_id,

editoast/src/views/mod.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,11 @@ async fn version() -> Json<Version> {
334334
(status = 200, description = "Return the core service version", body = Version),
335335
),
336336
)]
337-
async fn core_version(app_state: State<AppState>) -> Json<Version> {
338-
let core = app_state.core_client.clone();
337+
async fn core_version(
338+
State(AppState {
339+
core_client: core, ..
340+
}): State<AppState>,
341+
) -> Json<Version> {
339342
let response = CoreVersionRequest {}.fetch(&core).await;
340343
let response = response.unwrap_or(Version { git_describe: None });
341344
Json(response)

editoast/src/views/stdcm_search_environment.rs

+3-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use axum::response::Response;
66
use axum::Extension;
77
use chrono::NaiveDateTime;
88
use editoast_authz::BuiltinRole;
9+
use editoast_models::DbConnectionPoolV2;
910
use serde::de::Error as SerdeError;
1011
use serde::Deserialize;
1112
use std::result::Result as StdResult;
@@ -19,7 +20,6 @@ use crate::models::stdcm_search_environment::StdcmSearchEnvironment;
1920
use crate::models::Changeset;
2021
use crate::views::AuthenticationExt;
2122
use crate::views::AuthorizationError;
22-
use crate::AppState;
2323
use crate::Model;
2424

2525
crate::routes! {
@@ -106,7 +106,7 @@ impl From<StdcmSearchEnvironmentCreateForm> for Changeset<StdcmSearchEnvironment
106106
)
107107
)]
108108
async fn overwrite(
109-
State(app_state): State<AppState>,
109+
State(db_pool): State<DbConnectionPoolV2>,
110110
Extension(auth): AuthenticationExt,
111111
Json(form): Json<StdcmSearchEnvironmentCreateForm>,
112112
) -> Result<impl IntoResponse> {
@@ -118,11 +118,8 @@ async fn overwrite(
118118
return Err(AuthorizationError::Unauthorized.into());
119119
}
120120

121-
let db_pool = app_state.db_pool.clone();
122121
let conn = &mut db_pool.get().await?;
123-
124122
let changeset: Changeset<StdcmSearchEnvironment> = form.into();
125-
126123
Ok((StatusCode::CREATED, Json(changeset.overwrite(conn).await?)))
127124
}
128125

@@ -135,7 +132,7 @@ async fn overwrite(
135132
)
136133
)]
137134
async fn retrieve_latest(
138-
State(app_state): State<AppState>,
135+
State(db_pool): State<DbConnectionPoolV2>,
139136
Extension(auth): AuthenticationExt,
140137
) -> Result<Response> {
141138
let authorized = auth
@@ -146,9 +143,7 @@ async fn retrieve_latest(
146143
return Err(AuthorizationError::Unauthorized.into());
147144
}
148145

149-
let db_pool = app_state.db_pool.clone();
150146
let conn = &mut db_pool.get().await?;
151-
152147
let search_env = StdcmSearchEnvironment::retrieve_latest(conn).await;
153148
if let Some(search_env) = search_env {
154149
Ok(Json(search_env).into_response())

0 commit comments

Comments
 (0)