Skip to content

Commit

Permalink
editoast: stop loading infra caches at server start up
Browse files Browse the repository at this point in the history
  • Loading branch information
majaziri authored and flomonster committed Nov 24, 2022
1 parent f78db45 commit 5795e65
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 128 deletions.
104 changes: 80 additions & 24 deletions editoast/src/infra_cache.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::api_error::ApiError;
use crate::infra::Infra;
use crate::schema::operation::{OperationResult, RailjsonObject};
use crate::schema::*;
use chashmap::{CHashMap, ReadGuard, WriteGuard};
use diesel::sql_types::{Double, Integer, Text};
use diesel::PgConnection;
use diesel::{sql_query, QueryableByName, RunQueryDsl};
Expand Down Expand Up @@ -326,22 +329,23 @@ impl InfraCache {
}

/// Given an infra id load infra cache from database
pub fn load(conn: &PgConnection, infra_id: i32) -> InfraCache {
pub fn load(conn: &PgConnection, infra: &Infra) -> Result<InfraCache, Box<dyn ApiError>> {
let infra_id = infra.id;
let mut infra_cache = Self::default();

// Load track sections list
sql_query(
"SELECT obj_id, (data->>'length')::float as length, data->>'geo' as geo, data->>'sch' as sch FROM osrd_infra_tracksectionmodel WHERE infra_id = $1",
)
.bind::<Integer, _>(infra_id)
.load::<TrackQueryable>(conn)
.expect("Error loading track sections").into_iter().for_each(|track| infra_cache.add::<TrackSectionCache>(track.into()));
.load::<TrackQueryable>(conn)?
.into_iter().for_each(|track| infra_cache.add::<TrackSectionCache>(track.into()));

// Load signal tracks references
sql_query(
"SELECT obj_id, data->>'track' AS track, (data->>'position')::float AS position FROM osrd_infra_signalmodel WHERE infra_id = $1")
.bind::<Integer, _>(infra_id)
.load::<SignalCache>(conn).expect("Error loading signal refs").into_iter().for_each(|signal|
.load::<SignalCache>(conn)?.into_iter().for_each(|signal|
infra_cache.add(signal)
);

Expand All @@ -359,7 +363,7 @@ impl InfraCache {
sql_query(
"SELECT obj_id, data->>'parts' AS parts FROM osrd_infra_operationalpointmodel WHERE infra_id = $1")
.bind::<Integer, _>(infra_id)
.load::<OperationalPointQueryable>(conn).expect("Error loading operational point refs").into_iter().for_each(|op|
.load::<OperationalPointQueryable>(conn)?.into_iter().for_each(|op|
infra_cache.add::<OperationalPointCache>(op.into())
);

Expand All @@ -372,7 +376,7 @@ impl InfraCache {
sql_query(
"SELECT obj_id, data->>'switch_type' AS switch_type, data->>'ports' AS ports FROM osrd_infra_switchmodel WHERE infra_id = $1")
.bind::<Integer, _>(infra_id)
.load::<SwitchQueryable>(conn).expect("Error loading switch refs").into_iter().for_each(|switch| {
.load::<SwitchQueryable>(conn)?.into_iter().for_each(|switch| {
infra_cache.add::<SwitchCache>(switch.into());
});

Expand All @@ -385,15 +389,15 @@ impl InfraCache {
sql_query(
"SELECT obj_id, data->>'track' AS track, (data->>'position')::float AS position FROM osrd_infra_detectormodel WHERE infra_id = $1")
.bind::<Integer, _>(infra_id)
.load::<DetectorCache>(conn).expect("Error loading detector refs").into_iter().for_each(|detector|
.load::<DetectorCache>(conn)?.into_iter().for_each(|detector|
infra_cache.add(detector)
);

// Load buffer stop tracks references
sql_query(
"SELECT obj_id, data->>'track' AS track, (data->>'position')::float AS position FROM osrd_infra_bufferstopmodel WHERE infra_id = $1")
.bind::<Integer, _>(infra_id)
.load::<BufferStopCache>(conn).expect("Error loading buffer stop refs").into_iter().for_each(|buffer_stop|
.load::<BufferStopCache>(conn)?.into_iter().for_each(|buffer_stop|
infra_cache.add(buffer_stop)
);

Expand All @@ -402,7 +406,39 @@ impl InfraCache {
.into_iter()
.for_each(|catenary| infra_cache.add::<Catenary>(catenary));

infra_cache
Ok(infra_cache)
}

/// This function tries to get the infra from the cache, if it fails, it loads it from the database
/// If the infra is not found in the database, it returns `None`
pub fn get_or_load<'a>(
conn: &PgConnection,
infra_caches: &'a CHashMap<i32, InfraCache>,
infra: &Infra,
) -> Result<ReadGuard<'a, i32, InfraCache>, Box<dyn ApiError>> {
// Cache hit
if let Some(infra_cache) = infra_caches.get(&infra.id) {
return Ok(infra_cache);
}
// Cache miss
infra_caches.insert_new(infra.id, InfraCache::load(conn, infra)?);
Ok(infra_caches.get(&infra.id).unwrap())
}

/// This function tries to get the infra from the cache, if it fails, it loads it from the database
/// If the infra is not found in the database, it returns `None`
pub fn get_or_load_mut<'a>(
conn: &PgConnection,
infra_caches: &'a CHashMap<i32, InfraCache>,
infra: &Infra,
) -> Result<WriteGuard<'a, i32, InfraCache>, Box<dyn ApiError>> {
// Cache hit
if let Some(infra_cache) = infra_caches.get_mut(&infra.id) {
return Ok(infra_cache);
}
// Cache miss
infra_caches.insert_new(infra.id, InfraCache::load(conn, infra)?);
Ok(infra_caches.get_mut(&infra.id).unwrap())
}

/// Get all track sections references of a given track and type
Expand Down Expand Up @@ -480,6 +516,8 @@ impl InfraCache {
pub mod tests {
use std::collections::HashMap;

use chashmap::CHashMap;

use crate::chartos::BoundingBox;
use crate::infra::tests::test_infra_transaction;
use crate::infra_cache::{InfraCache, SwitchCache};
Expand All @@ -503,7 +541,7 @@ pub mod tests {
fn load_track_section() {
test_infra_transaction(|conn, infra| {
let track = create_track(conn, infra.id, Default::default());
let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();
assert_eq!(infra_cache.track_sections().len(), 1);
assert!(infra_cache.track_sections().contains_key(track.get_id()));
});
Expand All @@ -513,9 +551,7 @@ pub mod tests {
fn load_signal() {
test_infra_transaction(|conn, infra| {
let signal = create_signal(conn, infra.id, Default::default());

let infra_cache = InfraCache::load(conn, infra.id);

let infra_cache = InfraCache::load(conn, &infra).unwrap();
assert!(infra_cache.signals().contains_key(signal.get_id()));
let refs = infra_cache.track_sections_refs;
assert_eq!(refs.get("InvalidRef").unwrap().len(), 1);
Expand All @@ -533,9 +569,7 @@ pub mod tests {
..Default::default()
},
);

let infra_cache = InfraCache::load(conn, infra.id);

let infra_cache = InfraCache::load(conn, &infra).unwrap();
assert!(infra_cache.speed_sections().contains_key(speed.get_id()));
let refs = infra_cache.track_sections_refs;
assert_eq!(refs.get("InvalidRef").unwrap().len(), 1);
Expand All @@ -554,7 +588,7 @@ pub mod tests {
},
);

let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();

assert!(infra_cache.routes().contains_key(route.get_id()));
let refs = infra_cache.track_sections_refs;
Expand All @@ -574,7 +608,7 @@ pub mod tests {
},
);

let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();

assert!(infra_cache.operational_points().contains_key(op.get_id()));
let refs = infra_cache.track_sections_refs;
Expand All @@ -586,7 +620,7 @@ pub mod tests {
fn load_track_section_link() {
test_infra_transaction(|conn, infra| {
let link = create_link(conn, infra.id, Default::default());
let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();
assert!(infra_cache
.track_section_links()
.contains_key(link.get_id()));
Expand All @@ -604,7 +638,7 @@ pub mod tests {
..Default::default()
},
);
let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();
assert!(infra_cache.switches().contains_key(switch.get_id()));
})
}
Expand All @@ -613,7 +647,7 @@ pub mod tests {
fn load_switch_type() {
test_infra_transaction(|conn, infra| {
let s_type = create_switch_type(conn, infra.id, Default::default());
let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();
assert!(infra_cache.switch_types().contains_key(s_type.get_id()));
})
}
Expand All @@ -623,7 +657,7 @@ pub mod tests {
test_infra_transaction(|conn, infra| {
let detector = create_detector(conn, infra.id, Default::default());

let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();

assert!(infra_cache.detectors().contains_key(detector.get_id()));
let refs = infra_cache.track_sections_refs;
Expand All @@ -636,7 +670,7 @@ pub mod tests {
test_infra_transaction(|conn, infra| {
let bs = create_buffer_stop(conn, infra.id, Default::default());

let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();

assert!(infra_cache.buffer_stops().contains_key(bs.get_id()));
let refs = infra_cache.track_sections_refs;
Expand All @@ -656,7 +690,7 @@ pub mod tests {
},
);

let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra).unwrap();

assert!(infra_cache.catenaries().contains_key(catenary.get_id()));
let refs = infra_cache.track_sections_refs;
Expand Down Expand Up @@ -956,4 +990,26 @@ pub mod tests {

infra_cache
}

#[test]
fn load_infra_cache() {
test_infra_transaction(|conn, infra| {
let infra_caches = CHashMap::new();
InfraCache::get_or_load(conn, &infra_caches, &infra).unwrap();
assert_eq!(infra_caches.len(), 1);
InfraCache::get_or_load(conn, &infra_caches, &infra).unwrap();
assert_eq!(infra_caches.len(), 1);
});
}

#[test]
fn load_infra_cache_mut() {
test_infra_transaction(|conn, infra| {
let infra_caches = CHashMap::new();
InfraCache::get_or_load_mut(conn, &infra_caches, &infra).unwrap();
assert_eq!(infra_caches.len(), 1);
InfraCache::get_or_load_mut(conn, &infra_caches, &infra).unwrap();
assert_eq!(infra_caches.len(), 1);
});
}
}
27 changes: 5 additions & 22 deletions editoast/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ async fn run() -> Result<(), Box<dyn Error + Send + Sync>> {
}
}
pub fn create_server(
infra_caches: Arc<CHashMap<i32, InfraCache>>,
port: u16,
pg_config: &PostgresConfig,
chartos_config: ChartosConfig,
Expand Down Expand Up @@ -88,7 +87,7 @@ pub fn create_server(
let mut rocket = rocket::custom(config)
.attach(DBConnection::fairing())
.attach(cors)
.manage(infra_caches)
.manage(Arc::<CHashMap<i32, InfraCache>>::default())
.manage(chartos_config);

// Mount routes
Expand All @@ -103,24 +102,7 @@ async fn runserver(
pg_config: PostgresConfig,
chartos_config: ChartosConfig,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let conn = pg_config.make_connection();
let infras = Infra::list(&conn);

// Initialize infra caches
let infra_caches = Arc::new(CHashMap::new());
for infra in infras.iter() {
println!(
"🍞 Loading cache for infra {}[{}]...",
infra.name.bold(),
infra.id
);
let infra_cache = InfraCache::load(&conn, infra.id);
infra_caches.insert_new(infra.id, infra_cache);
}
println!("✅ Done loading infra caches!");

let rocket = create_server(infra_caches, args.port, &pg_config, chartos_config);

let rocket = create_server(args.port, &pg_config, chartos_config);
// Run server
let _rocket = rocket.launch().await?;
Ok(())
Expand Down Expand Up @@ -155,7 +137,7 @@ async fn generate(
infra.name.bold(),
infra.id
);
let infra_cache = InfraCache::load(&conn, infra.id);
let infra_cache = InfraCache::load(&conn, &infra)?;
if infra.refresh(&conn, args.force, &infra_cache)? {
chartos::invalidate_all(infra.id, &chartos_config).await;
println!("✅ Infra {}[{}] generated!", infra.name.bold(), infra.id);
Expand Down Expand Up @@ -184,7 +166,7 @@ fn import_railjson(
println!("✅ Infra {}[{}] saved!", infra.name.bold(), infra.id);
// Generate only if the was set
if args.generate {
let infra_cache = InfraCache::load(conn, infra.id);
let infra_cache = InfraCache::load(conn, &infra)?;
infra.refresh(conn, true, &infra_cache)?;
println!(
"✅ Infra {}[{}] generated data refreshed!",
Expand Down Expand Up @@ -224,6 +206,7 @@ fn clear(args: ClearArgs, pg_config: PostgresConfig) -> Result<(), Box<dyn Error
#[cfg(test)]
mod tests {
use crate::client::{ImportRailjsonArgs, PostgresConfig};

use crate::import_railjson;
use crate::schema::RailJson;
use diesel::result::Error;
Expand Down
6 changes: 1 addition & 5 deletions editoast/src/views/infra/edition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,10 @@ use crate::schema::operation::{Operation, OperationResult};

pub fn edit(
conn: &PgConnection,
infra: i32,
infra: &Infra,
operations: &[Operation],
infra_cache: &mut InfraCache,
) -> ApiResult<(Vec<OperationResult>, InvalidationZone)> {
// Use a transaction to give scope to the infra lock
// Retrieve and lock infra
let infra = Infra::retrieve_for_update(conn, infra)?;

// Check if the infra is locked
if infra.locked {
return Err(InfraLockedError { infra_id: infra.id }.into());
Expand Down
Loading

0 comments on commit 5795e65

Please sign in to comment.