Skip to content

Commit 8985a5b

Browse files
authored
Merge pull request #34 from koulanurag/checkers
- one-hot representation for items
- reward matched with original paper
2 parents 1b0f9db + da89bb2 commit 8985a5b

File tree

2 files changed

+64
-32
lines changed

2 files changed

+64
-32
lines changed

ma_gym/envs/checkers/checkers.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
class Checkers(gym.Env):
1717
"""
18-
The map contains apples and lemons. The first player (red) is very sensitive and scores 5 for
19-
the team for an apple (green square) and −5 for a lemon (orange square). The second (blue), less sensitive
18+
The map contains apples and lemons. The first player (red) is very sensitive and scores 10 for
19+
the team for an apple (green square) and −10 for a lemon (orange square). The second (blue), less sensitive
2020
player scores 1 for the team for an apple and −1 for a lemon. There is a wall of lemons between the
2121
players and the apples. Apples and lemons disappear when collected, and the environment resets
2222
when all apples are eaten. It is important that the sensitive agent eats the apples while the less sensitive
@@ -35,16 +35,16 @@ def __init__(self, full_observable=False, step_cost=-0.01, max_steps=100):
3535
self.full_observable = full_observable
3636

3737
self.action_space = MultiAgentActionSpace([spaces.Discrete(5) for _ in range(self.n_agents)])
38-
self._obs_high = np.array([1.0, 1.0] + [max(OBSERVATION_MEANING.keys())] * 9, dtype=np.float32)
39-
self._obs_low = np.array([0.0, 0.0] + [min(OBSERVATION_MEANING.keys())] * 9, dtype=np.float32)
38+
self._obs_high = np.ones(2 + (3 * 3 * 5))
39+
self._obs_low = np.zeros(2 + (3 * 3 * 5))
4040
if self.full_observable:
4141
self._obs_high = np.tile(self._obs_high, self.n_agents)
4242
self._obs_low = np.tile(self._obs_low, self.n_agents)
43-
self.observation_space = MultiAgentObservationSpace(
44-
[spaces.Box(self._obs_low, self._obs_high) for _ in range(self.n_agents)])
43+
self.observation_space = MultiAgentObservationSpace([spaces.Box(self._obs_low, self._obs_high)
44+
for _ in range(self.n_agents)])
4545

4646
self.init_agent_pos = {0: [0, self._grid_shape[1] - 2], 1: [2, self._grid_shape[1] - 2]}
47-
self.agent_reward = {0: {'lemon': -5, 'apple': 5},
47+
self.agent_reward = {0: {'lemon': -10, 'apple': 10},
4848
1: {'lemon': -1, 'apple': 1}}
4949

5050
self.agent_prev_pos = None
@@ -107,19 +107,19 @@ def get_agent_obs(self):
107107

108108
# add 3 x3 mask around the agent current location and share neighbours
109109
# ( in practice: this information may not be so critical since the map never changes)
110-
_agent_i_neighbour = np.zeros((3, 3))
110+
_agent_i_neighbour = np.zeros((3, 3, 5))
111111
for r in range(pos[0] - 1, pos[0] + 2):
112112
for c in range(pos[1] - 1, pos[1] + 2):
113113
if self.is_valid((r, c)):
114-
item = 0
114+
item = [0, 0, 0, 0, 0]
115115
if PRE_IDS['lemon'] in self._full_obs[r][c]:
116-
item = 1
116+
item[ITEM_ONE_HOT_INDEX['lemon']] = 1
117117
elif PRE_IDS['apple'] in self._full_obs[r][c]:
118-
item = 2
118+
item[ITEM_ONE_HOT_INDEX['apple']] = 1
119119
elif PRE_IDS['agent'] in self._full_obs[r][c]:
120-
item = 3
120+
item[ITEM_ONE_HOT_INDEX[self._full_obs[r][c]]] = 1
121121
elif PRE_IDS['wall'] in self._full_obs[r][c]:
122-
item = -1
122+
item[ITEM_ONE_HOT_INDEX['wall']] = 1
123123
_agent_i_neighbour[r - (pos[0] - 1)][c - (pos[1] - 1)] = item
124124
_agent_i_obs += _agent_i_neighbour.flatten().tolist()
125125

@@ -267,7 +267,6 @@ def close(self):
267267
# each pre-id should be unique and single char
268268
PRE_IDS = {
269269
'agent': 'A',
270-
'prey': 'P',
271270
'wall': 'W',
272271
'empty': '0',
273272
'lemon': 'Y', # yellow color
@@ -278,6 +277,13 @@ def close(self):
278277
0: 'red',
279278
1: 'blue'
280279
}
280+
ITEM_ONE_HOT_INDEX = {
281+
'lemon': 0,
282+
'apple': 1,
283+
'A1': 2,
284+
'A2': 3,
285+
'wall': 4,
286+
}
281287
WALL_COLOR = 'black'
282288
LEMON_COLOR = 'yellow'
283289
APPLE_COLOR = 'green'

tests/envs/test_checkers.py

+44-18
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ def test_reset(env):
2727

2828
# add agent 1 obs
2929
agent_1_obs = [0.0, 0.86]
30-
agent_1_obs += np.array([[0, 0, 0],
31-
[1, 3, 0],
32-
[2, 0, 0]]).flatten().tolist()
30+
agent_1_obs += np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
31+
[[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]],
32+
[[0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]).flatten().tolist()
3333
# add agent 2 obs
3434
agent_2_obs = [0.67, 0.86]
35-
agent_2_obs += np.array([[2, 0, 0],
36-
[1, 3, 0],
37-
[0, 0, 0]]).flatten().tolist()
35+
agent_2_obs += np.array([[[0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
36+
[[1, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0]],
37+
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]).flatten().tolist()
3838

3939
init_obs_n = [agent_1_obs, agent_2_obs]
4040

@@ -49,20 +49,23 @@ def test_reset(env):
4949

5050
@pytest.mark.parametrize('pos,valid',
5151
[((-1, -1), False), ((-1, 0), False), ((-1, 8), False), ((3, 8), False)])
52-
def test_is_valid(env, pos, valid):
52+
def test_pos_validity(env, pos, valid):
5353
assert env.is_valid(pos) == valid
5454

5555

5656
@pytest.mark.parametrize('action_n,output',
5757
[([1, 1], # action
58-
([[0.0, 0.71, 0.0, 0.0, 0.0, 2.0, 3.0, 0.0, 1.0, 2.0, 0.0],
59-
[0.67, 0.71, 1.0, 2.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0]], # obs
58+
([[0.0, 0.71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0,
59+
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
60+
[0.67, 0.71, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,
61+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
6062
{'lemon': 7, 'apple': 9}))]) # food_count
6163
def test_step(env, action_n, output):
6264
env.reset()
6365
target_obs_n, food_count = output
6466
obs_n, reward_n, done_n, info = env.step(action_n)
6567

68+
assert obs_n == target_obs_n, 'observation does not match . Expected {}. Got {}'.format(target_obs_n, obs_n)
6669
for k, v in food_count.items():
6770
assert info['food_count'][k] == food_count[k], '{} does not match'.format(k)
6871
assert env._step_count == 1
@@ -99,18 +102,18 @@ def test_observation_space(env):
99102
assert env.observation_space.contains(env.observation_space.sample())
100103

101104

102-
@parametrize_plus('env', [fixture_ref(env),
103-
fixture_ref(env_full)])
104-
def test_rollout(env):
105+
@parametrize_plus('env', [fixture_ref(env)])
106+
def test_rollout_env(env):
105107
actions = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1],
106108
[0, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]]
107-
target_rewards = [[-5.01, -1.01], [4.99, 0.99], [-5.01, -1.01], [4.99, 0.99],
108-
[-5.01, -1.01], [4.99, 0.99], [-0.01, -0.01], [-0.01, -0.01],
109-
[-5.01, -0.01], [4.99, -0.01], [-5.01, -0.01], [4.99, -0.01],
110-
[-5.01, -0.01], [4.99, -0.01]]
111-
for episode_i in range(2):
109+
target_rewards = [[-10.01, -1.01], [9.99, 0.99], [-10.01, -1.01], [9.99, 0.99],
110+
[-10.01, -1.01], [9.99, 0.99], [-0.01, -0.01], [-0.01, -0.01],
111+
[-10.01, -0.01], [9.99, -0.01], [-10.01, -0.01], [9.99, -0.01],
112+
[-10.01, -0.01], [9.99, -0.01]]
112113

113-
env.reset()
114+
for episode_i in range(1): # multiple episode to validate the seq. again on reset.
115+
116+
obs = env.reset()
114117
done = [False for _ in range(env.n_agents)]
115118
for step_i in range(len(actions)):
116119
obs, reward_n, done, _ = env.step(actions[step_i])
@@ -132,3 +135,26 @@ def test_max_steps(env):
132135
step_i += 1
133136
assert step_i == env._max_steps
134137
assert done == [True for _m in range(env.n_agents)]
138+
139+
140+
@parametrize_plus('env', [fixture_ref(env),
141+
fixture_ref(env_full)])
142+
def test_collision(env):
143+
for episode_i in range(2):
144+
env.reset()
145+
obs_1, reward_n, done, _ = env.step([0, 2])
146+
obs_2, reward_n, done, _ = env.step([0, 2])
147+
148+
assert obs_1 == obs_2
149+
150+
151+
@parametrize_plus('env', [fixture_ref(env),
152+
fixture_ref(env_full)])
153+
def test_revisit_fruit_cell(env):
154+
for episode_i in range(2):
155+
env.reset()
156+
obs_1, reward_1, done, _ = env.step([1, 1])
157+
obs_2, reward_2, done, _ = env.step([3, 3])
158+
obs_3, reward_3, done, _ = env.step([1, 1])
159+
160+
assert reward_1 != reward_3

0 commit comments

Comments
 (0)