Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

118 Add test_lru_cache_problem.py. #119

Merged
merged 5 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 38 additions & 28 deletions flatland/core/transition_map.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""
TransitionMap and derived classes.
"""
import traceback
import uuid
from functools import lru_cache

import numpy as np
from importlib_resources import path
from numpy import array
import traceback

from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
from flatland.utils.decorators import enable_infrastructure_lru_cache, send_infrastructure_data_change_signal_to_reset_lru_cache
from flatland.utils.ordered_set import OrderedSet


Expand Down Expand Up @@ -135,7 +135,6 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]
grid.

"""
send_infrastructure_data_change_signal_to_reset_lru_cache()
self.width = width
self.height = height
self.transitions = transitions
Expand All @@ -145,8 +144,19 @@ def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]
else:
self.random_generator.seed(random_seed)
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
self._reset_cache()

def _reset_cache(self):
# use __eq__ and __hash__ to control cache lifecycle of instance methods, see https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls.
self.uuid = uuid.uuid4().int

def __eq__(self, __value):
return isinstance(__value, GridTransitionMap) and self.uuid == __value.uuid

def __hash__(self):
return self.uuid

@enable_infrastructure_lru_cache(maxsize=1_000_000)
@lru_cache(maxsize=1_000_000)
def get_full_transitions(self, row, column):
"""
Returns the full transitions for the cell at (row, column) in the format transition_map's transitions.
Expand All @@ -165,7 +175,7 @@ def get_full_transitions(self, row, column):
"""
return self.grid[(row, column)]

@enable_infrastructure_lru_cache(maxsize=4_000_000)
@lru_cache(maxsize=4_000_000)
def get_transitions(self, row, column, orientation):
"""
Return a tuple of transitions available in a cell specified by
Expand Down Expand Up @@ -206,17 +216,17 @@ def set_transitions(self, cell_id, new_transitions):
Tuple of new transitions validitiy for the cell.

"""
send_infrastructure_data_change_signal_to_reset_lru_cache()
#assert len(cell_id) in (2, 3), \
self._reset_cache()
# assert len(cell_id) in (2, 3), \
# 'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.'
if len(cell_id) == 3:
self.grid[cell_id[0:2]] = self.transitions.set_transitions(self.grid[cell_id[0:2]],
cell_id[2],
new_transitions)
cell_id[2],
new_transitions)
elif len(cell_id) == 2:
self.grid[cell_id] = new_transitions

@enable_infrastructure_lru_cache(maxsize=4_000_000)
@lru_cache(maxsize=4_000_000)
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id` can perform a
Expand All @@ -240,7 +250,7 @@ def get_transition(self, cell_id, transition_index):
0/1 allowed/not allowed, a probability in [0,1], etc...)

"""
#assert len(cell_id) == 3, \
# assert len(cell_id) == 3, \
# 'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 2 or 3.'
return self.transitions.get_transition(self.grid[cell_id[0:2]], cell_id[2], transition_index)

Expand All @@ -264,39 +274,39 @@ def set_transition(self, cell_id, transition_index, new_transition, remove_deade
0/1 allowed/not allowed, a probability in [0,1], etc...)

"""
send_infrastructure_data_change_signal_to_reset_lru_cache()
#assert len(cell_id) == 3, \
self._reset_cache()
# assert len(cell_id) == 3, \
# 'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.'

nDir = cell_id[2]
if type(nDir) == np.ndarray:
# I can't work out how to dump a complete backtrace here
try:
assert type(nDir)==int, "cell direction is not an int"
assert type(nDir) == int, "cell direction is not an int"
except Exception as e:
traceback.print_stack()
print("fixing nDir:", cell_id, nDir)
nDir = int(nDir[0])

#if type(transition_index) not in (int, np.int64):
# if type(transition_index) not in (int, np.int64):
if isinstance(transition_index, np.ndarray):
#print("fixing transition_index:", cell_id, transition_index)
# print("fixing transition_index:", cell_id, transition_index)
if type(transition_index) == np.ndarray:
transition_index = int(transition_index.ravel()[0])
else:
# print("transition_index type:", type(transition_index))
transition_index = int(transition_index)

#if type(new_transition) not in (int, bool):
# if type(new_transition) not in (int, bool):
if isinstance(new_transition, np.ndarray):
#print("fixing new_transition:", cell_id, new_transition)
# print("fixing new_transition:", cell_id, new_transition)
new_transition = int(new_transition.ravel()[0])

#print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends)
# print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends)

self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition(
self.grid[cell_id[0:2]],
nDir, # cell_id[2],
nDir, # cell_id[2],
transition_index,
new_transition,
remove_deadends)
Expand Down Expand Up @@ -332,7 +342,7 @@ def load_transition_map(self, package, resource):
(height,width) )

"""
send_infrastructure_data_change_signal_to_reset_lru_cache()
self._reset_cache()
with path(package, resource) as file_in:
new_grid = np.load(file_in)

Expand All @@ -343,7 +353,7 @@ def load_transition_map(self, package, resource):
self.height = new_height
self.grid = new_grid

@enable_infrastructure_lru_cache(maxsize=1_000_000)
@lru_cache(maxsize=1_000_000)
def is_dead_end(self, rcPos: IntVector2DArray):
"""
Check if the cell is a dead-end.
Expand All @@ -360,7 +370,7 @@ def is_dead_end(self, rcPos: IntVector2DArray):
cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
return Grid4Transitions.has_deadend(cell_transition)

@enable_infrastructure_lru_cache(maxsize=1_000_000)
@lru_cache(maxsize=1_000_000)
def is_simple_turn(self, rcPos: IntVector2DArray):
"""
Check if the cell is a left/right simple turn
Expand Down Expand Up @@ -388,7 +398,7 @@ def is_simple_turn(trans):

return is_simple_turn(tmp)

@enable_infrastructure_lru_cache(maxsize=4_000_000)
@lru_cache(maxsize=4_000_000)
def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
"""
Breath first search for a possible path from one node with a certain orientation to a target node.
Expand Down Expand Up @@ -417,7 +427,7 @@ def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVec

return False

@enable_infrastructure_lru_cache(maxsize=1_000_000)
@lru_cache(maxsize=1_000_000)
def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Expand Down Expand Up @@ -504,7 +514,7 @@ def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):

Returns: True (valid) or False (invalid)
"""
send_infrastructure_data_change_signal_to_reset_lru_cache()
self._reset_cache()
cell_transition = self.grid[tuple(rcPos)]

if check_this_cell:
Expand Down Expand Up @@ -548,7 +558,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
"""
Fixes broken transitions
"""
send_infrastructure_data_change_signal_to_reset_lru_cache()
self._reset_cache()
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
Expand Down Expand Up @@ -625,7 +635,7 @@ def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
self.set_transitions((rcPos[0], rcPos[1]), transition)
return True

@enable_infrastructure_lru_cache(maxsize=1_000_000)
@lru_cache(maxsize=1_000_000)
def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
new_pos: IntVector2D, end_pos: IntVector2D):
"""
Expand Down
10 changes: 6 additions & 4 deletions flatland/envs/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,19 @@ def set_full_state(cls, env, env_dict):
-------
env_dict: dict
"""
env.rail.grid = np.array(env_dict["grid"])
grid = np.array(env_dict["grid"])

# Initialise the env with the frozen agents in the file
env.agents = env_dict.get("agents", [])

# For consistency, set number_of_agents, which is the number which will be generated on reset
env.number_of_agents = env.get_num_agents()

env.height, env.width = env.rail.grid.shape
env.rail.height = env.height
env.rail.width = env.width
env.height, env.width = grid.shape

# use new rail object instance for lru cache scoping and garbage collection to work properly
env.rail = GridTransitionMap(height=env.height, width=env.width)
env.rail.grid = grid
env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)

# TODO merge with https://github.com/flatland-association/flatland-rl/pull/97/files
Expand Down
11 changes: 6 additions & 5 deletions flatland/envs/rail_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Definition of the RailEnv environment.
"""
import random
from functools import lru_cache
from typing import List, Optional, Dict, Tuple

import numpy as np
Expand All @@ -26,8 +27,7 @@
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils.transition_utils import check_valid_action
from flatland.utils import seeding
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache, \
enable_infrastructure_lru_cache
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache
from flatland.utils.rendertools import RenderTool, AgentRenderVariant


Expand Down Expand Up @@ -239,8 +239,9 @@ def reset_agents(self):
agent.reset()
self.active_agents = [i for i in range(len(self.agents))]

@enable_infrastructure_lru_cache()
def action_required(self, agent_state, is_cell_entry):
@lru_cache()
@staticmethod
def action_required(agent_state, is_cell_entry):
"""
Check if an agent needs to provide an action

Expand Down Expand Up @@ -459,7 +460,7 @@ def get_info_dict(self):
state - State from the trains's state machine
"""
info_dict = {
'action_required': {i: self.action_required(agent.state, agent.speed_counter.is_cell_entry)
'action_required': {i: RailEnv.action_required(agent.state, agent.speed_counter.is_cell_entry)
for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)
Expand Down
3 changes: 3 additions & 0 deletions flatland/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
infrastructure_lru_cache_functions = []


# TODO https://github.com/flatland-association/flatland-rl/issues/104 1. revise which caches need to be scoped at all - some seem not to require cache clearing at all. 2. refactor with need to explicitly reset cache in calls dispersed in the whole code base. Use classes to group the cache scope by overriding eq/hash for instance method lru caching (see https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls)
def enable_infrastructure_lru_cache(*args, **kwargs):
def decorator(func):
func = lru_cache(*args, **kwargs)(func)
Expand All @@ -12,6 +13,8 @@ def decorator(func):
return decorator


# send_infrastructure_data_change_signal_to_reset_lru_cache() has a problem with instance methods - the methods are not properly cleared by it.
# Therefore, make sure to override eq/hash to control cache lifecycle for instance method lru caching (see https://stackoverflow.com/questions/33672412/python-functools-lru-cache-with-instance-methods-release-object and https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls)
def send_infrastructure_data_change_signal_to_reset_lru_cache():
for func in infrastructure_lru_cache_functions:
func.cache_clear()
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ dependencies = [
"setuptools",
"svgutils",
"timeout_decorator",

]
dynamic = ["version"]

Expand Down
16 changes: 5 additions & 11 deletions tests/test_flatland_core_transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from flatland.core.grid.grid8 import Grid8Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache


# remove whitespace in string; keep whitespace below for easier reading
Expand Down Expand Up @@ -127,28 +126,23 @@ def test_adding_new_valid_transition():
assert (grid_map.validate_new_transition((5, 6), (5, 5), (5, 6), (10, 10)) is True)

# adding invalid turn
send_infrastructure_data_change_signal_to_reset_lru_cache()
grid_map.grid[(5, 5)] = rail_trans.transitions[2]
grid_map.set_transitions((5, 5), rail_trans.transitions[2])
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)

# should create #4 -> valid
send_infrastructure_data_change_signal_to_reset_lru_cache()
grid_map.grid[(5, 5)] = rail_trans.transitions[3]
grid_map.set_transitions((5, 5), rail_trans.transitions[3])
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is True)

# adding invalid turn
send_infrastructure_data_change_signal_to_reset_lru_cache()
grid_map.grid[(5, 5)] = rail_trans.transitions[7]
grid_map.set_transitions((5, 5), rail_trans.transitions[7])
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)

# test path start condition
send_infrastructure_data_change_signal_to_reset_lru_cache()
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
grid_map.set_transitions((5, 5), rail_trans.transitions[3])
assert (grid_map.validate_new_transition(None, (5, 5), (5, 6), (10, 10)) is True)

# test path end condition
send_infrastructure_data_change_signal_to_reset_lru_cache()
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
grid_map.set_transitions((5, 5), rail_trans.transitions[3])
assert (grid_map.validate_new_transition((5, 4), (5, 5), (6, 5), (6, 5)) is True)


Expand Down
Loading