
Commit 49b7753

feat: add clock versions (#35)
- Switch and Checkers now have v2 and v3 versions that include clock info in the observation.
1 parent 8985a5b commit 49b7753
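
For context, a quick sketch of what the new version numbering means in practice. This is a reader's aid, not part of the commit; it assumes `gym` and `ma_gym` are installed, with usage style per the project README:

```python
# v0/v1: no clock (partial / full observability)
# v2/v3: clock appended to each observation (partial / full observability)
import gym

env = gym.make('ma_gym:Switch2-v2')  # partially observable, with clock
obs_n = env.reset()
# With the clock enabled, each agent's observation ends with
# step_count / max_steps, a value in [0, 1].
env.close()
```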

File tree

6 files changed (+52, -32 lines)


README.md (+4, -2)

````diff
@@ -1,8 +1,10 @@
 # ma-gym
-A collection of multi agent environments based on OpenAI gym.
+It's a collection of multi agent environments based on OpenAI gym. Also, you can use [**minimal-marl**](https://github.com/koulanurag/minimal-marl) to warm-start training of agents.
 
 ![Python package](https://github.com/koulanurag/ma-gym/workflows/Python%20package/badge.svg)
 ![Upload Python Package](https://github.com/koulanurag/ma-gym/workflows/Upload%20Python%20Package/badge.svg)
+[![Wiki Docs](https://img.shields.io/badge/-Wiki%20Docs-informational?style=flat)](https://github.com/koulanurag/ma-gym/wiki)
+
 
 ## Installation
 Using PyPI:
@@ -45,7 +47,7 @@ while not all(done_n):
 env.close()
 ```
 
-Please refer to [Wiki](https://github.com/koulanurag/ma-gym/wiki/Usage) for complete usage details
+Please refer to [**Wiki**](https://github.com/koulanurag/ma-gym/wiki/Usage) for complete usage details
 
 ## Environments:
 - [x] Checkers
````

ma_gym/__init__.py (+16, -21)

```diff
@@ -17,28 +17,23 @@
 
 # add new environments : iterate over full observability
 for i, observability in enumerate([False, True]):
-    register(
-        id='CrossOver-v' + str(i),
-        entry_point='ma_gym.envs.crossover:CrossOver',
-        kwargs={'full_observable': observability, 'step_cost': -0.5}
-    )
-
-    register(
-        id='Checkers-v' + str(i),
-        entry_point='ma_gym.envs.checkers:Checkers',
-        kwargs={'full_observable': observability}
-    )
 
-    register(
-        id='Switch2-v' + str(i),
-        entry_point='ma_gym.envs.switch:Switch',
-        kwargs={'n_agents': 2, 'full_observable': observability, 'step_cost': -0.1}
-    )
-    register(
-        id='Switch4-v' + str(i),
-        entry_point='ma_gym.envs.switch:Switch',
-        kwargs={'n_agents': 4, 'full_observable': observability, 'step_cost': -0.1}
-    )
+    for clock in [False, True]:
+        register(
+            id='Checkers-v{}'.format(i + (2 if clock else 0)),
+            entry_point='ma_gym.envs.checkers:Checkers',
+            kwargs={'full_observable': observability, 'step_cost': -0.01, 'clock': clock}
+        )
+        register(
+            id='Switch2-v{}'.format(i + (2 if clock else 0)),
+            entry_point='ma_gym.envs.switch:Switch',
+            kwargs={'n_agents': 2, 'full_observable': observability, 'step_cost': -0.1, 'clock': clock}
+        )
+        register(
+            id='Switch4-v{}'.format(i + (2 if clock else 0)),
+            entry_point='ma_gym.envs.switch:Switch',
+            kwargs={'n_agents': 4, 'full_observable': observability, 'step_cost': -0.1, 'clock': clock}
+        )
 
 for num_max_cars in [4, 10]:
     register(
```
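
As a reader's aid (not part of the commit), the version arithmetic `i + (2 if clock else 0)` resolves to the following IDs, sketched here for `Switch2`:

```python
# Reproduces the registration loop's version arithmetic; the printed
# IDs and kwargs mirror the register(...) calls above.
for i, observability in enumerate([False, True]):
    for clock in [False, True]:
        version = i + (2 if clock else 0)
        print('Switch2-v{}: full_observable={}, clock={}'.format(version, observability, clock))
# Switch2-v0: full_observable=False, clock=False
# Switch2-v2: full_observable=False, clock=True
# Switch2-v1: full_observable=True, clock=False
# Switch2-v3: full_observable=True, clock=True
```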

ma_gym/envs/checkers/checkers.py (+6, -4)

```diff
@@ -26,17 +26,18 @@ class Checkers(gym.Env):
     """
     metadata = {'render.modes': ['human', 'rgb_array']}
 
-    def __init__(self, full_observable=False, step_cost=-0.01, max_steps=100):
+    def __init__(self, full_observable=False, step_cost=-0.01, max_steps=100, clock=False):
         self._grid_shape = (3, 8)
         self.n_agents = 2
         self._max_steps = max_steps
         self._step_count = None
         self._step_cost = step_cost
         self.full_observable = full_observable
+        self._add_clock = clock
 
         self.action_space = MultiAgentActionSpace([spaces.Discrete(5) for _ in range(self.n_agents)])
-        self._obs_high = np.ones(2 + (3 * 3 * 5))
-        self._obs_low = np.zeros(2 + (3 * 3 * 5))
+        self._obs_high = np.ones(2 + (3 * 3 * 5) + (1 if clock else 0))
+        self._obs_low = np.zeros(2 + (3 * 3 * 5) + (1 if clock else 0))
         if self.full_observable:
             self._obs_high = np.tile(self._obs_high, self.n_agents)
             self._obs_low = np.tile(self._obs_low, self.n_agents)
@@ -124,7 +125,8 @@ def get_agent_obs(self):
             _agent_i_obs += _agent_i_neighbour.flatten().tolist()
 
             # adding time
-            # _agent_i_obs += [self._step_count / self._max_steps]
+            if self._add_clock:
+                _agent_i_obs += [self._step_count / self._max_steps]
             _obs.append(_agent_i_obs)
 
         if self.full_observable:
```
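
A minimal sketch of the observable effect, assuming `ma_gym` is installed: the `clock` flag adds exactly one trailing entry (`step_count / max_steps`) per agent, so the Checkers observation grows from 2 + 3*3*5 = 47 entries to 48.

```python
import gym

env_v0 = gym.make('ma_gym:Checkers-v0')  # clock=False
env_v2 = gym.make('ma_gym:Checkers-v2')  # clock=True
obs_v0, obs_v2 = env_v0.reset(), env_v2.reset()
# One extra trailing entry per agent when the clock is on:
assert len(obs_v2[0]) == len(obs_v0[0]) + 1  # 48 == 47 + 1
```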

ma_gym/envs/switch/switch_one_corridor.py (+7, -4)

```diff
@@ -16,14 +16,16 @@
 class Switch(gym.Env):
     metadata = {'render.modes': ['human', 'rgb_array']}
 
-    def __init__(self, full_observable: bool = False, step_cost: float = 0, n_agents: int = 4, max_steps: int = 50):
+    def __init__(self, full_observable: bool = False, step_cost: float = 0, n_agents: int = 4, max_steps: int = 50,
+                 clock: bool = True):
         assert 2 <= n_agents <= 4, 'Number of Agents has to be in range [2,4]'
         self._grid_shape = (3, 7)
         self.n_agents = n_agents
         self._max_steps = max_steps
         self._step_count = None
         self._step_cost = step_cost
         self._total_episode_reward = None
+        self._add_clock = clock
 
         self.action_space = MultiAgentActionSpace([spaces.Discrete(5) for _ in range(self.n_agents)])  # l,r,t,d,noop
 
@@ -44,8 +46,8 @@ def __init__(self, full_observable: bool = False, step_cost: float = 0, n_agents
 
         self.full_observable = full_observable
         # agent pos (2)
-        self._obs_high = np.array([1., 1.], dtype=np.float32)
-        self._obs_low = np.array([0., 0.], dtype=np.float32)
+        self._obs_high = np.ones(2 + (1 if self._add_clock else 0))
+        self._obs_low = np.zeros(2 + (1 if self._add_clock else 0))
         if self.full_observable:
             self._obs_high = np.tile(self._obs_high, self.n_agents)
             self._obs_low = np.tile(self._obs_low, self.n_agents)
@@ -91,7 +93,8 @@ def get_agent_obs(self):
             pos = self.agent_pos[agent_i]
             _agent_i_obs = [round(pos[0] / (self._grid_shape[0] - 1), 2),
                             round(pos[1] / (self._grid_shape[1] - 1), 2)]
-            # _agent_i_obs += [self._step_count / self._max_steps]  # add current step count (for time reference)
+            if self._add_clock:
+                _agent_i_obs += [self._step_count / self._max_steps]  # add current step count (for time reference)
             _obs.append(_agent_i_obs)
 
         if self.full_observable:
```
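
Note that `clock` defaults to `True` here, unlike `Checkers` (where it defaults to `False`); the registrations in `ma_gym/__init__.py` always pass `clock` explicitly, so the default only matters when constructing the class directly. A small sketch of that case:

```python
from ma_gym.envs.switch import Switch

env = Switch(n_agents=2, clock=True)
obs_n = env.reset()
# Per agent: [row / (rows - 1), col / (cols - 1), step_count / max_steps]
assert len(obs_n[0]) == 3
```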

setup.py (+1, -1)

```diff
@@ -12,7 +12,7 @@
 extras['all'] = [item for group in extras.values() for item in group]
 
 setup(name='ma_gym',
-      version='0.0.6',
+      version='0.0.7',
       description='A collection of multi agent environments based on OpenAI gym.',
       long_description_content_type='text/markdown',
       long_description=open(path.join(path.abspath(path.dirname(__file__)), 'README.md'), encoding='utf-8').read(),
```

tests/envs/test_switch2.py (+18)

```diff
@@ -125,3 +125,21 @@ def test_optimal_rollout(env):
                 reward_n, step_i)
         assert done == target_dones[step_i]
         step_i += 1
+
+
+@parametrize_plus('env',
+                  [fixture_ref(env),
+                   fixture_ref(env_full)])
+def test_max_steps(env):
+    """ All agent remain at their initial position for the entire duration"""
+    for _ in range(2):
+        env.reset()
+        step_i = 0
+        done = [False for _ in range(env.n_agents)]
+        while not all(done):
+            obs, reward_n, done, _ = env.step([4 for _ in range(env.n_agents)])
+            target_reward = [env._step_cost for _ in range(env.n_agents)]
+            step_i += 1
+            assert (reward_n == target_reward), \
+                'step_cost is not correct. Expected {} ; Got {}'.format(target_reward, reward_n)
+        assert step_i == env._max_steps, 'max-steps should be reached'
```
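
The no-op rollout works because action index 4 is the noop (per the `# l,r,t,d,noop` comment in `Switch`), so agents never move and the episode can only end by hitting `max_steps`. To run just this test, something like the following should work, assuming `pytest` and `pytest-cases` (the source of the `parametrize_plus` / `fixture_ref` helpers) are installed:

```python
import pytest

# Run only the new test from this file; -q keeps output terse.
pytest.main(['-q', 'tests/envs/test_switch2.py', '-k', 'test_max_steps'])
```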
