@@ -27,14 +27,14 @@ def test_reset(env):
27
27
28
28
# add agent 1 obs
29
29
agent_1_obs = [0.0 , 0.86 ]
30
- agent_1_obs += np .array ([[0 , 0 , 0 ],
31
- [1 , 3 , 0 ],
32
- [2 , 0 , 0 ]]).flatten ().tolist ()
30
+ agent_1_obs += np .array ([[[ 0 , 0 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ] ],
31
+ [[ 1 , 0 , 0 , 0 , 0 ], [ 0 , 0 , 1 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ] ],
32
+ [[ 0 , 1 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ] ]]).flatten ().tolist ()
33
33
# add agent 2 obs
34
34
agent_2_obs = [0.67 , 0.86 ]
35
- agent_2_obs += np .array ([[2 , 0 , 0 ],
36
- [1 , 3 , 0 ],
37
- [0 , 0 , 0 ]]).flatten ().tolist ()
35
+ agent_2_obs += np .array ([[[ 0 , 1 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ] ],
36
+ [[ 1 , 0 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 1 , 0 ], [ 0 , 0 , 0 , 0 , 0 ] ],
37
+ [[ 0 , 0 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 , 0 ] ]]).flatten ().tolist ()
38
38
39
39
init_obs_n = [agent_1_obs , agent_2_obs ]
40
40
@@ -49,20 +49,23 @@ def test_reset(env):
49
49
50
50
@pytest .mark .parametrize ('pos,valid' ,
51
51
[((- 1 , - 1 ), False ), ((- 1 , 0 ), False ), ((- 1 , 8 ), False ), ((3 , 8 ), False )])
52
- def test_is_valid (env , pos , valid ):
52
+ def test_pos_validity (env , pos , valid ):
53
53
assert env .is_valid (pos ) == valid
54
54
55
55
56
56
@pytest .mark .parametrize ('action_n,output' ,
57
57
[([1 , 1 ], # action
58
- ([[0.0 , 0.71 , 0.0 , 0.0 , 0.0 , 2.0 , 3.0 , 0.0 , 1.0 , 2.0 , 0.0 ],
59
- [0.67 , 0.71 , 1.0 , 2.0 , 0.0 , 2.0 , 3.0 , 0.0 , 0.0 , 0.0 , 0.0 ]], # obs
58
+ ([[0.0 , 0.71 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 ,
59
+ 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
60
+ [0.67 , 0.71 , 1 , 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 ,
61
+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ]],
60
62
{'lemon' : 7 , 'apple' : 9 }))]) # food_count
61
63
def test_step (env , action_n , output ):
62
64
env .reset ()
63
65
target_obs_n , food_count = output
64
66
obs_n , reward_n , done_n , info = env .step (action_n )
65
67
68
+ assert obs_n == target_obs_n , 'observation does not match . Expected {}. Got {}' .format (target_obs_n , obs_n )
66
69
for k , v in food_count .items ():
67
70
assert info ['food_count' ][k ] == food_count [k ], '{} does not match' .format (k )
68
71
assert env ._step_count == 1
@@ -99,18 +102,18 @@ def test_observation_space(env):
99
102
assert env .observation_space .contains (env .observation_space .sample ())
100
103
101
104
102
- @parametrize_plus ('env' , [fixture_ref (env ),
103
- fixture_ref (env_full )])
104
- def test_rollout (env ):
105
+ @parametrize_plus ('env' , [fixture_ref (env )])
106
+ def test_rollout_env (env ):
105
107
actions = [[1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ],
106
108
[0 , 4 ], [3 , 4 ], [3 , 4 ], [3 , 4 ], [3 , 4 ], [3 , 4 ]]
107
- target_rewards = [[- 5.01 , - 1.01 ], [4.99 , 0.99 ], [- 5.01 , - 1.01 ], [4.99 , 0.99 ],
108
- [- 5.01 , - 1.01 ], [4.99 , 0.99 ], [- 0.01 , - 0.01 ], [- 0.01 , - 0.01 ],
109
- [- 5.01 , - 0.01 ], [4.99 , - 0.01 ], [- 5.01 , - 0.01 ], [4.99 , - 0.01 ],
110
- [- 5.01 , - 0.01 ], [4.99 , - 0.01 ]]
111
- for episode_i in range (2 ):
109
+ target_rewards = [[- 10.01 , - 1.01 ], [9.99 , 0.99 ], [- 10.01 , - 1.01 ], [9.99 , 0.99 ],
110
+ [- 10.01 , - 1.01 ], [9.99 , 0.99 ], [- 0.01 , - 0.01 ], [- 0.01 , - 0.01 ],
111
+ [- 10.01 , - 0.01 ], [9.99 , - 0.01 ], [- 10.01 , - 0.01 ], [9.99 , - 0.01 ],
112
+ [- 10.01 , - 0.01 ], [9.99 , - 0.01 ]]
112
113
113
- env .reset ()
114
+ for episode_i in range (1 ): # multiple episode to validate the seq. again on reset.
115
+
116
+ obs = env .reset ()
114
117
done = [False for _ in range (env .n_agents )]
115
118
for step_i in range (len (actions )):
116
119
obs , reward_n , done , _ = env .step (actions [step_i ])
@@ -132,3 +135,26 @@ def test_max_steps(env):
132
135
step_i += 1
133
136
assert step_i == env ._max_steps
134
137
assert done == [True for _m in range (env .n_agents )]
138
+
139
+
140
+ @parametrize_plus ('env' , [fixture_ref (env ),
141
+ fixture_ref (env_full )])
142
+ def test_collision (env ):
143
+ for episode_i in range (2 ):
144
+ env .reset ()
145
+ obs_1 , reward_n , done , _ = env .step ([0 , 2 ])
146
+ obs_2 , reward_n , done , _ = env .step ([0 , 2 ])
147
+
148
+ assert obs_1 == obs_2
149
+
150
+
151
+ @parametrize_plus ('env' , [fixture_ref (env ),
152
+ fixture_ref (env_full )])
153
+ def test_revisit_fruit_cell (env ):
154
+ for episode_i in range (2 ):
155
+ env .reset ()
156
+ obs_1 , reward_1 , done , _ = env .step ([1 , 1 ])
157
+ obs_2 , reward_2 , done , _ = env .step ([3 , 3 ])
158
+ obs_3 , reward_3 , done , _ = env .step ([1 , 1 ])
159
+
160
+ assert reward_1 != reward_3
0 commit comments