Skip to content

Commit 6b46ccb

Browse files
committed
editoast: add --force option for import-rolling-stock command
Signed-off-by: Louis Greiner <[email protected]>
1 parent 846383c commit 6b46ccb

File tree

3 files changed

+151
-18
lines changed

3 files changed

+151
-18
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ docker compose up -d --build
4747
./scripts/load-railjson-infra.sh small_infra tests/data/infras/small_infra/infra.json
4848

4949
# import rolling stocks with realistic characterics, representative of the industry
50-
./scripts/load-railjson-rolling-stock.sh tests/data/rolling_stocks/realistic/*.json
50+
./scripts/load-railjson-rolling-stock.sh tests/data/rolling_stocks/realistic/*.json --force
5151

5252
# import more rolling stocks
5353
./scripts/load-railjson-rolling-stock.sh tests/data/rolling_stocks/*.json

editoast/src/client/import_rolling_stock.rs

+132-11
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,27 @@ use crate::models::towed_rolling_stock::TowedRollingStockModel;
1616
use crate::models::RollingStockModel;
1717
use crate::CliError;
1818

19-
#[derive(Args, Debug)]
19+
#[derive(Args, Clone, Debug)]
2020
#[command(about, long_about = "Import a rolling stock given a json file")]
2121
pub struct ImportRollingStockArgs {
2222
/// Rolling stock file path
2323
rolling_stock_path: Vec<PathBuf>,
24+
25+
/// If true, force the update of the rolling stock if it already exists
26+
#[clap(long, default_value_t = false)]
27+
pub force: bool,
2428
}
2529

2630
pub async fn import_rolling_stock(
2731
args: ImportRollingStockArgs,
2832
db_pool: Arc<DbConnectionPoolV2>,
2933
) -> Result<(), Box<dyn Error + Send + Sync>> {
3034
for rolling_stock_path in args.rolling_stock_path {
35+
let mut conn = db_pool.get().await?;
3136
let rolling_stock_file = File::open(rolling_stock_path)?;
3237
let rolling_stock: RollingStock =
3338
serde_json::from_reader(BufReader::new(rolling_stock_file))?;
39+
let rolling_stock_name = rolling_stock.name.clone();
3440
let rolling_stock: Changeset<RollingStockModel> = rolling_stock.into();
3541
match rolling_stock.validate() {
3642
Ok(()) => {
@@ -42,16 +48,42 @@ pub async fn import_rolling_stock(
4248
.unwrap_or("rolling stock without name")
4349
.bold()
4450
);
45-
let rolling_stock = rolling_stock
46-
.locked(false)
47-
.version(0)
48-
.create(&mut db_pool.get().await?)
49-
.await?;
50-
println!(
51-
"✅ Rolling stock {}[{}] saved!",
52-
&rolling_stock.name.bold(),
53-
&rolling_stock.id
54-
);
51+
let existing_rolling_stock =
52+
RollingStockModel::retrieve(&mut conn, rolling_stock_name.clone()).await?;
53+
match (existing_rolling_stock, args.force) {
54+
(Some(_), true) => {
55+
let rolling_stock = rolling_stock
56+
.locked(false)
57+
.version(0)
58+
.update(&mut conn, rolling_stock_name.clone())
59+
.await?
60+
.unwrap();
61+
println!(
62+
" ↳ ✅ Rolling stock {}[{}] saved! (forced update)",
63+
&rolling_stock_name.bold(),
64+
&rolling_stock.id,
65+
);
66+
}
67+
(Some(existing_rolling_stock), false) => {
68+
println!(
69+
" ↳ ⚠️ Rolling stock {}[{}] already existing! (try use \"--force\" to update it)",
70+
&rolling_stock_name.bold(),
71+
&existing_rolling_stock.id,
72+
);
73+
}
74+
_ => {
75+
let rolling_stock = rolling_stock
76+
.locked(false)
77+
.version(0)
78+
.create(&mut conn)
79+
.await?;
80+
println!(
81+
" ↳ ✅ Rolling stock {}[{}] saved!",
82+
&rolling_stock_name.bold(),
83+
&rolling_stock.id,
84+
);
85+
}
86+
}
5587
}
5688
Err(e) => {
5789
let mut error_message = "❌ Rolling stock was not created!".to_string();
@@ -116,6 +148,7 @@ mod tests {
116148

117149
use crate::client::generate_temp_file;
118150

151+
use editoast_common::units;
119152
use editoast_models::DbConnectionPoolV2;
120153
use rstest::rstest;
121154

@@ -133,6 +166,7 @@ mod tests {
133166
let db_pool = DbConnectionPoolV2::for_tests();
134167
let args = ImportRollingStockArgs {
135168
rolling_stock_path: vec!["non/existing/railjson/file/location".into()],
169+
force: false,
136170
};
137171

138172
// WHEN
@@ -154,6 +188,7 @@ mod tests {
154188
let file = generate_temp_file(&non_electric_rs);
155189
let args = ImportRollingStockArgs {
156190
rolling_stock_path: vec![file.path().into()],
191+
force: false,
157192
};
158193

159194
// WHEN
@@ -183,6 +218,7 @@ mod tests {
183218
let file = generate_temp_file(&non_electric_rs);
184219
let args = ImportRollingStockArgs {
185220
rolling_stock_path: vec![file.path().into()],
221+
force: false,
186222
};
187223

188224
// WHEN
@@ -216,6 +252,7 @@ mod tests {
216252
let file = generate_temp_file(&electric_rs);
217253
let args = ImportRollingStockArgs {
218254
rolling_stock_path: vec![file.path().into()],
255+
force: false,
219256
};
220257

221258
// WHEN
@@ -243,6 +280,7 @@ mod tests {
243280
let file = generate_temp_file(&electric_rolling_stock);
244281
let args = ImportRollingStockArgs {
245282
rolling_stock_path: vec![file.path().into()],
283+
force: false,
246284
};
247285

248286
// WHEN
@@ -263,6 +301,87 @@ mod tests {
263301
assert!(electrical_power_startup_time.is_some());
264302
assert!(raise_pantograph_time.is_some());
265303
}
304+
305+
#[rstest]
306+
async fn import_existing_rolling_stock_without_force() {
307+
// GIVEN
308+
let db_pool = DbConnectionPoolV2::for_tests();
309+
let existing_rolling_stock_name = "existing_rolling_stock";
310+
let existing_rolling_stock_form =
311+
get_fast_rolling_stock_schema(existing_rolling_stock_name);
312+
313+
let existing_rolling_stock: Changeset<RollingStockModel> =
314+
existing_rolling_stock_form.clone().into();
315+
existing_rolling_stock
316+
.locked(false)
317+
.version(0)
318+
.create(&mut db_pool.get_ok())
319+
.await
320+
.unwrap();
321+
322+
// second rolling stock with same values except length (100.0 instead of 400.0)
323+
let mut updated_rolling_stock_form: RollingStock = existing_rolling_stock_form.clone();
324+
updated_rolling_stock_form.length = units::meter::new(100.0);
325+
let file = generate_temp_file(&updated_rolling_stock_form);
326+
let args = ImportRollingStockArgs {
327+
rolling_stock_path: vec![file.path().into()],
328+
force: false,
329+
};
330+
331+
// WHEN
332+
let result = import_rolling_stock(args, db_pool.clone().into()).await;
333+
334+
// THEN
335+
assert!(result.is_ok(), "import should succeed, but result as skipped, as a rolling stock already exists and --force is disabled");
336+
let rolling_stock = RollingStockModel::retrieve(
337+
&mut db_pool.get_ok(),
338+
existing_rolling_stock_name.to_string(),
339+
)
340+
.await
341+
.unwrap();
342+
assert!(rolling_stock.is_some());
343+
assert!(rolling_stock.unwrap().length == units::meter::new(400.0));
344+
}
345+
346+
#[rstest]
347+
async fn import_existing_rolling_stock_with_force() {
348+
// GIVEN
349+
let db_pool = DbConnectionPoolV2::for_tests();
350+
let existing_rolling_stock_name = "existing_rolling_stock";
351+
let existing_rolling_stock_form =
352+
get_fast_rolling_stock_schema(existing_rolling_stock_name);
353+
let existing_rolling_stock: Changeset<RollingStockModel> =
354+
existing_rolling_stock_form.clone().into();
355+
existing_rolling_stock
356+
.locked(false)
357+
.version(0)
358+
.create(&mut db_pool.get_ok())
359+
.await
360+
.unwrap();
361+
362+
// second rolling stock with same values except length (100.0 instead of 400.0)
363+
let mut updated_rolling_stock_form: RollingStock = existing_rolling_stock_form.clone();
364+
updated_rolling_stock_form.length = units::meter::new(100.0);
365+
let file = generate_temp_file(&updated_rolling_stock_form);
366+
let args = ImportRollingStockArgs {
367+
rolling_stock_path: vec![file.path().into()],
368+
force: true,
369+
};
370+
371+
// WHEN
372+
let result = import_rolling_stock(args, db_pool.clone().into()).await;
373+
374+
// THEN
375+
assert!(result.is_ok(), "import should succeed, but result as skipped, as a rolling stock already exists and --force is disabled");
376+
let rolling_stock = RollingStockModel::retrieve(
377+
&mut db_pool.get_ok(),
378+
existing_rolling_stock_name.to_string(),
379+
)
380+
.await
381+
.unwrap();
382+
assert!(rolling_stock.is_some());
383+
assert!(rolling_stock.unwrap().length == units::meter::new(100.0));
384+
}
266385
}
267386

268387
mod towed_rolling_stock {
@@ -279,6 +398,7 @@ mod tests {
279398
let db_pool = DbConnectionPoolV2::for_tests();
280399
let args = ImportRollingStockArgs {
281400
rolling_stock_path: vec!["non/existing/railjson/file/location".into()],
401+
force: false,
282402
};
283403

284404
// WHEN
@@ -301,6 +421,7 @@ mod tests {
301421
let file = generate_temp_file(&towed_rolling_stock_form);
302422
let args = ImportRollingStockArgs {
303423
rolling_stock_path: vec![file.path().into()],
424+
force: false,
304425
};
305426

306427
// WHEN

scripts/load-railjson-rolling-stock.sh

+18-6
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,24 @@
99
set -e
1010

1111
if [ "$#" = 0 ]; then
12-
echo "Missing path to RailJSON rolling stock"
13-
exit 1
12+
echo "Missing path to RailJSON rolling stock"
13+
exit 1
1414
fi
1515

16-
echo "Loading $# example rolling stock(s)"
17-
for rolling_stock_path in "$@"; do
18-
docker cp "${rolling_stock_path}" osrd-editoast:tmp/stock.json
19-
docker exec osrd-editoast editoast import-rolling-stock //tmp/stock.json
16+
FORCE_OPTION=""
17+
ROLLING_STOCK_PATHS=""
18+
for arg in "$@"; do
19+
if [ "$arg" = "--force" ]; then
20+
FORCE_OPTION="--force"
21+
else
22+
ROLLING_STOCK_PATHS="$ROLLING_STOCK_PATHS $arg"
23+
fi
24+
done
25+
26+
echo "Loading $(echo "$ROLLING_STOCK_PATHS" | wc -w) example rolling stock(s)"
27+
for rolling_stock_path in $ROLLING_STOCK_PATHS; do
28+
docker cp "$rolling_stock_path" osrd-editoast:tmp/stock.json
29+
# ignore the mandatory "" around $FORCE_OPTION, since "" is interpreted as an arg
30+
# shellcheck disable=SC2086
31+
docker exec osrd-editoast editoast import-rolling-stock //tmp/stock.json $FORCE_OPTION
2032
done

0 commit comments

Comments
 (0)