|
3 | 3 | from flatland.envs.rail_env import RailEnv
|
4 | 4 | from flatland.envs.rail_generators import sparse_rail_generator
|
5 | 5 |
|
| 6 | +maxsize = 1000000 |
| 7 | +env_42_hits = 53 |
| 8 | +env_42_cache_size = 1137 |
| 9 | +env_43_hits = 60 |
| 10 | +env_43_cache_size = 1108 |
| 11 | +grid_size = 30 * 30 |
| 12 | +hits_42_900_43_900_42_900_43_900 = env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits + grid_size |
| 13 | +cache_size_42_43_42_43 = env_42_cache_size + env_43_cache_size + env_42_cache_size + env_43_cache_size |
| 14 | +cache_size_42_43_42 = env_42_cache_size + env_43_cache_size + env_42_cache_size |
| 15 | +cache_size_42_43 = env_42_cache_size + env_43_cache_size |
| 16 | + |
6 | 17 |
|
7 |
| -# TODO refactor parametrized load and load_new! |
8 | 18 | def test_lru_load():
|
9 |
| - # seed 42 |
| 19 | + # avoid side effects from other tests |
| 20 | + _clear_all_lru_caches() |
| 21 | + |
| 22 | + # (1) new env with seed 42 |
10 | 23 | env_42 = RailEnv(width=30, height=30,
|
11 | 24 | rail_generator=sparse_rail_generator(seed=1),
|
12 | 25 | line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42)
|
13 |
| - |
| 26 | + assert _info_lru_cache() == (0, 0, maxsize, 0) |
14 | 27 | env_42.reset(random_seed=42)
|
| 28 | + assert _info_lru_cache() == (env_42_hits, env_42_cache_size, maxsize, env_42_cache_size) |
15 | 29 | transitions_42 = {}
|
16 | 30 | for r in range(30):
|
17 | 31 | for c in range(30):
|
18 | 32 | transitions_42[(r, c)] = env_42.rail.get_full_transitions(r, c)
|
| 33 | + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) |
19 | 34 |
|
| 35 | + # (1b) save env with seed 42 |
20 | 36 | RailEnvPersister.save(env_42, "env_42.pkl")
|
| 37 | + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) |
21 | 38 |
|
22 |
| - # seed 43 |
| 39 | + # (2) new env with seed 43 |
23 | 40 | env_43 = RailEnv(width=30, height=30,
|
24 | 41 | rail_generator=sparse_rail_generator(seed=2),
|
25 | 42 | line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43)
|
26 |
| - |
| 43 | + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) |
27 | 44 | env_43.reset(random_seed=43)
|
| 45 | + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits, 2245, maxsize, 2245) |
| 46 | + |
28 | 47 | transitions_43 = {}
|
29 | 48 | for r in range(30):
|
30 | 49 | for c in range(30):
|
31 | 50 | transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c)
|
32 | 51 | # reset clears the cache, so the transitions are indeed different
|
33 | 52 | assert set(transitions_42.items()) != set(transitions_43.items())
|
| 53 | + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43) |
34 | 54 |
|
35 |
| - # seed 42 bis |
| 55 | + # (3) second new env with seed 42 |
36 | 56 | env_42_bis = RailEnv(width=30, height=30,
|
37 | 57 | rail_generator=sparse_rail_generator(seed=1),
|
38 | 58 | line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42)
|
39 |
| - |
| 59 | + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43) |
40 | 60 | env_42_bis.reset(random_seed=42)
|
| 61 | + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits, cache_size_42_43_42, maxsize, |
| 62 | + cache_size_42_43_42) |
| 63 | + |
41 | 64 | transitions_42_bis = {}
|
42 | 65 | for r in range(30):
|
43 | 66 | for c in range(30):
|
44 | 67 | transitions_42_bis[(r, c)] = env_42.rail.get_full_transitions(r, c)
|
45 | 68 | # sanity check: same seed gives same transitions
|
46 | 69 | assert set(transitions_42.items()) == set(transitions_42_bis.items())
|
| 70 | + assert _info_lru_cache() == ( |
| 71 | + env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size, cache_size_42_43_42, maxsize, |
| 72 | + cache_size_42_43_42) |
47 | 73 |
|
48 |
| - # populate cache with infrastructure from seed 43 |
| 74 | + # (4) populate cache with infrastructure from seed 43 |
49 | 75 | env_43 = RailEnv(width=30, height=30,
|
50 | 76 | rail_generator=sparse_rail_generator(seed=2),
|
51 | 77 | line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43)
|
52 | 78 | env_43.reset(random_seed=43)
|
| 79 | + |
| 80 | + assert _info_lru_cache() == ( |
| 81 | + env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits, |
| 82 | + cache_size_42_43_42_43, maxsize, |
| 83 | + cache_size_42_43_42_43) |
| 84 | + |
53 | 85 | transitions_43 = {}
|
54 | 86 | for r in range(30):
|
55 | 87 | for c in range(30):
|
56 | 88 | transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c)
|
57 | 89 | # reset clears the cache, so the transitions are indeed different
|
58 | 90 | assert set(transitions_42.items()) != set(transitions_43.items())
|
| 91 | + assert _info_lru_cache() == ( |
| 92 | + hits_42_900_43_900_42_900_43_900, |
| 93 | + cache_size_42_43_42_43, maxsize, |
| 94 | + cache_size_42_43_42_43) |
59 | 95 |
|
60 |
| - # load env_42 from file |
| 96 | + # (5) load env_42 from file |
61 | 97 | RailEnvPersister.load(env_43, "env_42.pkl")
|
| 98 | + # load does no reset -> no additional caching |
| 99 | + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43) |
62 | 100 | env_42_tri = env_43
|
63 |
| - |
64 | 101 | transitions_42_tri = {}
|
65 | 102 | for r in range(30):
|
66 | 103 | for c in range(30):
|
67 | 104 | transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c)
|
68 | 105 | # load() now invalidates cache correctly
|
69 | 106 | assert set(transitions_42.items()) == set(transitions_42_tri.items())
|
| 107 | + # 30*30 additional misses are cached: |
| 108 | + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43 + grid_size, maxsize, cache_size_42_43_42_43 + grid_size) |
70 | 109 |
|
71 | 110 |
|
72 | 111 | def test_lru_load_new():
|
73 |
| - # seed 42 |
| 112 | + # avoid side effects from other tests |
| 113 | + _clear_all_lru_caches() |
| 114 | + |
| 115 | + # (1) new env with seed 42 |
74 | 116 | env_42 = RailEnv(width=30, height=30,
|
75 | 117 | rail_generator=sparse_rail_generator(seed=1),
|
76 | 118 | line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42)
|
77 |
| - |
| 119 | + assert _info_lru_cache() == (0, 0, maxsize, 0) |
78 | 120 | env_42.reset(random_seed=42)
|
| 121 | + assert _info_lru_cache() == (env_42_hits, env_42_cache_size, maxsize, env_42_cache_size) |
79 | 122 | transitions_42 = {}
|
80 | 123 | for r in range(30):
|
81 | 124 | for c in range(30):
|
82 | 125 | transitions_42[(r, c)] = env_42.rail.get_full_transitions(r, c)
|
| 126 | + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) |
83 | 127 |
|
| 128 | + # (1b) save env with seed 42 |
84 | 129 | RailEnvPersister.save(env_42, "env_42.pkl")
|
| 130 | + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) |
85 | 131 |
|
86 |
| - # seed 43 |
| 132 | + # (2) new env with seed 43 |
87 | 133 | env_43 = RailEnv(width=30, height=30,
|
88 | 134 | rail_generator=sparse_rail_generator(seed=2),
|
89 | 135 | line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43)
|
90 |
| - |
| 136 | + assert _info_lru_cache() == (env_42_hits + grid_size, env_42_cache_size, maxsize, env_42_cache_size) |
91 | 137 | env_43.reset(random_seed=43)
|
| 138 | + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits, cache_size_42_43, maxsize, cache_size_42_43) |
92 | 139 | transitions_43 = {}
|
93 | 140 | for r in range(30):
|
94 | 141 | for c in range(30):
|
95 | 142 | transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c)
|
96 | 143 | # reset clears the cache, so the transitions are indeed different
|
97 | 144 | assert set(transitions_42.items()) != set(transitions_43.items())
|
| 145 | + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size, cache_size_42_43, maxsize, cache_size_42_43) |
98 | 146 |
|
99 |
| - # seed 42 bis |
| 147 | + # (3) second new env with seed 42 |
100 | 148 | env_42_bis = RailEnv(width=30, height=30,
|
101 | 149 | rail_generator=sparse_rail_generator(seed=1),
|
102 | 150 | line_generator=sparse_line_generator(), number_of_agents=2, random_seed=42)
|
103 | 151 |
|
104 | 152 | env_42_bis.reset(random_seed=42)
|
| 153 | + assert _info_lru_cache() == ( |
| 154 | + env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits, cache_size_42_43_42, maxsize, cache_size_42_43_42) |
105 | 155 | transitions_42_bis = {}
|
106 | 156 | for r in range(30):
|
107 | 157 | for c in range(30):
|
108 | 158 | transitions_42_bis[(r, c)] = env_42.rail.get_full_transitions(r, c)
|
109 | 159 | # sanity check: same seed gives same transitions
|
110 | 160 | assert set(transitions_42.items()) == set(transitions_42_bis.items())
|
| 161 | + assert _info_lru_cache() == (env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size, cache_size_42_43_42, maxsize, cache_size_42_43_42) |
111 | 162 |
|
112 |
| - # populate cache with infrastructure from seed 43 |
| 163 | + # (4) populate cache with infrastructure from seed 43 |
113 | 164 | env_43 = RailEnv(width=30, height=30,
|
114 | 165 | rail_generator=sparse_rail_generator(seed=2),
|
115 | 166 | line_generator=sparse_line_generator(), number_of_agents=2, random_seed=43)
|
116 | 167 | env_43.reset(random_seed=43)
|
| 168 | + |
| 169 | + assert _info_lru_cache() == ( |
| 170 | + env_42_hits + grid_size + env_43_hits + grid_size + env_42_hits + grid_size + env_43_hits, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43) |
| 171 | + |
117 | 172 | transitions_43 = {}
|
118 | 173 | for r in range(30):
|
119 | 174 | for c in range(30):
|
120 | 175 | transitions_43[(r, c)] = env_43.rail.get_full_transitions(r, c)
|
121 | 176 | # reset clears the cache, so the transitions are indeed different
|
122 | 177 | assert set(transitions_42.items()) != set(transitions_43.items())
|
| 178 | + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43) |
123 | 179 |
|
124 |
| - # load_new() env_42 from file |
| 180 | + # (5) load_new() env_42 from file |
125 | 181 | # N.B.line `env.rail = GridTransitionMap(1, 1)` in `load_new` has side effect of clearing infrastructure cache.
|
126 | 182 | env_42_tri, _ = RailEnvPersister.load_new("env_42.pkl")
|
127 |
| - |
| 183 | + # load does no reset -> no additional caching |
| 184 | + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43, maxsize, cache_size_42_43_42_43) |
128 | 185 | transitions_42_tri = {}
|
129 | 186 | for r in range(30):
|
130 | 187 | for c in range(30):
|
131 | 188 | transitions_42_tri[(r, c)] = env_42_tri.rail.get_full_transitions(r, c)
|
132 | 189 | # load_new() invalidates cache (so env_43 transitions are cleared)
|
133 | 190 | assert set(transitions_42.items()) == set(transitions_42_tri.items())
|
| 191 | + # 900 additional misses are cached: |
| 192 | + assert _info_lru_cache() == (hits_42_900_43_900_42_900_43_900, cache_size_42_43_42_43 + grid_size, maxsize, cache_size_42_43_42_43 + grid_size) |
| 193 | + |
| 194 | + |
| 195 | +def _info_lru_cache(): |
| 196 | + import functools |
| 197 | + import gc |
| 198 | + |
| 199 | + gc.collect() |
| 200 | + wrappers = [ |
| 201 | + a for a in gc.get_objects() |
| 202 | + if isinstance(a, functools._lru_cache_wrapper)] |
| 203 | + # print(wrappers) |
| 204 | + for wrapper in wrappers: |
| 205 | + if wrapper.__name__ == "get_full_transitions": |
| 206 | + print(f"{wrapper.__name__} {wrapper.cache_info()}") |
| 207 | + return wrapper.cache_info() |
| 208 | + |
| 209 | + |
| 210 | +# https://stackoverflow.com/questions/40273767/clear-all-lru-cache-in-python |
| 211 | +def _clear_all_lru_caches(): |
| 212 | + import functools |
| 213 | + import gc |
| 214 | + |
| 215 | + gc.collect() |
| 216 | + wrappers = [ |
| 217 | + a for a in gc.get_objects() |
| 218 | + if isinstance(a, functools._lru_cache_wrapper)] |
| 219 | + |
| 220 | + for wrapper in wrappers: |
| 221 | + wrapper.cache_clear() |
0 commit comments