Skip to content

Commit 281e005

Browse files
committed
editoast: refactor persist_railjson and persist
1 parent 3040c34 commit 281e005

File tree

11 files changed

+156
-156
lines changed

11 files changed

+156
-156
lines changed

editoast/src/fixtures.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ pub mod tests {
472472
Infra::changeset()
473473
.name("small_infra".to_owned())
474474
.last_railjson_version()
475-
.persist(railjson, db_pool)
475+
.persist(railjson, db_pool.get().await.unwrap().deref_mut())
476476
.await
477477
.unwrap()
478478
}

editoast/src/generated_data/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub async fn refresh_all(
9494
// The other layers depend on track section layer.
9595
// We must wait until its completion before running the other requests in parallel
9696
TrackSectionLayer::refresh_pool(db_pool.clone(), infra, infra_cache).await?;
97-
debug!("⚙️ Infra {infra}: track section layer is generated");
97+
debug!(infra_id = infra, "Infra: track section layer is generated");
9898
// The analyze step significantly improves the performance when importing and generating together
9999
// It doesn’t seem to make a different when the generation step is ran separately
100100
// It isn’t clear why without analyze the Postgres server seems to run at 100% without halting

editoast/src/main.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,9 @@ async fn import_railjson(
629629
let railjson: RailJson = serde_json::from_reader(BufReader::new(railjson_file))?;
630630

631631
println!("🍞 Importing infra {infra_name}");
632-
let mut infra = infra.persist_v2(railjson, db_pool.clone()).await?;
632+
let mut infra = infra
633+
.persist(railjson, db_pool.get().await?.deref_mut())
634+
.await?;
633635

634636
infra
635637
.bump_version(db_pool.get().await?.deref_mut())

editoast/src/modelsv2/fixtures.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::io::Cursor;
2-
use std::sync::Arc;
32

43
use chrono::Utc;
54
use editoast_schemas::infra::InfraObject;
@@ -13,7 +12,6 @@ use crate::modelsv2::rolling_stock_livery::RollingStockLiveryModel;
1312
use crate::modelsv2::timetable::Timetable;
1413
use crate::modelsv2::train_schedule::TrainSchedule;
1514
use crate::modelsv2::DbConnection;
16-
use crate::modelsv2::DbConnectionPoolV2;
1715
use crate::modelsv2::Document;
1816
use crate::modelsv2::ElectricalProfileSet;
1917
use crate::modelsv2::Infra;
@@ -286,15 +284,15 @@ where
286284
railjson_object
287285
}
288286

289-
pub async fn create_small_infra(db_pool: Arc<DbConnectionPoolV2>) -> Infra {
287+
pub async fn create_small_infra(conn: &mut DbConnection) -> Infra {
290288
let railjson: RailJson = serde_json::from_str(include_str!(
291289
"../../../tests/data/infras/small_infra/infra.json"
292290
))
293291
.unwrap();
294292
Infra::changeset()
295293
.name("small_infra".to_owned())
296294
.last_railjson_version()
297-
.persist_v2(railjson, db_pool)
295+
.persist(railjson, conn)
298296
.await
299297
.unwrap()
300298
}

editoast/src/modelsv2/infra.rs

+62-46
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@ use crate::modelsv2::get_geometry_layer_table;
3636
use crate::modelsv2::get_table;
3737
use crate::modelsv2::prelude::*;
3838
use crate::modelsv2::railjson::persist_railjson;
39-
use crate::modelsv2::railjson::persist_railjson_v2;
4039
use crate::modelsv2::Create;
4140
use crate::modelsv2::DbConnection;
42-
use crate::modelsv2::DbConnectionPool;
4341
use crate::modelsv2::DbConnectionPoolV2;
4442
use crate::tables::infra::dsl;
4543
use editoast_schemas::infra::RailJson;
@@ -77,16 +75,11 @@ pub struct Infra {
7775
}
7876

7977
impl InfraChangeset {
80-
pub async fn persist(
81-
self,
82-
railjson: RailJson,
83-
db_pool: Arc<DbConnectionPool>,
84-
) -> Result<Infra> {
85-
let conn = &mut db_pool.get().await?;
78+
pub async fn persist(self, railjson: RailJson, conn: &mut DbConnection) -> Result<Infra> {
8679
let infra = self.create(conn).await?;
8780
// TODO: lock infra for update
8881
debug!("🛤 Begin importing all railjson objects");
89-
if let Err(e) = persist_railjson(db_pool, infra.id, railjson).await {
82+
if let Err(e) = persist_railjson(conn, infra.id, railjson).await {
9083
error!("Could not import infrastructure {}. Rolling back", infra.id);
9184
infra.delete(conn).await?;
9285
return Err(e);
@@ -95,23 +88,6 @@ impl InfraChangeset {
9588
Ok(infra)
9689
}
9790

98-
pub async fn persist_v2(
99-
self,
100-
railjson: RailJson,
101-
db_pool: Arc<DbConnectionPoolV2>,
102-
) -> Result<Infra> {
103-
let infra = self.create(db_pool.get().await?.deref_mut()).await?;
104-
// TODO: lock infra for update
105-
debug!("🛤 Begin importing all railjson objects");
106-
if let Err(e) = persist_railjson_v2(db_pool.clone(), infra.id, railjson).await {
107-
error!("Could not import infrastructure {}. Rolling back", infra.id);
108-
infra.delete(db_pool.get().await?.deref_mut()).await?;
109-
return Err(e);
110-
};
111-
debug!("🛤 Import finished successfully");
112-
Ok(infra)
113-
}
114-
11591
#[must_use = "builder methods are intended to be chained"]
11692
pub fn last_railjson_version(self) -> Self {
11793
self.railjson_version(RAILJSON_VERSION.to_owned())
@@ -338,8 +314,6 @@ pub mod tests {
338314

339315
use super::Infra;
340316
use crate::error::EditoastError;
341-
use crate::fixtures::tests::db_pool;
342-
use crate::fixtures::tests::IntoFixture;
343317
use crate::modelsv2::fixtures::create_empty_infra;
344318
use crate::modelsv2::infra::DEFAULT_INFRA_VERSION;
345319
use crate::modelsv2::prelude::*;
@@ -379,15 +353,15 @@ pub mod tests {
379353

380354
#[rstest]
381355
async fn persists_railjson_ko_version() {
382-
let pool = db_pool();
356+
let db_pool = DbConnectionPoolV2::for_tests();
383357
let railjson_with_invalid_version = RailJson {
384358
version: "0".to_string(),
385359
..Default::default()
386360
};
387361
let res = Infra::changeset()
388362
.name("test".to_owned())
389363
.last_railjson_version()
390-
.persist(railjson_with_invalid_version, pool)
364+
.persist(railjson_with_invalid_version, db_pool.get_ok().deref_mut())
391365
.await;
392366
assert!(res.is_err());
393367
let expected_error = RailJsonError::UnsupportedVersion {
@@ -419,14 +393,13 @@ pub mod tests {
419393
version: RAILJSON_VERSION.to_string(),
420394
};
421395

422-
let pool = db_pool();
396+
let db_pool = DbConnectionPoolV2::for_tests();
423397
let infra = Infra::changeset()
424398
.name("persist_railjson_ok_infra".to_owned())
425399
.last_railjson_version()
426-
.persist(railjson.clone(), pool.clone())
400+
.persist(railjson.clone(), db_pool.get_ok().deref_mut())
427401
.await
428-
.expect("could not persist infra")
429-
.into_fixture(pool.clone());
402+
.expect("could not persist infra");
430403

431404
// THEN
432405
assert_eq!(infra.railjson_version, railjson.version);
@@ -436,51 +409,94 @@ pub mod tests {
436409
objects
437410
}
438411

439-
let conn = &mut pool.get().await.unwrap();
440412
let id = infra.id;
441413

442414
assert_eq!(
443-
sort::<BufferStop>(find_all_schemas(conn, id).await.unwrap()),
415+
sort::<BufferStop>(
416+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
417+
.await
418+
.unwrap()
419+
),
444420
sort(railjson.buffer_stops)
445421
);
446422
assert_eq!(
447-
sort::<Route>(find_all_schemas(conn, id).await.unwrap()),
423+
sort::<Route>(
424+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
425+
.await
426+
.unwrap()
427+
),
448428
sort(railjson.routes)
449429
);
450430
assert_eq!(
451-
sort::<SwitchType>(find_all_schemas(conn, id).await.unwrap()),
431+
sort::<SwitchType>(
432+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
433+
.await
434+
.unwrap()
435+
),
452436
sort(railjson.extended_switch_types)
453437
);
454438
assert_eq!(
455-
sort::<Switch>(find_all_schemas(conn, id).await.unwrap()),
439+
sort::<Switch>(
440+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
441+
.await
442+
.unwrap()
443+
),
456444
sort(railjson.switches)
457445
);
458446
assert_eq!(
459-
sort::<TrackSection>(find_all_schemas(conn, id).await.unwrap()),
447+
sort::<TrackSection>(
448+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
449+
.await
450+
.unwrap()
451+
),
460452
sort(railjson.track_sections)
461453
);
462454
assert_eq!(
463-
sort::<SpeedSection>(find_all_schemas(conn, id).await.unwrap()),
455+
sort::<SpeedSection>(
456+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
457+
.await
458+
.unwrap()
459+
),
464460
sort(railjson.speed_sections)
465461
);
466462
assert_eq!(
467-
sort::<NeutralSection>(find_all_schemas(conn, id).await.unwrap()),
463+
sort::<NeutralSection>(
464+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
465+
.await
466+
.unwrap()
467+
),
468468
sort(railjson.neutral_sections)
469469
);
470470
assert_eq!(
471-
sort::<Electrification>(find_all_schemas(conn, id).await.unwrap()),
471+
sort::<Electrification>(
472+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
473+
.await
474+
.unwrap()
475+
),
472476
sort(railjson.electrifications)
473477
);
474478
assert_eq!(
475-
sort::<Signal>(find_all_schemas(conn, id).await.unwrap()),
479+
sort::<Signal>(
480+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
481+
.await
482+
.unwrap()
483+
),
476484
sort(railjson.signals)
477485
);
478486
assert_eq!(
479-
sort::<Detector>(find_all_schemas(conn, id).await.unwrap()),
487+
sort::<Detector>(
488+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
489+
.await
490+
.unwrap()
491+
),
480492
sort(railjson.detectors)
481493
);
482494
assert_eq!(
483-
sort::<OperationalPoint>(find_all_schemas(conn, id).await.unwrap()),
495+
sort::<OperationalPoint>(
496+
find_all_schemas(db_pool.get_ok().deref_mut(), id)
497+
.await
498+
.unwrap()
499+
),
484500
sort(railjson.operational_points)
485501
);
486502
}

0 commit comments

Comments
 (0)