Skip to content

Commit 7ef78f2

Browse files
authored
110 Over- and underpasses (aka. level-free diamond crossings). (#120)
* Add regression test diamond crossing without over and underpasses. * First working version of level-free crossings. * Add failing test * Fix head-on collision on level-free elements. * Improve type hints. * Remove obsolete code. * Formatting. * Improve type hints. * Cleanup. * Formatting. * Improve documentation. * Add p_level_free in rail gen to sample percentage of level free diamond-crossings from generated diamond crossings. * Cleanup. Signed-off-by: chenkins <[email protected]> * Apply suggestions from code review --------- Signed-off-by: chenkins <[email protected]>
1 parent de36600 commit 7ef78f2

9 files changed

+537
-146
lines changed

flatland/envs/line_generators.py

+11-21
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
1-
"""Line generators (railway undertaking, "EVU")."""
2-
import warnings
1+
"""Line generators: Railway Undertaking (RU) / Eisenbahnverkehrsunternehmen (EVU)."""
32
from typing import Tuple, List, Callable, Mapping, Optional, Any
43

5-
import numpy as np
64
from numpy.random.mtrand import RandomState
75

8-
from flatland.core.grid.grid4_utils import get_new_position
96
from flatland.core.transition_map import GridTransitionMap
10-
from flatland.envs.agent_utils import EnvAgent
11-
from flatland.envs.timetable_utils import Line
127
from flatland.envs import persistence
13-
from flatland.utils.decorators import enable_infrastructure_lru_cache
8+
from flatland.envs.timetable_utils import Line
149

1510
AgentPosition = Tuple[int, int]
1611
LineGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Line]
@@ -46,8 +41,8 @@ def __init__(self, speed_ratio_map: Mapping[float, float] = None, seed: int = 1)
4641
self.speed_ratio_map = speed_ratio_map
4742
self.seed = seed
4843

49-
def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any=None, num_resets: int = 0,
50-
np_random: RandomState = None) -> Line:
44+
def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
45+
np_random: RandomState = None) -> Line:
5146
pass
5247

5348
def __call__(self, *args, **kwargs):
@@ -81,7 +76,7 @@ def decide_orientation(self, rail, start, target, possible_orientations, np_rand
8176
return 0
8277

8378
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int,
84-
np_random: RandomState) -> Line:
79+
np_random: RandomState) -> Line:
8580
"""
8681
8782
The generator that assigns tasks to all the agents
@@ -102,12 +97,10 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re
10297
agents_target = []
10398
agents_direction = []
10499

105-
106100
city1, city2 = None, None
107101
city1_num_stations, city2_num_stations = None, None
108102
city1_possible_orientations, city2_possible_orientations = None, None
109103

110-
111104
for agent_idx in range(num_agents):
112105

113106
if (agent_idx % 2 == 0):
@@ -118,9 +111,9 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re
118111
city1_num_stations = len(train_stations[city1])
119112
city2_num_stations = len(train_stations[city2])
120113
city1_possible_orientations = [city_orientation[city1],
121-
(city_orientation[city1] + 2) % 4]
114+
(city_orientation[city1] + 2) % 4]
122115
city2_possible_orientations = [city_orientation[city2],
123-
(city_orientation[city2] + 2) % 4]
116+
(city_orientation[city2] + 2) % 4]
124117

125118
# Agent 1 : city1 > city2, Agent 2: city2 > city1
126119
agent_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations
@@ -143,13 +136,11 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re
143136
agent_orientation = self.decide_orientation(
144137
rail, agent_start, agent_target, city2_possible_orientations, np_random)
145138

146-
147139
# agent1 details
148140
agents_position.append((agent_start[0][0], agent_start[0][1]))
149141
agents_target.append((agent_target[0][0], agent_target[0][1]))
150142
agents_direction.append(agent_orientation)
151143

152-
153144
if self.speed_ratio_map:
154145
speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random)
155146
else:
@@ -163,7 +154,7 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re
163154
timedelay_factor * alpha * (rail.width + rail.height + num_agents / len(city_positions)))
164155

165156
return Line(agent_positions=agents_position, agent_directions=agents_direction,
166-
agent_targets=agents_target, agent_speeds=speeds)
157+
agent_targets=agents_target, agent_speeds=speeds)
167158

168159

169160
def line_from_file(filename, load_from_package=None) -> LineGenerator:
@@ -182,11 +173,10 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:
182173

183174
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
184175
np_random: RandomState = None) -> Line:
185-
186176
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
187177

188178
max_episode_steps = env_dict.get("max_episode_steps", 0)
189-
if (max_episode_steps==0):
179+
if (max_episode_steps == 0):
190180
print("This env file has no max_episode_steps (deprecated) - setting to 100")
191181
max_episode_steps = 100
192182

@@ -196,12 +186,12 @@ def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_r
196186
agents_position = [a.initial_position for a in agents]
197187

198188
# this logic is wrong - we should really load the initial_direction as the direction.
199-
#agents_direction = [a.direction for a in agents]
189+
# agents_direction = [a.direction for a in agents]
200190
agents_direction = [a.initial_direction for a in agents]
201191
agents_target = [a.target for a in agents]
202192
agents_speed = [a.speed_counter.speed for a in agents]
203193

204194
return Line(agent_positions=agents_position, agent_directions=agents_direction,
205-
agent_targets=agents_target, agent_speeds=agents_speed)
195+
agent_targets=agents_target, agent_speeds=agents_speed)
206196

207197
return generator

flatland/envs/rail_env.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
"""
44
import random
55
from functools import lru_cache
6-
from typing import List, Optional, Dict, Tuple
6+
from typing import List, Optional, Dict, Tuple, Set
77

88
import numpy as np
99

1010
import flatland.envs.timetable_generators as ttg
1111
from flatland.core.env import Environment
1212
from flatland.core.env_observation_builder import ObservationBuilder
1313
from flatland.core.grid.grid4 import Grid4Transitions
14+
from flatland.core.grid.grid_utils import Vector2D
1415
from flatland.core.transition_map import GridTransitionMap
1516
from flatland.envs import agent_chains as ac
1617
from flatland.envs import line_generators as line_gen
@@ -99,7 +100,7 @@ class RailEnv(Environment):
99100
def __init__(self,
100101
width,
101102
height,
102-
rail_generator=None,
103+
rail_generator: "RailGenerator" = None,
103104
line_generator: "LineGenerator" = None, # : line_gen.LineGenerator = line_gen.random_line_generator(),
104105
number_of_agents=2,
105106
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
@@ -167,7 +168,7 @@ def __init__(self,
167168
self.rail_generator = rail_generator
168169
if line_generator is None:
169170
line_generator = line_gen.sparse_line_generator()
170-
self.line_generator: LineGenerator = line_generator
171+
self.line_generator: "LineGenerator" = line_generator
171172
self.timetable_generator = timetable_generator
172173

173174
self.rail: Optional[GridTransitionMap] = None
@@ -205,6 +206,8 @@ def __init__(self,
205206

206207
self.motionCheck = ac.MotionCheck()
207208

209+
self.level_free_positions: Set[Vector2D] = set()
210+
208211
def _seed(self, seed):
209212
self.np_random, seed = seeding.np_random(seed)
210213
random.seed(seed)
@@ -314,6 +317,8 @@ def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True,
314317
agents_hints = None
315318
if optionals and 'agents_hints' in optionals:
316319
agents_hints = optionals['agents_hints']
320+
if optionals and 'level_free_positions' in optionals:
321+
self.level_free_positions = optionals['level_free_positions']
317322

318323
line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
319324
self.num_resets, self.np_random)
@@ -562,20 +567,32 @@ def step(self, action_dict: Dict[int, RailEnvActions]):
562567
direction=new_direction,
563568
preprocessed_action=preprocessed_action)
564569

570+
# only conflict if the level-free cell is traversed through the same axis (horizontally (0 north or 2 south), or vertically (1 east or 3 west)
571+
new_position_level_free = new_position
572+
if new_position in self.level_free_positions:
573+
new_position_level_free = (new_position, new_direction % 2)
574+
agent_position_level_free = agent.position
575+
if agent.position in self.level_free_positions:
576+
agent_position_level_free = (agent.position, agent.direction % 2)
577+
565578
# This is for storing and later checking for conflicts of agents trying to occupy same cell
566-
self.motionCheck.addAgent(i_agent, agent.position, new_position)
579+
self.motionCheck.addAgent(i_agent, agent_position_level_free, new_position_level_free)
567580

568581
# Find conflicts between trains trying to occupy same cell
569582
self.motionCheck.find_conflicts()
570583

571584
for agent in self.agents:
572585
i_agent = agent.handle
573586

587+
agent_position_level_free = agent.position
588+
if agent.position in self.level_free_positions:
589+
agent_position_level_free = (agent.position, agent.direction % 2)
590+
574591
## Update positions
575592
if agent.malfunction_handler.in_malfunction:
576593
movement_allowed = False
577594
else:
578-
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
595+
movement_allowed = self.motionCheck.check_motion(i_agent, agent_position_level_free)
579596

580597
movement_inside_cell = agent.state == TrainState.STOPPED and not agent.speed_counter.is_cell_exit
581598
movement_allowed = movement_allowed or movement_inside_cell
@@ -727,6 +744,7 @@ def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVar
727744
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
728745
show_predictions=show_predictions,
729746
show_rowcols=show_rowcols, return_image=return_image)
747+
730748
def initialize_renderer(self, mode, gl,
731749
agent_render_variant,
732750
show_debug,

flatland/envs/rail_generators.py

+35-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
1+
"""Rail generators: infrastructure manager (IM) / Infrastrukturbetreiber (ISB)."""
2+
import math
23
import warnings
34
from typing import Callable, Tuple, Optional, Dict, List
45

@@ -9,7 +10,7 @@
910
from flatland.core.grid.grid4_utils import direction_to_point
1011
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D, \
1112
Vec2dOperations
12-
from flatland.core.grid.rail_env_grid import RailEnvTransitions
13+
from flatland.core.grid.rail_env_grid import RailEnvTransitions, RailEnvTransitionsEnum
1314
from flatland.core.transition_map import GridTransitionMap
1415
from flatland.envs import persistence
1516
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
@@ -35,8 +36,7 @@ def __init__(self, *args, **kwargs):
3536
"""
3637
pass
3738

38-
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
39-
np_random: RandomState = None) -> RailGeneratorProduct:
39+
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
4040
pass
4141

4242
def __call__(self, *args, **kwargs) -> RailGeneratorProduct:
@@ -53,8 +53,7 @@ class EmptyRailGen(RailGen):
5353
Primarily used by the editor
5454
"""
5555

56-
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
57-
np_random: RandomState = None) -> RailGenerator:
56+
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGenerator:
5857
rail_trans = RailEnvTransitions()
5958
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
6059
rail_array = grid_map.grid
@@ -100,8 +99,7 @@ def __init__(self, rail_map, optionals=None):
10099
self.rail_map = rail_map
101100
self.optionals = optionals
102101

103-
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
104-
np_random: RandomState = None) -> RailGeneratorProduct:
102+
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
105103
return self.rail_map, self.optionals
106104

107105

@@ -116,7 +114,7 @@ def sparse_rail_generator(*args, **kwargs):
116114
class SparseRailGen(RailGen):
117115

118116
def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2,
119-
max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator:
117+
max_rail_pairs_in_city: int = 2, seed=None, p_level_free: float = 0) -> RailGenerator:
120118
"""
121119
Generates railway networks with cities and inner city rails
122120
@@ -133,6 +131,8 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b
133131
Number of parallel tracks in the city. This represents the number of tracks in the trainstations
134132
seed: int
135133
Initiate the seed
134+
p_level_free : float
135+
Percentage of diamond-crossings which are level-free.
136136
137137
Returns
138138
-------
@@ -143,9 +143,9 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b
143143
self.max_rails_between_cities = max_rails_between_cities
144144
self.max_rail_pairs_in_city = max_rail_pairs_in_city
145145
self.seed = seed
146+
self.p_level_free = p_level_free
146147

147-
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
148-
np_random: RandomState = None) -> RailGenerator:
148+
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
149149
"""
150150
151151
Parameters
@@ -181,7 +181,7 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0
181181
min_nr_rail_pairs_in_city = 1 # (min pair must be 1)
182182
rail_pairs_in_city = min_nr_rail_pairs_in_city if self.max_rail_pairs_in_city < min_nr_rail_pairs_in_city else self.max_rail_pairs_in_city # (pairs can be 1,2,3)
183183
rails_between_cities = (rail_pairs_in_city * 2) if self.max_rails_between_cities > (
184-
rail_pairs_in_city * 2) else self.max_rails_between_cities
184+
rail_pairs_in_city * 2) else self.max_rails_between_cities
185185

186186
# We compute the city radius by the given max number of rails it can contain.
187187
# The radius is equal to the number of tracks divided by 2
@@ -237,11 +237,28 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0
237237

238238
# Fix all transition elements
239239
self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
240-
return grid_map, {'agents_hints': {
241-
'city_positions': city_positions,
242-
'train_stations': train_stations,
243-
'city_orientations': city_orientations
244-
}}
240+
241+
# choose p_level_free percentage of diamond crossings to be level-free
242+
num_diamond_crossings = np.count_nonzero(grid_map.grid[grid_map.grid == RailEnvTransitionsEnum.diamond_crossing])
243+
num_level_free_diamond_crossings = math.floor(self.p_level_free * num_diamond_crossings)
244+
# ceil with probability p_ceil
245+
p_ceil = (self.p_level_free * num_diamond_crossings) % 1.0
246+
num_level_free_diamond_crossings += np_random.choice([1, 0], p=(p_ceil, 1 - p_ceil))
247+
level_free_positions = set()
248+
if num_level_free_diamond_crossings > 0:
249+
choice = np_random.choice(num_diamond_crossings, size=num_level_free_diamond_crossings, replace=False)
250+
positions_diamond_crossings = (grid_map.grid == RailEnvTransitionsEnum.diamond_crossing).nonzero()
251+
level_free_positions = {tuple(positions_diamond_crossings[choice[i]]) for i in range(len(choice))}
252+
253+
return grid_map, {
254+
'agents_hints':
255+
{
256+
'city_positions': city_positions,
257+
'train_stations': train_stations,
258+
'city_orientations': city_orientations
259+
},
260+
'level_free_positions': level_free_positions
261+
}
245262

246263
def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int,
247264
height: int, np_random: RandomState = None) -> Tuple[
@@ -264,7 +281,6 @@ def _generate_random_city_positions(self, num_cities: int, city_radius: int, wid
264281
Returns
265282
-------
266283
Returns a list of all city positions as coordinates (x,y)
267-
268284
"""
269285

270286
city_positions: IntVector2DArray = []
@@ -322,7 +338,6 @@ def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: in
322338
Returns
323339
-------
324340
Returns a list of all city positions as coordinates (x,y)
325-
326341
"""
327342
aspect_ratio = height / width
328343
# Compute max numbe of possible cities per row and col.
@@ -526,7 +541,7 @@ def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction)
526541
"""
527542
Given a list of clostest neighbours in each direction this returns the city index of the neighbor in a given
528543
direction. Direction is a 90 degree cone facing the desired directiont.
529-
Exampe:
544+
Example:
530545
North: The closes neighbour in the North direction is within the cone spanned by a line going
531546
North-West and North-East
532547
@@ -677,7 +692,6 @@ def _fix_transitions(self, city_cells: set, inter_city_lines: List[IntVector2DAr
677692
Each cell contains the prefered orientation of cells. If no prefered orientation is present it is set to -1
678693
grid_map: RailEnvTransitions
679694
The grid map containing the rails. Used to draw new rails
680-
681695
"""
682696

683697
# Fix all cities with illegal transition maps

flatland/envs/schedule_generators.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
raise ImportError(" Schedule Generators is now renamed to line_generators, any reference to schedule should be replaced with line")
1+
raise ImportError(" Schedule Generators is now renamed to line_generators + timetable_generators, any reference to schedule should be replaced with line")

flatland/envs/sparse_rail_gen.py

-21
This file was deleted.

0 commit comments

Comments
 (0)