Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

110 Over- and underpasses (aka. level-free diamond crossings). #120

Merged
merged 14 commits into from
Mar 5, 2025
32 changes: 11 additions & 21 deletions flatland/envs/line_generators.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
"""Line generators (railway undertaking, "EVU")."""
import warnings
"""Line generators: Railway Undertaking (RU) / Eisenbahnverkehrsunternehmen (EVU)."""
from typing import Tuple, List, Callable, Mapping, Optional, Any

import numpy as np
from numpy.random.mtrand import RandomState

from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.timetable_utils import Line
from flatland.envs import persistence
from flatland.utils.decorators import enable_infrastructure_lru_cache
from flatland.envs.timetable_utils import Line

AgentPosition = Tuple[int, int]
LineGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Line]
Expand Down Expand Up @@ -46,8 +41,8 @@ def __init__(self, speed_ratio_map: Mapping[float, float] = None, seed: int = 1)
self.speed_ratio_map = speed_ratio_map
self.seed = seed

def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any=None, num_resets: int = 0,
np_random: RandomState = None) -> Line:
def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
np_random: RandomState = None) -> Line:
pass

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -81,7 +76,7 @@ def decide_orientation(self, rail, start, target, possible_orientations, np_rand
return 0

def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int,
np_random: RandomState) -> Line:
np_random: RandomState) -> Line:
"""

The generator that assigns tasks to all the agents
Expand All @@ -102,12 +97,10 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re
agents_target = []
agents_direction = []


city1, city2 = None, None
city1_num_stations, city2_num_stations = None, None
city1_possible_orientations, city2_possible_orientations = None, None


for agent_idx in range(num_agents):

if (agent_idx % 2 == 0):
Expand All @@ -118,9 +111,9 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re
city1_num_stations = len(train_stations[city1])
city2_num_stations = len(train_stations[city2])
city1_possible_orientations = [city_orientation[city1],
(city_orientation[city1] + 2) % 4]
(city_orientation[city1] + 2) % 4]
city2_possible_orientations = [city_orientation[city2],
(city_orientation[city2] + 2) % 4]
(city_orientation[city2] + 2) % 4]

# Agent 1 : city1 > city2, Agent 2: city2 > city1
agent_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations
Expand All @@ -143,13 +136,11 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re
agent_orientation = self.decide_orientation(
rail, agent_start, agent_target, city2_possible_orientations, np_random)


# agent1 details
agents_position.append((agent_start[0][0], agent_start[0][1]))
agents_target.append((agent_target[0][0], agent_target[0][1]))
agents_direction.append(agent_orientation)


if self.speed_ratio_map:
speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random)
else:
Expand All @@ -163,7 +154,7 @@ def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_re
timedelay_factor * alpha * (rail.width + rail.height + num_agents / len(city_positions)))

return Line(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds)
agent_targets=agents_target, agent_speeds=speeds)


def line_from_file(filename, load_from_package=None) -> LineGenerator:
Expand All @@ -182,11 +173,10 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:

def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
np_random: RandomState = None) -> Line:

env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)

max_episode_steps = env_dict.get("max_episode_steps", 0)
if (max_episode_steps==0):
if (max_episode_steps == 0):
print("This env file has no max_episode_steps (deprecated) - setting to 100")
max_episode_steps = 100

Expand All @@ -196,12 +186,12 @@ def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_r
agents_position = [a.initial_position for a in agents]

# this logic is wrong - we should really load the initial_direction as the direction.
#agents_direction = [a.direction for a in agents]
# agents_direction = [a.direction for a in agents]
agents_direction = [a.initial_direction for a in agents]
agents_target = [a.target for a in agents]
agents_speed = [a.speed_counter.speed for a in agents]

return Line(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=agents_speed)
agent_targets=agents_target, agent_speeds=agents_speed)

return generator
28 changes: 23 additions & 5 deletions flatland/envs/rail_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
"""
import random
from functools import lru_cache
from typing import List, Optional, Dict, Tuple
from typing import List, Optional, Dict, Tuple, Set

import numpy as np

import flatland.envs.timetable_generators as ttg
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid_utils import Vector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs import agent_chains as ac
from flatland.envs import line_generators as line_gen
Expand Down Expand Up @@ -99,7 +100,7 @@ class RailEnv(Environment):
def __init__(self,
width,
height,
rail_generator=None,
rail_generator: "RailGenerator" = None,
line_generator: "LineGenerator" = None, # : line_gen.LineGenerator = line_gen.random_line_generator(),
number_of_agents=2,
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
Expand Down Expand Up @@ -167,7 +168,7 @@ def __init__(self,
self.rail_generator = rail_generator
if line_generator is None:
line_generator = line_gen.sparse_line_generator()
self.line_generator: LineGenerator = line_generator
self.line_generator: "LineGenerator" = line_generator
self.timetable_generator = timetable_generator

self.rail: Optional[GridTransitionMap] = None
Expand Down Expand Up @@ -205,6 +206,8 @@ def __init__(self,

self.motionCheck = ac.MotionCheck()

self.level_free_positions: Set[Vector2D] = set()

def _seed(self, seed):
self.np_random, seed = seeding.np_random(seed)
random.seed(seed)
Expand Down Expand Up @@ -314,6 +317,8 @@ def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True,
agents_hints = None
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
if optionals and 'level_free_positions' in optionals:
self.level_free_positions = optionals['level_free_positions']

line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
self.num_resets, self.np_random)
Expand Down Expand Up @@ -562,20 +567,32 @@ def step(self, action_dict: Dict[int, RailEnvActions]):
direction=new_direction,
preprocessed_action=preprocessed_action)

# only conflict if the level-free cell is traversed through the same axis (horizontally (0 north or 2 south), or vertically (1 east or 3 west)
new_position_level_free = new_position
if new_position in self.level_free_positions:
new_position_level_free = (new_position, new_direction % 2)
agent_position_level_free = agent.position
if agent.position in self.level_free_positions:
agent_position_level_free = (agent.position, agent.direction % 2)

# This is for storing and later checking for conflicts of agents trying to occupy same cell
self.motionCheck.addAgent(i_agent, agent.position, new_position)
self.motionCheck.addAgent(i_agent, agent_position_level_free, new_position_level_free)

# Find conflicts between trains trying to occupy same cell
self.motionCheck.find_conflicts()

for agent in self.agents:
i_agent = agent.handle

agent_position_level_free = agent.position
if agent.position in self.level_free_positions:
agent_position_level_free = (agent.position, agent.direction % 2)

## Update positions
if agent.malfunction_handler.in_malfunction:
movement_allowed = False
else:
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
movement_allowed = self.motionCheck.check_motion(i_agent, agent_position_level_free)

movement_inside_cell = agent.state == TrainState.STOPPED and not agent.speed_counter.is_cell_exit
movement_allowed = movement_allowed or movement_inside_cell
Expand Down Expand Up @@ -727,6 +744,7 @@ def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVar
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)

def initialize_renderer(self, mode, gl,
agent_render_variant,
show_debug,
Expand Down
56 changes: 35 additions & 21 deletions flatland/envs/rail_generators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
"""Rail generators: infrastructure manager (IM) / Infrastrukturbetreiber (ISB)."""
import math
import warnings
from typing import Callable, Tuple, Optional, Dict, List

Expand All @@ -9,7 +10,7 @@
from flatland.core.grid.grid4_utils import direction_to_point
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions, RailEnvTransitionsEnum
from flatland.core.transition_map import GridTransitionMap
from flatland.envs import persistence
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
Expand All @@ -35,8 +36,7 @@ def __init__(self, *args, **kwargs):
"""
pass

def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGeneratorProduct:
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
pass

def __call__(self, *args, **kwargs) -> RailGeneratorProduct:
Expand All @@ -53,8 +53,7 @@ class EmptyRailGen(RailGen):
Primarily used by the editor
"""

def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator:
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGenerator:
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
Expand Down Expand Up @@ -100,8 +99,7 @@ def __init__(self, rail_map, optionals=None):
self.rail_map = rail_map
self.optionals = optionals

def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGeneratorProduct:
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
return self.rail_map, self.optionals


Expand All @@ -116,7 +114,7 @@ def sparse_rail_generator(*args, **kwargs):
class SparseRailGen(RailGen):

def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2,
max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator:
max_rail_pairs_in_city: int = 2, seed=None, p_level_free: float = 0) -> RailGenerator:
"""
Generates railway networks with cities and inner city rails

Expand All @@ -133,6 +131,8 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b
Number of parallel tracks in the city. This represents the number of tracks in the trainstations
seed: int
Initiate the seed
p_level_free : float
Percentage of diamond-crossings which are level-free.

Returns
-------
Expand All @@ -143,9 +143,9 @@ def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_b
self.max_rails_between_cities = max_rails_between_cities
self.max_rail_pairs_in_city = max_rail_pairs_in_city
self.seed = seed
self.p_level_free = p_level_free

def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator:
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGeneratorProduct:
"""

Parameters
Expand Down Expand Up @@ -181,7 +181,7 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0
min_nr_rail_pairs_in_city = 1 # (min pair must be 1)
rail_pairs_in_city = min_nr_rail_pairs_in_city if self.max_rail_pairs_in_city < min_nr_rail_pairs_in_city else self.max_rail_pairs_in_city # (pairs can be 1,2,3)
rails_between_cities = (rail_pairs_in_city * 2) if self.max_rails_between_cities > (
rail_pairs_in_city * 2) else self.max_rails_between_cities
rail_pairs_in_city * 2) else self.max_rails_between_cities

# We compute the city radius by the given max number of rails it can contain.
# The radius is equal to the number of tracks divided by 2
Expand Down Expand Up @@ -237,11 +237,28 @@ def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0

# Fix all transition elements
self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
return grid_map, {'agents_hints': {
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}}

# choose p_level_free percentage of diamond crossings to be level-free
num_diamond_crossings = np.count_nonzero(grid_map.grid[grid_map.grid == RailEnvTransitionsEnum.diamond_crossing])
num_level_free_diamond_crossings = math.floor(self.p_level_free * num_diamond_crossings)
# ceil with probability p_ceil
p_ceil = (self.p_level_free * num_diamond_crossings) % 1.0
num_level_free_diamond_crossings += np_random.choice([1, 0], p=(p_ceil, 1 - p_ceil))
level_free_positions = set()
if num_level_free_diamond_crossings > 0:
choice = np_random.choice(num_diamond_crossings, size=num_level_free_diamond_crossings, replace=False)
positions_diamond_crossings = (grid_map.grid == RailEnvTransitionsEnum.diamond_crossing).nonzero()
level_free_positions = {tuple(positions_diamond_crossings[choice[i]]) for i in range(len(choice))}

return grid_map, {
'agents_hints':
{
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
},
'level_free_positions': level_free_positions
}

def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int,
height: int, np_random: RandomState = None) -> Tuple[
Expand All @@ -264,7 +281,6 @@ def _generate_random_city_positions(self, num_cities: int, city_radius: int, wid
Returns
-------
Returns a list of all city positions as coordinates (x,y)

"""

city_positions: IntVector2DArray = []
Expand Down Expand Up @@ -322,7 +338,6 @@ def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: in
Returns
-------
Returns a list of all city positions as coordinates (x,y)

"""
aspect_ratio = height / width
# Compute max numbe of possible cities per row and col.
Expand Down Expand Up @@ -526,7 +541,7 @@ def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction)
"""
Given a list of clostest neighbours in each direction this returns the city index of the neighbor in a given
direction. Direction is a 90 degree cone facing the desired directiont.
Exampe:
Example:
North: The closes neighbour in the North direction is within the cone spanned by a line going
North-West and North-East

Expand Down Expand Up @@ -677,7 +692,6 @@ def _fix_transitions(self, city_cells: set, inter_city_lines: List[IntVector2DAr
Each cell contains the prefered orientation of cells. If no prefered orientation is present it is set to -1
grid_map: RailEnvTransitions
The grid map containing the rails. Used to draw new rails

"""

# Fix all cities with illegal transition maps
Expand Down
2 changes: 1 addition & 1 deletion flatland/envs/schedule_generators.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
raise ImportError(" Schedule Generators is now renamed to line_generators, any reference to schedule should be replaced with line")
raise ImportError(" Schedule Generators is now renamed to line_generators + timetable_generators, any reference to schedule should be replaced with line")
21 changes: 0 additions & 21 deletions flatland/envs/sparse_rail_gen.py

This file was deleted.

Loading