@@ -27,10 +27,17 @@ use chashmap::CHashMap;
27
27
use clap:: Parser ;
28
28
use client:: {
29
29
ClearArgs , Client , Color , Commands , DeleteProfileSetArgs , ElectricalProfilesCommands ,
30
- GenerateArgs , ImportProfileSetArgs , ImportRailjsonArgs , ImportRollingStockArgs , InfraCloneArgs ,
31
- InfraCommands , ListProfileSetArgs , MakeMigrationArgs , RedisConfig , RefreshArgs , RunserverArgs ,
32
- SearchCommands ,
30
+ GenerateArgs , ImportProfileSetArgs , ImportRailjsonArgs , ImportRollingStockArgs ,
31
+ ImportTrainArgs , InfraCloneArgs , InfraCommands , ListProfileSetArgs , MakeMigrationArgs ,
32
+ RedisConfig , RefreshArgs , RunserverArgs , SearchCommands , TrainsCommands ,
33
33
} ;
34
+ use modelsv2:: {
35
+ timetable:: Timetable , train_schedule:: TrainSchedule , train_schedule:: TrainScheduleChangeset ,
36
+ Create as CreateV2 , CreateBatch , Model , Retrieve as RetrieveV2 ,
37
+ } ;
38
+ use schema:: v2:: trainschedule:: TrainScheduleBase ;
39
+ use views:: v2:: train_schedule:: TrainScheduleForm ;
40
+
34
41
use colored:: * ;
35
42
use core:: CoreClient ;
36
43
use diesel:: { sql_query, ConnectionError , ConnectionResult } ;
@@ -185,9 +192,65 @@ async fn run() -> Result<(), Box<dyn Error + Send + Sync>> {
185
192
}
186
193
InfraCommands :: ImportRailjson ( args) => import_railjson ( args, create_db_pool ( ) ?) . await ,
187
194
} ,
195
+ Commands :: Trains ( subcommand) => match subcommand {
196
+ TrainsCommands :: Import ( args) => trains_import ( args, create_db_pool ( ) ?) . await ,
197
+ } ,
188
198
}
189
199
}
190
200
201
+ async fn trains_import (
202
+ args : ImportTrainArgs ,
203
+ db_pool : Data < DbPool > ,
204
+ ) -> Result < ( ) , Box < dyn Error + Send + Sync > > {
205
+ let train_file = match File :: open ( args. path . clone ( ) ) {
206
+ Ok ( file) => file,
207
+ Err ( e) => {
208
+ let error = CliError :: new (
209
+ 1 ,
210
+ format ! ( "❌ Could not open file {:?} ({:?})" , args. path, e) ,
211
+ ) ;
212
+ return Err ( Box :: new ( error) ) ;
213
+ }
214
+ } ;
215
+
216
+ let conn = & mut db_pool. get ( ) . await ?;
217
+ let timetable = match args. timetable {
218
+ Some ( timetable) => match Timetable :: retrieve ( conn, timetable) . await ? {
219
+ Some ( timetable) => timetable,
220
+ None => {
221
+ let error = CliError :: new ( 1 , format ! ( "❌ Timetable not found, id: {0}" , timetable) ) ;
222
+ return Err ( Box :: new ( error) ) ;
223
+ }
224
+ } ,
225
+ None => {
226
+ let changeset = Timetable :: changeset ( ) ;
227
+ changeset. create ( conn) . await ?
228
+ }
229
+ } ;
230
+
231
+ let train_schedules: Vec < TrainScheduleBase > =
232
+ serde_json:: from_reader ( BufReader :: new ( train_file) ) ?;
233
+ let changesets: Vec < TrainScheduleChangeset > = train_schedules
234
+ . into_iter ( )
235
+ . map ( |train_schedule| {
236
+ TrainScheduleForm {
237
+ timetable_id : timetable. id ,
238
+ train_schedule,
239
+ }
240
+ . into ( )
241
+ } )
242
+ . collect ( ) ;
243
+ let inserted: Vec < _ > = TrainSchedule :: create_batch ( conn, changesets) . await ?;
244
+
245
+ println ! (
246
+ "✅ {} train schedules created for timetable with id {}" ,
247
+ inserted. len( ) ,
248
+ timetable. id
249
+ ) ;
250
+
251
+ Ok ( ( ) )
252
+ }
253
+
191
254
fn init_sentry ( args : & RunserverArgs ) -> Option < ClientInitGuard > {
192
255
match ( args. sentry_dsn . clone ( ) , args. sentry_env . clone ( ) ) {
193
256
( Some ( sentry_dsn) , Some ( sentry_env) ) => Some ( sentry:: init ( (
@@ -790,18 +853,43 @@ mod tests {
790
853
use super :: * ;
791
854
792
855
use crate :: fixtures:: tests:: {
793
- db_pool, electrical_profile_set, get_fast_rolling_stock, TestFixture ,
856
+ db_pool, electrical_profile_set, get_fast_rolling_stock, get_trainschedule_json_array,
857
+ TestFixture ,
794
858
} ;
795
859
use diesel:: sql_query;
796
860
use diesel:: sql_types:: Text ;
797
861
use diesel_async:: RunQueryDsl ;
862
+ use modelsv2:: DeleteStatic ;
798
863
use rand:: distributions:: Alphanumeric ;
799
864
use rand:: { thread_rng, Rng } ;
800
865
use rstest:: rstest;
801
866
use serde:: Serialize ;
802
867
use std:: io:: Write ;
803
868
use tempfile:: NamedTempFile ;
804
869
870
+ #[ rstest]
871
+ async fn import_train_schedule_v2 ( db_pool : Data < DbPool > ) {
872
+ let conn = & mut db_pool. get ( ) . await . unwrap ( ) ;
873
+
874
+ let changeset = Timetable :: changeset ( ) ;
875
+ let timetable = changeset. create ( conn) . await . unwrap ( ) ;
876
+
877
+ let mut file = NamedTempFile :: new ( ) . unwrap ( ) ;
878
+ file. write_all ( get_trainschedule_json_array ( ) . as_bytes ( ) )
879
+ . unwrap ( ) ;
880
+
881
+ let args = ImportTrainArgs {
882
+ path : file. path ( ) . into ( ) ,
883
+ timetable : Some ( timetable. id ) ,
884
+ } ;
885
+
886
+ let result = trains_import ( args, db_pool. clone ( ) ) . await ;
887
+
888
+ assert ! ( result. is_ok( ) , "{:?}" , result) ;
889
+
890
+ Timetable :: delete_static ( conn, timetable. id ) . await . unwrap ( ) ;
891
+ }
892
+
805
893
#[ rstest]
806
894
async fn import_rolling_stock_ko_file_not_found ( db_pool : Data < DbPool > ) {
807
895
// GIVEN
0 commit comments