Skip to content

Commit 5b61c0f

Browse files
committed
railjson_generator: replace 'pytype' by 'pyright'
Signed-off-by: Jean SIMARD <[email protected]>
1 parent 5197260 commit 5b61c0f

15 files changed

+83
-449
lines changed

.github/workflows/build.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,10 @@ jobs:
250250
poetry run ruff check
251251
poetry run ruff format --check
252252
253-
- name: Pytype
253+
- name: Pyright
254254
run: |
255255
cd python/railjson_generator
256-
poetry run pytype -j auto
256+
poetry run pyright
257257
258258
- name: Pytest
259259
run: |

python/railjson_generator/poetry.lock

+17-382
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/railjson_generator/pyproject.toml

+1-4
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,9 @@ pytest-cov = "^4.1.0"
1515
python = ">=3.9,<3.12"
1616

1717
[tool.poetry.group.dev.dependencies]
18-
pytype = { platform = "linux", version = "^2023.10.17" }
18+
pyright = "1.1.393"
1919
ruff = "0.9.5"
2020

21-
[tool.pytype]
22-
inputs = ["railjson_generator"]
23-
2421
[build-system]
2522
requires = ["poetry-core>=1.0.0"]
2623
build-backend = "poetry.core.masonry.api"

python/railjson_generator/railjson_generator/schema/infra/infra.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def to_rjs(self) -> infra.RailJsonInfra:
3636
track_sections=[track.to_rjs() for track in self.track_sections],
3737
switches=[switch.to_rjs() for switch in self.switches],
3838
routes=[route.to_rjs() for route in self.routes],
39-
signals=self.make_rjs_signals(),
40-
buffer_stops=self.make_rjs_buffer_stops(),
41-
detectors=self.make_rjs_detectors(),
39+
signals=list(self.make_rjs_signals()),
40+
buffer_stops=list(self.make_rjs_buffer_stops()),
41+
detectors=list(self.make_rjs_detectors()),
4242
operational_points=self.make_rjs_operational_points(),
4343
extended_switch_types=[],
4444
speed_sections=[
@@ -85,8 +85,7 @@ def make_rjs_operational_points(self):
8585
new_op = infra.OperationalPoint(
8686
id=op.label,
8787
parts=parts_per_op[op.label],
88-
name=op.label,
89-
extensions={
88+
extensions={ # pyright: ignore[reportCallIssue] - 'extensions' exists but is register through 'register_extension'
9089
"sncf": infra.OperationalPointSncfExtension(
9190
ci=0,
9291
ch="BV",
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import Mapping
22

33
from geojson_pydantic import LineString
4+
from geojson_pydantic.types import LineStringCoords
45

56

6-
def make_geo_line(*points) -> LineString:
7+
def make_geo_line(points: LineStringCoords) -> LineString:
78
return LineString(coordinates=points, type="LineString")
89

910

10-
def make_geo_lines(*points) -> Mapping[str, LineString]:
11-
return {"geo": make_geo_line(*points)}
11+
def make_geo_lines(points: LineStringCoords) -> Mapping[str, LineString]:
12+
return {"geo": make_geo_line(points)}

python/railjson_generator/railjson_generator/schema/infra/neutral_section.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
def _neutral_section_id():
1212
# pytype: disable=name-error
13-
res = f"neutral_section.{NeutralSection._INDEX}"
14-
NeutralSection._INDEX += 1
13+
res = f"neutral_section.{NeutralSection._index}"
14+
NeutralSection._index += 1
1515
# pytype: enable=name-error
1616
return res
1717

@@ -33,7 +33,7 @@ class NeutralSection:
3333
lower_pantograph: bool = field(default=False)
3434
label: str = field(default_factory=_neutral_section_id)
3535

36-
_INDEX = 0
36+
_index = 0
3737

3838
def add_track_range(
3939
self, track: TrackSection, begin: float, end: float, direction: Direction

python/railjson_generator/railjson_generator/schema/infra/route.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __post_init__(self):
3535

3636
def to_rjs(self):
3737
return infra.Route(
38-
id=self.label,
38+
id=self.label, # pyright: ignore[reportArgumentType] - '__post_init__' ensures 'label' always has a value
3939
entry_point=self.entry_point.get_waypoint_ref(),
4040
entry_point_direction=infra.Direction[self.entry_point_direction.name],
4141
exit_point=self.exit_point.get_waypoint_ref(),

python/railjson_generator/railjson_generator/schema/infra/signal.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
def _signal_id():
1010
# pytype: disable=name-error
11-
res = f"signal.{Signal._INDEX}"
12-
Signal._INDEX += 1
11+
res = f"signal.{Signal._index}"
12+
Signal._index += 1
1313
# pytype: enable=name-error
1414
return res
1515

@@ -20,7 +20,7 @@ class SignalConditionalParameters:
2020
parameters: Dict[str, str]
2121

2222
def to_rjs(self):
23-
return infra.SignalConditionalParameters(
23+
return infra.ConditionalParameter(
2424
on_route=self.on_route,
2525
parameters=self.parameters,
2626
)
@@ -69,7 +69,7 @@ class Signal:
6969
installation_type: str = "CARRE"
7070
side: infra.Side = infra.Side.LEFT
7171

72-
_INDEX = 0
72+
_index = 0
7373

7474
def add_logical_signal(self, *args, **kwargs) -> LogicalSignal:
7575
signal = LogicalSignal(*args, **kwargs)
@@ -84,7 +84,7 @@ def to_rjs(self, track):
8484
direction=infra.Direction[self.direction.name],
8585
sight_distance=self.sight_distance,
8686
logical_signals=[sig.to_rjs() for sig in self.logical_signals],
87-
extensions={
87+
extensions={ # pyright: ignore[reportCallIssue] - 'extensions' exists but is register through 'register_extension'
8888
"sncf": {
8989
"label": self.label,
9090
"side": self.side,

python/railjson_generator/railjson_generator/schema/infra/speed_section.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
def _speed_section_id():
1212
# pytype: disable=name-error
13-
res = f"speed_section.{SpeedSection._INDEX}"
14-
SpeedSection._INDEX += 1
13+
res = f"speed_section.{SpeedSection._index}"
14+
SpeedSection._index += 1
1515
# pytype: enable=name-error
1616
return res
1717

@@ -24,7 +24,7 @@ class SpeedSection:
2424
label: str = field(default_factory=_speed_section_id)
2525
on_routes: Optional[List[str]] = None
2626

27-
_INDEX = 0
27+
_index = 0
2828

2929
def add_track_range(self, track, begin, end, applicable_directions):
3030
self.track_ranges.append(

python/railjson_generator/railjson_generator/schema/infra/switch.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
def _switch_id():
1818
# pytype: disable=name-error
19-
res = f"switch.{Switch._INDEX}"
20-
Switch._INDEX += 1
19+
res = f"switch.{Switch._index}"
20+
Switch._index += 1
2121
# pytype: enable=name-error
2222
return res
2323

@@ -30,7 +30,7 @@ class SwitchGroup:
3030

3131
@dataclass
3232
class Switch:
33-
_INDEX = 0
33+
_index = 0
3434
# overridden by subclasses
3535
PORT_NAMES = []
3636
SWITCH_TYPE = ""
@@ -85,7 +85,7 @@ def to_rjs(self):
8585
port_name: getattr(self, port_name).to_rjs()
8686
for port_name in self.PORT_NAMES
8787
},
88-
extensions={"sncf": {"label": self.label}},
88+
extensions={"sncf": {"label": self.label}}, # pyright: ignore[reportCallIssue] - 'extensions' exists but is register through 'register_extension'
8989
)
9090

9191

python/railjson_generator/railjson_generator/schema/infra/track_section.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass, field
22
from typing import List, Optional, Tuple
33

4+
from geojson_pydantic.types import LineStringCoords
45
from osrd_schemas import infra
56
from pydantic import ValidationError
67

@@ -21,15 +22,15 @@
2122

2223
def _track_id():
2324
# pytype: disable=name-error
24-
res = f"track.{TrackSection._INDEX}"
25-
TrackSection._INDEX += 1
25+
res = f"track.{TrackSection._index}"
26+
TrackSection._index += 1
2627
# pytype: enable=name-error
2728
return res
2829

2930

3031
@dataclass
3132
class TrackSection:
32-
_INDEX = 0
33+
_index = 0
3334

3435
length: float
3536
label: str = field(default_factory=_track_id)
@@ -131,10 +132,12 @@ def neighbors(self, direction: Direction):
131132
return self.begining_links
132133

133134
def to_rjs(self):
134-
if self.coordinates == [(None, None), (None, None)]:
135-
self.coordinates = [(0, 0), (0, 0)]
135+
# Replace 'None' by '0.0'
136+
coordinates: LineStringCoords = list(
137+
map(lambda pos: (pos[0] or 0.0, pos[1] or 0.0), self.coordinates)
138+
)
136139
try:
137-
geo_data = make_geo_lines(*self.coordinates)
140+
geo_data = make_geo_lines(coordinates)
138141
except ValidationError:
139142
print(f"Track section {self.label} has invalid coordinates:")
140143
print(self.coordinates)
@@ -149,7 +152,7 @@ def to_rjs(self):
149152
for loading_gauge_limit in self.loading_gauge_limits
150153
],
151154
**geo_data,
152-
extensions={
155+
extensions={ # pyright: ignore[reportCallIssue] - 'extensions' exists but is register through 'register_extension'
153156
"sncf": {
154157
"line_code": self.line_code,
155158
"line_name": self.line_name,

python/railjson_generator/railjson_generator/schema/infra/waypoint.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
def _buffer_stop_id():
1212
# pytype: disable=name-error
13-
res = f"buffer_stop.{BufferStop._INDEX}"
14-
BufferStop._INDEX += 1
13+
res = f"buffer_stop.{BufferStop._index}"
14+
BufferStop._index += 1
1515
# pytype: enable=name-error
1616
return res
1717

@@ -21,7 +21,7 @@ class BufferStop:
2121
position: float
2222
label: str = field(default_factory=_buffer_stop_id)
2323

24-
_INDEX = 0
24+
_index = 0
2525

2626
@property
2727
def id(self):
@@ -45,8 +45,8 @@ def get_direction(self, track) -> Direction:
4545

4646
def _detector_id():
4747
# pytype: disable=name-error
48-
res = f"detector.{Detector._INDEX}"
49-
Detector._INDEX += 1
48+
res = f"detector.{Detector._index}"
49+
Detector._index += 1
5050
# pytype: enable=name-error
5151
return res
5252

@@ -56,7 +56,7 @@ class Detector:
5656
position: float
5757
label: str = field(default_factory=_detector_id)
5858

59-
_INDEX = 0
59+
_index = 0
6060

6161
@property
6262
def id(self):

python/railjson_generator/railjson_generator/schema/simulation/train_schedule.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from dataclasses import dataclass, field
2-
from typing import List
2+
from typing import Union
33

44
from osrd_schemas.train_schedule import (
5-
Allowance,
65
AllowanceDistribution,
76
AllowancePercentValue,
87
AllowanceTimePerDistanceValue,
@@ -17,8 +16,8 @@
1716

1817
def _train_id():
1918
# pytype: disable=name-error
20-
res = f"train.{TrainSchedule._INDEX}"
21-
TrainSchedule._INDEX += 1
19+
res = f"train.{TrainSchedule._index}"
20+
TrainSchedule._index += 1
2221
# pytype: enable=name-error
2322
return res
2423

@@ -29,10 +28,12 @@ class TrainSchedule:
2928
rolling_stock: str = field(default="fast_rolling_stock")
3029
departure_time: float = field(default=0.0)
3130
initial_speed: float = field(default=0.0)
32-
stops: List[Stop] = field(default_factory=list)
33-
allowances: List[Allowance] = field(default_factory=list)
31+
stops: list[Stop] = field(default_factory=list)
32+
allowances: list[Union[EngineeringAllowance, StandardAllowance]] = field(
33+
default_factory=list
34+
)
3435

35-
_INDEX = 0
36+
_index = 0
3637

3738
def add_stop(self, *args, **kwargs):
3839
stop = Stop(*args, **kwargs)
@@ -54,18 +55,18 @@ def add_standard_single_value_allowance(
5455
"""Add a standard allowance with a single value. For more information on allowances, see
5556
the documentation of the Allowance class in osrd_schemas."""
5657
if value_type == "time":
57-
value = AllowanceTimeValue(seconds=value)
58+
allowance = AllowanceTimeValue(seconds=value)
5859
elif value_type == "time_per_distance":
59-
value = AllowanceTimePerDistanceValue(minutes=value)
60+
allowance = AllowanceTimePerDistanceValue(minutes=value)
6061
elif value_type == "percentage":
61-
value = AllowancePercentValue(percentage=value)
62+
allowance = AllowancePercentValue(percentage=value)
6263
else:
6364
raise ValueError(f"Unknown value kind {value_type}")
6465

6566
distribution = AllowanceDistribution(distribution)
6667

6768
self.add_allowance(
68-
default_value=value,
69+
default_value=allowance,
6970
distribution=distribution,
7071
ranges=[],
7172
capacity_speed_limit=-1,
@@ -86,8 +87,8 @@ def format(self):
8687

8788
def _group_id():
8889
# pytype: disable=name-error
89-
res = f"group.{TrainScheduleGroup._INDEX}"
90-
TrainScheduleGroup._INDEX += 1
90+
res = f"group.{TrainScheduleGroup._index}"
91+
TrainScheduleGroup._index += 1
9192
# pytype: enable=name-error
9293
return res
9394

@@ -96,11 +97,11 @@ def _group_id():
9697
class TrainScheduleGroup:
9798
"""A group of train schedules that share the same waypoints."""
9899

99-
schedules: List[TrainSchedule] = field(default_factory=list)
100-
waypoints: List[List[DirectedLocation]] = field(default_factory=list)
100+
schedules: list[TrainSchedule] = field(default_factory=list)
101+
waypoints: list[list[DirectedLocation]] = field(default_factory=list)
101102
id: str = field(default_factory=_group_id)
102103

103-
_INDEX = 0
104+
_index = 0
104105

105106
def format(self):
106107
return {

python/railjson_generator/railjson_generator/test_speed_section.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from osrd_schemas.infra import ApplicableDirections
2-
1+
from railjson_generator.schema.infra.direction import ApplicableDirection
32
from railjson_generator.schema.infra.speed_section import (
43
ApplicableDirectionsTrackRange,
54
SpeedSection,
@@ -23,27 +22,27 @@ def test_applicable_track_ranges(self):
2322
begin=50.0,
2423
end=100.0,
2524
track=track1,
26-
applicable_directions=ApplicableDirections.BOTH,
25+
applicable_directions=ApplicableDirection.BOTH,
2726
),
2827
ApplicableDirectionsTrackRange(
2928
begin=0.0,
3029
end=200.0,
3130
track=track2,
32-
applicable_directions=ApplicableDirections.START_TO_STOP,
31+
applicable_directions=ApplicableDirection.START_TO_STOP,
3332
),
3433
)
3534

3635
ref.add_track_range(
3736
begin=50.0,
3837
end=100.0,
3938
track=track1,
40-
applicable_directions=ApplicableDirections.BOTH,
39+
applicable_directions=ApplicableDirection.BOTH,
4140
)
4241
ref.add_track_range(
4342
begin=0.0,
4443
end=200.0,
4544
track=track2,
46-
applicable_directions=ApplicableDirections.START_TO_STOP,
45+
applicable_directions=ApplicableDirection.START_TO_STOP,
4746
)
4847

4948
assert speed == ref

0 commit comments

Comments
 (0)