Skip to content

Commit de36600

Browse files
authored
118 Add test_lru_cache_problem.py. (#119)
* Add test_lru_cache_problem.py.
* Use methodtools.lru_cache for proper instance method caching. Keep GridTransitionMap object lifecycle in sync with cache lifecycle.
* Get rid of methodtools and go back to functools.lru_cache using proper eq/hash definition on instance methods.
* Update requirements.txt.
* Add cache size assertions.
1 parent 2538de0 commit de36600

7 files changed

+279
-49
lines changed

flatland/core/transition_map.py

+38-28
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
"""
22
TransitionMap and derived classes.
33
"""
4+
import traceback
5+
import uuid
46
from functools import lru_cache
57

68
import numpy as np
79
from importlib_resources import path
810
from numpy import array
9-
import traceback
1011

1112
from flatland.core.grid.grid4 import Grid4Transitions
1213
from flatland.core.grid.grid4_utils import get_new_position, get_direction
1314
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
1415
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
1516
from flatland.core.grid.rail_env_grid import RailEnvTransitions
1617
from flatland.core.transitions import Transitions
17-
from flatland.utils.decorators import enable_infrastructure_lru_cache, send_infrastructure_data_change_signal_to_reset_lru_cache
1818
from flatland.utils.ordered_set import OrderedSet
1919

2020

@@ -135,7 +135,6 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]
135135
grid.
136136
137137
"""
138-
send_infrastructure_data_change_signal_to_reset_lru_cache()
139138
self.width = width
140139
self.height = height
141140
self.transitions = transitions
@@ -145,8 +144,19 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]
145144
else:
146145
self.random_generator.seed(random_seed)
147146
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
147+
self._reset_cache()
148+
149+
def _reset_cache(self):
150+
# use __eq__ and __hash__ to control cache lifecycle of instance methods, see https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls.
151+
self.uuid = uuid.uuid4().int
152+
153+
def __eq__(self, __value):
154+
return isinstance(__value, GridTransitionMap) and self.uuid == __value.uuid
155+
156+
def __hash__(self):
157+
return self.uuid
148158

149-
@enable_infrastructure_lru_cache(maxsize=1_000_000)
159+
@lru_cache(maxsize=1_000_000)
150160
def get_full_transitions(self, row, column):
151161
"""
152162
Returns the full transitions for the cell at (row, column) in the format transition_map's transitions.
@@ -165,7 +175,7 @@ def get_full_transitions(self, row, column):
165175
"""
166176
return self.grid[(row, column)]
167177

168-
@enable_infrastructure_lru_cache(maxsize=4_000_000)
178+
@lru_cache(maxsize=4_000_000)
169179
def get_transitions(self, row, column, orientation):
170180
"""
171181
Return a tuple of transitions available in a cell specified by
@@ -206,17 +216,17 @@ def set_transitions(self, cell_id, new_transitions):
206216
Tuple of new transitions validity for the cell.
207217
208218
"""
209-
send_infrastructure_data_change_signal_to_reset_lru_cache()
210-
#assert len(cell_id) in (2, 3), \
219+
self._reset_cache()
220+
# assert len(cell_id) in (2, 3), \
211221
# 'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.'
212222
if len(cell_id) == 3:
213223
self.grid[cell_id[0:2]] = self.transitions.set_transitions(self.grid[cell_id[0:2]],
214-
cell_id[2],
215-
new_transitions)
224+
cell_id[2],
225+
new_transitions)
216226
elif len(cell_id) == 2:
217227
self.grid[cell_id] = new_transitions
218228

219-
@enable_infrastructure_lru_cache(maxsize=4_000_000)
229+
@lru_cache(maxsize=4_000_000)
220230
def get_transition(self, cell_id, transition_index):
221231
"""
222232
Return the status of whether an agent in cell `cell_id` can perform a
@@ -240,7 +250,7 @@ def get_transition(self, cell_id, transition_index):
240250
0/1 allowed/not allowed, a probability in [0,1], etc...)
241251
242252
"""
243-
#assert len(cell_id) == 3, \
253+
# assert len(cell_id) == 3, \
244254
# 'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 3.'
245255
return self.transitions.get_transition(self.grid[cell_id[0:2]], cell_id[2], transition_index)
246256

@@ -264,39 +274,39 @@ def set_transition(self, cell_id, transition_index, new_transition, remove_deade
264274
0/1 allowed/not allowed, a probability in [0,1], etc...)
265275
266276
"""
267-
send_infrastructure_data_change_signal_to_reset_lru_cache()
268-
#assert len(cell_id) == 3, \
277+
self._reset_cache()
278+
# assert len(cell_id) == 3, \
269279
# 'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.'
270280

271281
nDir = cell_id[2]
272282
if type(nDir) == np.ndarray:
273283
# I can't work out how to dump a complete backtrace here
274284
try:
275-
assert type(nDir)==int, "cell direction is not an int"
285+
assert type(nDir) == int, "cell direction is not an int"
276286
except Exception as e:
277287
traceback.print_stack()
278288
print("fixing nDir:", cell_id, nDir)
279289
nDir = int(nDir[0])
280290

281-
#if type(transition_index) not in (int, np.int64):
291+
# if type(transition_index) not in (int, np.int64):
282292
if isinstance(transition_index, np.ndarray):
283-
#print("fixing transition_index:", cell_id, transition_index)
293+
# print("fixing transition_index:", cell_id, transition_index)
284294
if type(transition_index) == np.ndarray:
285295
transition_index = int(transition_index.ravel()[0])
286296
else:
287297
# print("transition_index type:", type(transition_index))
288298
transition_index = int(transition_index)
289299

290-
#if type(new_transition) not in (int, bool):
300+
# if type(new_transition) not in (int, bool):
291301
if isinstance(new_transition, np.ndarray):
292-
#print("fixing new_transition:", cell_id, new_transition)
302+
# print("fixing new_transition:", cell_id, new_transition)
293303
new_transition = int(new_transition.ravel()[0])
294304

295-
#print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends)
305+
# print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends)
296306

297307
self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition(
298308
self.grid[cell_id[0:2]],
299-
nDir, # cell_id[2],
309+
nDir, # cell_id[2],
300310
transition_index,
301311
new_transition,
302312
remove_deadends)
@@ -332,7 +342,7 @@ def load_transition_map(self, package, resource):
332342
(height,width) )
333343
334344
"""
335-
send_infrastructure_data_change_signal_to_reset_lru_cache()
345+
self._reset_cache()
336346
with path(package, resource) as file_in:
337347
new_grid = np.load(file_in)
338348

@@ -343,7 +353,7 @@ def load_transition_map(self, package, resource):
343353
self.height = new_height
344354
self.grid = new_grid
345355

346-
@enable_infrastructure_lru_cache(maxsize=1_000_000)
356+
@lru_cache(maxsize=1_000_000)
347357
def is_dead_end(self, rcPos: IntVector2DArray):
348358
"""
349359
Check if the cell is a dead-end.
@@ -360,7 +370,7 @@ def is_dead_end(self, rcPos: IntVector2DArray):
360370
cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
361371
return Grid4Transitions.has_deadend(cell_transition)
362372

363-
@enable_infrastructure_lru_cache(maxsize=1_000_000)
373+
@lru_cache(maxsize=1_000_000)
364374
def is_simple_turn(self, rcPos: IntVector2DArray):
365375
"""
366376
Check if the cell is a left/right simple turn
@@ -388,7 +398,7 @@ def is_simple_turn(trans):
388398

389399
return is_simple_turn(tmp)
390400

391-
@enable_infrastructure_lru_cache(maxsize=4_000_000)
401+
@lru_cache(maxsize=4_000_000)
392402
def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
393403
"""
394404
Breadth-first search for a possible path from one node with a certain orientation to a target node.
@@ -417,7 +427,7 @@ def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVec
417427

418428
return False
419429

420-
@enable_infrastructure_lru_cache(maxsize=1_000_000)
430+
@lru_cache(maxsize=1_000_000)
421431
def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
422432
"""
423433
Check validity of cell at rcPos = tuple(row, column)
@@ -504,7 +514,7 @@ def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
504514
505515
Returns: True (valid) or False (invalid)
506516
"""
507-
send_infrastructure_data_change_signal_to_reset_lru_cache()
517+
self._reset_cache()
508518
cell_transition = self.grid[tuple(rcPos)]
509519

510520
if check_this_cell:
@@ -548,7 +558,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
548558
"""
549559
Fixes broken transitions
550560
"""
551-
send_infrastructure_data_change_signal_to_reset_lru_cache()
561+
self._reset_cache()
552562
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
553563
grcPos = array(rcPos)
554564
grcMax = self.grid.shape
@@ -625,7 +635,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
625635
self.set_transitions((rcPos[0], rcPos[1]), transition)
626636
return True
627637

628-
@enable_infrastructure_lru_cache(maxsize=1_000_000)
638+
@lru_cache(maxsize=1_000_000)
629639
def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
630640
new_pos: IntVector2D, end_pos: IntVector2D):
631641
"""

flatland/envs/persistence.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,19 @@ def set_full_state(cls, env, env_dict):
187187
-------
188188
env_dict: dict
189189
"""
190-
env.rail.grid = np.array(env_dict["grid"])
190+
grid = np.array(env_dict["grid"])
191191

192192
# Initialise the env with the frozen agents in the file
193193
env.agents = env_dict.get("agents", [])
194194

195195
# For consistency, set number_of_agents, which is the number which will be generated on reset
196196
env.number_of_agents = env.get_num_agents()
197197

198-
env.height, env.width = env.rail.grid.shape
199-
env.rail.height = env.height
200-
env.rail.width = env.width
198+
env.height, env.width = grid.shape
199+
200+
# use new rail object instance for lru cache scoping and garbage collection to work properly
201+
env.rail = GridTransitionMap(height=env.height, width=env.width)
202+
env.rail.grid = grid
201203
env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)
202204

203205
# TODO merge with https://github.com/flatland-association/flatland-rl/pull/97/files

flatland/envs/rail_env.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Definition of the RailEnv environment.
33
"""
44
import random
5+
from functools import lru_cache
56
from typing import List, Optional, Dict, Tuple
67

78
import numpy as np
@@ -26,8 +27,7 @@
2627
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
2728
from flatland.envs.step_utils.transition_utils import check_valid_action
2829
from flatland.utils import seeding
29-
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache, \
30-
enable_infrastructure_lru_cache
30+
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache
3131
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
3232

3333

@@ -239,8 +239,9 @@ def reset_agents(self):
239239
agent.reset()
240240
self.active_agents = [i for i in range(len(self.agents))]
241241

242-
@enable_infrastructure_lru_cache()
243-
def action_required(self, agent_state, is_cell_entry):
242+
@lru_cache()
243+
@staticmethod
244+
def action_required(agent_state, is_cell_entry):
244245
"""
245246
Check if an agent needs to provide an action
246247
@@ -459,7 +460,7 @@ def get_info_dict(self):
459460
state - State from the trains's state machine
460461
"""
461462
info_dict = {
462-
'action_required': {i: self.action_required(agent.state, agent.speed_counter.is_cell_entry)
463+
'action_required': {i: RailEnv.action_required(agent.state, agent.speed_counter.is_cell_entry)
463464
for i, agent in enumerate(self.agents)},
464465
'malfunction': {
465466
i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)

flatland/utils/decorators.py

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
infrastructure_lru_cache_functions = []
44

55

6+
# TODO https://github.com/flatland-association/flatland-rl/issues/104 1. Revise which caches need to be scoped at all - some seem not to require cache clearing at all. 2. Refactor away the need to explicitly reset the cache through calls dispersed throughout the code base. Use classes to group the cache scope by overriding eq/hash for instance method lru caching (see https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls)
67
def enable_infrastructure_lru_cache(*args, **kwargs):
78
def decorator(func):
89
func = lru_cache(*args, **kwargs)(func)
@@ -12,6 +13,8 @@ def decorator(func):
1213
return decorator
1314

1415

16+
# send_infrastructure_data_change_signal_to_reset_lru_cache() has a problem with instance methods - the caches of instance methods are not properly cleared by it.
17+
# Therefore, make sure to override eq/hash to control cache lifecycle for instance method lru caching (see https://stackoverflow.com/questions/33672412/python-functools-lru-cache-with-instance-methods-release-object and https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls)
1518
def send_infrastructure_data_change_signal_to_reset_lru_cache():
1619
for func in infrastructure_lru_cache_functions:
1720
func.cache_clear()

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ dependencies = [
4646
"setuptools",
4747
"svgutils",
4848
"timeout_decorator",
49-
5049
]
5150
dynamic = ["version"]
5251

tests/test_flatland_core_transitions.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from flatland.core.grid.grid8 import Grid8Transitions
77
from flatland.core.grid.rail_env_grid import RailEnvTransitions
88
from flatland.core.transition_map import GridTransitionMap
9-
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache
109

1110

1211
# remove whitespace in string; keep whitespace below for easier reading
@@ -127,28 +126,23 @@ def test_adding_new_valid_transition():
127126
assert (grid_map.validate_new_transition((5, 6), (5, 5), (5, 6), (10, 10)) is True)
128127

129128
# adding invalid turn
130-
send_infrastructure_data_change_signal_to_reset_lru_cache()
131-
grid_map.grid[(5, 5)] = rail_trans.transitions[2]
129+
grid_map.set_transitions((5, 5), rail_trans.transitions[2])
132130
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)
133131

134132
# should create #4 -> valid
135-
send_infrastructure_data_change_signal_to_reset_lru_cache()
136-
grid_map.grid[(5, 5)] = rail_trans.transitions[3]
133+
grid_map.set_transitions((5, 5), rail_trans.transitions[3])
137134
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is True)
138135

139136
# adding invalid turn
140-
send_infrastructure_data_change_signal_to_reset_lru_cache()
141-
grid_map.grid[(5, 5)] = rail_trans.transitions[7]
137+
grid_map.set_transitions((5, 5), rail_trans.transitions[7])
142138
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)
143139

144140
# test path start condition
145-
send_infrastructure_data_change_signal_to_reset_lru_cache()
146-
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
141+
grid_map.set_transitions((5, 5), rail_trans.transitions[3])
147142
assert (grid_map.validate_new_transition(None, (5, 5), (5, 6), (10, 10)) is True)
148143

149144
# test path end condition
150-
send_infrastructure_data_change_signal_to_reset_lru_cache()
151-
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
145+
grid_map.set_transitions((5, 5), rail_trans.transitions[3])
152146
assert (grid_map.validate_new_transition((5, 4), (5, 5), (6, 5), (6, 5)) is True)
153147

154148

0 commit comments

Comments
 (0)