1
1
"""Rail generators: infrastructure manager (IM) / Infrastrukturbetreiber (ISB)."""
2
+ import math
2
3
import warnings
3
4
from typing import Callable , Tuple , Optional , Dict , List
4
5
9
10
from flatland .core .grid .grid4_utils import direction_to_point
10
11
from flatland .core .grid .grid_utils import IntVector2DArray , IntVector2D , \
11
12
Vec2dOperations
12
- from flatland .core .grid .rail_env_grid import RailEnvTransitions
13
+ from flatland .core .grid .rail_env_grid import RailEnvTransitions , RailEnvTransitionsEnum
13
14
from flatland .core .transition_map import GridTransitionMap
14
15
from flatland .envs import persistence
15
16
from flatland .envs .grid4_generators_utils import connect_rail_in_grid_map , connect_straight_line_in_grid_map , \
@@ -35,8 +36,7 @@ def __init__(self, *args, **kwargs):
35
36
"""
36
37
pass
37
38
38
- def generate (self , width : int , height : int , num_agents : int , num_resets : int = 0 ,
39
- np_random : RandomState = None ) -> RailGeneratorProduct :
39
+ def generate (self , width : int , height : int , num_agents : int , num_resets : int = 0 , np_random : RandomState = None ) -> RailGeneratorProduct :
40
40
pass
41
41
42
42
def __call__ (self , * args , ** kwargs ) -> RailGeneratorProduct :
@@ -53,8 +53,7 @@ class EmptyRailGen(RailGen):
53
53
Primarily used by the editor
54
54
"""
55
55
56
- def generate (self , width : int , height : int , num_agents : int , num_resets : int = 0 ,
57
- np_random : RandomState = None ) -> RailGenerator :
56
+ def generate (self , width : int , height : int , num_agents : int , num_resets : int = 0 , np_random : RandomState = None ) -> RailGenerator :
58
57
rail_trans = RailEnvTransitions ()
59
58
grid_map = GridTransitionMap (width = width , height = height , transitions = rail_trans )
60
59
rail_array = grid_map .grid
@@ -100,8 +99,7 @@ def __init__(self, rail_map, optionals=None):
100
99
self .rail_map = rail_map
101
100
self .optionals = optionals
102
101
103
- def generate (self , width : int , height : int , num_agents : int , num_resets : int = 0 ,
104
- np_random : RandomState = None ) -> RailGeneratorProduct :
102
+ def generate (self , width : int , height : int , num_agents : int , num_resets : int = 0 , np_random : RandomState = None ) -> RailGeneratorProduct :
105
103
return self .rail_map , self .optionals
106
104
107
105
@@ -116,7 +114,7 @@ def sparse_rail_generator(*args, **kwargs):
116
114
class SparseRailGen (RailGen ):
117
115
118
116
def __init__ (self , max_num_cities : int = 2 , grid_mode : bool = False , max_rails_between_cities : int = 2 ,
119
- max_rail_pairs_in_city : int = 2 , seed = None ) -> RailGenerator :
117
+ max_rail_pairs_in_city : int = 2 , seed = None , p_level_free : float = 0 ) -> RailGenerator :
120
118
"""
121
119
Generates railway networks with cities and inner city rails
122
120
@@ -133,6 +131,8 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b
133
131
Number of parallel tracks in the city. This represents the number of tracks in the trainstations
134
132
seed: int
135
133
Initiate the seed
134
+ p_level_free : float
135
+ Percentage of diamond-crossings which are level-free.
136
136
137
137
Returns
138
138
-------
@@ -143,9 +143,9 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b
143
143
self .max_rails_between_cities = max_rails_between_cities
144
144
self .max_rail_pairs_in_city = max_rail_pairs_in_city
145
145
self .seed = seed
146
+ self .p_level_free = p_level_free
146
147
147
- def generate (self , width : int , height : int , num_agents : int , num_resets : int = 0 ,
148
- np_random : RandomState = None ) -> RailGenerator :
148
+ def generate (self , width : int , height : int , num_agents : int , num_resets : int = 0 , np_random : RandomState = None ) -> RailGeneratorProduct :
149
149
"""
150
150
151
151
Parameters
@@ -181,7 +181,7 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0
181
181
min_nr_rail_pairs_in_city = 1 # (min pair must be 1)
182
182
rail_pairs_in_city = min_nr_rail_pairs_in_city if self .max_rail_pairs_in_city < min_nr_rail_pairs_in_city else self .max_rail_pairs_in_city # (pairs can be 1,2,3)
183
183
rails_between_cities = (rail_pairs_in_city * 2 ) if self .max_rails_between_cities > (
184
- rail_pairs_in_city * 2 ) else self .max_rails_between_cities
184
+ rail_pairs_in_city * 2 ) else self .max_rails_between_cities
185
185
186
186
# We compute the city radius by the given max number of rails it can contain.
187
187
# The radius is equal to the number of tracks divided by 2
@@ -237,11 +237,28 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0
237
237
238
238
# Fix all transition elements
239
239
self ._fix_transitions (city_cells , inter_city_lines , grid_map , vector_field )
240
- return grid_map , {'agents_hints' : {
241
- 'city_positions' : city_positions ,
242
- 'train_stations' : train_stations ,
243
- 'city_orientations' : city_orientations
244
- }}
240
+
241
+ # choose p_level_free percentage of diamond crossings to be level-free
242
+ num_diamond_crossings = np .count_nonzero (grid_map .grid [grid_map .grid == RailEnvTransitionsEnum .diamond_crossing ])
243
+ num_level_free_diamond_crossings = math .floor (self .p_level_free * num_diamond_crossings )
244
+ # ceil with probability p_ceail
245
+ p_ceil = (self .p_level_free * num_diamond_crossings ) % 1.0
246
+ num_level_free_diamond_crossings += np_random .choice ([1 , 0 ], p = (p_ceil , 1 - p_ceil ))
247
+ level_free_positions = set ()
248
+ if num_level_free_diamond_crossings > 0 :
249
+ choice = np_random .choice (num_diamond_crossings , size = num_level_free_diamond_crossings , replace = False )
250
+ positions_diamond_crossings = (grid_map .grid == RailEnvTransitionsEnum .diamond_crossing ).nonzero ()
251
+ level_free_positions = {tuple (positions_diamond_crossings [choice [i ]]) for i in range (len (choice ))}
252
+
253
+ return grid_map , {
254
+ 'agents_hints' :
255
+ {
256
+ 'city_positions' : city_positions ,
257
+ 'train_stations' : train_stations ,
258
+ 'city_orientations' : city_orientations
259
+ },
260
+ 'level_free_positions' : level_free_positions
261
+ }
245
262
246
263
def _generate_random_city_positions (self , num_cities : int , city_radius : int , width : int ,
247
264
height : int , np_random : RandomState = None ) -> Tuple [
@@ -264,7 +281,6 @@ def _generate_random_city_positions(self, num_cities: int, city_radius: int, wid
264
281
Returns
265
282
-------
266
283
Returns a list of all city positions as coordinates (x,y)
267
-
268
284
"""
269
285
270
286
city_positions : IntVector2DArray = []
@@ -322,7 +338,6 @@ def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: in
322
338
Returns
323
339
-------
324
340
Returns a list of all city positions as coordinates (x,y)
325
-
326
341
"""
327
342
aspect_ratio = height / width
328
343
# Compute max numbe of possible cities per row and col.
@@ -526,7 +541,7 @@ def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction)
526
541
"""
527
542
Given a list of clostest neighbours in each direction this returns the city index of the neighbor in a given
528
543
direction. Direction is a 90 degree cone facing the desired directiont.
529
- Exampe :
544
+ Example :
530
545
North: The closes neighbour in the North direction is within the cone spanned by a line going
531
546
North-West and North-East
532
547
@@ -677,7 +692,6 @@ def _fix_transitions(self, city_cells: set, inter_city_lines: List[IntVector2DAr
677
692
Each cell contains the prefered orientation of cells. If no prefered orientation is present it is set to -1
678
693
grid_map: RailEnvTransitions
679
694
The grid map containing the rails. Used to draw new rails
680
-
681
695
"""
682
696
683
697
# Fix all cities with illegal transition maps
0 commit comments