Skip to content

Commit d2557c1

Browse files
committed
editoast: add train import command for train schedule v2
1 parent 6d33ff6 commit d2557c1

File tree

4 files changed

+193
-5
lines changed

4 files changed

+193
-5
lines changed

editoast/src/client/mod.rs

+16
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ pub enum Commands {
5050
Search(SearchCommands),
5151
#[command(subcommand, about, long_about = "Infrastructure related commands")]
5252
Infra(InfraCommands),
53+
#[command(subcommand, about, long_about = "Trains related commands")]
54+
Trains(TrainsCommands),
55+
}
56+
57+
#[derive(Subcommand, Debug)]
58+
pub enum TrainsCommands {
59+
Import(ImportTrainArgs),
60+
}
61+
62+
#[derive(Args, Debug, Derivative)]
63+
#[derivative(Default)]
64+
#[command(about, long_about = "Import a train given a JSON file")]
65+
pub struct ImportTrainArgs {
66+
#[arg(long, help = "The timetable id on which attach the trains to")]
67+
pub timetable: Option<i64>,
68+
pub path: PathBuf,
5369
}
5470

5571
#[derive(Subcommand, Debug)]

editoast/src/fixtures.rs

+4
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ pub mod tests {
142142
rs
143143
}
144144

145+
pub fn get_trainschedule_json_array() -> &'static str {
146+
include_str!("./tests/train_schedules/simple_array.json")
147+
}
148+
145149
pub async fn named_other_rolling_stock(
146150
name: &str,
147151
db_pool: Data<DbPool>,

editoast/src/main.rs

+96-5
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,17 @@ use chashmap::CHashMap;
2727
use clap::Parser;
2828
use client::{
2929
ClearArgs, Client, Color, Commands, DeleteProfileSetArgs, ElectricalProfilesCommands,
30-
GenerateArgs, ImportProfileSetArgs, ImportRailjsonArgs, ImportRollingStockArgs, InfraCloneArgs,
31-
InfraCommands, ListProfileSetArgs, MakeMigrationArgs, RedisConfig, RefreshArgs, RunserverArgs,
32-
SearchCommands,
30+
GenerateArgs, ImportProfileSetArgs, ImportRailjsonArgs, ImportRollingStockArgs,
31+
ImportTrainArgs, InfraCloneArgs, InfraCommands, ListProfileSetArgs, MakeMigrationArgs,
32+
RedisConfig, RefreshArgs, RunserverArgs, SearchCommands, TrainsCommands,
3333
};
34+
use modelsv2::{
35+
timetable::Timetable, train_schedule::TrainSchedule, train_schedule::TrainScheduleChangeset,
36+
Create as CreateV2, CreateBatch, Model, Retrieve as RetrieveV2,
37+
};
38+
use schema::v2::trainschedule::TrainScheduleBase;
39+
use views::v2::train_schedule::TrainScheduleForm;
40+
3441
use colored::*;
3542
use core::CoreClient;
3643
use diesel::{sql_query, ConnectionError, ConnectionResult};
@@ -185,9 +192,65 @@ async fn run() -> Result<(), Box<dyn Error + Send + Sync>> {
185192
}
186193
InfraCommands::ImportRailjson(args) => import_railjson(args, create_db_pool()?).await,
187194
},
195+
Commands::Trains(subcommand) => match subcommand {
196+
TrainsCommands::Import(args) => trains_import(args, create_db_pool()?).await,
197+
},
188198
}
189199
}
190200

201+
async fn trains_import(
202+
args: ImportTrainArgs,
203+
db_pool: Data<DbPool>,
204+
) -> Result<(), Box<dyn Error + Send + Sync>> {
205+
let train_file = match File::open(args.path.clone()) {
206+
Ok(file) => file,
207+
Err(e) => {
208+
let error = CliError::new(
209+
1,
210+
format!("❌ Could not open file {:?} ({:?})", args.path, e),
211+
);
212+
return Err(Box::new(error));
213+
}
214+
};
215+
216+
let conn = &mut db_pool.get().await?;
217+
let timetable = match args.timetable {
218+
Some(timetable) => match Timetable::retrieve(conn, timetable).await? {
219+
Some(timetable) => timetable,
220+
None => {
221+
let error = CliError::new(1, format!("❌ Timetable not found, id: {0}", timetable));
222+
return Err(Box::new(error));
223+
}
224+
},
225+
None => {
226+
let changeset = Timetable::changeset();
227+
changeset.create(conn).await?
228+
}
229+
};
230+
231+
let train_schedules: Vec<TrainScheduleBase> =
232+
serde_json::from_reader(BufReader::new(train_file))?;
233+
let changesets: Vec<TrainScheduleChangeset> = train_schedules
234+
.into_iter()
235+
.map(|train_schedule| {
236+
TrainScheduleForm {
237+
timetable_id: timetable.id,
238+
train_schedule,
239+
}
240+
.into()
241+
})
242+
.collect();
243+
let inserted: Vec<_> = TrainSchedule::create_batch(conn, changesets).await?;
244+
245+
println!(
246+
"✅ {} train schedules created for timetable with id {}",
247+
inserted.len(),
248+
timetable.id
249+
);
250+
251+
Ok(())
252+
}
253+
191254
fn init_sentry(args: &RunserverArgs) -> Option<ClientInitGuard> {
192255
match (args.sentry_dsn.clone(), args.sentry_env.clone()) {
193256
(Some(sentry_dsn), Some(sentry_env)) => Some(sentry::init((
@@ -790,10 +853,11 @@ mod tests {
790853
use super::*;
791854

792855
use crate::fixtures::tests::{
793-
db_pool, electrical_profile_set, get_fast_rolling_stock, TestFixture,
856+
db_pool, electrical_profile_set, get_fast_rolling_stock, get_trainschedule_json_array,
857+
TestFixture,
794858
};
795859
use diesel::sql_query;
796-
use diesel::sql_types::Text;
860+
use diesel::sql_types::{BigInt, Text};
797861
use diesel_async::RunQueryDsl;
798862
use rand::distributions::Alphanumeric;
799863
use rand::{thread_rng, Rng};
@@ -802,6 +866,33 @@ mod tests {
802866
use std::io::Write;
803867
use tempfile::NamedTempFile;
804868

869+
#[rstest]
870+
async fn import_train_schedule_v2(db_pool: Data<DbPool>) {
871+
let conn = &mut db_pool.get().await.unwrap();
872+
873+
let changeset = Timetable::changeset();
874+
let timetable = changeset.create(conn).await.unwrap();
875+
876+
let mut file = NamedTempFile::new().unwrap();
877+
file.write(get_trainschedule_json_array().as_bytes())
878+
.unwrap();
879+
880+
let args = ImportTrainArgs {
881+
path: file.path().into(),
882+
timetable: Some(timetable.id),
883+
};
884+
885+
let result = trains_import(args, db_pool.clone()).await;
886+
887+
assert!(result.is_ok(), "{:?}", result);
888+
889+
sql_query("DELETE FROM timetable_v2 WHERE id = $1")
890+
.bind::<BigInt, _>(timetable.id)
891+
.execute(conn)
892+
.await
893+
.unwrap();
894+
}
895+
805896
#[rstest]
806897
async fn import_rolling_stock_ko_file_not_found(db_pool: Data<DbPool>) {
807898
// GIVEN
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
[
2+
{
3+
"train_name": "ABC3615",
4+
"rolling_stock_name": "R2D2",
5+
"labels": [
6+
"choo-choo",
7+
"tchou-tchou"
8+
],
9+
"speed_limit_tag": "MA100",
10+
"start_time": "2023-12-21T08:51:30+00:00",
11+
"path": [
12+
{
13+
"id": "a",
14+
"uic": 87210
15+
},
16+
{
17+
"id": "b",
18+
"track": "foo",
19+
"offset": 10
20+
},
21+
{
22+
"id": "c",
23+
"deleted": true,
24+
"trigram": "ABC"
25+
},
26+
{
27+
"id": "d",
28+
"operational_point": "X"
29+
}
30+
],
31+
"constraint_distribution": "MARECO",
32+
"schedule": [
33+
{
34+
"at": "a",
35+
"stop_for": "PT5M",
36+
"locked": true
37+
},
38+
{
39+
"at": "b",
40+
"arrival": "PT10M",
41+
"stop_for": "PT5M"
42+
},
43+
{
44+
"at": "c",
45+
"stop_for": "PT5M"
46+
},
47+
{
48+
"at": "d",
49+
"arrival": "PT50M",
50+
"locked": true
51+
}
52+
],
53+
"margins": {
54+
"boundaries": [
55+
"b",
56+
"c"
57+
],
58+
"values": [
59+
"5%",
60+
"3min/km",
61+
"none"
62+
]
63+
},
64+
"initial_speed": 2.5,
65+
"power_restrictions": [
66+
{
67+
"from": "b",
68+
"to": "c",
69+
"value": "M1C1"
70+
}
71+
],
72+
"comfort": "AIR_CONDITIONING",
73+
"options": {
74+
"use_electrical_profiles": true
75+
}
76+
}
77+
]

0 commit comments

Comments
 (0)