-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrandom_agent.py
55 lines (44 loc) · 1.43 KB
/
random_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.evaluators.client import FlatlandRemoteClient
# Module-level client shared by the whole script: the episode loop below uses
# it to create environments, step them, and submit the final results.
remote_client = FlatlandRemoteClient()
def my_controller(obs, _env):
    """Choose a random action for every agent in the environment.

    Args:
        obs: Current observation. Accepted for interface compatibility but
            not consulted; the policy is purely random.
        _env: Environment object exposing an iterable ``agents`` attribute.

    Returns:
        dict: Maps each agent's positional index in ``_env.agents`` to an
        integer drawn by ``np.random.randint(0, 5)`` (i.e. one of 0..4).
    """
    # One draw per agent, taken in agent order.
    return {
        handle: np.random.randint(0, 5)
        for handle, _agent in enumerate(_env.agents)
    }
# Observation builder passed to the remote client each time an environment
# is created.
my_observation_builder = GlobalObsForRailEnv()

# Set to False to silence the per-step reward dump.
DEBUG_STEP_REWARDS = True

# Main evaluation loop: one outer iteration per episode served by the remote
# evaluator, one inner iteration per environment step.
episode = 0
while True:
    print("/ start random_agent", flush=True)
    print("==============")
    episode += 1
    print("[INFO] EPISODE_START : {}".format(episode))
    # NO WAY TO CHECK service/self.evaluation_done in client
    obs, info = remote_client.env_create(obs_builder_object=my_observation_builder)
    if not obs:
        # The remote env returns False as the first obs
        # when it is done evaluating all the individual episodes
        print("[INFO] DONE ALL, BREAKING")
        break

    while True:
        action = my_controller(obs, remote_client.env)
        try:
            # Store the fresh observation back into `obs` so the controller
            # sees the current state on the next step, not the initial one.
            obs, all_rewards, done, info = remote_client.env_step(action)
        except Exception as err:
            # `except Exception` (not a bare `except`) lets Ctrl-C and
            # SystemExit propagate. After a failed step `all_rewards` and
            # `done` are unset or stale, so leave this episode instead of
            # falling through (NameError) or stepping forever.
            print("[ERR] DONE BUT step() CALLED")
            print("[ERR] ", repr(err))
            break

        if DEBUG_STEP_REWARDS:
            print("-----")
            print("[DEBUG] REW: ", all_rewards)

        if done['__all__']:
            print("[INFO] EPISODE_DONE : ", episode)
            print("[INFO] TOTAL_REW: ", sum(all_rewards.values()))
            break

print("Evaluation Complete...")
print(remote_client.submit())
print("\\ end random_agent", flush=True)