1
1
import json
2
2
from pathlib import Path
3
- from typing import Any , Iterable , List
3
+ from typing import Any , Iterable , List , Optional
4
4
5
5
import pytest
6
6
import requests
@@ -80,13 +80,36 @@ def small_scenario(small_infra: Infra, foo_project_id: int, foo_study_id: int) -
80
80
yield Scenario (foo_project_id , foo_study_id , scenario_id , small_infra .id , timetable_id )
81
81
82
82
83
- def _create_fast_rolling_stocks (test_rolling_stocks : List [TestRollingStock ] = None ):
83
+ def get_rolling_stock (editoast_url : str , rolling_stock_name : str ) -> int :
84
+ """
85
+ Returns the ID corresponding to the rolling stock name, if available.
86
+ :param editoast_url: Api url
87
+ :param rolling_stock_name: name of the rolling stock
88
+ :return: ID the rolling stock
89
+ """
90
+ page = 1
91
+ while page is not None :
92
+ # TODO: feel free to reduce page_size when https://github.com/osrd-project/osrd/issues/5350 is fixed
93
+ r = requests .get (editoast_url + "light_rolling_stock/" , params = {"page" : page , "page_size" : 1_000 })
94
+ if r .status_code // 100 != 2 :
95
+ raise RuntimeError (f"Rolling stock error { r .status_code } : { r .content } " )
96
+ rjson = r .json ()
97
+ for rolling_stock in rjson ["results" ]:
98
+ if rolling_stock ["name" ] == rolling_stock_name :
99
+ return rolling_stock ["id" ]
100
+ page = rjson .get ("next" )
101
+ raise ValueError (f"Unable to find rolling stock { rolling_stock_name } " )
102
+
103
+
104
+ def create_fast_rolling_stocks (test_rolling_stocks : Optional [List [TestRollingStock ]] = None ):
84
105
if test_rolling_stocks is None :
85
106
payload = json .loads (FAST_ROLLING_STOCK_JSON_PATH .read_text ())
86
- response = requests .post (f"{ EDITOAST_URL } rolling_stock/" , json = payload ).json ()
87
- # TODO: if the fast_rolling_stock already exists, we should probably fetch it
88
- assert "id" in response , f"Failed to create rolling stock: { response } "
89
- return [response ["id" ]]
107
+ response = requests .post (f"{ EDITOAST_URL } rolling_stock/" , json = payload )
108
+ rjson = response .json ()
109
+ if response .status_code // 100 == 4 and "NameAlreadyUsed" in rjson ["type" ]:
110
+ return [get_rolling_stock (EDITOAST_URL , rjson ["context" ]["name" ])]
111
+ assert "id" in rjson , f"Failed to create rolling stock: { rjson } "
112
+ return [rjson ["id" ]]
90
113
ids = []
91
114
for rs in test_rolling_stocks :
92
115
payload = json .loads (rs .base_path .read_text ())
@@ -98,15 +121,15 @@ def _create_fast_rolling_stocks(test_rolling_stocks: List[TestRollingStock] = No
98
121
99
122
@pytest .fixture
100
123
def fast_rolling_stocks (request : Any ) -> Iterable [int ]:
101
- ids = _create_fast_rolling_stocks (request .node .get_closest_marker ("names_and_metadata" ).args [0 ])
124
+ ids = create_fast_rolling_stocks (request .node .get_closest_marker ("names_and_metadata" ).args [0 ])
102
125
yield ids
103
126
for id in ids :
104
127
requests .delete (f"{ EDITOAST_URL } rolling_stock/{ id } ?force=true" )
105
128
106
129
107
130
@pytest .fixture
108
131
def fast_rolling_stock () -> int :
109
- id = _create_fast_rolling_stocks ()[0 ]
132
+ id = create_fast_rolling_stocks ()[0 ]
110
133
yield id
111
134
requests .delete (f"{ EDITOAST_URL } rolling_stock/{ id } ?force=true" )
112
135
0 commit comments