diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py index 43e4a8f5..419a4662 100644 --- a/flatland/envs/line_generators.py +++ b/flatland/envs/line_generators.py @@ -1,16 +1,11 @@ -"""Line generators (railway undertaking, "EVU").""" -import warnings +"""Line generators: Railway Undertaking (RU) / Eisenbahnverkehrsunternehmen (EVU).""" from typing import Tuple, List, Callable, Mapping, Optional, Any -import numpy as np from numpy.random.mtrand import RandomState -from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgent -from flatland.envs.timetable_utils import Line from flatland.envs import persistence -from flatland.utils.decorators import enable_infrastructure_lru_cache +from flatland.envs.timetable_utils import Line AgentPosition = Tuple[int, int] LineGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Line] @@ -46,8 +41,8 @@ def __init__(self, speed_ratio_map: Mapping[float, float] = None, seed: int = 1) self.speed_ratio_map = speed_ratio_map self.seed = seed - def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any=None, num_resets: int = 0, - np_random: RandomState = None) -> Line: + def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, + np_random: RandomState = None) -> Line: pass def __call__(self, *args, **kwargs): @@ -81,7 +76,7 @@ def decide_orientation(self, rail, start, target, possible_orientations, np_rand return 0 def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int, - np_random: RandomState) -> Line: + np_random: RandomState) -> Line: """ The generator that assigns tasks to all the agents @@ -102,12 +97,10 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re agents_target = [] agents_direction = [] - city1, city2 = None, None city1_num_stations, city2_num_stations = None, None city1_possible_orientations, city2_possible_orientations = None, None - for agent_idx in range(num_agents): if (agent_idx % 2 == 0): @@ -118,9 +111,9 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re city1_num_stations = len(train_stations[city1]) city2_num_stations = len(train_stations[city2]) city1_possible_orientations = [city_orientation[city1], - (city_orientation[city1] + 2) % 4] + (city_orientation[city1] + 2) % 4] city2_possible_orientations = [city_orientation[city2], - (city_orientation[city2] + 2) % 4] + (city_orientation[city2] + 2) % 4] # Agent 1 : city1 > city2, Agent 2: city2 > city1 agent_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations @@ -143,13 +136,11 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re agent_orientation = self.decide_orientation( rail, agent_start, agent_target, city2_possible_orientations, np_random) - # agent1 details agents_position.append((agent_start[0][0], agent_start[0][1])) agents_target.append((agent_target[0][0], agent_target[0][1])) agents_direction.append(agent_orientation) - if self.speed_ratio_map: speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random) else: @@ -163,7 +154,7 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re timedelay_factor * alpha * (rail.width + rail.height + num_agents / len(city_positions))) return Line(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds) + agent_targets=agents_target, agent_speeds=speeds) def line_from_file(filename, load_from_package=None) -> LineGenerator: @@ -182,11 +173,10 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, np_random: RandomState = None) -> Line: - env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) max_episode_steps = env_dict.get("max_episode_steps", 0) - if (max_episode_steps==0): + if (max_episode_steps == 0): print("This env file has no max_episode_steps (deprecated) - setting to 100") max_episode_steps = 100 @@ -196,12 +186,12 @@ def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_r agents_position = [a.initial_position for a in agents] # this logic is wrong - we should really load the initial_direction as the direction. - #agents_direction = [a.direction for a in agents] + # agents_direction = [a.direction for a in agents] agents_direction = [a.initial_direction for a in agents] agents_target = [a.target for a in agents] agents_speed = [a.speed_counter.speed for a in agents] return Line(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed) + agent_targets=agents_target, agent_speeds=agents_speed) return generator diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index c20d7a07..c68167dd 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -3,7 +3,7 @@ """ import random from functools import lru_cache -from typing import List, Optional, Dict, Tuple +from typing import List, Optional, Dict, Tuple, Set import numpy as np @@ -11,6 +11,7 @@ from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.grid_utils import Vector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs import agent_chains as ac from flatland.envs import line_generators as line_gen @@ -99,7 +100,7 @@ class RailEnv(Environment): def __init__(self, width, height, - rail_generator=None, + rail_generator: "RailGenerator" = None, line_generator: "LineGenerator" = None, # : line_gen.LineGenerator = line_gen.random_line_generator(), number_of_agents=2, obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), @@ -167,7 +168,7 @@ def __init__(self, self.rail_generator = rail_generator if line_generator is None: line_generator = line_gen.sparse_line_generator() - self.line_generator: LineGenerator = line_generator + self.line_generator: "LineGenerator" = line_generator self.timetable_generator = timetable_generator self.rail: Optional[GridTransitionMap] = None @@ -205,6 +206,8 @@ def __init__(self, self.motionCheck = ac.MotionCheck() + self.level_free_positions: Set[Vector2D] = set() + def _seed(self, seed): self.np_random, seed = seeding.np_random(seed) random.seed(seed) @@ -314,6 +317,8 @@ def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, agents_hints = None if optionals and 'agents_hints' in optionals: agents_hints = optionals['agents_hints'] + if optionals and 'level_free_positions' in optionals: + self.level_free_positions = optionals['level_free_positions'] line = self.line_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets, self.np_random) @@ -562,8 +567,16 @@ def step(self, action_dict: Dict[int, RailEnvActions]): direction=new_direction, preprocessed_action=preprocessed_action) + # only conflict if the level-free cell is traversed through the same axis (horizontally (0 north or 2 south), or vertically (1 east or 3 west) + new_position_level_free = new_position + if new_position in self.level_free_positions: + new_position_level_free = (new_position, new_direction % 2) + agent_position_level_free = agent.position + if agent.position in self.level_free_positions: + agent_position_level_free = (agent.position, agent.direction % 2) + # This is for storing and later checking for conflicts of agents trying to occupy same cell - self.motionCheck.addAgent(i_agent, agent.position, new_position) + self.motionCheck.addAgent(i_agent, agent_position_level_free, new_position_level_free) # Find conflicts between trains trying to occupy same cell self.motionCheck.find_conflicts() @@ -571,11 +584,15 @@ def step(self, action_dict: Dict[int, RailEnvActions]): for agent in self.agents: i_agent = agent.handle + agent_position_level_free = agent.position + if agent.position in self.level_free_positions: + agent_position_level_free = (agent.position, agent.direction % 2) + ## Update positions if agent.malfunction_handler.in_malfunction: movement_allowed = False else: - movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) + movement_allowed = self.motionCheck.check_motion(i_agent, agent_position_level_free) movement_inside_cell = agent.state == TrainState.STOPPED and not agent.speed_counter.is_cell_exit movement_allowed = movement_allowed or movement_inside_cell @@ -727,6 +744,7 @@ def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVar return self.update_renderer(mode=mode, show=show, show_observations=show_observations, show_predictions=show_predictions, show_rowcols=show_rowcols, return_image=return_image) + def initialize_renderer(self, mode, gl, agent_render_variant, show_debug, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 9c3b3c5c..fe79d8cd 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -1,4 +1,5 @@ -"""Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" +"""Rail generators: infrastructure manager (IM) / Infrastrukturbetreiber (ISB).""" +import math import warnings from typing import Callable, Tuple, Optional, Dict, List @@ -9,7 +10,7 @@ from flatland.core.grid.grid4_utils import direction_to_point from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D, \ Vec2dOperations -from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions, RailEnvTransitionsEnum from flatland.core.transition_map import GridTransitionMap from flatland.envs import persistence from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \ @@ -35,8 +36,7 @@ def __init__(self, *args, **kwargs): """ pass - def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, - np_random: RandomState = None) -> RailGeneratorProduct: + def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct: pass def __call__(self, *args, **kwargs) -> RailGeneratorProduct: @@ -53,8 +53,7 @@ class EmptyRailGen(RailGen): Primarily used by the editor """ - def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, - np_random: RandomState = None) -> RailGenerator: + def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGenerator: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid @@ -100,8 +99,7 @@ def __init__(self, rail_map, optionals=None): self.rail_map = rail_map self.optionals = optionals - def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, - np_random: RandomState = None) -> RailGeneratorProduct: + def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct: return self.rail_map, self.optionals @@ -116,7 +114,7 @@ def sparse_rail_generator(*args, **kwargs): class SparseRailGen(RailGen): def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2, - max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator: + max_rail_pairs_in_city: int = 2, seed=None, p_level_free: float = 0) -> RailGenerator: """ Generates railway networks with cities and inner city rails @@ -133,6 +131,8 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b Number of parallel tracks in the city. This represents the number of tracks in the trainstations seed: int Initiate the seed + p_level_free : float + Percentage of diamond-crossings which are level-free. Returns ------- @@ -143,9 +143,9 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b self.max_rails_between_cities = max_rails_between_cities self.max_rail_pairs_in_city = max_rail_pairs_in_city self.seed = seed + self.p_level_free = p_level_free - def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, - np_random: RandomState = None) -> RailGenerator: + def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct: """ Parameters @@ -181,7 +181,7 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0 min_nr_rail_pairs_in_city = 1 # (min pair must be 1) rail_pairs_in_city = min_nr_rail_pairs_in_city if self.max_rail_pairs_in_city < min_nr_rail_pairs_in_city else self.max_rail_pairs_in_city # (pairs can be 1,2,3) rails_between_cities = (rail_pairs_in_city * 2) if self.max_rails_between_cities > ( - rail_pairs_in_city * 2) else self.max_rails_between_cities + rail_pairs_in_city * 2) else self.max_rails_between_cities # We compute the city radius by the given max number of rails it can contain. # The radius is equal to the number of tracks divided by 2 @@ -237,11 +237,28 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0 # Fix all transition elements self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field) - return grid_map, {'agents_hints': { - 'city_positions': city_positions, - 'train_stations': train_stations, - 'city_orientations': city_orientations - }} + + # choose p_level_free percentage of diamond crossings to be level-free + num_diamond_crossings = np.count_nonzero(grid_map.grid[grid_map.grid == RailEnvTransitionsEnum.diamond_crossing]) + num_level_free_diamond_crossings = math.floor(self.p_level_free * num_diamond_crossings) + # ceil with probability p_ceil + p_ceil = (self.p_level_free * num_diamond_crossings) % 1.0 + num_level_free_diamond_crossings += np_random.choice([1, 0], p=(p_ceil, 1 - p_ceil)) + level_free_positions = set() + if num_level_free_diamond_crossings > 0: + choice = np_random.choice(num_diamond_crossings, size=num_level_free_diamond_crossings, replace=False) + positions_diamond_crossings = (grid_map.grid == RailEnvTransitionsEnum.diamond_crossing).nonzero() + level_free_positions = {tuple(positions_diamond_crossings[choice[i]]) for i in range(len(choice))} + + return grid_map, { + 'agents_hints': + { + 'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + }, + 'level_free_positions': level_free_positions + } def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int, height: int, np_random: RandomState = None) -> Tuple[ @@ -264,7 +281,6 @@ def _generate_random_city_positions(self, num_cities: int, city_radius: int, wid Returns ------- Returns a list of all city positions as coordinates (x,y) - """ city_positions: IntVector2DArray = [] @@ -322,7 +338,6 @@ def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: in Returns ------- Returns a list of all city positions as coordinates (x,y) - """ aspect_ratio = height / width # Compute max numbe of possible cities per row and col. @@ -526,7 +541,7 @@ def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction) """ Given a list of clostest neighbours in each direction this returns the city index of the neighbor in a given direction. Direction is a 90 degree cone facing the desired directiont. - Exampe: + Example: North: The closes neighbour in the North direction is within the cone spanned by a line going North-West and North-East @@ -677,7 +692,6 @@ def _fix_transitions(self, city_cells: set, inter_city_lines: List[IntVector2DAr Each cell contains the prefered orientation of cells. If no prefered orientation is present it is set to -1 grid_map: RailEnvTransitions The grid map containing the rails. Used to draw new rails - """ # Fix all cities with illegal transition maps diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 6abaddd0..1a51037f 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -1 +1 @@ -raise ImportError(" Schedule Generators is now renamed to line_generators, any reference to schedule should be replaced with line") \ No newline at end of file +raise ImportError(" Schedule Generators is now renamed to line_generators + timetable_generators, any reference to schedule should be replaced with line") diff --git a/flatland/envs/sparse_rail_gen.py b/flatland/envs/sparse_rail_gen.py deleted file mode 100644 index e001754b..00000000 --- a/flatland/envs/sparse_rail_gen.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" -import sys -import warnings -from typing import Callable, Tuple, Optional, Dict, List - -import numpy as np -from numpy.random.mtrand import RandomState - -from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point -from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d -from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \ - Vec2dOperations -from flatland.core.grid.rail_env_grid import RailEnvTransitions -from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \ - fix_inner_nodes, align_cell_to_city -from flatland.envs import persistence - -from flatland.envs.rail_generators import RailGeneratorProduct, RailGenerator - diff --git a/flatland/envs/timetable_generators.py b/flatland/envs/timetable_generators.py index 5a2a8ad9..36132a34 100644 --- a/flatland/envs/timetable_generators.py +++ b/flatland/envs/timetable_generators.py @@ -1,9 +1,5 @@ -import os -import json -import itertools -import warnings -from typing import Tuple, List, Callable, Mapping, Optional, Any -from flatland.envs.timetable_utils import Timetable +"""Timetable generators: Railway Undertaking (RU) / Eisenbahnverkehrsunternehmen (EVU).""" +from typing import List import numpy as np from numpy.random.mtrand import RandomState @@ -11,6 +7,8 @@ from flatland.envs.agent_utils import EnvAgent from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env_shortest_paths import get_shortest_paths +from flatland.envs.timetable_utils import Timetable + def len_handle_none(v): if v is not None: @@ -18,14 +16,15 @@ def len_handle_none(v): else: return 0 -def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap, - agents_hints: dict, np_random: RandomState = None) -> Timetable: + +def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap, + agents_hints: dict, np_random: RandomState = None) -> Timetable: """ Calculates earliest departure and latest arrival times for the agents This is the new addition in Flatland 3 Also calculates the max episodes steps based on the density of the timetable - inputs: + inputs: agents - List of all the agents rail_env.agents distance_map - Distance map of positions to tagets of each agent in each direction agent_hints - Uses the number of cities @@ -43,35 +42,35 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap, timedelay_factor = 4 alpha = 2 max_episode_steps = int(timedelay_factor * alpha * \ - (distance_map.rail.width + distance_map.rail.height + (len(agents) / num_cities))) - + (distance_map.rail.width + distance_map.rail.height + (len(agents) / num_cities))) + # Multipliers old_max_episode_steps_multiplier = 3.0 new_max_episode_steps_multiplier = 1.5 - travel_buffer_multiplier = 1.3 # must be strictly lesser than new_max_episode_steps_multiplier + travel_buffer_multiplier = 1.3 # must be strictly lesser than new_max_episode_steps_multiplier assert new_max_episode_steps_multiplier > travel_buffer_multiplier end_buffer_multiplier = 0.05 mean_shortest_path_multiplier = 0.2 - + shortest_paths = get_shortest_paths(distance_map) - shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()] + shortest_paths_lengths = [len_handle_none(v) for k, v in shortest_paths.items()] # Find mean_shortest_path_time agent_speeds = [agent.speed_counter.speed for agent in agents] - agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds) + agent_shortest_path_times = np.array(shortest_paths_lengths) / np.array(agent_speeds) mean_shortest_path_time = np.mean(agent_shortest_path_times) # Deciding on a suitable max_episode_steps longest_speed_normalized_time = np.max(agent_shortest_path_times) mean_path_delay = mean_shortest_path_time * mean_shortest_path_multiplier max_episode_steps_new = int(np.ceil(longest_speed_normalized_time * new_max_episode_steps_multiplier) + mean_path_delay) - + max_episode_steps_old = int(max_episode_steps * old_max_episode_steps_multiplier) max_episode_steps = min(max_episode_steps_new, max_episode_steps_old) - + end_buffer = int(max_episode_steps * end_buffer_multiplier) - latest_arrival_max = max_episode_steps-end_buffer + latest_arrival_max = max_episode_steps - end_buffer # Useless unless needed by returning earliest_departures = [] @@ -80,12 +79,12 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap, for agent in agents: agent_shortest_path_time = agent_shortest_path_times[agent.handle] agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) + mean_path_delay)) - + departure_window_max = max(latest_arrival_max - agent_travel_time_max, 1) - + earliest_departure = np_random.randint(0, departure_window_max) latest_arrival = earliest_departure + agent_travel_time_max - + earliest_departures.append(earliest_departure) latest_arrivals.append(latest_arrival) @@ -93,15 +92,13 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap, agent.latest_arrival = latest_arrival return Timetable(earliest_departures=earliest_departures, latest_arrivals=latest_arrivals, - max_episode_steps=max_episode_steps) + max_episode_steps=max_episode_steps) -def ttgen_flatland2(agents: List[EnvAgent], distance_map: DistanceMap, - agents_hints: dict, np_random: RandomState = None) -> Timetable: +def ttgen_flatland2(agents: List[EnvAgent], distance_map: DistanceMap, + agents_hints: dict, np_random: RandomState = None) -> Timetable: nMaxSteps = 1000 return Timetable( - earliest_departures=[0]*len(agents), - latest_arrivals=[nMaxSteps]*len(agents), + earliest_departures=[0] * len(agents), + latest_arrivals=[nMaxSteps] * len(agents), max_episode_steps=1000) - - diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 96243c00..0fe1ffcc 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -2,7 +2,7 @@ import numpy as np -from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions, RailEnvTransitionsEnum from flatland.core.transition_map import GridTransitionMap @@ -42,16 +42,16 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array, Dict]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - city_positions = [(0,3), (6, 6)] + city_positions = [(0, 3), (6, 6)] train_stations = [ - [( (0, 3), 0 ) ], - [( (6, 6), 0 ) ], - ] + [((0, 3), 0)], + [((6, 6), 0)], + ] city_orientations = [0, 2] agents_hints = {'city_positions': city_positions, 'train_stations': train_stations, 'city_orientations': city_orientations - } + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -93,16 +93,16 @@ def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array, Dict]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - city_positions = [(0,3), (6, 6)] + city_positions = [(0, 3), (6, 6)] train_stations = [ - [( (0, 3), 0 ) ], - [( (6, 6), 0 ) ], - ] + [((0, 3), 0)], + [((6, 6), 0)], + ] city_orientations = [0, 2] agents_hints = {'city_positions': city_positions, 'train_stations': train_stations, 'city_orientations': city_orientations - } + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -141,16 +141,16 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array, Dict]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - city_positions = [(0,3), (6, 6)] + city_positions = [(0, 3), (6, 6)] train_stations = [ - [( (0, 3), 0 ) ], - [( (6, 6), 0 ) ], - ] + [((0, 3), 0)], + [((6, 6), 0)], + ] city_orientations = [0, 2] agents_hints = {'city_positions': city_positions, 'train_stations': train_stations, 'city_orientations': city_orientations - } + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -190,16 +190,16 @@ def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array, Dict]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - city_positions = [(0,3), (6, 6)] + city_positions = [(0, 3), (6, 6)] train_stations = [ - [( (0, 3), 0 ) ], - [( (6, 6), 0 ) ], - ] + [((0, 3), 0)], + [((6, 6), 0)], + ] city_orientations = [0, 2] agents_hints = {'city_positions': city_positions, 'train_stations': train_stations, 'city_orientations': city_orientations - } + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -245,16 +245,16 @@ def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array, D rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - city_positions = [(0,3), (6, 6)] + city_positions = [(0, 3), (6, 6)] train_stations = [ - [( (0, 3), 0 ) ], - [( (6, 6), 0 ) ], - ] + [((0, 3), 0)], + [((6, 6), 0)], + ] city_orientations = [0, 2] agents_hints = {'city_positions': city_positions, 'train_stations': train_stations, 'city_orientations': city_orientations - } + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -294,20 +294,21 @@ def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array, Dict[str, A rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - city_positions = [(0,3), (6, 6)] + city_positions = [(0, 3), (6, 6)] train_stations = [ - [( (0, 3), 0 ) ], - [( (6, 6), 0 ) ], - ] + [((0, 3), 0)], + [((6, 6), 0)], + ] city_orientations = [0, 2] agents_hints = {'city_positions': city_positions, 'train_stations': train_stations, 'city_orientations': city_orientations - } + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals -def make_oval_rail() -> Tuple[GridTransitionMap, np.array]: + +def make_oval_rail() -> Tuple[GridTransitionMap, np.array, Any]: transitions = RailEnvTransitions() cells = transitions.transition_list @@ -322,7 +323,7 @@ def make_oval_rail() -> Tuple[GridTransitionMap, np.array]: rail_map = np.array( [[empty] * 9] + [[empty] + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west] + [empty]] + - [[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]]+ + [[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] + [[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] + [[empty] + [right_turn_from_east] + [horizontal_straight] * 5 + [right_turn_from_north] + [empty]] + [[empty] * 9], dtype=np.uint16) @@ -341,4 +342,39 @@ def make_oval_rail() -> Tuple[GridTransitionMap, np.array]: 'city_orientations': city_orientations } optionals = {'agents_hints': agents_hints} - return rail, rail_map, optionals + return rail, rail_map, optionals + + +def make_diamond_crossing_rail() -> Tuple[GridTransitionMap, np.array, Dict]: + # We instantiate a very simple rail network on a 6x10 grid: + # Note that some cells have invalid RailEnvTransitions! + # | + # | + # _ _ | _ _ _ _ _ _ _ + # | + # | + # | + transitions = RailEnvTransitions() + rail_map = np.array( + [[RailEnvTransitionsEnum.empty] * 2 + [RailEnvTransitionsEnum.dead_end_from_south] + [RailEnvTransitionsEnum.empty] * 7] + + [[RailEnvTransitionsEnum.empty] * 2 + [RailEnvTransitionsEnum.vertical_straight] + [RailEnvTransitionsEnum.empty] * 7] * 2 + + [[RailEnvTransitionsEnum.dead_end_from_east] + [RailEnvTransitionsEnum.horizontal_straight] * 1 + [RailEnvTransitionsEnum.diamond_crossing] * 1 + [ + RailEnvTransitionsEnum.horizontal_straight] * 6 + [RailEnvTransitionsEnum.dead_end_from_west]] + + [[RailEnvTransitionsEnum.empty] * 2 + [RailEnvTransitionsEnum.vertical_straight] + [RailEnvTransitionsEnum.empty] * 7] * 2 + + [[RailEnvTransitionsEnum.empty] * 2 + [RailEnvTransitionsEnum.dead_end_from_north] + [RailEnvTransitionsEnum.empty] * 7] + , dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + city_positions = [(1, 4), (4, 4)] + train_stations = [ + [((1, 4), 0)], + [((4, 4), 0)], + ] + city_orientations = [1, 3] + agents_hints = {'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } + optionals = {'agents_hints': agents_hints} + return rail, rail_map, optionals diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index f6c6a9cd..cd2f2fa8 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -4,13 +4,14 @@ import numpy as np from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.envs.line_generators import sparse_line_generator from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.line_generators import sparse_line_generator from flatland.utils.rendertools import RenderTool -#deactivated: need a fix to match with astar, to activate rename def test_sparse... + +# deactivated: need a fix to match with astar, to activate rename def test_sparse... def notest_sparse_rail_generator(): env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10, max_rails_between_cities=3, @@ -503,7 +504,8 @@ def notest_sparse_rail_generator(): assert s0 == 36, "actual={}".format(s0) assert s1 == 27, "actual={}".format(s1) -#deactivated test: need a fix for astar, to activate rename test_sparse... + +# deactivated test: need a fix for astar, to activate rename test_sparse... def notest_sparse_rail_generator_deterministic(): """Check that sparse_rail_generator runs deterministic over different python versions!""" @@ -1274,6 +1276,7 @@ def notest_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(29, 23) == 0, "[29][23]" assert env.rail.get_full_transitions(29, 24) == 0, "[29][24]" + def test_rail_env_action_required_info(): speed_ration_map = {1.: 0.25, # Fast passenger train 1. / 2.: 0.25, # Fast freight train @@ -1285,7 +1288,7 @@ def test_rail_env_action_required_info(): seed=5, # Random seed grid_mode=False # Ordered distribution of nodes ), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False) + obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False) env_only_if_action_required = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator( max_num_cities=10, @@ -1315,7 +1318,7 @@ def test_rail_env_action_required_info(): action_dict_only_if_action_required.update({a: action}) else: print("[{}] not action_required {}, speed_counter={}".format(step, a, - env_always_action.agents[a].speed_counter)) + env_always_action.agents[a].speed_counter)) obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step( action_dict_always_action) @@ -1435,22 +1438,141 @@ def test_sparse_generator_changes_to_grid_mode(): with warnings.catch_warnings(record=True) as w: rail_env.reset(True, True, random_seed=15) assert "[WARNING]" in str(w[-1].message) - + + +def test_sparse_generator_with_level_free_03(): + """Check that sparse generator generates level-free diamond-crossings.""" + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator( + max_num_cities=5, + max_rails_between_cities=3, + seed=215545, # Random seed + grid_mode=True, + p_level_free=0.3 + ), + line_generator=sparse_line_generator(speed_ration_map), + number_of_agents=1, + random_seed=1) + env.reset() + assert env.level_free_positions == set() + + +def test_sparse_generator_with_level_free_04(): + """Check that sparse generator generates level-free diamond crossings.""" + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator( + max_num_cities=5, + max_rails_between_cities=3, + seed=215545, # Random seed + grid_mode=True, + p_level_free=0.4 + ), + line_generator=sparse_line_generator(speed_ration_map), + number_of_agents=1, + random_seed=1) + env.reset() + assert env.level_free_positions == {(12, 20)} + + +def test_sparse_generator_with_level_free_08(): + """Check that sparse generator generates level with diamond crossings.""" + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator( + max_num_cities=5, + max_rails_between_cities=3, + seed=215545, # Random seed + grid_mode=True, + p_level_free=0.8 + ), + line_generator=sparse_line_generator(speed_ration_map), + number_of_agents=1, + random_seed=1) + env.reset() + assert env.level_free_positions == {(12, 20)} + + +def test_sparse_generator_with_level_free_09(): + """Check that sparse generator generates level with diamond crossings.""" + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator( + max_num_cities=5, + max_rails_between_cities=3, + seed=215545, # Random seed + grid_mode=True, + p_level_free=0.9 + ), + line_generator=sparse_line_generator(speed_ration_map), + number_of_agents=1, + random_seed=1) + env.reset() + assert env.level_free_positions == {(12, 20), (4, 18)} + + +def test_sparse_generator_with_level_free_10(): + """Check that sparse generator generates all diamond crossings as level-free if p_level_free=1.0.""" + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator( + max_num_cities=5, + max_rails_between_cities=3, + seed=215545, # Random seed + grid_mode=True, + p_level_free=1 + ), + line_generator=sparse_line_generator(speed_ration_map), + number_of_agents=1, + random_seed=1) + env.reset() + assert env.level_free_positions == {(4, 18), (12, 20)} def main(): # Make warnings into errors, to generate stack backtraces - warnings.simplefilter("error",) # category=DeprecationWarning) + warnings.simplefilter("error", ) # category=DeprecationWarning) # Then run selected tests. - #test_sparse_rail_generator() - #test_sparse_rail_generator_deterministic() - #test_rail_env_action_required_info() - #test_rail_env_malfunction_speed_info() - #test_sparse_generator_with_too_man_cities_does_not_break_down() - #test_sparse_generator_with_illegal_params_aborts() - #test_sparse_generator_changes_to_grid_mode() + # test_sparse_rail_generator() + # test_sparse_rail_generator_deterministic() + # test_rail_env_action_required_info() + # test_rail_env_malfunction_speed_info() + # test_sparse_generator_with_too_man_cities_does_not_break_down() + # test_sparse_generator_with_illegal_params_aborts() + # test_sparse_generator_changes_to_grid_mode() + if __name__ == "__main__": main() - diff --git a/tests/test_over_under_passes.py b/tests/test_over_under_passes.py new file mode 100644 index 00000000..2cccc52e --- /dev/null +++ b/tests/test_over_under_passes.py @@ -0,0 +1,235 @@ +import time + +from flatland.envs.line_generators import sparse_line_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env_action import RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.rail_trainrun_data_structures import Waypoint +from flatland.envs.step_utils.states import TrainState +from flatland.utils.rendertools import RenderTool +from flatland.utils.simple_rail import make_diamond_crossing_rail + + +def test_diamond_crossing_without_over_and_underpasses(rendering: bool = False): + rail, rail_map, optionals = make_diamond_crossing_rail() + + env = RailEnv( + width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + record_steps=True + ) + + env.reset() + env._max_episode_steps = 555 + + # set the initial position + agent_0 = env.agents[0] + agent_0.initial_position = (3, 0) # one cell ahead of diamond crossing facing east + agent_0.position = (3, 0) # one cell ahead of diamond crossing facing east + agent_0.direction = 3 # east + agent_0.initial_direction = 3 # east + agent_0.target = (3, 9) # east dead-end + agent_0.moving = True + agent_0.latest_arrival = 999 + agent_0._set_state(TrainState.MOVING) + + agent_1 = env.agents[1] + agent_1.initial_position = (1, 2) # one cell ahead of diamond crossing facing south + agent_1.position = (1, 2) # one cell ahead of diamond crossing facing south + agent_1.direction = 2 # south + agent_1.initial_direction = 2 # south + agent_1.target = (6, 2) # south dead-end + agent_1.moving = True + agent_1.latest_arrival = 999 + agent_1._set_state(TrainState.MOVING) + + env.distance_map._compute(env.agents, env.rail) + done = False + env_renderer = None + if rendering: + env_renderer = RenderTool(env) + while not done: + _, _, dones, _ = env.step({ + 0: RailEnvActions.MOVE_FORWARD, + 1: RailEnvActions.MOVE_FORWARD, + }) + done = dones["__all__"] + if env_renderer is not None: + env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + time.sleep(1.2) + + waypoints = [] + for agent_states in env.cur_episode: + cur = [] + for agent_state in agent_states: + r, c, d, _, _, _ = agent_state + cur.append(Waypoint((r, c), d)) + waypoints.append(cur) + expected = [ + # agent 0 and agent 1 both want to enter the diamond-crossing at (3,2) + [Waypoint(position=(3, 1), direction=1), Waypoint(position=(2, 2), direction=2)], + # agent 1 waits until agent 0 has passed the diamond crossing at (3,2) + [Waypoint(position=(3, 2), direction=1), Waypoint(position=(2, 2), direction=2)], + [Waypoint(position=(3, 3), direction=1), Waypoint(position=(3, 2), direction=2)], + [Waypoint(position=(3, 4), direction=1), Waypoint(position=(4, 2), direction=2)], + [Waypoint(position=(3, 5), direction=1), Waypoint(position=(5, 2), direction=2)], + [Waypoint(position=(3, 6), direction=1), Waypoint(position=(0, 0), direction=2)], + [Waypoint(position=(3, 7), direction=1), Waypoint(position=(0, 0), direction=2)], + [Waypoint(position=(3, 8), direction=1), Waypoint(position=(0, 0), direction=2)], + [Waypoint(position=(0, 0), direction=1), Waypoint(position=(0, 0), direction=2)] + ] + assert expected == waypoints, waypoints + + +def test_diamond_crossing_with_over_and_underpasses(rendering: bool = False): + rail, rail_map, optionals = make_diamond_crossing_rail() + + env = RailEnv( + width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + record_steps=True + ) + + env.reset() + env._max_episode_steps = 555 + + # set the initial position + agent_0 = env.agents[0] + agent_0.initial_position = (3, 0) # one cell ahead of diamond crossing facing east + agent_0.position = (3, 0) # one cell ahead of diamond crossing facing east + agent_0.direction = 3 # east + agent_0.initial_direction = 3 # east + agent_0.target = (3, 9) # east dead-end + agent_0.moving = True + agent_0.latest_arrival = 999 + agent_0._set_state(TrainState.MOVING) + + agent_1 = env.agents[1] + agent_1.initial_position = (1, 2) # one cell ahead of diamond crossing facing south + agent_1.position = (1, 2) # one cell ahead of diamond crossing facing south + agent_1.direction = 2 # south + agent_1.initial_direction = 2 # south + agent_1.target = (6, 2) # south dead-end + agent_1.moving = True + agent_1.latest_arrival = 999 + agent_1._set_state(TrainState.MOVING) + + env.level_free_positions.add((3, 2)) + + env.distance_map._compute(env.agents, env.rail) + done = False + env_renderer = None + if rendering: + env_renderer = RenderTool(env) + while not done: + _, _, dones, _ = env.step({ + 0: RailEnvActions.MOVE_FORWARD, + 1: RailEnvActions.MOVE_FORWARD, + }) + done = dones["__all__"] + if env_renderer is not None: + env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + time.sleep(1.2) + + waypoints = [] + for agent_states in env.cur_episode: + cur = [] + for agent_state in agent_states: + r, c, d, _, _, _ = agent_state + cur.append(Waypoint((r, c), d)) + waypoints.append(cur) + expected = [ + # agent 0 and agent 1 both want to enter the diamond-crossing at (3,2) + [Waypoint(position=(3, 1), direction=1), Waypoint(position=(2, 2), direction=2)], + # agent 0 and agent 1 can enter the level-free diamond crossing at (3,2) + [Waypoint(position=(3, 2), direction=1), Waypoint(position=(3, 2), direction=2)], + [Waypoint(position=(3, 3), direction=1), Waypoint(position=(4, 2), direction=2)], + [Waypoint(position=(3, 4), direction=1), Waypoint(position=(5, 2), direction=2)], + [Waypoint(position=(3, 5), direction=1), Waypoint(position=(0, 0), direction=2)], + [Waypoint(position=(3, 6), direction=1), Waypoint(position=(0, 0), direction=2)], + [Waypoint(position=(3, 7), direction=1), Waypoint(position=(0, 0), direction=2)], + [Waypoint(position=(3, 8), direction=1), Waypoint(position=(0, 0), direction=2)], + [Waypoint(position=(0, 0), direction=1), Waypoint(position=(0, 0), direction=2)] + ] + assert expected == waypoints, waypoints + + +def test_diamond_crossing_with_over_and_underpasses_head_on(rendering: bool = False): + rail, rail_map, optionals = make_diamond_crossing_rail() + + env = RailEnv( + width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + record_steps=True + ) + env.reset() + env._max_episode_steps = 5 + + # set the initial position + agent_0 = env.agents[0] + agent_0.initial_position = (3, 0) # one cell ahead of diamond crossing facing east + agent_0.position = (3, 0) # one cell ahead of diamond crossing facing east + agent_0.direction = 3 # east + agent_0.initial_direction = 3 # east + agent_0.target = (3, 9) # east dead-end + agent_0.moving = True + agent_0.latest_arrival = 999 + agent_0._set_state(TrainState.MOVING) + + agent_1 = env.agents[1] + agent_1.initial_position = (3, 4) # one cell ahead of diamond crossing facing west + agent_1.position = (3, 4) # one cell ahead of diamond crossing facing west + agent_1.direction = 3 # west + agent_1.initial_direction = 3 # west + agent_1.target = (3, 0) # west dead-end + agent_1.moving = True + agent_1.latest_arrival = 999 + agent_1._set_state(TrainState.MOVING) + + env.level_free_positions.add((3, 2)) + + env.distance_map._compute(env.agents, env.rail) + done = False + env_renderer = None + if rendering: + env_renderer = RenderTool(env) + while not done: + _, _, dones, _ = env.step({ + 0: RailEnvActions.MOVE_FORWARD, + 1: RailEnvActions.MOVE_FORWARD, + }) + done = dones["__all__"] + if env_renderer is not None: + env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + time.sleep(1.2) + + waypoints = [] + for agent_states in env.cur_episode: + cur = [] + for agent_state in agent_states: + r, c, d, _, _, _ = agent_state + cur.append(Waypoint((r, c), d)) + waypoints.append(cur) + expected = [ + # agent 0 and agent 1 both want to enter the diamond-crossing at (3,2) + [Waypoint(position=(3, 1), direction=1), Waypoint(position=(3, 3), direction=3)], + # agent 0 and agent 1 are stuck (head-on) + [Waypoint(position=(3, 2), direction=1), Waypoint(position=(3, 3), direction=3)], + [Waypoint(position=(3, 2), direction=1), Waypoint(position=(3, 3), direction=3)], + [Waypoint(position=(3, 2), direction=1), Waypoint(position=(3, 3), direction=3)], + [Waypoint(position=(3, 2), direction=1), Waypoint(position=(3, 3), direction=3)]] + assert expected == waypoints, waypoints