diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index d8e74fd..2f81a23 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -25,6 +25,10 @@ jobs:
           registry: ghcr.io
           username: ${{ github.actor }}
           password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Convert github repo name and tag string to lowercase
+        id: convert2lowercase
+        run: INPUT=${{ github.repository }}:${{ github.event.inputs.tag }}; echo "REPO_TAG_LOWERCASE=${INPUT,,}" >> ${GITHUB_OUTPUT}
+      - run: echo ${{ steps.convert2lowercase.outputs.REPO_TAG_LOWERCASE }}
       - name: Build and Push Container Image
         uses: docker/build-push-action@v6
         with:
@@ -33,4 +37,4 @@ jobs:
           platforms: linux/amd64,linux/arm64/v8
           push: true
           tags: |
-            ghcr.io/${{ github.repository }}:${{ github.event.inputs.tag }}
\ No newline at end of file
+            ghcr.io/${{ steps.convert2lowercase.outputs.REPO_TAG_LOWERCASE }}
diff --git a/Dockerfile b/Dockerfile
index 7a57382..7eff961 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -22,6 +22,7 @@ RUN conda --version && \
     python -c 'from flatland.evaluators.client import FlatlandRemoteClient'
 
 COPY run.sh ./
-COPY random_agent.py ./
+COPY src/ ./src
+COPY run_solution.py ./
 
 ENTRYPOINT bash run.sh
diff --git a/README.md b/README.md
index cf20b8f..30991dd 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
 More precisely, Flatland 3 Benchmarks follow Flatland 3 Challenge's
 [Round 2 Environment Configurations](https://flatland-association.github.io/flatland-book/challenges/flatland3/envconfig.html#round-2),
 having the same environment configuration but generated with different seeds.
 
-This starterkit features a random agent [random_agent.py](random_agent.py)
+This starterkit features a shortest-path deadlock-avoidance agent: [run_solution.py](run_solution.py)
 
 ## TL;DR; aka. First Submission
@@ -41,6 +41,7 @@
 There is a `demo` showcase illustrating this setup:
 
 ```shell
 cd demo
+docker compose build
 docker compose up
 ```
diff --git a/random_agent.py b/random_agent.py
deleted file mode 100644
index 3bd69bd..0000000
--- a/random_agent.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import numpy as np
-from flatland.envs.observations import GlobalObsForRailEnv
-from flatland.evaluators.client import FlatlandRemoteClient
-
-remote_client = FlatlandRemoteClient()
-
-
-def my_controller(obs, _env):
-    _action = {}
-    for _idx, _ in enumerate(_env.agents):
-        _action[_idx] = np.random.randint(0, 5)
-    return _action
-
-
-my_observation_builder = GlobalObsForRailEnv()
-
-episode = 0
-
-while True:
-    print("/ start random_agent", flush=True)
-    print("==============")
-    episode += 1
-    print("[INFO] EPISODE_START : {}".format(episode))
-    # NO WAY TO CHECK service/self.evaluation_done in client
-
-    obs, info = remote_client.env_create(obs_builder_object=my_observation_builder)
-    if not obs:
-        """
-        The remote env returns False as the first obs
-        when it is done evaluating all the individual episodes
-        """
-        print("[INFO] DONE ALL, BREAKING")
-        break
-
-    while True:
-        action = my_controller(obs, remote_client.env)
-        try:
-            observation, all_rewards, done, info = remote_client.env_step(
-                action)
-        except:
-            print("[ERR] DONE BUT step() CALLED")
-
-        if (True):  # debug
-            print("-----")
-            # print(done)
-            print("[DEBUG] REW: ", all_rewards)
-        # break
-        if done['__all__']:
-            print("[INFO] EPISODE_DONE : ", episode)
-            print("[INFO] TOTAL_REW: ", sum(list(all_rewards.values())))
-            break
-
-print("Evaluation Complete...")
-print(remote_client.submit())
-print("\\ end random_agent", flush=True)
diff --git a/run.sh b/run.sh
index 5695821..e32d995 100755
--- a/run.sh
+++ b/run.sh
@@ -6,5 +6,5 @@ source /home/conda/.bashrc
 source activate base
 conda activate flatland-rl
 python -m pip list
-python random_agent.py
+python run_solution.py
 echo "\\ end submission_template/run.sh"
diff --git a/run_solution.py b/run_solution.py
new file mode 100644
index 0000000..23ccd02
--- /dev/null
+++ b/run_solution.py
@@ -0,0 +1,93 @@
+from flatland.envs.rail_env import RailEnv
+from flatland.evaluators.client import FlatlandRemoteClient
+
+from src.observation.dummy_observation import FlatlandDummyObservation
+from src.policy.deadlock_avoidance_policy import DeadLockAvoidancePolicy
+from src.policy.random_policy import RandomPolicy
+from src.utils.progress_bar import ProgressBar
+
+remote_client = FlatlandRemoteClient()
+
+# ------------------- user code -------------------------
+my_observation_builder = FlatlandDummyObservation()
+flatlandSolver = DeadLockAvoidancePolicy()
+
+use_random_policy = False
+if use_random_policy:
+    flatlandSolver = RandomPolicy()
+# ---------------------------------------------------------
+
+episode = 0
+
+while True:
+    print("/ start {}".format(flatlandSolver.get_name()), flush=True)
+    episode += 1
+    print("[INFO] EPISODE_START : {}".format(episode))
+    # NO WAY TO CHECK service/self.evaluation_done in client
+
+    observations, info = remote_client.env_create(obs_builder_object=my_observation_builder)
+    if isinstance(observations, bool):
+        if not observations:
+            """
+            The remote env returns False as the first obs
+            when it is done evaluating all the individual episodes
+            """
+            print("[INFO] DONE ALL, BREAKING")
+            break
+
+    # ------------------- user code -------------------------
+    # init the policy
+    env: RailEnv = remote_client.env
+    flatlandSolver.reset(env)
+    flatlandSolver.start_episode(False)
+    # ---------------------------------------------------------
+
+    pbar = ProgressBar()
+    total_reward = 0
+    nbr_done = 0
+    while True:
+        try:
+            # ------------------- user code -------------------------
+            # call the policy to act
+            flatlandSolver.start_step(False)
+            # ---------------------------------------------------------
+
+            # ------------------- user code -------------------------
+            actions = {}
+            eps = 0
+            for handle in env.get_agent_handles():
+                # choose action for agent (handle)
+                action = flatlandSolver.act(handle, observations[handle])
+                actions.update({handle: action})
+            # ---------------------------------------------------------
+
+            observations, all_rewards, done, info = remote_client.env_step(actions)
+            total_reward += sum(list(all_rewards.values()))
+            if env._elapsed_steps < env._max_episode_steps:
+                nbr_done = sum(list(done.values())[:-1])
+
+            # ------------------- user code -------------------------
+            flatlandSolver.end_step(False)
+            # ---------------------------------------------------------
+
+        except:
+            print("[ERR] DONE BUT step() CALLED")
+
+        if (True):  # debug
+            if done['__all__']:
+                pbar.console_print(nbr_done, env.get_num_agents(), 'Nbr of done agents: {}'.format(len(done) - 1), '')
+
+        # break
+        if done['__all__']:
+            print("[INFO] TOTAL_REW: ", total_reward)
+            print("[INFO] EPISODE_DONE : ", episode)
+            break
+
+    # ------------------- user code -------------------------
+    # do clean up
+    flatlandSolver.end_episode(False)
+    # ---------------------------------------------------------
+
+print("Evaluation Complete...")
+print(remote_client.submit())
+print("\\ end {}".format(flatlandSolver.get_name()), flush=True)
diff --git a/src/observation/dummy_observation.py b/src/observation/dummy_observation.py
new file mode 100644
index 0000000..de086d3
--- /dev/null
+++ b/src/observation/dummy_observation.py
@@ -0,0 +1,15 @@
+from typing import Optional, List
+
+import numpy as np
+from flatland.core.env_observation_builder import DummyObservationBuilder
+
+
+class FlatlandDummyObservation(DummyObservationBuilder):
+    def __init__(self):
+        pass
+
+    def get_many(self, handles: Optional[List[int]] = None) -> bool:
+        obs = np.ones((len(handles), 1))
+        for handle in handles:
+            obs[handle][0] = handle
+        return obs
diff --git a/src/policy/deadlock_avoidance_policy.py b/src/policy/deadlock_avoidance_policy.py
new file mode 100644
index 0000000..4b82955
--- /dev/null
+++ b/src/policy/deadlock_avoidance_policy.py
@@ -0,0 +1,218 @@
+from functools import lru_cache
+from typing import Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+from flatland.envs.fast_methods import fast_count_nonzero
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.step_utils.states import TrainState
+
+from src.policy.policy import Policy
+from src.utils.flatland.shortest_distance_walker import ShortestDistanceWalker
+
+# activate LRU caching
+flatland_deadlock_avoidance_policy_lru_cache_functions = []
+
+
+def _enable_flatland_deadlock_avoidance_policy_lru_cache(*args, **kwargs):
+    def decorator(func):
+        func = lru_cache(*args, **kwargs)(func)
+        flatland_deadlock_avoidance_policy_lru_cache_functions.append(func)
+        return func
+
+    return decorator
+
+
+def _send_flatland_deadlock_avoidance_policy_data_change_signal_to_reset_lru_cache():
+    for func in flatland_deadlock_avoidance_policy_lru_cache_functions:
+        func.cache_clear()
+
+
+class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
+    def __init__(self, env: RailEnv):
+        super().__init__(env)
+        self.shortest_distance_agent_map = None
+        self.full_shortest_distance_agent_map = None
+        self.agent_positions = None
+        self.opp_agent_map = {}
+        self.same_agent_map = {}
+
+    def reset(self, env: RailEnv):
+        super(DeadlockAvoidanceShortestDistanceWalker, self).reset(env)
+        self.shortest_distance_agent_map = None
+        self.full_shortest_distance_agent_map = None
+        self.agent_positions = None
+        self.opp_agent_map = {}
+        self.same_agent_map = {}
+        _send_flatland_deadlock_avoidance_policy_data_change_signal_to_reset_lru_cache()
+
+    def clear(self, agent_positions):
+        self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                     self.env.height,
+                                                     self.env.width),
+                                                    dtype=int) - 1
+
+        self.full_shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                          self.env.height,
+                                                          self.env.width),
+                                                         dtype=int) - 1
+
+        self.agent_positions = agent_positions
+
+        self.opp_agent_map = {}
+        self.same_agent_map = {}
+
+    def getData(self):
+        return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
+
+    def callback(self, handle, agent, position, direction, action, possible_transitions) -> bool:
+        opp_a = self.agent_positions[position]
+        if opp_a != -1 and opp_a != handle:
+            if self.env.agents[opp_a].direction != direction:
+                d = self.opp_agent_map.get(handle, [])
+                if opp_a not in d:
+                    d.append(opp_a)
+                self.opp_agent_map.update({handle: d})
+            else:
+                if len(self.opp_agent_map.get(handle, [])) == 0:
+                    d = self.same_agent_map.get(handle, [])
+                    if opp_a not in d:
+                        d.append(opp_a)
+                    self.same_agent_map.update({handle: d})
+
+        if len(self.opp_agent_map.get(handle, [])) == 0:
+            if self._is_no_switch_cell(position):
+                self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+        self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+        return True
+
+    @_enable_flatland_deadlock_avoidance_policy_lru_cache(maxsize=100000)
+    def _is_no_switch_cell(self, position) -> bool:
+        for new_dir in range(4):
+            possible_transitions = self.env.rail.get_transitions(*position, new_dir)
+            num_transitions = fast_count_nonzero(possible_transitions)
+            if num_transitions > 1:
+                return False
+        return True
+
+
+# define Python user-defined exceptions
+class InvalidRawEnvironmentException(Exception):
+    def __init__(self, env, message="This policy works only with a RailEnv or its specialized version. "
+                                    "Please check the raw_env."):
+        self.env = env
+        self.message = message
+        super().__init__(self.message)
+
+
+class DeadLockAvoidancePolicy(Policy):
+    def __init__(self,
+                 action_size: int = 5,
+                 min_free_cell: int = 1,
+                 enable_eps=False,
+                 show_debug_plot=False):
+        super(Policy, self).__init__()
+        self.env: RailEnv = None
+        self.loss = 0
+        self.action_size = action_size
+        self.agent_can_move = {}
+        self.show_debug_plot = show_debug_plot
+        self.enable_eps = enable_eps
+        self.shortest_distance_walker: Union[DeadlockAvoidanceShortestDistanceWalker, None] = None
+        self.min_free_cell = min_free_cell
+        self.agent_positions = None
+
+    def get_name(self):
+        return self.__class__.__name__
+
+    def step(self, handle, state, action, reward, next_state, done):
+        pass
+
+    def act(self, handle, state, eps=0.):
+        # Epsilon-greedy action selection
+        if self.enable_eps:
+            if np.random.random() < eps:
+                return np.random.choice(np.arange(self.action_size))
+
+        # agent = self.env.agents[state[0]]
+        check = self.agent_can_move.get(handle, None)
+        act = RailEnvActions.STOP_MOVING
+        if check is not None:
+            act = check[3]
+        return act
+
+    def reset(self, env: RailEnv):
+        self.env = env
+        if self.shortest_distance_walker is not None:
+            self.shortest_distance_walker.reset(self.env)
+            self.shortest_distance_walker = None
+        self.agent_positions = None
+        self.shortest_distance_walker = None
+
+    def start_step(self, train):
+        self._build_agent_position_map()
+        self._shortest_distance_mapper()
+        self._extract_agent_can_move()
+
+    def _build_agent_position_map(self):
+        # build map with agent positions (only active agents)
+        self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.state in [TrainState.MOVING, TrainState.STOPPED, TrainState.MALFUNCTION]:
+                if agent.position is not None:
+                    self.agent_positions[agent.position] = handle
+
+    def _shortest_distance_mapper(self):
+        if self.shortest_distance_walker is None:
+            self.shortest_distance_walker = DeadlockAvoidanceShortestDistanceWalker(self.env)
+        self.shortest_distance_walker.clear(self.agent_positions)
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.state <= TrainState.MALFUNCTION:
+                self.shortest_distance_walker.walk_to_target(handle)
+
+    def _extract_agent_can_move(self):
+        self.agent_can_move = {}
+        shortest_distance_agent_map, full_shortest_distance_agent_map = self.shortest_distance_walker.getData()
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.state < TrainState.DONE:
+                if self._check_agent_can_move(handle,
+                                              shortest_distance_agent_map[handle],
+                                              self.shortest_distance_walker.same_agent_map.get(handle, []),
+                                              self.shortest_distance_walker.opp_agent_map.get(handle, []),
+                                              full_shortest_distance_agent_map):
+                    next_position, next_direction, action, _ = self.shortest_distance_walker.walk_one_step(handle)
+                    self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
+
+        if self.show_debug_plot:
+            a = np.floor(np.sqrt(self.env.get_num_agents()))
+            b = np.ceil(self.env.get_num_agents() / a)
+            for handle in range(self.env.get_num_agents()):
+                plt.subplot(a, b, handle + 1)
+                plt.imshow(full_shortest_distance_agent_map[handle] + shortest_distance_agent_map[handle])
+            plt.show(block=False)
+            plt.pause(0.01)
+
+    def _check_agent_can_move(self,
+                              handle,
+                              my_shortest_walking_path,
+                              same_agents,
+                              opp_agents,
+                              full_shortest_distance_agent_map):
+        agent_positions_map = (self.agent_positions > -1).astype(int)
+        len_opp_agents = len(opp_agents)
+        for opp_a in opp_agents:
+            opp = full_shortest_distance_agent_map[opp_a]
+            delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int)
+            sum_delta = np.sum(delta)
+            if sum_delta < (self.min_free_cell + len_opp_agents):
+                return False
+        return True
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
diff --git a/src/policy/policy.py b/src/policy/policy.py
new file mode 100644
index 0000000..6df10d6
--- /dev/null
+++ b/src/policy/policy.py
@@ -0,0 +1,46 @@
+from flatland.envs.rail_env import RailEnv
+
+
+class Policy:
+
+    def __init__(self):
+        print('>> ' + self.get_name())
+
+    def get_name(self):
+        raise NotImplementedError
+
+    def save(self, filename):
+        raise NotImplementedError
+
+    def load(self, filename):
+        raise NotImplementedError
+
+    def start_episode(self, train: bool):
+        pass
+
+    def start_step(self, train: bool):
+        pass
+
+    def act(self, handle: int, state, eps=0.):
+        raise NotImplementedError
+
+    def step(self, handle: int, state, action, reward, next_state, done):
+        raise NotImplementedError
+
+    def end_step(self, train: bool):
+        pass
+
+    def end_episode(self, train: bool):
+        pass
+
+    def load_replay_buffer(self, filename):
+        pass
+
+    def test(self):
+        pass
+
+    def reset(self, env: RailEnv):
+        pass
+
+    def clone(self):
+        return self
diff --git a/src/policy/random_policy.py b/src/policy/random_policy.py
new file mode 100644
index 0000000..a7802f6
--- /dev/null
+++ b/src/policy/random_policy.py
@@ -0,0 +1,25 @@
+import numpy as np
+
+from src.policy.policy import Policy
+
+
+class RandomPolicy(Policy):
+    def __init__(self,
+                 action_size: int = 5):
+        super(RandomPolicy, self).__init__()
+        self.action_size = action_size
+
+    def get_name(self):
+        return self.__class__.__name__
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
+
+    def act(self, handle: int, state, eps=0.):
+        return np.random.choice(self.action_size)
+
+    def step(self, handle: int, state, action, reward, next_state, done):
+        pass
diff --git a/src/utils/flatland/shortest_distance_walker.py b/src/utils/flatland/shortest_distance_walker.py
new file mode 100644
index 0000000..2bea882
--- /dev/null
+++ b/src/utils/flatland/shortest_distance_walker.py
@@ -0,0 +1,130 @@
+from functools import lru_cache
+
+import numpy as np
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.fast_methods import fast_count_nonzero, fast_argmax
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+
+# activate LRU caching
+_flatland_shortest_distance_walker_lru_cache_functions = []
+
+
+def _enable_flatland_shortest_distance_walker_lru_cache(*args, **kwargs):
+    def decorator(func):
+        func = lru_cache(*args, **kwargs)(func)
+        _flatland_shortest_distance_walker_lru_cache_functions.append(func)
+        return func
+
+    return decorator
+
+
+def _send_flatland_shortest_distance_walker_data_change_signal_to_reset_lru_cache():
+    for func in _flatland_shortest_distance_walker_lru_cache_functions:
+        func.cache_clear()
+
+
+class ShortestDistanceWalker:
+    def __init__(self, env: RailEnv):
+        self.env = env
+        self.distance_map = None
+
+    def reset(self, env: RailEnv):
+        _send_flatland_shortest_distance_walker_data_change_signal_to_reset_lru_cache()
+        self.env = env
+        self.distance_map = None
+
+    @_enable_flatland_shortest_distance_walker_lru_cache(maxsize=100000)
+    def walk(self, handle, position, direction):
+        if self.distance_map is None:
+            self.distance_map = self.env.distance_map.get()
+
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+        num_transitions = fast_count_nonzero(possible_transitions)
+        if num_transitions == 1:
+            new_direction = fast_argmax(possible_transitions)
+            new_position = get_new_position(position, new_direction)
+            dist = self.distance_map[handle, new_position[0], new_position[1], new_direction]
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
+        else:
+            min_distances = []
+            positions = []
+            directions = []
+            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    new_position = get_new_position(position, new_direction)
+                    min_distances.append(
+                        self.distance_map[handle, new_position[0], new_position[1], new_direction]
+                    )
+                    positions.append(new_position)
+                    directions.append(new_direction)
+                else:
+                    min_distances.append(np.inf)
+                    positions.append(None)
+                    directions.append(None)
+
+            a = self.get_action(min_distances)
+            return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
+
+    def get_action(self, min_distances):
+        return np.argmin(min_distances)
+
+    def callback(self, handle, agent, position, direction, action, possible_transitions) -> bool:
+        return True
+
+    @_enable_flatland_shortest_distance_walker_lru_cache(maxsize=100000)
+    def get_agent_position_and_direction(self, agent_position, agent_direction, agent_initial_position):
+        if agent_position is not None:
+            position = agent_position
+        else:
+            position = agent_initial_position
+        direction = agent_direction
+        return position, direction
+
+    def walk_to_target(self, handle, position=None, direction=None, max_step=500):
+        agent = self.env.agents[handle]
+        position, direction = self._get_pos_dir_wtt(position, direction,
+                                                    agent.position, agent.direction,
+                                                    agent.initial_position)
+
+        agent = self.env.agents[handle]
+        step = 0
+        while (position != agent.target) and (step < max_step):
+            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if position is None:
+                break
+            if not self.callback(handle, agent, position, direction, action, possible_transitions):
+                break
+            step += 1
+
+    @_enable_flatland_shortest_distance_walker_lru_cache(maxsize=100000)
+    def _get_pos_dir_wtt(self, position, direction, agent_pos, agent_dir, agent_initial_position):
+
+        if position is None and direction is None:
+            position, direction = self.get_agent_position_and_direction(agent_pos, agent_dir, agent_initial_position)
+        elif position is None:
+            position, _ = self.get_agent_position_and_direction(agent_pos, agent_dir, agent_initial_position)
+        elif direction is None:
+            _, direction = self.get_agent_position_and_direction(agent_pos, agent_dir, agent_initial_position)
+
+        return position, direction
+
+    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
+    def walk_one_step(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        possible_transitions = (0, 1, 0, 0)
+        new_position = agent.target
+        new_direction = agent.direction
+        action = RailEnvActions.STOP_MOVING
+        if position != agent.target:
+            new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if new_position is None:
+                return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
+        self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
+        return new_position, new_direction, action, possible_transitions
diff --git a/src/utils/progress_bar.py b/src/utils/progress_bar.py
new file mode 100644
index 0000000..100be45
--- /dev/null
+++ b/src/utils/progress_bar.py
@@ -0,0 +1,16 @@
+import numpy as np
+
+
+class ProgressBar:
+    @staticmethod
+    def console_print(itr: float, max_value: float, info: str = '', start_str='\r'):
+        percent = max(0.0, min(100.0, (itr / max_value * 100)))
+        i_percent = 100.0 - percent
+        print('{}{}{} {:5.1f}% {}'.format(start_str,
+                                          "#" * int(np.ceil(percent)),
+                                          "-" * int(np.ceil(i_percent)),
+                                          percent,
+                                          info), end=' ')
+
+        if itr >= max_value:
+            print('')
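
Not part of the patch, only an illustrative sketch: a minimal local smoke test for the new `DeadLockAvoidancePolicy`, mirroring the evaluation loop in `run_solution.py` but without the remote evaluator. It assumes a flatland-rl release that exposes `sparse_rail_generator` and `sparse_line_generator` (as the Flatland 3 API does); the grid size, city count, and agent count are arbitrary example values.

```python
# Hypothetical local smoke test -- not part of this patch.
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator

from src.observation.dummy_observation import FlatlandDummyObservation
from src.policy.deadlock_avoidance_policy import DeadLockAvoidancePolicy

# Small example environment; sizes and counts are illustrative only.
env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(max_num_cities=2),
              line_generator=sparse_line_generator(),
              number_of_agents=2,
              obs_builder_object=FlatlandDummyObservation())
observations, info = env.reset()

policy = DeadLockAvoidancePolicy()
policy.reset(env)
policy.start_episode(False)

done = {"__all__": False}
total_reward = 0
while not done["__all__"]:
    policy.start_step(False)
    # same act() call as in run_solution.py; the policy ignores the observation content
    actions = {handle: policy.act(handle, observations[handle])
               for handle in env.get_agent_handles()}
    observations, rewards, done, info = env.step(actions)
    total_reward += sum(rewards.values())
    policy.end_step(False)

policy.end_episode(False)
print("local episode finished, total reward:", total_reward)
```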