|
8 | 8 | from enum import Enum, IntEnum
|
9 | 9 | from functools import cache
|
10 | 10 | from pathlib import Path
|
11 |
| -from typing import Dict, Iterable, List, Optional, Tuple |
| 11 | +from typing import Dict, Iterable, List, Optional, Tuple, TypeVar |
12 | 12 |
|
13 | 13 | import requests
|
14 | 14 | from osrd_schemas.switch_type import builtin_node_types
|
|
36 | 36 | """
|
37 | 37 |
|
38 | 38 |
|
39 |
| -class _FailedTest(Exception): |
40 |
| - pass |
41 |
| - |
42 |
| - |
43 |
| -class _ErrorType(str, Enum): |
44 |
| - SCHEDULE = "SCHEDULE" |
45 |
| - RESULT = "RESULT" |
46 |
| - STDCM = "STDCM" |
47 |
| - |
48 |
| - |
49 |
| -class _Endpoint(IntEnum): |
50 |
| - BEGIN = 0 |
51 |
| - END = 1 |
52 |
| - |
53 |
| - def opposite(self): |
54 |
| - return _Endpoint.END if self == _Endpoint.BEGIN else _Endpoint.BEGIN |
55 |
| - |
56 |
| - |
57 |
| -@dataclass(eq=True, frozen=True) |
58 |
| -class _TrackEndpoint: |
59 |
| - track: str |
60 |
| - endpoint: _Endpoint |
61 |
| - |
62 |
| - @staticmethod |
63 |
| - def from_dict(obj: Dict) -> "_TrackEndpoint": |
64 |
| - return _TrackEndpoint(obj["track"], _Endpoint._member_map_[obj["endpoint"]]) |
65 |
| - |
66 |
| - |
67 |
| -@dataclass |
68 |
| -class _InfraGraph: |
69 |
| - RJSInfra: Dict |
70 |
| - tracks: Dict[str, Dict] = field(default_factory=dict) |
71 |
| - links: Dict[_TrackEndpoint, List[_TrackEndpoint]] = field(default_factory=lambda: defaultdict(list)) |
72 |
| - |
73 |
| - def link(self, a: _TrackEndpoint, b: _TrackEndpoint): |
74 |
| - self.links[a].append(b) |
75 |
| - self.links[b].append(a) |
76 |
| - |
77 |
| - |
78 |
| -@dataclass |
79 |
| -class _RollingStock: |
80 |
| - name: str |
81 |
| - id: int |
82 |
| - |
83 |
| - |
84 | 39 | def run(
|
85 | 40 | editoast_url: str,
|
86 | 41 | scenario: Scenario,
|
@@ -186,6 +141,54 @@ def create_scenario(editoast_url: str, infra_id: int) -> Scenario:
|
186 | 141 | return Scenario(project_id, study_id, id, infra_id, timetable_id)
|
187 | 142 |
|
188 | 143 |
|
| 144 | +class _FailedTest(Exception): |
| 145 | + pass |
| 146 | + |
| 147 | + |
| 148 | +class _ErrorType(str, Enum): |
| 149 | + SCHEDULE = "SCHEDULE" |
| 150 | + RESULT = "RESULT" |
| 151 | + STDCM = "STDCM" |
| 152 | + |
| 153 | + |
| 154 | +class _Endpoint(IntEnum): |
| 155 | + BEGIN = 0 |
| 156 | + END = 1 |
| 157 | + |
| 158 | + def opposite(self): |
| 159 | + return _Endpoint.END if self == _Endpoint.BEGIN else _Endpoint.BEGIN |
| 160 | + |
| 161 | + |
| 162 | +@dataclass(eq=True, frozen=True) |
| 163 | +class _TrackEndpoint: |
| 164 | + track: str |
| 165 | + endpoint: _Endpoint |
| 166 | + |
| 167 | + @staticmethod |
| 168 | + def from_dict(obj: Dict) -> "_TrackEndpoint": |
| 169 | + return _TrackEndpoint(obj["track"], _Endpoint._member_map_[obj["endpoint"]]) |
| 170 | + |
| 171 | + |
| 172 | +@dataclass |
| 173 | +class _InfraGraph: |
| 174 | + RJSInfra: Dict |
| 175 | + tracks: Dict[str, Dict] = field(default_factory=dict) |
| 176 | + links: Dict[_TrackEndpoint, List[_TrackEndpoint]] = field(default_factory=lambda: defaultdict(list)) |
| 177 | + |
| 178 | + def link(self, a: _TrackEndpoint, b: _TrackEndpoint): |
| 179 | + self.links[a].append(b) |
| 180 | + self.links[b].append(a) |
| 181 | + |
| 182 | + |
| 183 | +@dataclass |
| 184 | +class _RollingStock: |
| 185 | + name: str |
| 186 | + id: int |
| 187 | + |
| 188 | + |
| 189 | +U = TypeVar("U") |
| 190 | + |
| 191 | + |
189 | 192 | def _make_error(
|
190 | 193 | error_type: _ErrorType,
|
191 | 194 | response: Response,
|
@@ -373,7 +376,7 @@ def _make_graph(editoast_url: str, infra: int) -> _InfraGraph:
|
373 | 376 | return graph
|
374 | 377 |
|
375 | 378 |
|
376 |
| -def _random_set_element(s: Iterable): |
| 379 | +def _random_set_element(s: Iterable[U]) -> U: |
377 | 380 | """
|
378 | 381 | Picks a random element in an iterable
|
379 | 382 | """
|
|
0 commit comments