Skip to content

Commit d56bf68

Browse files
committed
Add regression test diamond crossing without over and underpasses.
1 parent af4ed97 commit d56bf68

File tree

2 files changed

+156
-34
lines changed

2 files changed

+156
-34
lines changed

flatland/utils/simple_rail.py

+70-34
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from flatland.core.grid.rail_env_grid import RailEnvTransitions
5+
from flatland.core.grid.rail_env_grid import RailEnvTransitions, RailEnvTransitionsEnum
66
from flatland.core.transition_map import GridTransitionMap
77

88

@@ -42,16 +42,16 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array, Dict]:
4242
rail = GridTransitionMap(width=rail_map.shape[1],
4343
height=rail_map.shape[0], transitions=transitions)
4444
rail.grid = rail_map
45-
city_positions = [(0,3), (6, 6)]
45+
city_positions = [(3, 9), (6, 2)]
4646
train_stations = [
47-
[( (0, 3), 0 ) ],
48-
[( (6, 6), 0 ) ],
49-
]
47+
[((3, 9), 1)],
48+
[((6, 2), 2)],
49+
]
5050
city_orientations = [0, 2]
5151
agents_hints = {'city_positions': city_positions,
5252
'train_stations': train_stations,
5353
'city_orientations': city_orientations
54-
}
54+
}
5555
optionals = {'agents_hints': agents_hints}
5656
return rail, rail_map, optionals
5757

@@ -93,16 +93,16 @@ def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array, Dict]:
9393
rail = GridTransitionMap(width=rail_map.shape[1],
9494
height=rail_map.shape[0], transitions=transitions)
9595
rail.grid = rail_map
96-
city_positions = [(0,3), (6, 6)]
96+
city_positions = [(0, 3), (6, 6)]
9797
train_stations = [
98-
[( (0, 3), 0 ) ],
99-
[( (6, 6), 0 ) ],
100-
]
98+
[((0, 3), 0)],
99+
[((6, 6), 0)],
100+
]
101101
city_orientations = [0, 2]
102102
agents_hints = {'city_positions': city_positions,
103103
'train_stations': train_stations,
104104
'city_orientations': city_orientations
105-
}
105+
}
106106
optionals = {'agents_hints': agents_hints}
107107
return rail, rail_map, optionals
108108

@@ -141,16 +141,16 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array, Dict]:
141141
rail = GridTransitionMap(width=rail_map.shape[1],
142142
height=rail_map.shape[0], transitions=transitions)
143143
rail.grid = rail_map
144-
city_positions = [(0,3), (6, 6)]
144+
city_positions = [(0, 3), (6, 6)]
145145
train_stations = [
146-
[( (0, 3), 0 ) ],
147-
[( (6, 6), 0 ) ],
148-
]
146+
[((0, 3), 0)],
147+
[((6, 6), 0)],
148+
]
149149
city_orientations = [0, 2]
150150
agents_hints = {'city_positions': city_positions,
151151
'train_stations': train_stations,
152152
'city_orientations': city_orientations
153-
}
153+
}
154154
optionals = {'agents_hints': agents_hints}
155155
return rail, rail_map, optionals
156156

@@ -190,16 +190,16 @@ def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array, Dict]:
190190
rail = GridTransitionMap(width=rail_map.shape[1],
191191
height=rail_map.shape[0], transitions=transitions)
192192
rail.grid = rail_map
193-
city_positions = [(0,3), (6, 6)]
193+
city_positions = [(0, 3), (6, 6)]
194194
train_stations = [
195-
[( (0, 3), 0 ) ],
196-
[( (6, 6), 0 ) ],
197-
]
195+
[((0, 3), 0)],
196+
[((6, 6), 0)],
197+
]
198198
city_orientations = [0, 2]
199199
agents_hints = {'city_positions': city_positions,
200200
'train_stations': train_stations,
201201
'city_orientations': city_orientations
202-
}
202+
}
203203
optionals = {'agents_hints': agents_hints}
204204
return rail, rail_map, optionals
205205

@@ -245,16 +245,16 @@ def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array, D
245245
rail = GridTransitionMap(width=rail_map.shape[1],
246246
height=rail_map.shape[0], transitions=transitions)
247247
rail.grid = rail_map
248-
city_positions = [(0,3), (6, 6)]
248+
city_positions = [(0, 3), (6, 6)]
249249
train_stations = [
250-
[( (0, 3), 0 ) ],
251-
[( (6, 6), 0 ) ],
252-
]
250+
[((0, 3), 0)],
251+
[((6, 6), 0)],
252+
]
253253
city_orientations = [0, 2]
254254
agents_hints = {'city_positions': city_positions,
255255
'train_stations': train_stations,
256256
'city_orientations': city_orientations
257-
}
257+
}
258258
optionals = {'agents_hints': agents_hints}
259259
return rail, rail_map, optionals
260260

@@ -294,20 +294,21 @@ def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array, Dict[str, A
294294
rail = GridTransitionMap(width=rail_map.shape[1],
295295
height=rail_map.shape[0], transitions=transitions)
296296
rail.grid = rail_map
297-
city_positions = [(0,3), (6, 6)]
297+
city_positions = [(0, 3), (6, 6)]
298298
train_stations = [
299-
[( (0, 3), 0 ) ],
300-
[( (6, 6), 0 ) ],
301-
]
299+
[((0, 3), 0)],
300+
[((6, 6), 0)],
301+
]
302302
city_orientations = [0, 2]
303303
agents_hints = {'city_positions': city_positions,
304304
'train_stations': train_stations,
305305
'city_orientations': city_orientations
306-
}
306+
}
307307
optionals = {'agents_hints': agents_hints}
308308
return rail, rail_map, optionals
309309

310-
def make_oval_rail() -> Tuple[GridTransitionMap, np.array]:
310+
311+
def make_oval_rail() -> Tuple[GridTransitionMap, np.array, Any]:
311312
transitions = RailEnvTransitions()
312313
cells = transitions.transition_list
313314

@@ -322,7 +323,7 @@ def make_oval_rail() -> Tuple[GridTransitionMap, np.array]:
322323
rail_map = np.array(
323324
[[empty] * 9] +
324325
[[empty] + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west] + [empty]] +
325-
[[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]]+
326+
[[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] +
326327
[[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] +
327328
[[empty] + [right_turn_from_east] + [horizontal_straight] * 5 + [right_turn_from_north] + [empty]] +
328329
[[empty] * 9], dtype=np.uint16)
@@ -341,4 +342,39 @@ def make_oval_rail() -> Tuple[GridTransitionMap, np.array]:
341342
'city_orientations': city_orientations
342343
}
343344
optionals = {'agents_hints': agents_hints}
344-
return rail, rail_map, optionals
345+
return rail, rail_map, optionals
346+
347+
348+
def make_diamond_crossing_rail() -> Tuple[GridTransitionMap, np.array, Dict]:
349+
# We instantiate a very simple rail network on a 6x10 grid:
350+
# Note that some cells have invalid RailEnvTransitions!
351+
# |
352+
# |
353+
# _ _ | _ _ _ _ _ _ _
354+
# |
355+
# |
356+
# |
357+
transitions = RailEnvTransitions()
358+
rail_map = np.array(
359+
[[RailEnvTransitionsEnum.empty] * 2 + [RailEnvTransitionsEnum.dead_end_from_south] + [RailEnvTransitionsEnum.empty] * 7] +
360+
[[RailEnvTransitionsEnum.empty] * 2 + [RailEnvTransitionsEnum.vertical_straight] + [RailEnvTransitionsEnum.empty] * 7] * 2 +
361+
[[RailEnvTransitionsEnum.dead_end_from_east] + [RailEnvTransitionsEnum.horizontal_straight] * 1 + [RailEnvTransitionsEnum.diamond_crossing] * 1 + [
362+
RailEnvTransitionsEnum.horizontal_straight] * 6 + [RailEnvTransitionsEnum.dead_end_from_west]] +
363+
[[RailEnvTransitionsEnum.empty] * 2 + [RailEnvTransitionsEnum.vertical_straight] + [RailEnvTransitionsEnum.empty] * 7] * 2 +
364+
[[RailEnvTransitionsEnum.empty] * 2 + [RailEnvTransitionsEnum.dead_end_from_north] + [RailEnvTransitionsEnum.empty] * 7]
365+
, dtype=np.uint16)
366+
rail = GridTransitionMap(width=rail_map.shape[1],
367+
height=rail_map.shape[0], transitions=transitions)
368+
rail.grid = rail_map
369+
city_positions = [(1, 4), (4, 4)]
370+
train_stations = [
371+
[((1, 4), 0)],
372+
[((4, 4), 0)],
373+
]
374+
city_orientations = [1, 3]
375+
agents_hints = {'city_positions': city_positions,
376+
'train_stations': train_stations,
377+
'city_orientations': city_orientations
378+
}
379+
optionals = {'agents_hints': agents_hints}
380+
return rail, rail_map, optionals

tests/test_over_under_passes.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import time
2+
3+
from flatland.envs.line_generators import sparse_line_generator
4+
from flatland.envs.observations import TreeObsForRailEnv
5+
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
6+
from flatland.envs.rail_env import RailEnv
7+
from flatland.envs.rail_env_action import RailEnvActions
8+
from flatland.envs.rail_generators import rail_from_grid_transition_map
9+
from flatland.envs.rail_trainrun_data_structures import Waypoint
10+
from flatland.envs.step_utils.states import TrainState
11+
from flatland.utils.rendertools import RenderTool
12+
from flatland.utils.simple_rail import make_diamond_crossing_rail
13+
14+
15+
def test_diamond_crossing_without_over_and_underpasses(rendering: bool = False):
16+
rail, rail_map, optionals = make_diamond_crossing_rail()
17+
18+
# TODO better way to init state?
19+
env = RailEnv(
20+
width=rail_map.shape[1],
21+
height=rail_map.shape[0],
22+
# TODO typing
23+
rail_generator=rail_from_grid_transition_map(rail, optionals),
24+
line_generator=sparse_line_generator(),
25+
number_of_agents=2,
26+
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
27+
record_steps=True
28+
)
29+
30+
env.reset()
31+
env._max_episode_steps = 555
32+
33+
# set the initial position
34+
agent_0 = env.agents[0]
35+
agent_0.initial_position = (3, 1) # just ahead of diamond crossing facing east
36+
agent_0.position = (3, 1) # just ahead of diamond crossing facing east
37+
agent_0.direction = 1 # east
38+
agent_0.initial_direction = 1 # east
39+
agent_0.target = (3, 9) # east dead-end
40+
agent_0.moving = True
41+
agent_0.latest_arrival = 999
42+
agent_0._set_state(TrainState.MOVING)
43+
44+
agent_1 = env.agents[1]
45+
agent_1.initial_position = (2, 2) # just ahead of diamond crossing facing south
46+
agent_1.position = (2, 2) # just ahead of diamond crossing facing south
47+
agent_1.direction = 2 # south
48+
agent_1.initial_direction = 2 # south
49+
agent_1.target = (6, 2) # south dead-end
50+
agent_1.moving = True
51+
agent_1.latest_arrival = 999
52+
agent_1._set_state(TrainState.MOVING)
53+
54+
env.distance_map._compute(env.agents, env.rail)
55+
done = False
56+
env_renderer = None
57+
if rendering:
58+
env_renderer = RenderTool(env)
59+
while not done:
60+
_, _, dones, _ = env.step({
61+
0: RailEnvActions.MOVE_FORWARD,
62+
1: RailEnvActions.MOVE_FORWARD,
63+
})
64+
done = dones["__all__"]
65+
if env_renderer is not None:
66+
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
67+
time.sleep(0.2)
68+
69+
waypoints = []
70+
for agent_states in env.cur_episode:
71+
cur = []
72+
for agent_state in agent_states:
73+
r, c, d, _, _, _ = agent_state
74+
cur.append(Waypoint((r, c), d))
75+
waypoints.append(cur)
76+
expected = [
77+
[Waypoint(position=(3, 2), direction=1), Waypoint(position=(2, 2), direction=2)],
78+
[Waypoint(position=(3, 3), direction=1), Waypoint(position=(3, 2), direction=2)],
79+
[Waypoint(position=(3, 4), direction=1), Waypoint(position=(4, 2), direction=2)],
80+
[Waypoint(position=(3, 5), direction=1), Waypoint(position=(5, 2), direction=2)],
81+
[Waypoint(position=(3, 6), direction=1), Waypoint(position=(0, 0), direction=2)],
82+
[Waypoint(position=(3, 7), direction=1), Waypoint(position=(0, 0), direction=2)],
83+
[Waypoint(position=(3, 8), direction=1), Waypoint(position=(0, 0), direction=2)],
84+
[Waypoint(position=(0, 0), direction=1), Waypoint(position=(0, 0), direction=2)]
85+
]
86+
assert expected == waypoints, waypoints

0 commit comments

Comments
 (0)