Skip to content

Commit ff03b9d

Browse files
committed
Add p_level_free in rail gen to sample percentage of level free diamond-crossings from generated diamond crossings.
1 parent e0ab0b5 commit ff03b9d

4 files changed

+182
-42
lines changed

flatland/envs/rail_env.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
Definition of the RailEnv environment.
33
"""
44
import random
5-
from typing import List, Optional, Dict, Tuple
5+
from typing import List, Optional, Dict, Tuple, Set
66

77
import numpy as np
88

99
import flatland.envs.timetable_generators as ttg
1010
from flatland.core.env import Environment
1111
from flatland.core.env_observation_builder import ObservationBuilder
1212
from flatland.core.grid.grid4 import Grid4Transitions
13+
from flatland.core.grid.grid_utils import Vector2D
1314
from flatland.core.transition_map import GridTransitionMap
1415
from flatland.envs import agent_chains as ac
1516
from flatland.envs import line_generators as line_gen
@@ -205,7 +206,7 @@ def __init__(self,
205206

206207
self.motionCheck = ac.MotionCheck()
207208

208-
self.level_free = set()
209+
self.level_free_positions: Set[Vector2D] = set()
209210

210211
def _seed(self, seed):
211212
self.np_random, seed = seeding.np_random(seed)
@@ -315,6 +316,8 @@ def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True,
315316
agents_hints = None
316317
if optionals and 'agents_hints' in optionals:
317318
agents_hints = optionals['agents_hints']
319+
if optionals and 'level_free_positions' in optionals:
320+
self.level_free_positions = optionals['level_free_positions']
318321

319322
line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
320323
self.num_resets, self.np_random)
@@ -565,10 +568,10 @@ def step(self, action_dict: Dict[int, RailEnvActions]):
565568

566569
# only conflict if the level-free cell is traversed through the same axis (horizontally (0 north or 2 south), or vertically (1 east or 3 west)
567570
new_position_level_free = new_position
568-
if new_position in self.level_free:
571+
if new_position in self.level_free_positions:
569572
new_position_level_free = (new_position, new_direction % 2)
570573
agent_position_level_free = agent.position
571-
if agent.position in self.level_free:
574+
if agent.position in self.level_free_positions:
572575
agent_position_level_free = (agent.position, agent.direction % 2)
573576

574577
# This is for storing and later checking for conflicts of agents trying to occupy same cell
@@ -581,7 +584,7 @@ def step(self, action_dict: Dict[int, RailEnvActions]):
581584
i_agent = agent.handle
582585

583586
agent_position_level_free = agent.position
584-
if agent.position in self.level_free:
587+
if agent.position in self.level_free_positions:
585588
agent_position_level_free = (agent.position, agent.direction % 2)
586589

587590
## Update positions
@@ -740,6 +743,7 @@ def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVar
740743
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
741744
show_predictions=show_predictions,
742745
show_rowcols=show_rowcols, return_image=return_image)
746+
743747
def initialize_renderer(self, mode, gl,
744748
agent_render_variant,
745749
show_debug,

flatland/envs/rail_generators.py

+34-20
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Rail generators: infrastructure manager (IM) / Infrastrukturbetreiber (ISB)."""
2+
import math
23
import warnings
34
from typing import Callable, Tuple, Optional, Dict, List
45

@@ -9,7 +10,7 @@
910
from flatland.core.grid.grid4_utils import direction_to_point
1011
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D, \
1112
Vec2dOperations
12-
from flatland.core.grid.rail_env_grid import RailEnvTransitions
13+
from flatland.core.grid.rail_env_grid import RailEnvTransitions, RailEnvTransitionsEnum
1314
from flatland.core.transition_map import GridTransitionMap
1415
from flatland.envs import persistence
1516
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
@@ -35,8 +36,7 @@ def __init__(self, *args, **kwargs):
3536
"""
3637
pass
3738

38-
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
39-
np_random: RandomState = None) -> RailGeneratorProduct:
39+
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
4040
pass
4141

4242
def __call__(self, *args, **kwargs) -> RailGeneratorProduct:
@@ -53,8 +53,7 @@ class EmptyRailGen(RailGen):
5353
Primarily used by the editor
5454
"""
5555

56-
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
57-
np_random: RandomState = None) -> RailGenerator:
56+
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGenerator:
5857
rail_trans = RailEnvTransitions()
5958
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
6059
rail_array = grid_map.grid
@@ -100,8 +99,7 @@ def __init__(self, rail_map, optionals=None):
10099
self.rail_map = rail_map
101100
self.optionals = optionals
102101

103-
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
104-
np_random: RandomState = None) -> RailGeneratorProduct:
102+
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
105103
return self.rail_map, self.optionals
106104

107105

@@ -116,7 +114,7 @@ def sparse_rail_generator(*args, **kwargs):
116114
class SparseRailGen(RailGen):
117115

118116
def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2,
119-
max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator:
117+
max_rail_pairs_in_city: int = 2, seed=None, p_level_free: float = 0) -> RailGenerator:
120118
"""
121119
Generates railway networks with cities and inner city rails
122120
@@ -133,6 +131,8 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b
133131
Number of parallel tracks in the city. This represents the number of tracks in the trainstations
134132
seed: int
135133
Initiate the seed
134+
p_level_free : float
135+
Percentage of diamond-crossings which are level-free.
136136
137137
Returns
138138
-------
@@ -143,9 +143,9 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b
143143
self.max_rails_between_cities = max_rails_between_cities
144144
self.max_rail_pairs_in_city = max_rail_pairs_in_city
145145
self.seed = seed
146+
self.p_level_free = p_level_free
146147

147-
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
148-
np_random: RandomState = None) -> RailGenerator:
148+
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
149149
"""
150150
151151
Parameters
@@ -181,7 +181,7 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0
181181
min_nr_rail_pairs_in_city = 1 # (min pair must be 1)
182182
rail_pairs_in_city = min_nr_rail_pairs_in_city if self.max_rail_pairs_in_city < min_nr_rail_pairs_in_city else self.max_rail_pairs_in_city # (pairs can be 1,2,3)
183183
rails_between_cities = (rail_pairs_in_city * 2) if self.max_rails_between_cities > (
184-
rail_pairs_in_city * 2) else self.max_rails_between_cities
184+
rail_pairs_in_city * 2) else self.max_rails_between_cities
185185

186186
# We compute the city radius by the given max number of rails it can contain.
187187
# The radius is equal to the number of tracks divided by 2
@@ -237,11 +237,28 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0
237237

238238
# Fix all transition elements
239239
self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
240-
return grid_map, {'agents_hints': {
241-
'city_positions': city_positions,
242-
'train_stations': train_stations,
243-
'city_orientations': city_orientations
244-
}}
240+
241+
# choose p_level_free percentage of diamond crossings to be level-free
242+
num_diamond_crossings = np.count_nonzero(grid_map.grid[grid_map.grid == RailEnvTransitionsEnum.diamond_crossing])
243+
num_level_free_diamond_crossings = math.floor(self.p_level_free * num_diamond_crossings)
244+
# ceil with probability p_ceail
245+
p_ceil = (self.p_level_free * num_diamond_crossings) % 1.0
246+
num_level_free_diamond_crossings += np_random.choice([1, 0], p=(p_ceil, 1 - p_ceil))
247+
level_free_positions = set()
248+
if num_level_free_diamond_crossings > 0:
249+
choice = np_random.choice(num_diamond_crossings, size=num_level_free_diamond_crossings, replace=False)
250+
positions_diamond_crossings = (grid_map.grid == RailEnvTransitionsEnum.diamond_crossing).nonzero()
251+
level_free_positions = {tuple(positions_diamond_crossings[choice[i]]) for i in range(len(choice))}
252+
253+
return grid_map, {
254+
'agents_hints':
255+
{
256+
'city_positions': city_positions,
257+
'train_stations': train_stations,
258+
'city_orientations': city_orientations
259+
},
260+
'level_free_positions': level_free_positions
261+
}
245262

246263
def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int,
247264
height: int, np_random: RandomState = None) -> Tuple[
@@ -264,7 +281,6 @@ def _generate_random_city_positions(self, num_cities: int, city_radius: int, wid
264281
Returns
265282
-------
266283
Returns a list of all city positions as coordinates (x,y)
267-
268284
"""
269285

270286
city_positions: IntVector2DArray = []
@@ -322,7 +338,6 @@ def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: in
322338
Returns
323339
-------
324340
Returns a list of all city positions as coordinates (x,y)
325-
326341
"""
327342
aspect_ratio = height / width
328343
# Compute max numbe of possible cities per row and col.
@@ -526,7 +541,7 @@ def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction)
526541
"""
527542
Given a list of clostest neighbours in each direction this returns the city index of the neighbor in a given
528543
direction. Direction is a 90 degree cone facing the desired directiont.
529-
Exampe:
544+
Example:
530545
North: The closes neighbour in the North direction is within the cone spanned by a line going
531546
North-West and North-East
532547
@@ -677,7 +692,6 @@ def _fix_transitions(self, city_cells: set, inter_city_lines: List[IntVector2DAr
677692
Each cell contains the prefered orientation of cells. If no prefered orientation is present it is set to -1
678693
grid_map: RailEnvTransitions
679694
The grid map containing the rails. Used to draw new rails
680-
681695
"""
682696

683697
# Fix all cities with illegal transition maps

0 commit comments

Comments
 (0)