1
1
"""
2
2
TransitionMap and derived classes.
3
3
"""
4
+ import traceback
5
+ import uuid
4
6
from functools import lru_cache
5
7
6
8
import numpy as np
7
9
from importlib_resources import path
8
10
from numpy import array
9
- import traceback
10
11
11
12
from flatland .core .grid .grid4 import Grid4Transitions
12
13
from flatland .core .grid .grid4_utils import get_new_position , get_direction
13
14
from flatland .core .grid .grid_utils import IntVector2DArray , IntVector2D
14
15
from flatland .core .grid .grid_utils import Vec2dOperations as Vec2d
15
16
from flatland .core .grid .rail_env_grid import RailEnvTransitions
16
17
from flatland .core .transitions import Transitions
17
- from flatland .utils .decorators import enable_infrastructure_lru_cache , send_infrastructure_data_change_signal_to_reset_lru_cache
18
18
from flatland .utils .ordered_set import OrderedSet
19
19
20
20
@@ -135,7 +135,6 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]
135
135
grid.
136
136
137
137
"""
138
- send_infrastructure_data_change_signal_to_reset_lru_cache ()
139
138
self .width = width
140
139
self .height = height
141
140
self .transitions = transitions
@@ -145,8 +144,19 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]
145
144
else :
146
145
self .random_generator .seed (random_seed )
147
146
self .grid = np .zeros ((height , width ), dtype = self .transitions .get_type ())
147
+ self ._reset_cache ()
148
+
149
+ def _reset_cache (self ):
150
+ # use __eq__ and __hash__ to control cache lifecycle of instance methods, see https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls.
151
+ self .uuid = uuid .uuid4 ().int
152
+
153
+ def __eq__ (self , __value ):
154
+ return isinstance (__value , GridTransitionMap ) and self .uuid == __value .uuid
155
+
156
+ def __hash__ (self ):
157
+ return self .uuid
148
158
149
- @enable_infrastructure_lru_cache (maxsize = 1_000_000 )
159
+ @lru_cache (maxsize = 1_000_000 )
150
160
def get_full_transitions (self , row , column ):
151
161
"""
152
162
Returns the full transitions for the cell at (row, column) in the format transition_map's transitions.
@@ -165,7 +175,7 @@ def get_full_transitions(self, row, column):
165
175
"""
166
176
return self .grid [(row , column )]
167
177
168
- @enable_infrastructure_lru_cache (maxsize = 4_000_000 )
178
+ @lru_cache (maxsize = 4_000_000 )
169
179
def get_transitions (self , row , column , orientation ):
170
180
"""
171
181
Return a tuple of transitions available in a cell specified by
@@ -206,17 +216,17 @@ def set_transitions(self, cell_id, new_transitions):
206
216
Tuple of new transitions validitiy for the cell.
207
217
208
218
"""
209
- send_infrastructure_data_change_signal_to_reset_lru_cache ()
210
- #assert len(cell_id) in (2, 3), \
219
+ self . _reset_cache ()
220
+ # assert len(cell_id) in (2, 3), \
211
221
# 'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.'
212
222
if len (cell_id ) == 3 :
213
223
self .grid [cell_id [0 :2 ]] = self .transitions .set_transitions (self .grid [cell_id [0 :2 ]],
214
- cell_id [2 ],
215
- new_transitions )
224
+ cell_id [2 ],
225
+ new_transitions )
216
226
elif len (cell_id ) == 2 :
217
227
self .grid [cell_id ] = new_transitions
218
228
219
- @enable_infrastructure_lru_cache (maxsize = 4_000_000 )
229
+ @lru_cache (maxsize = 4_000_000 )
220
230
def get_transition (self , cell_id , transition_index ):
221
231
"""
222
232
Return the status of whether an agent in cell `cell_id` can perform a
@@ -240,7 +250,7 @@ def get_transition(self, cell_id, transition_index):
240
250
0/1 allowed/not allowed, a probability in [0,1], etc...)
241
251
242
252
"""
243
- #assert len(cell_id) == 3, \
253
+ # assert len(cell_id) == 3, \
244
254
# 'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 2 or 3.'
245
255
return self .transitions .get_transition (self .grid [cell_id [0 :2 ]], cell_id [2 ], transition_index )
246
256
@@ -264,39 +274,39 @@ def set_transition(self, cell_id, transition_index, new_transition, remove_deade
264
274
0/1 allowed/not allowed, a probability in [0,1], etc...)
265
275
266
276
"""
267
- send_infrastructure_data_change_signal_to_reset_lru_cache ()
268
- #assert len(cell_id) == 3, \
277
+ self . _reset_cache ()
278
+ # assert len(cell_id) == 3, \
269
279
# 'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.'
270
280
271
281
nDir = cell_id [2 ]
272
282
if type (nDir ) == np .ndarray :
273
283
# I can't work out how to dump a complete backtrace here
274
284
try :
275
- assert type (nDir )== int , "cell direction is not an int"
285
+ assert type (nDir ) == int , "cell direction is not an int"
276
286
except Exception as e :
277
287
traceback .print_stack ()
278
288
print ("fixing nDir:" , cell_id , nDir )
279
289
nDir = int (nDir [0 ])
280
290
281
- #if type(transition_index) not in (int, np.int64):
291
+ # if type(transition_index) not in (int, np.int64):
282
292
if isinstance (transition_index , np .ndarray ):
283
- #print("fixing transition_index:", cell_id, transition_index)
293
+ # print("fixing transition_index:", cell_id, transition_index)
284
294
if type (transition_index ) == np .ndarray :
285
295
transition_index = int (transition_index .ravel ()[0 ])
286
296
else :
287
297
# print("transition_index type:", type(transition_index))
288
298
transition_index = int (transition_index )
289
299
290
- #if type(new_transition) not in (int, bool):
300
+ # if type(new_transition) not in (int, bool):
291
301
if isinstance (new_transition , np .ndarray ):
292
- #print("fixing new_transition:", cell_id, new_transition)
302
+ # print("fixing new_transition:", cell_id, new_transition)
293
303
new_transition = int (new_transition .ravel ()[0 ])
294
304
295
- #print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends)
305
+ # print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends)
296
306
297
307
self .grid [cell_id [0 ]][cell_id [1 ]] = self .transitions .set_transition (
298
308
self .grid [cell_id [0 :2 ]],
299
- nDir , # cell_id[2],
309
+ nDir , # cell_id[2],
300
310
transition_index ,
301
311
new_transition ,
302
312
remove_deadends )
@@ -332,7 +342,7 @@ def load_transition_map(self, package, resource):
332
342
(height,width) )
333
343
334
344
"""
335
- send_infrastructure_data_change_signal_to_reset_lru_cache ()
345
+ self . _reset_cache ()
336
346
with path (package , resource ) as file_in :
337
347
new_grid = np .load (file_in )
338
348
@@ -343,7 +353,7 @@ def load_transition_map(self, package, resource):
343
353
self .height = new_height
344
354
self .grid = new_grid
345
355
346
- @enable_infrastructure_lru_cache (maxsize = 1_000_000 )
356
+ @lru_cache (maxsize = 1_000_000 )
347
357
def is_dead_end (self , rcPos : IntVector2DArray ):
348
358
"""
349
359
Check if the cell is a dead-end.
@@ -360,7 +370,7 @@ def is_dead_end(self, rcPos: IntVector2DArray):
360
370
cell_transition = self .get_full_transitions (rcPos [0 ], rcPos [1 ])
361
371
return Grid4Transitions .has_deadend (cell_transition )
362
372
363
- @enable_infrastructure_lru_cache (maxsize = 1_000_000 )
373
+ @lru_cache (maxsize = 1_000_000 )
364
374
def is_simple_turn (self , rcPos : IntVector2DArray ):
365
375
"""
366
376
Check if the cell is a left/right simple turn
@@ -388,7 +398,7 @@ def is_simple_turn(trans):
388
398
389
399
return is_simple_turn (tmp )
390
400
391
- @enable_infrastructure_lru_cache (maxsize = 4_000_000 )
401
+ @lru_cache (maxsize = 4_000_000 )
392
402
def check_path_exists (self , start : IntVector2DArray , direction : int , end : IntVector2DArray ):
393
403
"""
394
404
Breath first search for a possible path from one node with a certain orientation to a target node.
@@ -417,7 +427,7 @@ def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVec
417
427
418
428
return False
419
429
420
- @enable_infrastructure_lru_cache (maxsize = 1_000_000 )
430
+ @lru_cache (maxsize = 1_000_000 )
421
431
def cell_neighbours_valid (self , rcPos : IntVector2DArray , check_this_cell = False ):
422
432
"""
423
433
Check validity of cell at rcPos = tuple(row, column)
@@ -504,7 +514,7 @@ def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
504
514
505
515
Returns: True (valid) or False (invalid)
506
516
"""
507
- send_infrastructure_data_change_signal_to_reset_lru_cache ()
517
+ self . _reset_cache ()
508
518
cell_transition = self .grid [tuple (rcPos )]
509
519
510
520
if check_this_cell :
@@ -548,7 +558,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
548
558
"""
549
559
Fixes broken transitions
550
560
"""
551
- send_infrastructure_data_change_signal_to_reset_lru_cache ()
561
+ self . _reset_cache ()
552
562
gDir2dRC = self .transitions .gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
553
563
grcPos = array (rcPos )
554
564
grcMax = self .grid .shape
@@ -625,7 +635,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
625
635
self .set_transitions ((rcPos [0 ], rcPos [1 ]), transition )
626
636
return True
627
637
628
- @enable_infrastructure_lru_cache (maxsize = 1_000_000 )
638
+ @lru_cache (maxsize = 1_000_000 )
629
639
def validate_new_transition (self , prev_pos : IntVector2D , current_pos : IntVector2D ,
630
640
new_pos : IntVector2D , end_pos : IntVector2D ):
631
641
"""
0 commit comments