Skip to content

Commit 999cee5

Browse files
committed
editoast: impl modelsv2::List for Infra and refactor cache queries
1 parent 08271f1 commit 999cee5

File tree

4 files changed

+78
-105
lines changed

4 files changed

+78
-105
lines changed

editoast/src/modelsv2/infra.rs

-33
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@ mod voltage;
77

88
use std::pin::Pin;
99

10-
use async_trait::async_trait;
1110
use chrono::NaiveDateTime;
1211
use chrono::Utc;
1312
use derivative::Derivative;
1413
use diesel::sql_query;
1514
use diesel::sql_types::BigInt;
16-
use diesel::QueryDsl;
1715
use diesel_async::RunQueryDsl;
1816
use editoast_derive::ModelV2;
1917
use futures::future::try_join_all;
@@ -29,8 +27,6 @@ use uuid::Uuid;
2927
use crate::error::Result;
3028
use crate::generated_data;
3129
use crate::infra_cache::InfraCache;
32-
use crate::models::List;
33-
use crate::models::NoParams;
3430
use crate::modelsv2::get_geometry_layer_table;
3531
use crate::modelsv2::get_table;
3632
use crate::modelsv2::prelude::*;
@@ -39,8 +35,6 @@ use crate::modelsv2::Create;
3935
use crate::modelsv2::DbConnection;
4036
use crate::modelsv2::DbConnectionPool;
4137
use crate::tables::infra::dsl;
42-
use crate::views::pagination::Paginate;
43-
use crate::views::pagination::PaginatedResponse;
4438
use editoast_schemas::infra::RailJson;
4539
use editoast_schemas::infra::RAILJSON_VERSION;
4640
use editoast_schemas::primitives::ObjectType;
@@ -227,33 +221,6 @@ impl Infra {
227221
}
228222
}
229223

230-
#[async_trait]
231-
impl List<NoParams> for Infra {
232-
async fn list_conn(
233-
conn: &mut DbConnection,
234-
page: i64,
235-
page_size: i64,
236-
_: NoParams,
237-
) -> Result<PaginatedResponse<Self>> {
238-
let PaginatedResponse {
239-
count,
240-
previous,
241-
next,
242-
results,
243-
} = dsl::infra
244-
.distinct()
245-
.paginate(page, page_size)
246-
.load_and_count::<Row<Self>>(conn)
247-
.await?;
248-
Ok(PaginatedResponse {
249-
count,
250-
previous,
251-
next,
252-
results: results.into_iter().map(Self::from_row).collect(),
253-
})
254-
}
255-
}
256-
257224
#[cfg(test)]
258225
pub mod tests {
259226
use actix_web::test as actix_test;

editoast/src/views/infra/mod.rs

+68-61
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@ use editoast_derive::EditoastError;
2626
use serde::Deserialize;
2727
use serde::Serialize;
2828
use std::collections::HashMap;
29+
use std::ops::DerefMut as _;
2930
use std::sync::Arc;
3031
use thiserror::Error;
3132
use utoipa::IntoParams;
3233
use utoipa::ToSchema;
3334

3435
use self::edition::edit;
3536
use self::edition::split_track_section;
37+
use super::pagination::PaginationStats;
3638
use super::params::List;
3739
use crate::core::infra_loading::InfraLoadRequest;
3840
use crate::core::infra_state::InfraStateRequest;
@@ -44,12 +46,10 @@ use crate::infra_cache::InfraCache;
4446
use crate::infra_cache::ObjectCache;
4547
use crate::map;
4648
use crate::map::MapLayers;
47-
use crate::models::List as ModelList;
48-
use crate::models::NoParams;
4949
use crate::modelsv2::prelude::*;
5050
use crate::modelsv2::DbConnectionPool;
5151
use crate::modelsv2::Infra;
52-
use crate::views::pagination::PaginatedResponse;
52+
use crate::views::pagination::PaginatedList as _;
5353
use crate::views::pagination::PaginationQueryParam;
5454
use crate::RedisClient;
5555
use editoast_schemas::infra::SwitchType;
@@ -211,40 +211,47 @@ async fn refresh(
211211
Ok(Json(RefreshResponse { infra_refreshed }))
212212
}
213213

214+
#[derive(Serialize, ToSchema)]
215+
struct InfraListResponse {
216+
#[serde(flatten)]
217+
stats: PaginationStats,
218+
results: Vec<InfraWithState>,
219+
}
220+
214221
/// Return a list of infras
215222
#[get("")]
216223
async fn list(
217224
db_pool: Data<DbConnectionPool>,
218225
core: Data<CoreClient>,
219226
pagination_params: Query<PaginationQueryParam>,
220-
) -> Result<Json<PaginatedResponse<InfraWithState>>> {
221-
let (page, per_page) = pagination_params
227+
) -> Result<Json<InfraListResponse>> {
228+
let settings = pagination_params
222229
.validate(1000)?
223230
.warn_page_size(100)
224-
.unpack();
225-
let db_pool = db_pool.into_inner();
226-
let infras = Infra::list(db_pool.clone(), page, per_page, NoParams).await?;
227-
let infra_state = call_core_infra_state(None, db_pool, core).await?;
228-
let infras_with_state: Vec<InfraWithState> = infras
229-
.results
230-
.into_iter()
231-
.map(|infra| {
232-
let infra_id = infra.id;
233-
let state = infra_state
234-
.get(&infra_id.to_string())
235-
.unwrap_or(&InfraStateResponse::default())
236-
.status;
237-
InfraWithState { infra, state }
238-
})
239-
.collect();
240-
let infras_with_state = PaginatedResponse::<InfraWithState> {
241-
count: infras.count,
242-
previous: infras.previous,
243-
next: infras.next,
244-
results: infras_with_state,
231+
.into_selection_settings();
232+
233+
let ((infras, stats), infra_states) = {
234+
let conn = &mut db_pool.get().await?;
235+
futures::try_join!(
236+
Infra::list_paginated(conn, settings),
237+
fetch_all_infra_states(core.as_ref()),
238+
)?
245239
};
246240

247-
Ok(Json(infras_with_state))
241+
let response = InfraListResponse {
242+
stats,
243+
results: infras
244+
.into_iter()
245+
.map(|infra| {
246+
let state = infra_states
247+
.get(&infra.id.to_string())
248+
.map(|response| response.status)
249+
.unwrap_or_default();
250+
InfraWithState { infra, state }
251+
})
252+
.collect(),
253+
};
254+
Ok(Json(response))
248255
}
249256

250257
#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq, Serialize, ToSchema)]
@@ -296,11 +303,7 @@ async fn get(
296303
let conn = &mut db_pool.get().await?;
297304
let infra =
298305
Infra::retrieve_or_fail(conn, infra_id, || InfraApiError::NotFound { infra_id }).await?;
299-
let infra_state = call_core_infra_state(Some(infra_id), db_pool.into_inner(), core).await?;
300-
let state = infra_state
301-
.get(&infra_id.to_string())
302-
.unwrap_or(&InfraStateResponse::default())
303-
.status;
306+
let state = fetch_infra_state(infra.id, core.as_ref()).await?.status;
304307
Ok(Json(InfraWithState { infra, state }))
305308
}
306309

@@ -562,12 +565,6 @@ async fn unlock(
562565
Ok(HttpResponse::NoContent().finish())
563566
}
564567

565-
#[derive(Debug, Default, Deserialize)]
566-
567-
pub struct StatePayload {
568-
infra: Option<i64>,
569-
}
570-
571568
/// Instructs Core to load an infra
572569
#[utoipa::path(
573570
tag = "infra",
@@ -595,21 +592,9 @@ async fn load(
595592
Ok(HttpResponse::NoContent().finish())
596593
}
597594

598-
/// Builds a Core cache_status request, runs it
599-
pub async fn call_core_infra_state(
600-
infra_id: Option<i64>,
601-
db_pool: Arc<DbConnectionPool>,
602-
core: Data<CoreClient>,
603-
) -> Result<HashMap<String, InfraStateResponse>> {
604-
if let Some(infra_id) = infra_id {
605-
let conn = &mut db_pool.get().await?;
606-
if !Infra::exists(conn, infra_id).await? {
607-
return Err(InfraApiError::NotFound { infra_id }.into());
608-
}
609-
}
610-
let infra_request = InfraStateRequest { infra: infra_id };
611-
let response = infra_request.fetch(core.as_ref()).await?;
612-
Ok(response)
595+
#[derive(Debug, Default, Deserialize)]
596+
pub struct StatePayload {
597+
infra: Option<i64>,
613598
}
614599

615600
#[get("/cache_status")]
@@ -618,14 +603,36 @@ async fn cache_status(
618603
db_pool: Data<DbConnectionPool>,
619604
core: Data<CoreClient>,
620605
) -> Result<Json<HashMap<String, InfraStateResponse>>> {
621-
let payload = match payload {
622-
Either::Left(state) => state.into_inner(),
623-
Either::Right(_) => Default::default(),
624-
};
625-
let infra_id = payload.infra;
626-
Ok(Json(
627-
call_core_infra_state(infra_id, db_pool.into_inner(), core).await?,
628-
))
606+
if let Either::Left(Json(StatePayload {
607+
infra: Some(infra_id),
608+
})) = payload
609+
{
610+
if !Infra::exists(db_pool.get().await?.deref_mut(), infra_id).await? {
611+
return Err(InfraApiError::NotFound { infra_id }.into());
612+
}
613+
let infra_state = fetch_infra_state(infra_id, core.as_ref()).await?;
614+
Ok(Json(HashMap::from([(infra_id.to_string(), infra_state)])))
615+
} else {
616+
Ok(Json(fetch_all_infra_states(core.as_ref()).await?))
617+
}
618+
}
619+
620+
/// Builds a Core cache_status request, runs it
621+
pub async fn fetch_infra_state(infra_id: i64, core: &CoreClient) -> Result<InfraStateResponse> {
622+
Ok(InfraStateRequest {
623+
infra: Some(infra_id),
624+
}
625+
.fetch(core)
626+
.await?
627+
.get(&infra_id.to_string())
628+
.cloned()
629+
.unwrap_or_default())
630+
}
631+
632+
pub async fn fetch_all_infra_states(
633+
core: &CoreClient,
634+
) -> Result<HashMap<String, InfraStateResponse>> {
635+
InfraStateRequest::default().fetch(core).await
629636
}
630637

631638
#[cfg(test)]

editoast/src/views/timetable/import.rs

+5-9
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use crate::modelsv2::Infra;
3636
use crate::modelsv2::OperationalPointModel;
3737
use crate::modelsv2::RetrieveBatch;
3838
use crate::modelsv2::RollingStockModel;
39-
use crate::views::infra::call_core_infra_state;
39+
use crate::views::infra::fetch_infra_state;
4040
use crate::views::infra::InfraApiError;
4141
use crate::views::infra::InfraState;
4242
use crate::views::pathfinding::save_core_pathfinding;
@@ -207,17 +207,13 @@ pub async fn post_timetable(
207207
.await?;
208208

209209
// Check infra is loaded
210-
let db_pool = db_pool.into_inner();
211-
let mut infra_state =
212-
call_core_infra_state(Some(infra_id), db_pool.clone(), core_client.clone()).await?;
213-
let infra_status = infra_state
214-
.remove(&infra_id.to_string())
215-
.unwrap_or_default()
216-
.status;
210+
let core_client = core_client.as_ref();
211+
let infra_status = fetch_infra_state(infra_id, core_client).await?.status;
217212
if infra_status != InfraState::Cached {
218213
return Err(TimetableError::InfraNotLoaded { infra_id }.into());
219214
}
220215

216+
let db_pool = db_pool.into_inner();
221217
let mut item_futures = Vec::new();
222218

223219
for item in data.into_inner() {
@@ -227,7 +223,7 @@ pub async fn post_timetable(
227223
db_pool.clone(),
228224
item,
229225
timetable_id,
230-
&core_client,
226+
core_client,
231227
));
232228
}
233229
let item_results = try_join_all(item_futures).await?;

tests/tests/test_infra.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ class _InfraDetails:
2323
@dataclass(frozen=True)
2424
class _InfraResponse:
2525
count: int
26-
next: Optional[Any]
27-
previous: Optional[Any]
26+
page_size: int
27+
page_count: int
28+
current: int
29+
previous: Optional[int]
30+
next: Optional[int]
2831
results: Iterable[Mapping[str, Any]]
2932

3033

0 commit comments

Comments
 (0)