Skip to content

Commit 16e32c3

Browse files
committed
Add cache size assertions.
1 parent aa40a08 commit 16e32c3

File tree

1 file changed

+106
-18
lines changed

1 file changed

+106
-18
lines changed

tests/test_lru_cache_problem.py

+106-18
Original file line numberDiff line numberDiff line change
@@ -3,131 +3,219 @@
33
from flatland.envs.rail_env import RailEnv
44
from flatland.envs.rail_generators import sparse_rail_generator
55

6+
maxsize = 1000000
7+
env_42_hits = 53
8+
env_42_cache_size = 1137
9+
env_43_hits = 60
10+
env_43_cache_size = 1108
11+
grid_size = 30 * 30
12+
hits_42_900_43_900_42_900_43_900 = env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits + grid_size
13+
cache_size_42_43_42_43 = env_42_cache_size + env_43_cache_size + env_42_cache_size + env_43_cache_size
14+
cache_size_42_43_42 = env_42_cache_size + env_43_cache_size + env_42_cache_size
15+
cache_size_42_43 = env_42_cache_size + env_43_cache_size
16+
617

7-
# TODO refactor parametrized load and load_new!
818
def test_lru_load():
9-
# seed 42
19+
# avoid side effects from other tests
20+
_clear_all_lru_caches()
21+
22+
# (1) new env with seed 42
1023
env_42 = RailEnv(width=30, height=30,
1124
rail_generator=sparse_rail_generator(seed=1),
1225
line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42)
13-
26+
assert _info_lru_cache() == (0, 0, maxsize, 0)
1427
env_42.reset(random_seed=42)
28+
assert _info_lru_cache() == (env_42_hits, env_42_cache_size, maxsize, env_42_cache_size)
1529
transitions_42 = {}
1630
for r in range(30):
1731
for c in range(30):
1832
transitions_42[(r, c)] = env_42.rail.get_full_transitions(r, c)
33+
assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size)
1934

35+
# (1b) save env with seed 42
2036
RailEnvPersister.save(env_42, "env_42.pkl")
37+
assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size)
2138

22-
# seed 43
39+
# (2) new env with seed 43
2340
env_43 = RailEnv(width=30, height=30,
2441
rail_generator=sparse_rail_generator(seed=2),
2542
line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43)
26-
43+
assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size)
2744
env_43.reset(random_seed=43)
45+
assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits, 2245, maxsize, 2245)
46+
2847
transitions_43 = {}
2948
for r in range(30):
3049
for c in range(30):
3150
transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c)
3251
# reset clears the cache, so the transitions are indeed different
3352
assert set(transitions_42.items()) != set(transitions_43.items())
53+
assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43)
3454

35-
# seed 42 bis
55+
# (3) second new env with seed 42
3656
env_42_bis = RailEnv(width=30, height=30,
3757
rail_generator=sparse_rail_generator(seed=1),
3858
line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42)
39-
59+
assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43)
4060
env_42_bis.reset(random_seed=42)
61+
assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits, cache_size_42_43_42, maxsize,
62+
cache_size_42_43_42)
63+
4164
transitions_42_bis = {}
4265
for r in range(30):
4366
for c in range(30):
4467
transitions_42_bis[(r, c)] = env_42.rail.get_full_transitions(r, c)
4568
# sanity check: same seed gives same transitions
4669
assert set(transitions_42.items()) == set(transitions_42_bis.items())
70+
assert _info_lru_cache() == (
71+
env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size, cache_size_42_43_42, maxsize,
72+
cache_size_42_43_42)
4773

48-
# populate cache with infrastructure from seed 43
74+
# (4) populate cache with infrastructure from seed 43
4975
env_43 = RailEnv(width=30, height=30,
5076
rail_generator=sparse_rail_generator(seed=2),
5177
line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43)
5278
env_43.reset(random_seed=43)
79+
80+
assert _info_lru_cache() == (
81+
env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits,
82+
cache_size_42_43_42_43, maxsize,
83+
cache_size_42_43_42_43)
84+
5385
transitions_43 = {}
5486
for r in range(30):
5587
for c in range(30):
5688
transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c)
5789
# reset clears the cache, so the transitions are indeed different
5890
assert set(transitions_42.items()) != set(transitions_43.items())
91+
assert _info_lru_cache() == (
92+
hits_42_900_43_900_42_900_43_900,
93+
cache_size_42_43_42_43, maxsize,
94+
cache_size_42_43_42_43)
5995

60-
# load env_42 from file
96+
# (5) load env_42 from file
6197
RailEnvPersister.load(env_43, "env_42.pkl")
98+
# load does no reset -> no additional caching
99+
assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43)
62100
env_42_tri = env_43
63-
64101
transitions_42_tri = {}
65102
for r in range(30):
66103
for c in range(30):
67104
transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c)
68105
# load() now invalidates cache correctly
69106
assert set(transitions_42.items()) == set(transitions_42_tri.items())
107+
# 30*30 additional misses are cached:
108+
assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43 + grid_size, maxsize, cache_size_42_43_42_43 + grid_size)
70109

71110

72111
def test_lru_load_new():
73-
# seed 42
112+
# avoid side effects from other tests
113+
_clear_all_lru_caches()
114+
115+
# (1) new env with seed 42
74116
env_42 = RailEnv(width=30, height=30,
75117
rail_generator=sparse_rail_generator(seed=1),
76118
line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42)
77-
119+
assert _info_lru_cache() == (0, 0, maxsize, 0)
78120
env_42.reset(random_seed=42)
121+
assert _info_lru_cache() == (env_42_hits, env_42_cache_size, maxsize, env_42_cache_size)
79122
transitions_42 = {}
80123
for r in range(30):
81124
for c in range(30):
82125
transitions_42[(r, c)] = env_42.rail.get_full_transitions(r, c)
126+
assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size)
83127

128+
# (1b) save env with seed 42
84129
RailEnvPersister.save(env_42, "env_42.pkl")
130+
assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size)
85131

86-
# seed 43
132+
# (2) new env with seed 43
87133
env_43 = RailEnv(width=30, height=30,
88134
rail_generator=sparse_rail_generator(seed=2),
89135
line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43)
90-
136+
assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size)
91137
env_43.reset(random_seed=43)
138+
assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits, cache_size_42_43, maxsize, cache_size_42_43)
92139
transitions_43 = {}
93140
for r in range(30):
94141
for c in range(30):
95142
transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c)
96143
# reset clears the cache, so the transitions are indeed different
97144
assert set(transitions_42.items()) != set(transitions_43.items())
145+
assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43)
98146

99-
# seed 42 bis
147+
# (3) second new env with seed 42
100148
env_42_bis = RailEnv(width=30, height=30,
101149
rail_generator=sparse_rail_generator(seed=1),
102150
line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42)
103151

104152
env_42_bis.reset(random_seed=42)
153+
assert _info_lru_cache() == (
154+
env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits, cache_size_42_43_42, maxsize, cache_size_42_43_42)
105155
transitions_42_bis = {}
106156
for r in range(30):
107157
for c in range(30):
108158
transitions_42_bis[(r, c)] = env_42.rail.get_full_transitions(r, c)
109159
# sanity check: same seed gives same transitions
110160
assert set(transitions_42.items()) == set(transitions_42_bis.items())
161+
assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size, cache_size_42_43_42, maxsize, cache_size_42_43_42)
111162

112-
# populate cache with infrastructure from seed 43
163+
# (4) populate cache with infrastructure from seed 43
113164
env_43 = RailEnv(width=30, height=30,
114165
rail_generator=sparse_rail_generator(seed=2),
115166
line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43)
116167
env_43.reset(random_seed=43)
168+
169+
assert _info_lru_cache() == (
170+
env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43)
171+
117172
transitions_43 = {}
118173
for r in range(30):
119174
for c in range(30):
120175
transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c)
121176
# reset clears the cache, so the transitions are indeed different
122177
assert set(transitions_42.items()) != set(transitions_43.items())
178+
assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43)
123179

124-
# load_new() env_42 from file
180+
# (5) load_new() env_42 from file
125181
# N.B.line `env.rail = GridTransitionMap(1, 1)` in `load_new` has side effect of clearing infrastructure cache.
126182
env_42_tri, _ = RailEnvPersister.load_new("env_42.pkl")
127-
183+
# load does no reset -> no additional caching
184+
assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43)
128185
transitions_42_tri = {}
129186
for r in range(30):
130187
for c in range(30):
131188
transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c)
132189
# load_new() invalidates cache (so env_43 transitions are cleared)
133190
assert set(transitions_42.items()) == set(transitions_42_tri.items())
191+
# 900 additional misses are cached:
192+
assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43 + grid_size, maxsize, cache_size_42_43_42_43 + grid_size)
193+
194+
195+
def _info_lru_cache():
196+
import functools
197+
import gc
198+
199+
gc.collect()
200+
wrappers = [
201+
a for a in gc.get_objects()
202+
if isinstance(a, functools._lru_cache_wrapper)]
203+
# print(wrappers)
204+
for wrapper in wrappers:
205+
if wrapper.__name__ == "get_full_transitions":
206+
print(f"{wrapper.__name__} {wrapper.cache_info()}")
207+
return wrapper.cache_info()
208+
209+
210+
# https://stackoverflow.com/questions/40273767/clear-all-lru-cache-in-python
211+
def _clear_all_lru_caches():
212+
import functools
213+
import gc
214+
215+
gc.collect()
216+
wrappers = [
217+
a for a in gc.get_objects()
218+
if isinstance(a, functools._lru_cache_wrapper)]
219+
220+
for wrapper in wrappers:
221+
wrapper.cache_clear()

0 commit comments

Comments
 (0)