@@ -16,21 +16,27 @@ use crate::models::towed_rolling_stock::TowedRollingStockModel;
16
16
use crate :: models:: RollingStockModel ;
17
17
use crate :: CliError ;
18
18
19
- #[ derive( Args , Debug ) ]
19
+ #[ derive( Args , Clone , Debug ) ]
20
20
#[ command( about, long_about = "Import a rolling stock given a json file" ) ]
21
21
pub struct ImportRollingStockArgs {
22
22
/// Rolling stock file path
23
23
rolling_stock_path : Vec < PathBuf > ,
24
+
25
+ /// If true, force the update of the rolling stock if it already exists
26
+ #[ clap( long, default_value_t = false ) ]
27
+ pub force : bool ,
24
28
}
25
29
26
30
pub async fn import_rolling_stock (
27
31
args : ImportRollingStockArgs ,
28
32
db_pool : Arc < DbConnectionPoolV2 > ,
29
33
) -> Result < ( ) , Box < dyn Error + Send + Sync > > {
30
34
for rolling_stock_path in args. rolling_stock_path {
35
+ let mut conn = db_pool. get ( ) . await ?;
31
36
let rolling_stock_file = File :: open ( rolling_stock_path) ?;
32
37
let rolling_stock: RollingStock =
33
38
serde_json:: from_reader ( BufReader :: new ( rolling_stock_file) ) ?;
39
+ let rolling_stock_name = rolling_stock. name . clone ( ) ;
34
40
let rolling_stock: Changeset < RollingStockModel > = rolling_stock. into ( ) ;
35
41
match rolling_stock. validate ( ) {
36
42
Ok ( ( ) ) => {
@@ -42,16 +48,42 @@ pub async fn import_rolling_stock(
42
48
. unwrap_or( "rolling stock without name" )
43
49
. bold( )
44
50
) ;
45
- let rolling_stock = rolling_stock
46
- . locked ( false )
47
- . version ( 0 )
48
- . create ( & mut db_pool. get ( ) . await ?)
49
- . await ?;
50
- println ! (
51
- "✅ Rolling stock {}[{}] saved!" ,
52
- & rolling_stock. name. bold( ) ,
53
- & rolling_stock. id
54
- ) ;
51
+ let existing_rolling_stock =
52
+ RollingStockModel :: retrieve ( & mut conn, rolling_stock_name. clone ( ) ) . await ?;
53
+ match ( existing_rolling_stock, args. force ) {
54
+ ( Some ( _) , true ) => {
55
+ let rolling_stock = rolling_stock
56
+ . locked ( false )
57
+ . version ( 0 )
58
+ . update ( & mut conn, rolling_stock_name. clone ( ) )
59
+ . await ?
60
+ . unwrap ( ) ;
61
+ println ! (
62
+ " ↳ ✅ Rolling stock {}[{}] saved! (forced update)" ,
63
+ & rolling_stock_name. bold( ) ,
64
+ & rolling_stock. id,
65
+ ) ;
66
+ }
67
+ ( Some ( existing_rolling_stock) , false ) => {
68
+ println ! (
69
+ " ↳ ⚠️ Rolling stock {}[{}] already existing! (try use \" --force\" to update it)" ,
70
+ & rolling_stock_name. bold( ) ,
71
+ & existing_rolling_stock. id,
72
+ ) ;
73
+ }
74
+ _ => {
75
+ let rolling_stock = rolling_stock
76
+ . locked ( false )
77
+ . version ( 0 )
78
+ . create ( & mut conn)
79
+ . await ?;
80
+ println ! (
81
+ " ↳ ✅ Rolling stock {}[{}] saved!" ,
82
+ & rolling_stock_name. bold( ) ,
83
+ & rolling_stock. id,
84
+ ) ;
85
+ }
86
+ }
55
87
}
56
88
Err ( e) => {
57
89
let mut error_message = "❌ Rolling stock was not created!" . to_string ( ) ;
@@ -116,6 +148,7 @@ mod tests {
116
148
117
149
use crate :: client:: generate_temp_file;
118
150
151
+ use editoast_common:: units;
119
152
use editoast_models:: DbConnectionPoolV2 ;
120
153
use rstest:: rstest;
121
154
@@ -133,6 +166,7 @@ mod tests {
133
166
let db_pool = DbConnectionPoolV2 :: for_tests ( ) ;
134
167
let args = ImportRollingStockArgs {
135
168
rolling_stock_path : vec ! [ "non/existing/railjson/file/location" . into( ) ] ,
169
+ force : false ,
136
170
} ;
137
171
138
172
// WHEN
@@ -154,6 +188,7 @@ mod tests {
154
188
let file = generate_temp_file ( & non_electric_rs) ;
155
189
let args = ImportRollingStockArgs {
156
190
rolling_stock_path : vec ! [ file. path( ) . into( ) ] ,
191
+ force : false ,
157
192
} ;
158
193
159
194
// WHEN
@@ -183,6 +218,7 @@ mod tests {
183
218
let file = generate_temp_file ( & non_electric_rs) ;
184
219
let args = ImportRollingStockArgs {
185
220
rolling_stock_path : vec ! [ file. path( ) . into( ) ] ,
221
+ force : false ,
186
222
} ;
187
223
188
224
// WHEN
@@ -216,6 +252,7 @@ mod tests {
216
252
let file = generate_temp_file ( & electric_rs) ;
217
253
let args = ImportRollingStockArgs {
218
254
rolling_stock_path : vec ! [ file. path( ) . into( ) ] ,
255
+ force : false ,
219
256
} ;
220
257
221
258
// WHEN
@@ -243,6 +280,7 @@ mod tests {
243
280
let file = generate_temp_file ( & electric_rolling_stock) ;
244
281
let args = ImportRollingStockArgs {
245
282
rolling_stock_path : vec ! [ file. path( ) . into( ) ] ,
283
+ force : false ,
246
284
} ;
247
285
248
286
// WHEN
@@ -263,6 +301,87 @@ mod tests {
263
301
assert ! ( electrical_power_startup_time. is_some( ) ) ;
264
302
assert ! ( raise_pantograph_time. is_some( ) ) ;
265
303
}
304
+
305
+ #[ rstest]
306
+ async fn import_existing_rolling_stock_without_force ( ) {
307
+ // GIVEN
308
+ let db_pool = DbConnectionPoolV2 :: for_tests ( ) ;
309
+ let existing_rolling_stock_name = "existing_rolling_stock" ;
310
+ let existing_rolling_stock_form =
311
+ get_fast_rolling_stock_schema ( existing_rolling_stock_name) ;
312
+
313
+ let existing_rolling_stock: Changeset < RollingStockModel > =
314
+ existing_rolling_stock_form. clone ( ) . into ( ) ;
315
+ existing_rolling_stock
316
+ . locked ( false )
317
+ . version ( 0 )
318
+ . create ( & mut db_pool. get_ok ( ) )
319
+ . await
320
+ . unwrap ( ) ;
321
+
322
+ // second rolling stock with same values except length (100.0 instead of 400.0)
323
+ let mut updated_rolling_stock_form: RollingStock = existing_rolling_stock_form. clone ( ) ;
324
+ updated_rolling_stock_form. length = units:: meter:: new ( 100.0 ) ;
325
+ let file = generate_temp_file ( & updated_rolling_stock_form) ;
326
+ let args = ImportRollingStockArgs {
327
+ rolling_stock_path : vec ! [ file. path( ) . into( ) ] ,
328
+ force : false ,
329
+ } ;
330
+
331
+ // WHEN
332
+ let result = import_rolling_stock ( args, db_pool. clone ( ) . into ( ) ) . await ;
333
+
334
+ // THEN
335
+ assert ! ( result. is_ok( ) , "import should succeed, but result as skipped, as a rolling stock already exists and --force is disabled" ) ;
336
+ let rolling_stock = RollingStockModel :: retrieve (
337
+ & mut db_pool. get_ok ( ) ,
338
+ existing_rolling_stock_name. to_string ( ) ,
339
+ )
340
+ . await
341
+ . unwrap ( ) ;
342
+ assert ! ( rolling_stock. is_some( ) ) ;
343
+ assert ! ( rolling_stock. unwrap( ) . length == units:: meter:: new( 400.0 ) ) ;
344
+ }
345
+
346
+ #[ rstest]
347
+ async fn import_existing_rolling_stock_with_force ( ) {
348
+ // GIVEN
349
+ let db_pool = DbConnectionPoolV2 :: for_tests ( ) ;
350
+ let existing_rolling_stock_name = "existing_rolling_stock" ;
351
+ let existing_rolling_stock_form =
352
+ get_fast_rolling_stock_schema ( existing_rolling_stock_name) ;
353
+ let existing_rolling_stock: Changeset < RollingStockModel > =
354
+ existing_rolling_stock_form. clone ( ) . into ( ) ;
355
+ existing_rolling_stock
356
+ . locked ( false )
357
+ . version ( 0 )
358
+ . create ( & mut db_pool. get_ok ( ) )
359
+ . await
360
+ . unwrap ( ) ;
361
+
362
+ // second rolling stock with same values except length (100.0 instead of 400.0)
363
+ let mut updated_rolling_stock_form: RollingStock = existing_rolling_stock_form. clone ( ) ;
364
+ updated_rolling_stock_form. length = units:: meter:: new ( 100.0 ) ;
365
+ let file = generate_temp_file ( & updated_rolling_stock_form) ;
366
+ let args = ImportRollingStockArgs {
367
+ rolling_stock_path : vec ! [ file. path( ) . into( ) ] ,
368
+ force : true ,
369
+ } ;
370
+
371
+ // WHEN
372
+ let result = import_rolling_stock ( args, db_pool. clone ( ) . into ( ) ) . await ;
373
+
374
+ // THEN
375
+ assert ! ( result. is_ok( ) , "import should succeed, but result as skipped, as a rolling stock already exists and --force is disabled" ) ;
376
+ let rolling_stock = RollingStockModel :: retrieve (
377
+ & mut db_pool. get_ok ( ) ,
378
+ existing_rolling_stock_name. to_string ( ) ,
379
+ )
380
+ . await
381
+ . unwrap ( ) ;
382
+ assert ! ( rolling_stock. is_some( ) ) ;
383
+ assert ! ( rolling_stock. unwrap( ) . length == units:: meter:: new( 100.0 ) ) ;
384
+ }
266
385
}
267
386
268
387
mod towed_rolling_stock {
@@ -279,6 +398,7 @@ mod tests {
279
398
let db_pool = DbConnectionPoolV2 :: for_tests ( ) ;
280
399
let args = ImportRollingStockArgs {
281
400
rolling_stock_path : vec ! [ "non/existing/railjson/file/location" . into( ) ] ,
401
+ force : false ,
282
402
} ;
283
403
284
404
// WHEN
@@ -301,6 +421,7 @@ mod tests {
301
421
let file = generate_temp_file ( & towed_rolling_stock_form) ;
302
422
let args = ImportRollingStockArgs {
303
423
rolling_stock_path : vec ! [ file. path( ) . into( ) ] ,
424
+ force : false ,
304
425
} ;
305
426
306
427
// WHEN
0 commit comments