From 863cf13df298e6c46562c9ec08c7bbe002ee7539 Mon Sep 17 00:00:00 2001 From: chenkins Date: Sat, 15 Feb 2025 10:11:06 +0100 Subject: [PATCH 1/5] Add test_lru_cache_problem.py. --- tests/test_lru_cache_problem.py | 132 ++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 tests/test_lru_cache_problem.py diff --git a/tests/test_lru_cache_problem.py b/tests/test_lru_cache_problem.py new file mode 100644 index 00000000..c92bcd17 --- /dev/null +++ b/tests/test_lru_cache_problem.py @@ -0,0 +1,132 @@ +from flatland.envs.line_generators import sparse_line_generator +from flatland.envs.persistence import RailEnvPersister +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator + + +def test_lru_load(): + # seed 42 + env_42 = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=1), + line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42) + + env_42.reset(random_seed=42) + transitions_42 = {} + for r in range(30): + for c in range(30): + transitions_42[(r, c)] = env_42.rail.get_full_transitions(r, c) + + RailEnvPersister.save(env_42, "env_42.pkl") + + # seed 43 + env_43 = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=2), + line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43) + + env_43.reset(random_seed=43) + transitions_43 = {} + for r in range(30): + for c in range(30): + transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c) + # reset clears the cache, so the transitions are indeed different + assert set(transitions_42.items()) != set(transitions_43.items()) + + # seed 42 bis + env_42_bis = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=1), + line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42) + + env_42_bis.reset(random_seed=42) + transitions_42_bis = {} + for r in range(30): + for c in range(30): + transitions_42_bis[(r, c)] = env_42.rail.get_full_transitions(r, c) + # sanity check: same seed gives same transitions + assert set(transitions_42.items()) == set(transitions_42_bis.items()) + + # populate cache with infrastructure from seed 43 + env_43 = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=2), + line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43) + env_43.reset(random_seed=43) + transitions_43 = {} + for r in range(30): + for c in range(30): + transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c) + # reset clears the cache, so the transitions are indeed different + assert set(transitions_42.items()) != set(transitions_43.items()) + + # load env_42 from file + RailEnvPersister.load(env_43, "env_42.pkl") + env_42_tri = env_43 + + transitions_42_tri = {} + for r in range(30): + for c in range(30): + transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c) + # load() does not invalidate cache (so env_43 transitions are still in the cache) - TODO to be fixed + assert set(transitions_42.items()) != set(transitions_42_tri.items()) + + +def test_lru_load_new(): + # seed 42 + env_42 = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=1), + line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42) + + env_42.reset(random_seed=42) + transitions_42 = {} + for r in range(30): + for c in range(30): + transitions_42[(r, c)] = env_42.rail.get_full_transitions(r, c) + + RailEnvPersister.save(env_42, "env_42.pkl") + + # seed 43 + env_43 = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=2), + line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43) + + env_43.reset(random_seed=43) + transitions_43 = {} + for r in range(30): + for c in range(30): + transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c) + # reset clears the cache, so the transitions are indeed different + assert set(transitions_42.items()) != set(transitions_43.items()) + + # seed 42 bis + env_42_bis = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=1), + line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42) + + env_42_bis.reset(random_seed=42) + transitions_42_bis = {} + for r in range(30): + for c in range(30): + transitions_42_bis[(r, c)] = env_42.rail.get_full_transitions(r, c) + # sanity check: same seed gives same transitions + assert set(transitions_42.items()) == set(transitions_42_bis.items()) + + # populate cache with infrastructure from seed 43 + env_43 = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=2), + line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43) + env_43.reset(random_seed=43) + transitions_43 = {} + for r in range(30): + for c in range(30): + transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c) + # reset clears the cache, so the transitions are indeed different + assert set(transitions_42.items()) != set(transitions_43.items()) + + # load env_42 from file + # N.B.line `env.rail = GridTransitionMap(1, 1)` in `load_new` has side effect of clearing infrastructure cache, but not `load()` TODO fix load() + env_42_tri, _ = RailEnvPersister.load_new("env_42.pkl") + + transitions_42_tri = {} + for r in range(30): + for c in range(30): + transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c) + # load_new() invalidates cache (so env_43 transitions are cleared) + assert set(transitions_42.items()) == set(transitions_42_tri.items()) From 8c175c71f200ebfdc9e13155e749c9b659453c1e Mon Sep 17 00:00:00 2001 From: chenkins Date: Tue, 18 Feb 2025 09:53:14 +0100 Subject: [PATCH 2/5] Use methodtools.lru_cache for proper instance method caching. Keep GridTransitionMap object lifecycle in sync with cache lifecycle. --- flatland/core/transition_map.py | 28 ++++++++++++++-------------- flatland/envs/persistence.py | 10 ++++++---- flatland/envs/rail_env.py | 11 ++++++----- flatland/utils/decorators.py | 4 ++++ pyproject.toml | 2 +- requirements.txt | 1 + tests/test_lru_cache_problem.py | 9 +++++---- 7 files changed, 37 insertions(+), 28 deletions(-) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 999c3e12..bc5fdc6d 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -1,12 +1,12 @@ """ TransitionMap and derived classes. """ -from functools import lru_cache +import traceback +import methodtools import numpy as np from importlib_resources import path from numpy import array -import traceback from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position, get_direction @@ -14,7 +14,7 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions -from flatland.utils.decorators import enable_infrastructure_lru_cache, send_infrastructure_data_change_signal_to_reset_lru_cache +from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache from flatland.utils.ordered_set import OrderedSet @@ -146,7 +146,7 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([] self.random_generator.seed(random_seed) self.grid = np.zeros((height, width), dtype=self.transitions.get_type()) - @enable_infrastructure_lru_cache(maxsize=1_000_000) + @methodtools.lru_cache(maxsize=1_000_000) def get_full_transitions(self, row, column): """ Returns the full transitions for the cell at (row, column) in the format transition_map's transitions. @@ -165,7 +165,7 @@ def get_full_transitions(self, row, column): """ return self.grid[(row, column)] - @enable_infrastructure_lru_cache(maxsize=4_000_000) + @methodtools.lru_cache(maxsize=4_000_000) def get_transitions(self, row, column, orientation): """ Return a tuple of transitions available in a cell specified by @@ -207,16 +207,16 @@ def set_transitions(self, cell_id, new_transitions): """ send_infrastructure_data_change_signal_to_reset_lru_cache() - #assert len(cell_id) in (2, 3), \ + # assert len(cell_id) in (2, 3), \ # 'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.' if len(cell_id) == 3: self.grid[cell_id[0:2]] = self.transitions.set_transitions(self.grid[cell_id[0:2]], - cell_id[2], - new_transitions) + cell_id[2], + new_transitions) elif len(cell_id) == 2: self.grid[cell_id] = new_transitions - @enable_infrastructure_lru_cache(maxsize=4_000_000) + @methodtools.lru_cache(maxsize=4_000_000) def get_transition(self, cell_id, transition_index): """ Return the status of whether an agent in cell `cell_id` can perform a @@ -343,7 +343,7 @@ def load_transition_map(self, package, resource): self.height = new_height self.grid = new_grid - @enable_infrastructure_lru_cache(maxsize=1_000_000) + @methodtools.lru_cache(maxsize=1_000_000) def is_dead_end(self, rcPos: IntVector2DArray): """ Check if the cell is a dead-end. @@ -360,7 +360,7 @@ def is_dead_end(self, rcPos: IntVector2DArray): cell_transition = self.get_full_transitions(rcPos[0], rcPos[1]) return Grid4Transitions.has_deadend(cell_transition) - @enable_infrastructure_lru_cache(maxsize=1_000_000) + @methodtools.lru_cache(maxsize=1_000_000) def is_simple_turn(self, rcPos: IntVector2DArray): """ Check if the cell is a left/right simple turn @@ -388,7 +388,7 @@ def is_simple_turn(trans): return is_simple_turn(tmp) - @enable_infrastructure_lru_cache(maxsize=4_000_000) + @methodtools.lru_cache(maxsize=4_000_000) def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray): """ Breath first search for a possible path from one node with a certain orientation to a target node. @@ -417,7 +417,7 @@ def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVec return False - @enable_infrastructure_lru_cache(maxsize=1_000_000) + @methodtools.lru_cache(maxsize=1_000_000) def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False): """ Check validity of cell at rcPos = tuple(row, column) @@ -625,7 +625,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1): self.set_transitions((rcPos[0], rcPos[1]), transition) return True - @enable_infrastructure_lru_cache(maxsize=1_000_000) + @methodtools.lru_cache(maxsize=1_000_000) def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D, new_pos: IntVector2D, end_pos: IntVector2D): """ diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py index 2e22a9f1..7404d2c6 100644 --- a/flatland/envs/persistence.py +++ b/flatland/envs/persistence.py @@ -187,7 +187,7 @@ def set_full_state(cls, env, env_dict): ------- env_dict: dict """ - env.rail.grid = np.array(env_dict["grid"]) + grid = np.array(env_dict["grid"]) # Initialise the env with the frozen agents in the file env.agents = env_dict.get("agents", []) @@ -195,9 +195,11 @@ def set_full_state(cls, env, env_dict): # For consistency, set number_of_agents, which is the number which will be generated on reset env.number_of_agents = env.get_num_agents() - env.height, env.width = env.rail.grid.shape - env.rail.height = env.height - env.rail.width = env.width + env.height, env.width = grid.shape + + # use new rail object instance for lru cache scoping and garbage collection to work properly + env.rail = GridTransitionMap(height=env.height, width=env.width) + env.rail.grid = grid env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False) # TODO merge with https://github.com/flatland-association/flatland-rl/pull/97/files diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6c1f33e0..c20d7a07 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -2,6 +2,7 @@ Definition of the RailEnv environment. """ import random +from functools import lru_cache from typing import List, Optional, Dict, Tuple import numpy as np @@ -26,8 +27,7 @@ from flatland.envs.step_utils.states import TrainState, StateTransitionSignals from flatland.envs.step_utils.transition_utils import check_valid_action from flatland.utils import seeding -from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache, \ - enable_infrastructure_lru_cache +from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache from flatland.utils.rendertools import RenderTool, AgentRenderVariant @@ -239,8 +239,9 @@ def reset_agents(self): agent.reset() self.active_agents = [i for i in range(len(self.agents))] - @enable_infrastructure_lru_cache() - def action_required(self, agent_state, is_cell_entry): + @lru_cache() + @staticmethod + def action_required(agent_state, is_cell_entry): """ Check if an agent needs to provide an action @@ -459,7 +460,7 @@ def get_info_dict(self): state - State from the trains's state machine """ info_dict = { - 'action_required': {i: self.action_required(agent.state, agent.speed_counter.is_cell_entry) + 'action_required': {i: RailEnv.action_required(agent.state, agent.speed_counter.is_cell_entry) for i, agent in enumerate(self.agents)}, 'malfunction': { i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents) diff --git a/flatland/utils/decorators.py b/flatland/utils/decorators.py index dd8b22b4..1d45adc6 100644 --- a/flatland/utils/decorators.py +++ b/flatland/utils/decorators.py @@ -3,6 +3,7 @@ infrastructure_lru_cache_functions = [] +# TODO https://github.com/flatland-association/flatland-rl/issues/104 1. revise which caches need to be scoped at all - some seem not to require cache clearing at all. 2. refactor with need to explicitly reset cache in calls dispersed in the whole code base. Use classes to group the cache scope using methodtools for instance method lru caching. def enable_infrastructure_lru_cache(*args, **kwargs): def decorator(func): func = lru_cache(*args, **kwargs)(func) @@ -12,6 +13,9 @@ def decorator(func): return decorator +# send_infrastructure_data_change_signal_to_reset_lru_cache() has a problem with instance methods - the methods are not properly cleared. +# Therefore, make sure to use methodtools for instance methods and to instantiantiate new objects to match instance and cache lifecycle. +# See https://stackoverflow.com/questions/33672412/python-functools-lru-cache-with-instance-methods-release-object def send_infrastructure_data_change_signal_to_reset_lru_cache(): for func in infrastructure_lru_cache_functions: func.cache_clear() diff --git a/pyproject.toml b/pyproject.toml index 96985941..73d47916 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "ipython", "ipywidgets", "matplotlib", + "methodtools", "msgpack_numpy", "msgpack", "networkx", @@ -46,7 +47,6 @@ dependencies = [ "setuptools", "svgutils", "timeout_decorator", - ] dynamic = ["version"] diff --git a/requirements.txt b/requirements.txt index e5569313..9c5b31a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -77,6 +77,7 @@ matplotlib==3.10.0 # seaborn matplotlib-inline==0.1.7 # via ipython +methodtools==0.4.7 msgpack==1.1.0 # via # flatland-rl (pyproject.toml) diff --git a/tests/test_lru_cache_problem.py b/tests/test_lru_cache_problem.py index c92bcd17..edf32af3 100644 --- a/tests/test_lru_cache_problem.py +++ b/tests/test_lru_cache_problem.py @@ -4,6 +4,7 @@ from flatland.envs.rail_generators import sparse_rail_generator +# TODO refactor parametrized load and load_new! def test_lru_load(): # seed 42 env_42 = RailEnv(width=30, height=30, @@ -64,8 +65,8 @@ def test_lru_load(): for r in range(30): for c in range(30): transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c) - # load() does not invalidate cache (so env_43 transitions are still in the cache) - TODO to be fixed - assert set(transitions_42.items()) != set(transitions_42_tri.items()) + # load() now invalidates cache correctly + assert set(transitions_42.items()) == set(transitions_42_tri.items()) def test_lru_load_new(): @@ -120,8 +121,8 @@ def test_lru_load_new(): # reset clears the cache, so the transitions are indeed different assert set(transitions_42.items()) != set(transitions_43.items()) - # load env_42 from file - # N.B.line `env.rail = GridTransitionMap(1, 1)` in `load_new` has side effect of clearing infrastructure cache, but not `load()` TODO fix load() + # load_new() env_42 from file + # N.B.line `env.rail = GridTransitionMap(1, 1)` in `load_new` has side effect of clearing infrastructure cache. env_42_tri, _ = RailEnvPersister.load_new("env_42.pkl") transitions_42_tri = {} From f236fe538fdd4ea2cde052fa1e4ad6c45d409210 Mon Sep 17 00:00:00 2001 From: chenkins Date: Tue, 18 Feb 2025 17:11:28 +0100 Subject: [PATCH 3/5] Get rid of methodtools and go back to functools.lru_cache using proper eq/hash definition on instance methods. --- flatland/core/transition_map.py | 60 ++++++++++++++----------- flatland/utils/decorators.py | 7 ++- pyproject.toml | 1 - tests/test_flatland_core_transitions.py | 16 +++---- 4 files changed, 43 insertions(+), 41 deletions(-) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index bc5fdc6d..64c4f5ac 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -2,8 +2,9 @@ TransitionMap and derived classes. """ import traceback +import uuid +from functools import lru_cache -import methodtools import numpy as np from importlib_resources import path from numpy import array @@ -14,7 +15,6 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions -from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache from flatland.utils.ordered_set import OrderedSet @@ -135,7 +135,6 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([] grid. """ - send_infrastructure_data_change_signal_to_reset_lru_cache() self.width = width self.height = height self.transitions = transitions @@ -145,8 +144,19 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([] else: self.random_generator.seed(random_seed) self.grid = np.zeros((height, width), dtype=self.transitions.get_type()) + self._reset_cache() - @methodtools.lru_cache(maxsize=1_000_000) + def _reset_cache(self): + # use __eq__ and __hash__ to control cache lifecycle of instance methods, see https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls. + self.uuid = uuid.uuid4().int + + def __eq__(self, __value): + return isinstance(__value, GridTransitionMap) and self.uuid == __value.uuid + + def __hash__(self): + return self.uuid + + @lru_cache(maxsize=1_000_000) def get_full_transitions(self, row, column): """ Returns the full transitions for the cell at (row, column) in the format transition_map's transitions. @@ -165,7 +175,7 @@ def get_full_transitions(self, row, column): """ return self.grid[(row, column)] - @methodtools.lru_cache(maxsize=4_000_000) + @lru_cache(maxsize=4_000_000) def get_transitions(self, row, column, orientation): """ Return a tuple of transitions available in a cell specified by @@ -206,7 +216,7 @@ def set_transitions(self, cell_id, new_transitions): Tuple of new transitions validitiy for the cell. """ - send_infrastructure_data_change_signal_to_reset_lru_cache() + self._reset_cache() # assert len(cell_id) in (2, 3), \ # 'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.' if len(cell_id) == 3: @@ -216,7 +226,7 @@ def set_transitions(self, cell_id, new_transitions): elif len(cell_id) == 2: self.grid[cell_id] = new_transitions - @methodtools.lru_cache(maxsize=4_000_000) + @lru_cache(maxsize=4_000_000) def get_transition(self, cell_id, transition_index): """ Return the status of whether an agent in cell `cell_id` can perform a @@ -240,7 +250,7 @@ def get_transition(self, cell_id, transition_index): 0/1 allowed/not allowed, a probability in [0,1], etc...) """ - #assert len(cell_id) == 3, \ + # assert len(cell_id) == 3, \ # 'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 2 or 3.' return self.transitions.get_transition(self.grid[cell_id[0:2]], cell_id[2], transition_index) @@ -264,39 +274,39 @@ def set_transition(self, cell_id, transition_index, new_transition, remove_deade 0/1 allowed/not allowed, a probability in [0,1], etc...) """ - send_infrastructure_data_change_signal_to_reset_lru_cache() - #assert len(cell_id) == 3, \ + self._reset_cache() + # assert len(cell_id) == 3, \ # 'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.' nDir = cell_id[2] if type(nDir) == np.ndarray: # I can't work out how to dump a complete backtrace here try: - assert type(nDir)==int, "cell direction is not an int" + assert type(nDir) == int, "cell direction is not an int" except Exception as e: traceback.print_stack() print("fixing nDir:", cell_id, nDir) nDir = int(nDir[0]) - #if type(transition_index) not in (int, np.int64): + # if type(transition_index) not in (int, np.int64): if isinstance(transition_index, np.ndarray): - #print("fixing transition_index:", cell_id, transition_index) + # print("fixing transition_index:", cell_id, transition_index) if type(transition_index) == np.ndarray: transition_index = int(transition_index.ravel()[0]) else: # print("transition_index type:", type(transition_index)) transition_index = int(transition_index) - #if type(new_transition) not in (int, bool): + # if type(new_transition) not in (int, bool): if isinstance(new_transition, np.ndarray): - #print("fixing new_transition:", cell_id, new_transition) + # print("fixing new_transition:", cell_id, new_transition) new_transition = int(new_transition.ravel()[0]) - #print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends) + # print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends) self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition( self.grid[cell_id[0:2]], - nDir, # cell_id[2], + nDir, # cell_id[2], transition_index, new_transition, remove_deadends) @@ -332,7 +342,7 @@ def load_transition_map(self, package, resource): (height,width) ) """ - send_infrastructure_data_change_signal_to_reset_lru_cache() + self._reset_cache() with path(package, resource) as file_in: new_grid = np.load(file_in) @@ -343,7 +353,7 @@ def load_transition_map(self, package, resource): self.height = new_height self.grid = new_grid - @methodtools.lru_cache(maxsize=1_000_000) + @lru_cache(maxsize=1_000_000) def is_dead_end(self, rcPos: IntVector2DArray): """ Check if the cell is a dead-end. @@ -360,7 +370,7 @@ def is_dead_end(self, rcPos: IntVector2DArray): cell_transition = self.get_full_transitions(rcPos[0], rcPos[1]) return Grid4Transitions.has_deadend(cell_transition) - @methodtools.lru_cache(maxsize=1_000_000) + @lru_cache(maxsize=1_000_000) def is_simple_turn(self, rcPos: IntVector2DArray): """ Check if the cell is a left/right simple turn @@ -388,7 +398,7 @@ def is_simple_turn(trans): return is_simple_turn(tmp) - @methodtools.lru_cache(maxsize=4_000_000) + @lru_cache(maxsize=4_000_000) def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray): """ Breath first search for a possible path from one node with a certain orientation to a target node. @@ -417,7 +427,7 @@ def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVec return False - @methodtools.lru_cache(maxsize=1_000_000) + @lru_cache(maxsize=1_000_000) def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False): """ Check validity of cell at rcPos = tuple(row, column) @@ -504,7 +514,7 @@ def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False): Returns: True (valid) or False (invalid) """ - send_infrastructure_data_change_signal_to_reset_lru_cache() + self._reset_cache() cell_transition = self.grid[tuple(rcPos)] if check_this_cell: @@ -548,7 +558,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1): """ Fixes broken transitions """ - send_infrastructure_data_change_signal_to_reset_lru_cache() + self._reset_cache() gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc] grcPos = array(rcPos) grcMax = self.grid.shape @@ -625,7 +635,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1): self.set_transitions((rcPos[0], rcPos[1]), transition) return True - @methodtools.lru_cache(maxsize=1_000_000) + @lru_cache(maxsize=1_000_000) def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D, new_pos: IntVector2D, end_pos: IntVector2D): """ diff --git a/flatland/utils/decorators.py b/flatland/utils/decorators.py index 1d45adc6..625e2fbf 100644 --- a/flatland/utils/decorators.py +++ b/flatland/utils/decorators.py @@ -3,7 +3,7 @@ infrastructure_lru_cache_functions = [] -# TODO https://github.com/flatland-association/flatland-rl/issues/104 1. revise which caches need to be scoped at all - some seem not to require cache clearing at all. 2. refactor with need to explicitly reset cache in calls dispersed in the whole code base. Use classes to group the cache scope using methodtools for instance method lru caching. +# TODO https://github.com/flatland-association/flatland-rl/issues/104 1. revise which caches need to be scoped at all - some seem not to require cache clearing at all. 2. refactor with need to explicitly reset cache in calls dispersed in the whole code base. Use classes to group the cache scope by overriding eq/hash for instance method lru caching (see https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls) def enable_infrastructure_lru_cache(*args, **kwargs): def decorator(func): func = lru_cache(*args, **kwargs)(func) @@ -13,9 +13,8 @@ def decorator(func): return decorator -# send_infrastructure_data_change_signal_to_reset_lru_cache() has a problem with instance methods - the methods are not properly cleared. -# Therefore, make sure to use methodtools for instance methods and to instantiantiate new objects to match instance and cache lifecycle. -# See https://stackoverflow.com/questions/33672412/python-functools-lru-cache-with-instance-methods-release-object +# send_infrastructure_data_change_signal_to_reset_lru_cache() has a problem with instance methods - the methods are not properly cleared by it. +# Therefore, make sure to override eq/hash to control cache lifecycle for instance method lru caching (see https://stackoverflow.com/questions/33672412/python-functools-lru-cache-with-instance-methods-release-object and https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls) def send_infrastructure_data_change_signal_to_reset_lru_cache(): for func in infrastructure_lru_cache_functions: func.cache_clear() diff --git a/pyproject.toml b/pyproject.toml index 73d47916..a095c43a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ dependencies = [ "ipython", "ipywidgets", "matplotlib", - "methodtools", "msgpack_numpy", "msgpack", "networkx", diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index e65b5106..20eaffde 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -6,7 +6,6 @@ from flatland.core.grid.grid8 import Grid8Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache # remove whitespace in string; keep whitespace below for easier reading @@ -127,28 +126,23 @@ def test_adding_new_valid_transition(): assert (grid_map.validate_new_transition((5, 6), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn - send_infrastructure_data_change_signal_to_reset_lru_cache() - grid_map.grid[(5, 5)] = rail_trans.transitions[2] + grid_map.set_transitions((5, 5), rail_trans.transitions[2]) assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False) # should create #4 -> valid - send_infrastructure_data_change_signal_to_reset_lru_cache() - grid_map.grid[(5, 5)] = rail_trans.transitions[3] + grid_map.set_transitions((5, 5), rail_trans.transitions[3]) assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn - send_infrastructure_data_change_signal_to_reset_lru_cache() - grid_map.grid[(5, 5)] = rail_trans.transitions[7] + grid_map.set_transitions((5, 5), rail_trans.transitions[7]) assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False) # test path start condition - send_infrastructure_data_change_signal_to_reset_lru_cache() - grid_map.grid[(5, 5)] = rail_trans.transitions[0] + grid_map.set_transitions((5, 5), rail_trans.transitions[3]) assert (grid_map.validate_new_transition(None, (5, 5), (5, 6), (10, 10)) is True) # test path end condition - send_infrastructure_data_change_signal_to_reset_lru_cache() - grid_map.grid[(5, 5)] = rail_trans.transitions[0] + grid_map.set_transitions((5, 5), rail_trans.transitions[3]) assert (grid_map.validate_new_transition((5, 4), (5, 5), (6, 5), (6, 5)) is True) From 49220a077ae080930004ee5e6853c06c7f055504 Mon Sep 17 00:00:00 2001 From: Christian Eichenberger Date: Tue, 18 Feb 2025 18:38:07 +0100 Subject: [PATCH 4/5] Update requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9c5b31a3..e5569313 100644 --- a/requirements.txt +++ b/requirements.txt @@ -77,7 +77,6 @@ matplotlib==3.10.0 # seaborn matplotlib-inline==0.1.7 # via ipython -methodtools==0.4.7 msgpack==1.1.0 # via # flatland-rl (pyproject.toml) From 6026a74511177a0a37a5cad55ec35b646eb6c38f Mon Sep 17 00:00:00 2001 From: chenkins Date: Wed, 19 Feb 2025 18:55:27 +0100 Subject: [PATCH 5/5] Add cache size assertions. --- tests/test_lru_cache_problem.py | 124 +++++++++++++++++++++++++++----- 1 file changed, 106 insertions(+), 18 deletions(-) diff --git a/tests/test_lru_cache_problem.py b/tests/test_lru_cache_problem.py index edf32af3..151a2795 100644 --- a/tests/test_lru_cache_problem.py +++ b/tests/test_lru_cache_problem.py @@ -3,131 +3,219 @@ from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator +maxsize = 1000000 +env_42_hits = 53 +env_42_cache_size = 1137 +env_43_hits = 60 +env_43_cache_size = 1108 +grid_size = 30 * 30 +hits_42_900_43_900_42_900_43_900 = env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits + grid_size +cache_size_42_43_42_43 = env_42_cache_size + env_43_cache_size + env_42_cache_size + env_43_cache_size +cache_size_42_43_42 = env_42_cache_size + env_43_cache_size + env_42_cache_size +cache_size_42_43 = env_42_cache_size + env_43_cache_size + -# TODO refactor parametrized load and load_new! def test_lru_load(): - # seed 42 + # avoid side effects from other tests + _clear_all_lru_caches() + + # (1) new env with seed 42 env_42 = RailEnv(width=30, height=30, rail_generator=sparse_rail_generator(seed=1), line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42) - + assert _info_lru_cache() == (0, 0, maxsize, 0) env_42.reset(random_seed=42) + assert _info_lru_cache() == (env_42_hits, env_42_cache_size, maxsize, env_42_cache_size) transitions_42 = {} for r in range(30): for c in range(30): transitions_42[(r, c)] = env_42.rail.get_full_transitions(r, c) + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) + # (1b) save env with seed 42 RailEnvPersister.save(env_42, "env_42.pkl") + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) - # seed 43 + # (2) new env with seed 43 env_43 = RailEnv(width=30, height=30, rail_generator=sparse_rail_generator(seed=2), line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43) - + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) env_43.reset(random_seed=43) + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits, 2245, maxsize, 2245) + transitions_43 = {} for r in range(30): for c in range(30): transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c) # reset clears the cache, so the transitions are indeed different assert set(transitions_42.items()) != set(transitions_43.items()) + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43) - # seed 42 bis + # (3) second new env with seed 42 env_42_bis = RailEnv(width=30, height=30, rail_generator=sparse_rail_generator(seed=1), line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42) - + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43) env_42_bis.reset(random_seed=42) + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits, cache_size_42_43_42, maxsize, + cache_size_42_43_42) + transitions_42_bis = {} for r in range(30): for c in range(30): transitions_42_bis[(r, c)] = env_42.rail.get_full_transitions(r, c) # sanity check: same seed gives same transitions assert set(transitions_42.items()) == set(transitions_42_bis.items()) + assert _info_lru_cache() == ( + env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size, cache_size_42_43_42, maxsize, + cache_size_42_43_42) - # populate cache with infrastructure from seed 43 + # (4) populate cache with infrastructure from seed 43 env_43 = RailEnv(width=30, height=30, rail_generator=sparse_rail_generator(seed=2), line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43) env_43.reset(random_seed=43) + + assert _info_lru_cache() == ( + env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits, + cache_size_42_43_42_43, maxsize, + cache_size_42_43_42_43) + transitions_43 = {} for r in range(30): for c in range(30): transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c) # reset clears the cache, so the transitions are indeed different assert set(transitions_42.items()) != set(transitions_43.items()) + assert _info_lru_cache() == ( + hits_42_900_43_900_42_900_43_900, + cache_size_42_43_42_43, maxsize, + cache_size_42_43_42_43) - # load env_42 from file + # (5) load env_42 from file RailEnvPersister.load(env_43, "env_42.pkl") + # load does no reset -> no additional caching + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43) env_42_tri = env_43 - transitions_42_tri = {} for r in range(30): for c in range(30): transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c) # load() now invalidates cache correctly assert set(transitions_42.items()) == set(transitions_42_tri.items()) + # 30*30 additional misses are cached: + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43 + grid_size, maxsize, cache_size_42_43_42_43 + grid_size) def test_lru_load_new(): - # seed 42 + # avoid side effects from other tests + _clear_all_lru_caches() + + # (1) new env with seed 42 env_42 = RailEnv(width=30, height=30, rail_generator=sparse_rail_generator(seed=1), line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42) - + assert _info_lru_cache() == (0, 0, maxsize, 0) env_42.reset(random_seed=42) + assert _info_lru_cache() == (env_42_hits, env_42_cache_size, maxsize, env_42_cache_size) transitions_42 = {} for r in range(30): for c in range(30): transitions_42[(r, c)] = env_42.rail.get_full_transitions(r, c) + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) + # (1b) save env with seed 42 RailEnvPersister.save(env_42, "env_42.pkl") + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) - # seed 43 + # (2) new env with seed 43 env_43 = RailEnv(width=30, height=30, rail_generator=sparse_rail_generator(seed=2), line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43) - + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) env_43.reset(random_seed=43) + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits, cache_size_42_43, maxsize, cache_size_42_43) transitions_43 = {} for r in range(30): for c in range(30): transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c) # reset clears the cache, so the transitions are indeed different assert set(transitions_42.items()) != set(transitions_43.items()) + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43) - # seed 42 bis + # (3) second new env with seed 42 env_42_bis = RailEnv(width=30, height=30, rail_generator=sparse_rail_generator(seed=1), line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42) env_42_bis.reset(random_seed=42) + assert _info_lru_cache() == ( + env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits, cache_size_42_43_42, maxsize, cache_size_42_43_42) transitions_42_bis = {} for r in range(30): for c in range(30): transitions_42_bis[(r, c)] = env_42.rail.get_full_transitions(r, c) # sanity check: same seed gives same transitions assert set(transitions_42.items()) == set(transitions_42_bis.items()) + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size, cache_size_42_43_42, maxsize, cache_size_42_43_42) - # populate cache with infrastructure from seed 43 + # (4) populate cache with infrastructure from seed 43 env_43 = RailEnv(width=30, height=30, rail_generator=sparse_rail_generator(seed=2), line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43) env_43.reset(random_seed=43) + + assert _info_lru_cache() == ( + env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43) + transitions_43 = {} for r in range(30): for c in range(30): transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c) # reset clears the cache, so the transitions are indeed different assert set(transitions_42.items()) != set(transitions_43.items()) + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43) - # load_new() env_42 from file + # (5) load_new() env_42 from file # N.B.line `env.rail = GridTransitionMap(1, 1)` in `load_new` has side effect of clearing infrastructure cache. env_42_tri, _ = RailEnvPersister.load_new("env_42.pkl") - + # load does no reset -> no additional caching + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43) transitions_42_tri = {} for r in range(30): for c in range(30): transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c) # load_new() invalidates cache (so env_43 transitions are cleared) assert set(transitions_42.items()) == set(transitions_42_tri.items()) + # 900 additional misses are cached: + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43 + grid_size, maxsize, cache_size_42_43_42_43 + grid_size) + + +def _info_lru_cache(): + import functools + import gc + + gc.collect() + wrappers = [ + a for a in gc.get_objects() + if isinstance(a, functools._lru_cache_wrapper)] + # print(wrappers) + for wrapper in wrappers: + if wrapper.__name__ == "get_full_transitions": + print(f"{wrapper.__name__} {wrapper.cache_info()}") + return wrapper.cache_info() + + +# https://stackoverflow.com/questions/40273767/clear-all-lru-cache-in-python +def _clear_all_lru_caches(): + import functools + import gc + + gc.collect() + wrappers = [ + a for a in gc.get_objects() + if isinstance(a, functools._lru_cache_wrapper)] + + for wrapper in wrappers: + wrapper.cache_clear()