Skip to content

Commit 564fad1

Browse files
fix: combat (#41)
Fix bugs: 1. In get_agent_obs(), the local variable '_type' was mistakenly written as 'type'. 2. self.agent_prev_pos and self.opp_prev_pos are not updated, but in __update_agent_view() and __update_opp_view() these previous positions are cleared; thus the movement records are incorrect. 3. self._agent_cool and self._opp_cool are not updated (i.e., the agents and opponents could execute the 'attack' action at every step if within range). Updates: (1) update the is_fireable() function to check whether the attacking agents have cooled down; (2) add self._agent_cool_step and self._opp_cool_step to track the cooldown of the agents and opponents. 4. In the opps_action() function, a dead opponent could still call the reduce_distance_move() function to sample a 'move' action, which raises a 'no where to move' exception when the position of the dead opponent is the same as that of any agent. 5. Update get_agent_obs(): only return the local view of the observation when the agent is alive; otherwise return a zero vector. Add 2 new functions to support CTDE. 6. Add get_state() to support CTDE (centralized training with decentralized execution). 7. In get_agent_obs(), the x-coordinates and y-coordinates of the entities (whether the entity is an 'agent' or an 'opponent') within each agent's view range were incorrectly set to the agent's own position.
1 parent df8ef6d commit 564fad1

File tree

1 file changed

+118
-47
lines changed

1 file changed

+118
-47
lines changed

ma_gym/envs/combat/combat.py

+118-47
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ class Combat(gym.Env):
4545
metadata = {'render.modes': ['human', 'rgb_array']}
4646

4747
def __init__(self, grid_shape=(15, 15), n_agents=5, n_opponents=5, init_health=3, full_observable=False,
48-
step_cost=0, max_steps=100):
48+
step_cost=0, max_steps=100, step_cool=1):
4949
self._grid_shape = grid_shape
5050
self.n_agents = n_agents
5151
self._n_opponents = n_opponents
5252
self._max_steps = max_steps
53+
self._step_cool = step_cool + 1
5354
self._step_cost = step_cost
5455
self._step_count = None
5556

@@ -58,25 +59,32 @@ def __init__(self, grid_shape=(15, 15), n_agents=5, n_opponents=5, init_health=3
5859

5960
self.agent_pos = {_: None for _ in range(self.n_agents)}
6061
self.agent_prev_pos = {_: None for _ in range(self.n_agents)}
61-
self.opp_pos = {_: None for _ in range(self.n_agents)}
62-
self.opp_prev_pos = {_: None for _ in range(self.n_agents)}
62+
self.opp_pos = {_: None for _ in range(self._n_opponents)}
63+
self.opp_prev_pos = {_: None for _ in range(self._n_opponents)}
6364

6465
self._init_health = init_health
6566
self.agent_health = {_: None for _ in range(self.n_agents)}
6667
self.opp_health = {_: None for _ in range(self._n_opponents)}
6768
self._agent_dones = [None for _ in range(self.n_agents)]
6869
self._agent_cool = {_: None for _ in range(self.n_agents)}
70+
self._agent_cool_step = {_: None for _ in range(self.n_agents)}
6971
self._opp_cool = {_: None for _ in range(self._n_opponents)}
72+
self._opp_cool_step = {_: None for _ in range(self._n_opponents)}
7073
self._total_episode_reward = None
7174
self.viewer = None
7275
self.full_observable = full_observable
7376

7477
# 5 * 5 * (type, id, health, cool, x, y)
7578
self._obs_low = np.repeat(np.array([-1., 0., 0., -1., 0., 0.], dtype=np.float32), 5 * 5)
7679
self._obs_high = np.repeat(np.array([1., n_opponents, init_health, 1., 1., 1.], dtype=np.float32), 5 * 5)
77-
self.observation_space = MultiAgentObservationSpace([spaces.Box(self._obs_low, self._obs_high) for _ in range(self.n_agents)])
80+
self.observation_space = MultiAgentObservationSpace(
81+
[spaces.Box(self._obs_low, self._obs_high) for _ in range(self.n_agents)])
7882
self.seed()
7983

84+
# For debug only
85+
self._agents_trace = {_: None for _ in range(self.n_agents)}
86+
self._opponents_trace = {_: None for _ in range(self._n_opponents)}
87+
8088
def get_action_meanings(self, agent_i=None):
8189
action_meaning = []
8290
for _ in range(self.n_agents):
@@ -100,44 +108,68 @@ def _one_hot_encoding(i, n):
100108
def get_agent_obs(self):
101109
"""
102110
When input to a model, each agent is represented by a set of one-hot binary vectors {i, t, l, h, c}
103-
encoding its unique ID, team ID, location, health points and cooldown.
111+
encoding its team ID, unique ID, location, health points and cooldown.
104112
A model controlling an agent also sees other agents in its visual range (5 × 5 surrounding area).
105113
:return:
106114
"""
107115
_obs = []
108116
for agent_i in range(self.n_agents):
109-
pos = self.agent_pos[agent_i]
110-
111-
# _agent_i_obs = self._one_hot_encoding(agent_i, self.n_agents)
112-
# _agent_i_obs += [pos[0] / self._grid_shape[0], pos[1] / (self._grid_shape[1] - 1)] # coordinates
113-
# _agent_i_obs += [self.agent_health[agent_i]]
114-
# _agent_i_obs += [1 if self._agent_cool else 0] # flag if agent is cooling down
115-
116-
# team id , unique id, location,health, cooldown
117-
117+
# team id , unique id, location, health, cooldown
118118
_agent_i_obs = np.zeros((6, 5, 5))
119-
for row in range(0, 5):
120-
for col in range(0, 5):
121-
122-
if self.is_valid([row + (pos[0] - 2), col + (pos[1] - 2)]) and (
123-
PRE_IDS['empty'] not in self._full_obs[row + (pos[0] - 2)][col + (pos[1] - 2)]):
124-
x = self._full_obs[row + pos[0] - 2][col + pos[1] - 2]
125-
_type = 1 if PRE_IDS['agent'] in x else -1
126-
_id = int(x[1:]) - 1 # id
127-
_agent_i_obs[0][row][col] = _type
128-
_agent_i_obs[1][row][col] = _id
129-
_agent_i_obs[2][row][col] = self.agent_health[_id] if type == 1 else self.opp_health[_id]
130-
_agent_i_obs[3][row][col] = self._agent_cool[_id] if type == 1 else self._opp_cool[_id]
131-
_agent_i_obs[3][row][col] = 1 if _agent_i_obs[3][row][col] else -1 # cool/uncool
132-
133-
_agent_i_obs[4][row][col] = pos[0] / self._grid_shape[0] # x-coordinate
134-
_agent_i_obs[5][row][col] = pos[1] / self._grid_shape[1] # y-coordinate
119+
hp = self.agent_health[agent_i]
120+
121+
# If agent is alive
122+
if hp > 0:
123+
# _agent_i_obs = self._one_hot_encoding(agent_i, self.n_agents)
124+
# _agent_i_obs += [pos[0] / self._grid_shape[0], pos[1] / (self._grid_shape[1] - 1)] # coordinates
125+
# _agent_i_obs += [self.agent_health[agent_i]]
126+
# _agent_i_obs += [1 if self._agent_cool else 0] # flag if agent is cooling down
127+
128+
pos = self.agent_pos[agent_i]
129+
for row in range(0, 5):
130+
for col in range(0, 5):
131+
if self.is_valid([row + (pos[0] - 2), col + (pos[1] - 2)]) and (
132+
PRE_IDS['empty'] not in self._full_obs[row + (pos[0] - 2)][col + (pos[1] - 2)]):
133+
x = self._full_obs[row + pos[0] - 2][col + pos[1] - 2]
134+
_type = 1 if PRE_IDS['agent'] in x else -1
135+
_id = int(x[1:]) - 1 # id
136+
_agent_i_obs[0][row][col] = _type
137+
_agent_i_obs[1][row][col] = _id
138+
_agent_i_obs[2][row][col] = self.agent_health[_id] if _type == 1 else self.opp_health[_id]
139+
_agent_i_obs[3][row][col] = self._agent_cool[_id] if _type == 1 else self._opp_cool[_id]
140+
_agent_i_obs[3][row][col] = 1 if _agent_i_obs[3][row][col] else -1 # cool/uncool
141+
entity_position = self.agent_pos[_id] if _type == 1 else self.opp_pos[_id]
142+
_agent_i_obs[4][row][col] = entity_position[0] / self._grid_shape[0] # x-coordinate
143+
_agent_i_obs[5][row][col] = entity_position[1] / self._grid_shape[1] # y-coordinate
135144

136145
_agent_i_obs = _agent_i_obs.flatten().tolist()
137146
_obs.append(_agent_i_obs)
138-
139147
return _obs
140148

149+
def get_state(self):
150+
state = np.zeros((self.n_agents + self._n_opponents, 6))
151+
# agent info
152+
for agent_i in range(self.n_agents):
153+
hp = self.agent_health[agent_i]
154+
if hp > 0:
155+
pos = self.agent_pos[agent_i]
156+
feature = np.array([1, agent_i, hp, 1 if self._agent_cool[agent_i] else -1,
157+
pos[0] / self._grid_shape[0], pos[1] / self._grid_shape[1]], dtype=np.float)
158+
state[agent_i] = feature
159+
160+
# opponent info
161+
for opp_i in range(self._n_opponents):
162+
opp_hp = self.opp_health[opp_i]
163+
if opp_hp > 0:
164+
pos = self.opp_pos[opp_i]
165+
feature = np.array([-1, opp_i, opp_hp, 1 if self._opp_cool[opp_i] else -1,
166+
pos[0] / self._grid_shape[0], pos[1] / self._grid_shape[1]], dtype=np.float)
167+
state[opp_i + self.n_agents] = feature
168+
return state.flatten()
169+
170+
def get_state_size(self):
171+
return (self.n_agents + self._n_opponents) * 6
172+
141173
def __create_grid(self):
142174
_grid = [[PRE_IDS['empty'] for _ in range(self._grid_shape[1])] for row in range(self._grid_shape[0])]
143175
return _grid
@@ -161,7 +193,9 @@ def __init_full_obs(self):
161193

162194
# select agent team center
163195
# Note : Leaving space from edges so as to have a 5x5 grid around it
164-
agent_team_center = self.np_random.randint(2, self._grid_shape[0] - 3), self.np_random.randint(2, self._grid_shape[1] - 3)
196+
agent_team_center = self.np_random.randint(2, self._grid_shape[0] - 3), self.np_random.randint(2,
197+
self._grid_shape[
198+
1] - 3)
165199
# randomly select agent pos
166200
for agent_i in range(self.n_agents):
167201
while True:
@@ -199,10 +233,17 @@ def reset(self):
199233
self.agent_health = {_: self._init_health for _ in range(self.n_agents)}
200234
self.opp_health = {_: self._init_health for _ in range(self._n_opponents)}
201235
self._agent_cool = {_: True for _ in range(self.n_agents)}
236+
self._agent_cool_step = {_: 0 for _ in range(self.n_agents)}
202237
self._opp_cool = {_: True for _ in range(self._n_opponents)}
238+
self._opp_cool_step = {_: 0 for _ in range(self._n_opponents)}
203239
self._agent_dones = [False for _ in range(self.n_agents)]
204240

205241
self.__init_full_obs()
242+
243+
# For debug only
244+
self._agents_trace = {_: [self.agent_pos[_]] for _ in range(self.n_agents)}
245+
self._opponents_trace = {_: [self.opp_pos[_]] for _ in range(self._n_opponents)}
246+
206247
return self.get_agent_obs()
207248

208249
def render(self, mode='human'):
@@ -252,8 +293,9 @@ def __update_agent_pos(self, agent_i, move):
252293

253294
if next_pos is not None and self._is_cell_vacant(next_pos):
254295
self.agent_pos[agent_i] = next_pos
255-
self._full_obs[curr_pos[0]][curr_pos[1]] = PRE_IDS['empty']
296+
self.agent_prev_pos[agent_i] = curr_pos
256297
self.__update_agent_view(agent_i)
298+
self._agents_trace[agent_i].append(next_pos)
257299

258300
def __update_opp_pos(self, opp_i, move):
259301

@@ -274,8 +316,9 @@ def __update_opp_pos(self, opp_i, move):
274316

275317
if next_pos is not None and self._is_cell_vacant(next_pos):
276318
self.opp_pos[opp_i] = next_pos
277-
self._full_obs[curr_pos[0]][curr_pos[1]] = PRE_IDS['empty']
319+
self.opp_prev_pos[opp_i] = curr_pos
278320
self.__update_opp_view(opp_i)
321+
self._opponents_trace[opp_i].append(next_pos)
279322

280323
def is_valid(self, pos):
281324
return (0 <= pos[0] < self._grid_shape[0]) and (0 <= pos[1] < self._grid_shape[1])
@@ -296,19 +339,18 @@ def is_visible(source_pos, target_pos):
296339
and (source_pos[1] - 2) <= target_pos[1] <= (source_pos[1] + 2)
297340

298341
@staticmethod
299-
def is_fireable(source_pos, target_pos):
342+
def is_fireable(source_cooling_down, source_pos, target_pos):
300343
"""
301344
Checks if the target_pos is in the firing range(5x5)
302345
303346
:param source_pos: Coordinates of the source
304347
:param target_pos: Coordinates of the target
305348
:return:
306349
"""
307-
return (source_pos[0] - 1) <= target_pos[0] <= (source_pos[0] + 1) \
350+
return source_cooling_down and (source_pos[0] - 1) <= target_pos[0] <= (source_pos[0] + 1) \
308351
and (source_pos[1] - 1) <= target_pos[1] <= (source_pos[1] + 1)
309352

310-
def reduce_distance_move(self, source_pos, target_pos):
311-
353+
def reduce_distance_move(self, opp_i, source_pos, agent_i, target_pos):
312354
# Todo: makes moves Enum
313355
_moves = []
314356
if source_pos[0] > target_pos[0]:
@@ -321,6 +363,13 @@ def reduce_distance_move(self, source_pos, target_pos):
321363
elif source_pos[1] < target_pos[1]:
322364
_moves.append('RIGHT')
323365

366+
if len(_moves) == 0:
367+
print(self._step_count, source_pos, target_pos)
368+
print("agent-{}, hp={}, move_trace={}".format(agent_i, self.agent_health[agent_i],
369+
self._agents_trace[agent_i]))
370+
print(
371+
"opponent-{}, hp={}, move_trace={}".format(opp_i, self.opp_health[opp_i], self._opponents_trace[opp_i]))
372+
raise AssertionError("One place exists 2 entities!")
324373
move = self.np_random.choice(_moves)
325374
for k, v in ACTION_MEANING.items():
326375
if move.lower() == v.lower():
@@ -356,17 +405,18 @@ def opps_action(self):
356405
action = None
357406
for _, agent_i in sorted(opp_agent_distance[opp_i]):
358407
if agent_i in visible_agents:
359-
if self.is_fireable(self.opp_pos[opp_i], self.agent_pos[agent_i]):
408+
if self.is_fireable(self._opp_cool[opp_i], self.opp_pos[opp_i], self.agent_pos[agent_i]):
360409
action = agent_i + 5
361-
else:
362-
action = self.reduce_distance_move(self.opp_pos[opp_i], self.agent_pos[agent_i])
410+
elif self.opp_health[opp_i] > 0:
411+
action = self.reduce_distance_move(opp_i, self.opp_pos[opp_i], agent_i, self.agent_pos[agent_i])
363412
break
364413
if action is None:
365-
logger.info('No visible agent for enemy:{}'.format(opp_i))
366-
action = self.np_random.choice(range(5))
414+
if self.opp_health[opp_i] > 0:
415+
# logger.debug('No visible agent for enemy:{}'.format(opp_i))
416+
action = self.np_random.choice(range(5))
417+
else:
418+
action = 4 # dead opponent could only execute 'no-op' action.
367419
opp_action_n.append(action)
368-
369-
370420
return opp_action_n
371421

372422
def step(self, agents_action):
@@ -387,28 +437,49 @@ def step(self, agents_action):
387437
if self.agent_health[agent_i] > 0:
388438
if action > 4: # attack actions
389439
target_opp = action - 5
390-
if self.is_fireable(self.agent_pos[agent_i], self.opp_pos[target_opp]) \
440+
if self.is_fireable(self._agent_cool[agent_i], self.agent_pos[agent_i], self.opp_pos[target_opp]) \
391441
and opp_health[target_opp] > 0:
442+
# Fire
392443
opp_health[target_opp] -= 1
393444
rewards[agent_i] += 1
394445

446+
# Update agent cooling down
447+
self._agent_cool[agent_i] = False
448+
self._agent_cool_step[agent_i] = self._step_cool
449+
450+
# Remove opp from the map
395451
if opp_health[target_opp] == 0:
396452
pos = self.opp_pos[target_opp]
397453
self._full_obs[pos[0]][pos[1]] = PRE_IDS['empty']
398454

455+
# Update agent cooling down
456+
self._agent_cool_step[agent_i] = max(self._agent_cool_step[agent_i] - 1, 0)
457+
if self._agent_cool_step[agent_i] == 0 and not self._agent_cool[agent_i]:
458+
self._agent_cool[agent_i] = True
459+
399460
opp_action = self.opps_action
400461
for opp_i, action in enumerate(opp_action):
401462
if self.opp_health[opp_i] > 0:
402463
target_agent = action - 5
403464
if action > 4: # attack actions
404-
if self.is_fireable(self.opp_pos[opp_i], self.agent_pos[target_agent]) \
465+
if self.is_fireable(self._opp_cool[opp_i], self.opp_pos[opp_i], self.agent_pos[target_agent]) \
405466
and agent_health[target_agent] > 0:
467+
# Fire
406468
agent_health[target_agent] -= 1
407469
rewards[target_agent] -= 1
408470

471+
# Update opp cooling down
472+
self._opp_cool[opp_i] = False
473+
self._opp_cool_step[opp_i] = self._step_cool
474+
475+
# Remove agent from the map
409476
if agent_health[target_agent] == 0:
410477
pos = self.agent_pos[target_agent]
411478
self._full_obs[pos[0]][pos[1]] = PRE_IDS['empty']
479+
# Update opp cooling down
480+
self._opp_cool_step[opp_i] = max(self._opp_cool_step[opp_i] - 1, 0)
481+
if self._opp_cool_step[opp_i] == 0 and not self._opp_cool[opp_i]:
482+
self._opp_cool[opp_i] = True
412483

413484
self.agent_health, self.opp_health = agent_health, opp_health
414485

0 commit comments

Comments (0)