Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 4093 additions and 776 deletions
import numpy as np
# from flatland.core.env import Environment
# from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror
from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99999, seed=0):
    """
    Build a rail generator that draws rails between random start/goal pairs.

    Parameters
    -------
    nr_start_goal : int
        Number of [start, goal] pairs to create (one agent per pair).
    nr_extra : int
        Number of extra connections to draw between cells that already hold rail.
    min_dist : int
        Minimum accepted `distance_on_rail` between a start and its goal.
    max_dist : int
        Maximum accepted `distance_on_rail` between a start and its goal.
    seed : int
        Base random seed; the generator seeds numpy with `seed + num_resets`.

    Returns
    -------
    function
        Generator `(width, height, num_agents, num_resets=0)` returning
        `(grid_map, agents_position, agents_direction, agents_target)`.
    """

    def generator(width, height, num_agents, num_resets=0):
        rail_trans = RailEnvTransitions()
        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
        rail_array = grid_map.grid
        rail_array.fill(0)

        np.random.seed(seed + num_resets)

        # generate rail array
        # step 1:
        # - generate a list of start and goal positions
        # - use a min/max distance allowed as input for this
        # - validate that start/goals are not placed too close to other start/goals
        #
        # step 2: (optional)
        # - place random elements on rails array
        #   - for instance "train station", etc.
        #
        # step 3:
        # - iterate over all [start, goal] pairs:
        #   - [first X pairs]
        #     - draw a rail from [start,goal]
        #     - draw either vertical or horizontal part first (randomly)
        #     - if rail crosses existing rail then validate new connection
        #       - if new connection is invalid turn 90 degrees to left/right
        #       - possibility that this fails to create a path to goal
        #         - on failure goto step1 and retry with seed+1
        #     - [avoid crossing other start,goal positions] (optional)
        #
        #   - [after X pairs]
        #     - find closest rail from start (Pa)
        #       - iterating outwards in a "circle" from start until an existing rail cell is hit
        #     - connect [start, Pa]
        #       - validate crossing rails
        #     - Do A* from Pa to find closest point on rail (Pb) to goal point
        #       - Basically normal A* but find point on rail which is closest to goal
        #       - since full path to goal is unlikely
        #     - connect [Pb, goal]
        #       - validate crossing rails
        #
        # step 4: (optional)
        # - add more rails to map randomly
        #
        # step 5:
        # - return transition map + list of [start, goal] points
        #

        def random_cell():
            # Cells index the grid as (row, col), so the row is bounded by
            # `height` and the column by `width`.
            return (np.random.randint(0, height), np.random.randint(0, width))

        start_goal = []
        start_dir = []

        def check_all_dist(sg_new):
            # Reject a candidate pair if either end is closer than 2 to any
            # end point that was already accepted.
            for sg in start_goal:
                for i in range(2):
                    for j in range(2):
                        if distance_on_rail(sg_new[i], sg[j]) < 2:
                            return False
            return True

        nr_created = 0
        created_sanity = 0
        sanity_max = 9000
        while nr_created < nr_start_goal and created_sanity < sanity_max:
            all_ok = False
            for _ in range(sanity_max):
                start = random_cell()
                goal = random_cell()

                # both end points must be on empty cells
                if rail_array[goal] != 0 or rail_array[start] != 0:
                    continue
                # check min/max distance
                dist_sg = distance_on_rail(start, goal)
                if dist_sg < min_dist:
                    continue
                if dist_sg > max_dist:
                    continue
                # check distance to existing points
                if check_all_dist([start, goal]):
                    all_ok = True
                    break

            if not all_ok:
                # no acceptable pair found within the sampling budget: give up
                break

            new_path = connect_rail(rail_trans, rail_array, start, goal)
            if len(new_path) >= 2:
                nr_created += 1
                start_goal.append([start, goal])
                start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
            else:
                # after too many failures we will give up
                created_sanity += 1

        # add extra connections between existing rail
        created_sanity = 0
        nr_created = 0
        while nr_created < nr_extra and created_sanity < sanity_max:
            all_ok = False
            for _ in range(sanity_max):
                start = random_cell()
                goal = random_cell()
                # both end points must already hold rail
                if rail_array[goal] == 0 or rail_array[start] == 0:
                    continue
                all_ok = True
                break
            if not all_ok:
                break
            new_path = connect_rail(rail_trans, rail_array, start, goal)
            if len(new_path) >= 2:
                nr_created += 1
            else:
                # count failed attempts so the loop is guaranteed to terminate
                created_sanity += 1

        agents_position = [sg[0] for sg in start_goal]
        agents_target = [sg[1] for sg in start_goal]
        agents_direction = start_dir

        return grid_map, agents_position, agents_direction, agents_target

    return generator
def rail_from_manual_specifications_generator(rail_spec):
    """
    Build a generator that turns a hand-written rail specification into a
    transition map.

    Parameters
    -------
    rail_spec : list of list of tuples
        Rows of columns of `(cell_type, rotation)` tuples; `cell_type` indexes
        `RailEnvTransitions().transitions` and `rotation` is the clock-wise
        rotation in [0, 90, 180, 270].

    Returns
    -------
    function
        Generator returning `(rail, agents_position, agents_direction,
        agents_target)`. The grid size is taken from `rail_spec`; the `width`
        and `height` arguments of the generator are overwritten. On an invalid
        cell type an error is printed and `[]` is returned instead.
    """

    def generator(width, height, num_agents, num_resets=0):
        transitions = RailEnvTransitions()

        # The specification dictates the grid size.
        height = len(rail_spec)
        width = len(rail_spec[0])
        rail = GridTransitionMap(width=width, height=height, transitions=transitions)

        for row_idx in range(height):
            for col_idx in range(width):
                spec = rail_spec[row_idx][col_idx]
                if not (0 <= spec[0] < len(transitions.transitions)):
                    print("ERROR - invalid cell type=", spec[0])
                    return []
                rotated = transitions.rotate_transition(transitions.transitions[spec[0]], spec[1])
                rail.set_transitions((row_idx, col_idx), rotated)

        placement = get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents)
        agents_position, agents_direction, agents_target = placement
        return rail, agents_position, agents_direction, agents_target

    return generator
def rail_from_GridTransitionMap_generator(rail_map):
    """
    Wrap an existing GridTransitionMap in a generator function.

    Parameters
    -------
    rail_map : GridTransitionMap object
        Map handed back unchanged every time the generator is called.

    Returns
    -------
    function
        Generator returning `(rail_map, agents_position, agents_direction,
        agents_target)`, with the agent data obtained from
        `get_rnd_agents_pos_tgt_dir_on_rail(rail_map, num_agents)`.
    """

    def generator(width, height, num_agents, num_resets=0):
        # width, height and num_resets are accepted for interface
        # compatibility only; the given map is used as-is.
        placement = get_rnd_agents_pos_tgt_dir_on_rail(rail_map, num_agents)
        agents_position, agents_direction, agents_target = placement
        return rail_map, agents_position, agents_direction, agents_target

    return generator
def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
    """
    Build a generator that cycles through saved grids, one per environment reset.

    Parameters
    -------
    list_of_filenames : list
        Filenames of the saved grids; the file used is
        `list_of_filenames[num_resets % len(list_of_filenames)]`.

    Returns
    -------
    function
        Generator returning `(rail_map, agents_position, agents_direction,
        agents_target)` for the grid loaded from the selected file.
    """

    def generator(width, height, num_agents, num_resets=0):
        rail_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())

        # Pick the file cyclically from the reset counter.
        file_index = num_resets % len(list_of_filenames)
        rail_map.load_transition_map(list_of_filenames[file_index], override_gridsize=False)

        # A 64-bit grid is paired with the 8-direction transition set.
        if rail_map.grid.dtype == np.uint64:
            rail_map.transitions = Grid8Transitions()

        placement = get_rnd_agents_pos_tgt_dir_on_rail(rail_map, num_agents)
        agents_position, agents_direction, agents_target = placement
        return rail_map, agents_position, agents_direction, agents_target

    return generator
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0):
return generate_rail_from_manual_specifications(list_of_specifications)
return generator
"""
def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
    """
    Dummy random level generator:
    - fill in cells at random in [width-2, height-2]
    - keep filling cells in among the unfilled ones, such that all transitions
      are legit;  if no cell can be filled in without violating some
      transitions, pick one among those that can satisfy most transitions
      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
      incompatible.
    - keep trying for a total number of insertions
      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
      board and try again from scratch.
    - finally pad the border of the map with dead-ends to avoid border issues.

    Dead-ends are not allowed inside the grid, only at the border; however, if
    no cell type can be inserted in a given cell (because of the neighboring
    transitions), deadends are allowed if they solve the problem. This was
    found to turn most un-generatable levels into valid ones.

    Parameters
    -------
    cell_type_relative_proportion : list of float
        Relative weight of each entry of `RailEnvTransitions().transitions`,
        indexed like that list. The default list is shared between calls but
        is only read here, never modified.

    Returns
    -------
    function
        Generator `(width, height, num_agents, num_resets=0)` returning
        `(return_rail, agents_position, agents_direction, agents_target)`,
        where `return_rail.grid` is a numpy.uint16 array of 16-bit transition
        bitmaps, one per cell.
    """

    def generator(width, height, num_agents, num_resets=0):
        """Generate one random level; `num_resets` is not used by this generator."""
        t_utils = RailEnvTransitions()

        transition_probability = cell_type_relative_proportion

        # Each entry is (template, transition): `template` is a list of four
        # 0/1 flags built from the transition bits, `transition` the rotated
        # 16-bit cell value. `transition_probabilities` is kept in lockstep.
        transitions_templates_ = []
        transition_probabilities = []
        for i in range(len(t_utils.transitions)):  # don't include dead-ends
            if t_utils.transitions[i] == int('0010000000000000', 2):
                continue

            # OR together the 4 outgoing bits over all four values of dir_.
            all_transitions = 0
            for dir_ in range(4):
                trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
                all_transitions |= (trans[0] << 3) | \
                                   (trans[1] << 2) | \
                                   (trans[2] << 1) | \
                                   (trans[3])

            # Expand to a zero-padded list of four binary digits.
            template = [int(x) for x in bin(all_transitions)[2:]]
            template = [0] * (4 - len(template)) + template

            # add all rotations
            for rot in [0, 90, 180, 270]:
                transitions_templates_.append((template,
                                               t_utils.rotate_transition(
                                                   t_utils.transitions[i],
                                                   rot)))
                transition_probabilities.append(transition_probability[i])
                # rotate the template by one position to follow the cell rotation
                template = [template[-1]] + template[:-1]

        def get_matching_templates(template):
            """
            Return `(transition, probability)` for every stored template that
            agrees with `template`; negative entries of `template` are wildcards.
            """
            ret = []
            for i in range(len(transitions_templates_)):
                is_match = True
                for j in range(4):
                    if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
                        is_match = False
                        break
                if is_match:
                    ret.append((transitions_templates_[i][1], transition_probabilities[i]))
            return ret

        # Insertion budget per attempt: 10 per inner (non-border) cell.
        MAX_INSERTIONS = (width - 2) * (height - 2) * 10
        MAX_ATTEMPTS_FROM_SCRATCH = 10

        attempt_number = 0
        while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
            # Start from an empty board; only inner cells are queued for filling.
            cells_to_fill = []
            rail = []
            for r in range(height):
                rail.append([None] * width)
                if r > 0 and r < height - 1:
                    cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]

            num_insertions = 0
            while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
                # cell = random.sample(cells_to_fill, 1)[0]
                cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
                cells_to_fill.remove(cell)
                row = cell[0]
                col = cell[1]

                # look at its neighbors and see what are the possible transitions
                # that can be chosen from, if any.
                # -1 = neighbour not set yet (no constraint), 0/1 = neighbour
                # does not / does offer a transition on that side.
                valid_template = [-1, -1, -1, -1]

                # el = (template index, direction queried on the neighbour, (d_row, d_col))
                for el in [(0, 2, (-1, 0)),
                           (1, 3, (0, 1)),
                           (2, 0, (1, 0)),
                           (3, 1, (0, -1))]:  # N, E, S, W
                    neigh_trans = rail[row + el[2][0]][col + el[2][1]]
                    if neigh_trans is not None:
                        # OR over all k of get_transition(neigh_trans, k, el[1]);
                        # presumably k is the facing direction and el[1] the exit
                        # direction pointing back at this cell -- TODO confirm
                        # against RailEnvTransitions.get_transition.
                        max_bit = 0
                        for k in range(4):
                            max_bit |= t_utils.get_transition(neigh_trans, k, el[1])

                        if max_bit:
                            valid_template[el[0]] = 1
                        else:
                            valid_template[el[0]] = 0

                possible_cell_transitions = get_matching_templates(valid_template)

                if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
                    # no cell can be filled in without violating some transitions
                    # can a dead-end solve the problem?
                    if valid_template.count(1) == 1:
                        for k in range(4):
                            if valid_template[k] == 1:
                                # rotation of the dead-end depends on which side
                                # holds the single connecting neighbour
                                rot = 0
                                if k == 0:
                                    rot = 180
                                elif k == 1:
                                    rot = 270
                                elif k == 2:
                                    rot = 0
                                elif k == 3:
                                    rot = 90

                                rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot)
                                num_insertions += 1

                                break

                    else:
                        # can I get valid transitions by removing a single
                        # neighboring cell?
                        bestk = -1
                        besttrans = []
                        for k in range(4):
                            tmp_template = valid_template[:]
                            tmp_template[k] = -1
                            possible_cell_transitions = get_matching_templates(tmp_template)
                            if len(possible_cell_transitions) > len(besttrans):
                                besttrans = possible_cell_transitions
                                bestk = k

                        if bestk >= 0:
                            # Replace the corresponding cell with None, append it
                            # to cells to fill, fill in a transition in the current
                            # cell.
                            # Default (bestk == 0) is the neighbour one row up.
                            replace_row = row - 1
                            replace_col = col
                            if bestk == 1:
                                replace_row = row
                                replace_col = col + 1
                            elif bestk == 2:
                                replace_row = row + 1
                                replace_col = col
                            elif bestk == 3:
                                replace_row = row
                                replace_col = col - 1

                            cells_to_fill.append((replace_row, replace_col))
                            rail[replace_row][replace_col] = None

                            # normalise the weights so they sum to 1 for np.random.choice
                            possible_transitions, possible_probabilities = zip(*besttrans)
                            possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]

                            rail[row][col] = np.random.choice(possible_transitions,
                                                              p=possible_probabilities)
                            num_insertions += 1

                        else:
                            # nothing fits even after relaxing one side: leave the cell empty
                            print('WARNING: still nothing!')
                            rail[row][col] = int('0000000000000000', 2)
                            num_insertions += 1
                            pass

                else:
                    # normalise the weights so they sum to 1 for np.random.choice
                    possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
                    possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]

                    rail[row][col] = np.random.choice(possible_transitions,
                                                      p=possible_probabilities)
                    num_insertions += 1

            if num_insertions == MAX_INSERTIONS:
                # Failed to generate a valid level; try again for a number of times
                attempt_number += 1
            else:
                break

        # Generation continues with the last board even after running out of attempts.
        if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
            print('ERROR: failed to generate level')

        # Finally pad the border of the map with dead-ends to avoid border issues;
        # at most 1 transition in the neigh cell
        # Each shift-and-mask below extracts one 4-bit group of the neighbour's
        # 16-bit value; a single bit of that group is then tested.
        for r in range(height):
            # Check for transitions coming from [r][1] to WEST
            max_bit = 0
            neigh_trans = rail[r][1]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & 1)
            if max_bit:
                rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
            else:
                rail[r][0] = int('0000000000000000', 2)

            # Check for transitions coming from [r][-2] to EAST
            max_bit = 0
            neigh_trans = rail[r][-2]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
            if max_bit:
                rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
                                                        90)
            else:
                rail[r][-1] = int('0000000000000000', 2)

        for c in range(width):
            # Check for transitions coming from [1][c] to NORTH
            max_bit = 0
            neigh_trans = rail[1][c]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
            if max_bit:
                rail[0][c] = int('0010000000000000', 2)
            else:
                rail[0][c] = int('0000000000000000', 2)

            # Check for transitions coming from [-2][c] to SOUTH
            max_bit = 0
            neigh_trans = rail[-2][c]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
            if max_bit:
                rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
            else:
                rail[-1][c] = int('0000000000000000', 2)

        # For display only, wrong levels
        # (any cell still unset is written as an empty cell)
        for r in range(height):
            for c in range(width):
                if rail[r][c] is None:
                    rail[r][c] = int('0000000000000000', 2)

        tmp_rail = np.asarray(rail, dtype=np.uint16)

        return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
        return_rail.grid = tmp_rail
        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
            return_rail,
            num_agents)
        return return_rail, agents_position, agents_direction, agents_target

    return generator
"""
Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point, get_new_position
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
                             rail_trans: RailEnvTransitions,
                             a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
                             flip_start_node_trans: bool = False, flip_end_node_trans: bool = False,
                             respect_transition_validity: bool = True, forbidden_cells: IntVector2DArray = None,
                             avoid_rail=False) -> IntVector2DArray:
    """
    Find a path from `start` to `end` with A* and write the matching rail
    transitions into `grid_map.grid`.

    :param grid_map: grid map whose `grid` is modified in place
    :param start: start position of rail
    :param end: end position of rail
    :param rail_trans: basic rail transition object used to set transition bits
    :param a_star_distance_function: distance function handed to `a_star`
    :param flip_start_node_trans: if the start cell is empty, add a dead-end there; otherwise leave it empty
    :param flip_end_node_trans: if the end cell is empty, add a dead-end there; otherwise leave it empty
    :param respect_transition_validity: forwarded to `a_star`; when False the line is drawn without
                                        respecting rail transitions
    :param forbidden_cells: cells forwarded to `a_star` that the rail must not go through
    :param avoid_rail: forwarded to `a_star`
    :return: list of cells in the path, or [] when the path has fewer than two cells
    """
    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, avoid_rail,
                                    respect_transition_validity,
                                    forbidden_cells)
    if len(path) < 2:
        return []

    last_cell = path[-1]
    heading = get_direction(path[0], path[1])

    for step, (cell, next_cell) in enumerate(zip(path[:-1], path[1:])):
        next_heading = get_direction(cell, next_cell)
        cell_trans = grid_map.grid[cell]

        if step > 0:
            # interior cell: open the forward move and its mirrored return move
            cell_trans = rail_trans.set_transition(cell_trans, heading, next_heading, 1)
            cell_trans = rail_trans.set_transition(cell_trans, mirror(next_heading), mirror(heading), 1)
        elif cell_trans != 0:
            # first cell joins existing rail
            cell_trans = rail_trans.set_transition(cell_trans, heading, next_heading, 1)
        elif flip_start_node_trans:
            # empty first cell: the direction is flipped because of how end points are defined
            cell_trans = rail_trans.set_transition(cell_trans, mirror(heading), next_heading, 1)
        else:
            cell_trans = 0
        grid_map.grid[cell] = cell_trans

        if next_cell == last_cell:
            # final cell of the path gets its own treatment
            end_trans = grid_map.grid[last_cell]
            if end_trans != 0:
                # path runs into existing rail
                end_trans = rail_trans.set_transition(end_trans, next_heading, next_heading, 1)
            elif flip_end_node_trans:
                end_trans = rail_trans.set_transition(end_trans, next_heading, mirror(next_heading), 1)
            else:
                end_trans = 0
            grid_map.grid[last_cell] = end_trans

        heading = next_heading

    return path
def connect_straight_line_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D,
                                      end: IntVector2D, rail_trans: RailEnvTransitions) -> IntVector2DArray:
    """
    Draw an axis-aligned rail between two cells of `grid_map.grid`.

    Diagonal lines are rejected: a message is printed and [] is returned.

    :param grid_map: grid map whose `grid` is modified in place
    :param start: Cell coordinates for start of line
    :param end: Cell coordinates for end of line
    :param rail_trans: transition object used to set the straight transitions
    :return: A list of all cells in the path, ordered by increasing coordinate
    """
    if start[0] != end[0] and start[1] != end[1]:
        print("No straight line possible!")
        return []

    direction = direction_to_point(start, end)
    is_vertical = direction is Grid4TransitionsEnum.NORTH or direction is Grid4TransitionsEnum.SOUTH

    # `axis` is the coordinate that varies along the line, `other` stays fixed.
    axis = 0 if is_vertical else 1
    other = 1 - axis
    varying = np.arange(min(start[axis], end[axis]), max(start[axis], end[axis]) + 1)
    fixed = np.repeat(start[other], np.abs(end[axis] - start[axis]) + 1)

    if is_vertical:
        path = list(zip(varying, fixed))
    else:  # Grid4TransitionsEnum.EAST or Grid4TransitionsEnum.WEST
        path = list(zip(fixed, varying))

    backwards = mirror(direction)
    for cell in path:
        cell_trans = rail_trans.set_transition(grid_map.grid[cell], direction, direction, 1)
        cell_trans = rail_trans.set_transition(cell_trans, mirror(direction), backwards, 1)
        grid_map.grid[cell] = cell_trans

    return path
def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, rail_trans: RailEnvTransitions):
    """
    Connect an inner city node to its neighbouring tracks.

    Nothing happens unless exactly two of the four neighbouring cells hold a
    non-zero transition value.

    :param grid_map: grid map whose `grid` is modified in place
    :param inner_node_pos: inner city node to fix
    :param rail_trans: transition object used to set the transition bits
    :return: None
    """
    occupied = [
        direction for direction in range(4)
        if grid_map.grid[get_new_position(inner_node_pos, direction)] > 0
    ]
    if len(occupied) != 2:
        return

    first, second = occupied[0], occupied[1]

    # Turn the node itself into a curve joining the two occupied sides.
    node_trans = rail_trans.set_transition(0, mirror(first), second, 1)
    node_trans = rail_trans.set_transition(node_trans, mirror(second), first, 1)
    grid_map.grid[inner_node_pos] = node_trans

    # Add the transition (direction -> mirrored direction) on each of the two neighbours.
    for direction in (first, second):
        neighbour = get_new_position(inner_node_pos, direction)
        grid_map.grid[neighbour] = rail_trans.set_transition(
            grid_map.grid[neighbour], direction, mirror(direction), 1)
    return
def align_cell_to_city(city_center, city_orientation, cell):
    """
    Align a cell to face the city center along the city orientation.

    @param city_center: Center needed for orientation
    @param city_orientation: Orientation of the city
    @param cell: Cell we would like to orient
    @return: 0 or 2 for an even city orientation (from the row offset of the
             cell), 1 or 3 for an odd one (from the column offset)
    """
    if city_orientation % 2:
        # odd orientation: compare columns
        offset = city_center[1] - cell[1]
        base = 1
    else:
        # even orientation: compare rows
        offset = cell[0] - city_center[0]
        base = 0
    return int(2 * np.clip(offset, 0, 1)) + base
"""Line generators (railway undertaking, "EVU")."""
import warnings
from typing import Tuple, List, Callable, Mapping, Optional, Any
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.timetable_utils import Line
from flatland.envs import persistence
# Grid cell of an agent as a pair of ints.
AgentPosition = Tuple[int, int]
# Callable type of a line generator: takes a GridTransitionMap, an int and two
# optional values and returns a Line.
# NOTE(review): the generators in this module also accept `np_random`, which
# this alias does not list -- confirm whether the alias is up to date.
LineGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Line]
def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None,
                                seed: int = None, np_random: RandomState = None) -> List[float]:
    """
    Draw one speed per agent according to the given ratios.

    Parameters
    ----------
    nb_agents : int
        The number of agents to generate a speed for
    speed_ratio_map : Mapping[float,float]
        A map of speeds mapping to their ratio of appearance. The ratios must sum up to 1.
        When None, every agent gets speed 1.0.
    seed : int
        Seed for the fallback random generator; only used when `np_random` is None.
    np_random : RandomState
        Random generator to draw from. When None, `RandomState(seed)` is created.

    Returns
    -------
    List[float]
        A list of size nb_agents of speeds with the corresponding probabilistic ratios.
    """
    if speed_ratio_map is None:
        return [1.0] * nb_agents

    if np_random is None:
        # No generator supplied: build one from `seed` instead of failing on None.
        np_random = RandomState(seed)

    speeds = list(speed_ratio_map.keys())
    speed_ratios = list(speed_ratio_map.values())
    chosen_classes = np_random.choice(len(speeds), nb_agents, p=speed_ratios)
    return [speeds[index] for index in chosen_classes]
class BaseLineGen:
    """
    Base class for line generators.

    Stores the speed ratio map and the seed; calling an instance forwards all
    arguments to `generate`.
    """

    def __init__(self, speed_ratio_map: Mapping[float, float] = None, seed: int = 1):
        self.seed = seed
        self.speed_ratio_map = speed_ratio_map

    def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any=None, num_resets: int = 0,
                 np_random: RandomState = None) -> Line:
        """Placeholder implementation that does nothing and returns None."""
        return None

    def __call__(self, *args, **kwargs):
        line = self.generate(*args, **kwargs)
        return line
def sparse_line_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> LineGenerator:
    """Create a `SparseLineGen` configured with the given speed ratio map and seed."""
    generator = SparseLineGen(speed_ratio_map=speed_ratio_map, seed=seed)
    return generator
class SparseLineGen(BaseLineGen):
    """
    This is the line generator which is used for Round 2 of the Flatland challenge. It produces lines
    to railway networks provided by sparse_rail_generator.

    :param speed_ratio_map: Speed ratios of all agents. They are probabilities of all different speeds and have to
            add up to 1.
    :param seed: Initiate random seed generator
    """

    def decide_orientation(self, rail, start, target, possible_orientations, np_random: RandomState) -> int:
        """
        Pick a start orientation from which the target can be reached.

        :param rail: object providing `check_path_exists(position, orientation, target_position)`
        :param start: start station; its first element is used as the position
        :param target: target station; its first element is used as the position
        :param possible_orientations: candidate orientations to test
        :param np_random: random generator used to choose among feasible orientations
        :return: a randomly chosen feasible orientation, or 0 when none is feasible
        """
        feasible_orientations = [
            orientation for orientation in possible_orientations
            if rail.check_path_exists(start[0], orientation, target[0])
        ]
        if feasible_orientations:
            return np_random.choice(feasible_orientations)
        return 0

    def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int,
                 np_random: RandomState) -> Line:
        """
        The generator that assigns tasks to all the agents

        Agents are created in pairs: an even-indexed agent travels between two
        freshly drawn cities, the following odd-indexed agent travels the same
        pair of cities in the opposite direction.

        :param rail: Rail infrastructure given by the rail_generator
        :param num_agents: Number of agents to include in the line
        :param hints: Hints provided by the rail_generator; the keys 'train_stations',
                      'city_positions' and 'city_orientations' are read
        :param num_resets: How often the generator has been reset.
        :param np_random: Random generator used for all draws
        :return: Line holding positions, directions, targets and speeds of all agents
        :raises ValueError: if agents are requested but fewer than two cities are available
        """
        _runtime_seed = self.seed + num_resets

        train_stations = hints['train_stations']
        city_positions = hints['city_positions']
        city_orientation = hints['city_orientations']

        if num_agents > 0 and len(city_positions) < 2:
            # Two distinct cities are drawn per agent pair; fail with a clear message.
            raise ValueError("SparseLineGen needs at least two cities to place agents")

        # Place agents and targets within available train stations
        agents_position = []
        agents_target = []
        agents_direction = []

        city1, city2 = None, None
        for agent_idx in range(num_agents):
            if agent_idx % 2 == 0:
                # Select 2 distinct cities for this pair of agents
                city1, city2 = np_random.choice(len(city_positions), 2, replace=False)
                origin, destination = city1, city2
            else:
                # Second agent of the pair travels the opposite way
                origin, destination = city2, city1

            origin_stations = train_stations[origin]
            destination_stations = train_stations[destination]
            # An agent may start along the city orientation or against it
            origin_orientations = [city_orientation[origin],
                                   (city_orientation[origin] + 2) % 4]

            # Start stations use even indices, target stations odd ones (modulo station count)
            agent_start_idx = (2 * np_random.randint(0, 10)) % len(origin_stations)
            agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % len(destination_stations)
            agent_start = origin_stations[agent_start_idx]
            agent_target = destination_stations[agent_target_idx]
            agent_orientation = self.decide_orientation(
                rail, agent_start, agent_target, origin_orientations, np_random)

            agents_position.append((agent_start[0][0], agent_start[0][1]))
            agents_target.append((agent_target[0][0], agent_target[0][1]))
            agents_direction.append(agent_orientation)

        if self.speed_ratio_map:
            speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed,
                                                 np_random=np_random)
        else:
            speeds = [1.0] * len(agents_position)

        return Line(agent_positions=agents_position, agent_directions=agents_direction,
                    agent_targets=agents_target, agent_speeds=speeds)
def line_from_file(filename, load_from_package=None) -> LineGenerator:
    """
    Utility to load a line from a saved environment file.

    Parameters
    ----------
    filename : file passed to `persistence.RailEnvPersister.load_env_dict`
    load_from_package : forwarded to `load_env_dict` as `load_from_package`

    Returns
    -------
    function
        Generator returning a `Line` with the initial positions, initial
        directions, targets and speeds of the agents stored in the file.
    """

    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
                  np_random: RandomState = None) -> Line:
        env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)

        # The value is only inspected to report old files; it is not returned.
        if env_dict.get("max_episode_steps", 0) == 0:
            print("This env file has no max_episode_steps (deprecated) - setting to 100")

        # setup with loaded data; directions come from `initial_direction`
        agents_position, agents_direction, agents_target, agents_speed = [], [], [], []
        for agent in env_dict["agents"]:
            agents_position.append(agent.initial_position)
            agents_direction.append(agent.initial_direction)
            agents_target.append(agent.target)
            agents_speed.append(agent.speed_counter.speed)

        return Line(agent_positions=agents_position, agent_directions=agents_direction,
                    agent_targets=agents_target, agent_speeds=agents_speed)

    return generator
"""Malfunction generators for rail systems"""
from typing import Callable, NamedTuple, Optional, Tuple
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.envs import persistence
# NOTE(review): MalfunctionParameters and MalfunctionProcessData declare the
# same three fields (malfunction_rate, min_duration, max_duration); it is
# unclear why both exist.
MalfunctionParameters = NamedTuple('MalfunctionParameters',
                                   [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
                                    [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])

# Result of a malfunction draw; the generators in this module use 0 for "no malfunction".
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])

# Callable type of a malfunction generator.
# NOTE(review): the alias lists (RandomState, bool) arguments, while the
# generators visible in this module take a single RandomState -- confirm
# whether the bool parameter is still in use.
MalfunctionGenerator = Callable[[RandomState, bool], Malfunction]
def _malfunction_prob(rate: float) -> float:
"""
Probability of a single agent to break. According to Poisson process with given rate
:param rate:
:return:
"""
if rate <= 0:
return 0.
else:
return 1 - np.exp(-rate)
class ParamMalfunctionGen:
    """ Preserving old behaviour of using MalfunctionParameters for constructor,
        but returning MalfunctionProcessData in get_process_data.
        Data structure and content is the same.
    """

    def __init__(self, parameters: MalfunctionParameters):
        # Keep the whole parameter tuple rather than copying its fields.
        self.MFP = parameters

    def generate(self, np_random: RandomState) -> Malfunction:
        """
        Draw a malfunction: with the probability given by the malfunction rate
        return a duration of `randint(min_duration, max_duration + 1) + 1`
        steps, otherwise a Malfunction of 0 steps.
        """
        agent_breaks = np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate)
        if not agent_breaks:
            return Malfunction(0)

        lower = self.MFP.min_duration
        upper = self.MFP.max_duration + 1
        return Malfunction(np_random.randint(lower, upper) + 1)

    def get_process_data(self):
        """Return the stored parameters repacked as MalfunctionProcessData."""
        fields = tuple(self.MFP)
        return MalfunctionProcessData(*fields)
class NoMalfunctionGen(ParamMalfunctionGen):
    """Malfunction generator configured with an all-zero parameter set."""

    def __init__(self):
        disabled = MalfunctionParameters(malfunction_rate=0, min_duration=0, max_duration=0)
        super().__init__(disabled)
class FileMalfunctionGen(ParamMalfunctionGen):
    """Malfunction generator whose parameters come from a saved environment."""

    def __init__(self, env_dict=None, filename=None, load_from_package=None):
        """ uses env_dict if populated, otherwise tries to load from file / package.
        """
        if env_dict is None:
            env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)

        if env_dict.get('malfunction') is None:
            # no malfunction entry stored: fall back to "no malfunctions"
            params = MalfunctionParameters(0, 0, 0)
        else:
            params = MalfunctionParameters(*env_dict["malfunction"])

        super().__init__(params)
################################################################################################
# OLD / DEPRECATED generator functions below. To be removed.
def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
    """
    Malfunction generator which generates no malfunctions (deprecated, see NoMalfunctionGen)

    Parameters
    ----------
    Nothing

    Returns
    -------
    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
    """
    print("DEPRECATED - use NoMalfunctionGen instead of no_malfunction_generator")

    def generator(np_random: RandomState = None) -> Malfunction:
        # never breaks anything, whatever random state is passed in
        return Malfunction(0)

    # rate and both duration bounds are zero
    process_data = MalfunctionProcessData(malfunction_rate=0., min_duration=0, max_duration=0)
    return generator, process_data
def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[
    MalfunctionGenerator, MalfunctionProcessData]:
    """
    Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent.

    Parameters
    ----------
    earlierst_malfunction: Earliest possible malfunction onset
    malfunction_duration: The duration of the single malfunction

    Returns
    -------
    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
    """
    # Total number of malfunctions handed out so far (shared by all agents)
    global_nr_malfunctions = 0
    # Number of generator calls seen per agent handle
    malfunction_calls = dict()

    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
        # Both counters live in the enclosing scope so that only a single malfunction is produced
        nonlocal global_nr_malfunctions
        nonlocal malfunction_calls

        # Reset malfunction generator
        if reset:
            global_nr_malfunctions = 0
            malfunction_calls = dict()
            return Malfunction(0)

        # No more malfunctions if we already had one, ignore all updates
        if global_nr_malfunctions > 0:
            return Malfunction(0)

        # Update number of calls per agent
        calls_so_far = malfunction_calls.get(agent.handle, 0) + 1
        malfunction_calls[agent.handle] = calls_so_far

        # Break an agent that is active at the time of the malfunction
        agent_is_active = agent.state in (TrainState.MOVING, TrainState.STOPPED)
        if agent_is_active and calls_so_far >= earlierst_malfunction:  # TODO : Dipam : Is this needed?
            global_nr_malfunctions += 1
            return Malfunction(malfunction_duration)
        return Malfunction(0)

    # The reported process data is all zeros (rate, min and max number of broken steps)
    return generator, MalfunctionProcessData(0., 0, 0)
def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
    """
    Utility to load the malfunction parameters from a saved environment (deprecated, see FileMalfunctionGen)

    Parameters
    ----------
    filename : Pickle file generated by env.save() or editor
    load_from_package : passed through to RailEnvPersister.load_env_dict

    Returns
    -------
    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
    """
    print("DEPRECATED - use FileMalfunctionGen instead of malfunction_from_file")

    env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
    # TODO: make this better by using namedtuple in the pickle file. See issue 282
    stored = env_dict.get('malfunction')
    if stored is None:
        # Nothing stored: rate and duration bounds are all zero
        mean_malfunction_rate = 0.
        min_number_of_steps_broken = 0
        max_number_of_steps_broken = 0
    else:
        process_data = MalfunctionProcessData._make(stored)
        env_dict['malfunction'] = process_data
        # Mean malfunction in number of time steps
        mean_malfunction_rate = process_data.malfunction_rate
        # Uniform distribution parameters for malfunction duration
        min_number_of_steps_broken = process_data.min_duration
        max_number_of_steps_broken = process_data.max_duration

    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
        """
        Generate malfunctions for agents

        Parameters
        ----------
        agent : agent whose malfunction_down_counter is checked before drawing
        np_random : random state used for the draws
        reset : if True, nothing is drawn and Malfunction(0) is returned

        Returns
        -------
        Malfunction with the number of time steps the agent is broken
        """
        # Dummy reset function as we don't implement specific seeding here
        if reset:
            return Malfunction(0)

        # Only draw for agents whose down counter is below 1
        may_break = agent.malfunction_handler.malfunction_down_counter < 1
        if may_break and np_random.rand() < _malfunction_prob(mean_malfunction_rate):
            # randint excludes its upper bound, hence the + 1 inside; one more step is added to the draw
            drawn_steps = np_random.randint(min_number_of_steps_broken, max_number_of_steps_broken + 1) + 1
            return Malfunction(drawn_steps)
        return Malfunction(0)

    return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
                                             max_number_of_steps_broken)
def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
    """
    Utility to build a malfunction generator from parameters (deprecated, see ParamMalfunctionGen)

    Parameters
    ----------
    parameters : contains all the parameters of the malfunction
        malfunction_rate : float rate per timestep at which each agent malfunctions
        min_duration : int minimal duration of a failure
        max_duration : int maximal duration of a failure

    Returns
    -------
    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
    """
    print("DEPRECATED - use ParamMalfunctionGen instead of malfunction_from_params")

    mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken = (
        parameters.malfunction_rate,
        parameters.min_duration,
        parameters.max_duration,
    )

    def generator(np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
        """
        Generate malfunctions for agents

        Parameters
        ----------
        np_random : random state used for the draws
        reset : if True, nothing is drawn and Malfunction(0) is returned

        Returns
        -------
        Malfunction with the number of time steps an agent is broken
        """
        # Dummy reset function as we don't implement specific seeding here
        if reset:
            return Malfunction(0)

        if not np_random.rand() < _malfunction_prob(mean_malfunction_rate):
            return Malfunction(0)
        # NOTE(review): unlike ParamMalfunctionGen.generate and malfunction_from_file, no extra
        # step is added to the drawn duration here - confirm whether that is intended.
        drawn_steps = np_random.randint(min_number_of_steps_broken, max_number_of_steps_broken + 1)
        return Malfunction(drawn_steps)

    process_data = MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
                                          max_number_of_steps_broken)
    return generator, process_data
"""
Collection of environment-specific ObservationBuilder.
"""
import collections
from typing import Optional, List, Dict, Tuple
import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.fast_methods import fast_argmax, fast_count_nonzero, fast_position_equal, fast_delete, fast_where
from flatland.envs.step_utils.states import TrainState
from flatland.utils.ordered_set import OrderedSet
# One node of the tree observation: twelve scalar features plus `childs`,
# a dict keyed by branch label that holds the sub-observations.
Node = collections.namedtuple('Node', [
    'dist_own_target_encountered',
    'dist_other_target_encountered',
    'dist_other_agent_encountered',
    'dist_potential_conflict',
    'dist_unusable_switch',
    'dist_to_next_branch',
    'dist_min_to_target',
    'num_agents_same_direction',
    'num_agents_opposite_direction',
    'num_agents_malfunctioning',
    'speed_min_fractional',
    'num_agents_ready_to_depart',
    'childs',
])
class TreeObsForRailEnv(ObservationBuilder):
"""
TreeObsForRailEnv object.
This object returns observation vectors for agents in the RailEnv environment.
The information is local to each agent and exploits the graph structure of the rail
network to simplify the representation of the state of the environment for each agent.
For details about the features in the tree observation see the get() function.
"""
tree_explored_actions_char = ['L', 'F', 'R', 'B']
def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
    """
    :param max_depth: maximum depth of the observation tree
    :param predictor: optional prediction builder; when set it is queried in get_many()
    """
    super().__init__()
    self.predictor = predictor
    self.max_depth = max_depth
    # NOTE(review): Node declares 12 scalar features - confirm that 11 is still the intended value
    self.observation_dim = 11
    # Lookup tables: targets are filled in reset(), agent tables are rebuilt in get_many()
    self.location_has_target = None
    self.location_has_agent = {}
    self.location_has_agent_direction = {}
def reset(self):
    """ Rebuild the lookup table of target cells from the agents of the current env. """
    target_cells = (tuple(agent.target) for agent in self.env.agents)
    self.location_has_target = dict.fromkeys(target_cells, 1)
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
    """
    Called whenever an observation has to be computed for the `env` environment, for each agent with handle
    in the `handles` list.

    Before delegating to the base class, the predictions (if a predictor is set) and the per-cell
    agent lookup tables used by the tree exploration are rebuilt.
    """
    if handles is None:
        handles = []

    if self.predictor:
        self.max_prediction_depth = 0
        self.predicted_pos = {}
        self.predicted_dir = {}
        self.predictions = self.predictor.get()
        if self.predictions:
            for t in range(self.predictor.max_depth + 1):
                # prediction rows at time t of every requested agent that has predictions
                rows = [self.predictions[a][t] for a in handles if self.predictions[a] is not None]
                pos_list = [row[1:3] for row in rows]
                dir_list = [row[3] for row in rows]
                self.predicted_pos[t] = coordinate_to_position(self.env.width, pos_list)
                self.predicted_dir[t] = dir_list
            self.max_prediction_depth = len(self.predicted_pos)

    # Update local lookup tables for all agents' positions;
    # agents in an off-map state only count towards "ready to depart" at their initial position
    self.location_has_agent = {}
    self.location_has_agent_direction = {}
    self.location_has_agent_speed = {}
    self.location_has_agent_malfunction = {}
    self.location_has_agent_ready_to_depart = {}

    ready_to_depart = self.location_has_agent_ready_to_depart
    for _agent in self.env.agents:
        off_map = _agent.state.is_off_map_state()

        if not off_map and _agent.position:
            cell = tuple(_agent.position)
            self.location_has_agent[cell] = 1
            self.location_has_agent_direction[cell] = _agent.direction
            self.location_has_agent_speed[cell] = _agent.speed_counter.speed
            self.location_has_agent_malfunction[cell] = _agent.malfunction_handler.malfunction_down_counter

        if off_map and _agent.initial_position:
            # count the agents waiting on each initial position
            start_cell = tuple(_agent.initial_position)
            ready_to_depart[start_cell] = ready_to_depart.get(start_cell, 0) + 1

    return super().get_many(handles)
def get(self, handle: int = 0) -> Optional[Node]:
    """
    Computes the current observation for agent `handle` in env

    The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
    movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
    The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
    the transitions. The order is::

        [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

    Each branch data is organized as::

        [root node information] +
        [recursive branch data from 'left'] +
        [... from 'forward'] +
        [... from 'right] +
        [... from 'back']

    Each node information is composed of 12 features:

    #1:
        if own target lies on the explored branch the current distance from the agent in number of cells is stored.

    #2:
        if another agent's target is detected the distance in number of cells from the agent's current location
        is stored

    #3:
        if another agent is detected the distance in number of cells from current agent position is stored.

    #4:
        possible conflict detected
        tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the
        distance in number of cells from current agent position

        0 = No other agent reserve the same cell at similar time

    #5:
        if a not usable switch (for agent) is detected we store the distance.

    #6:
        This feature stores the distance in number of cells to the next branching (current node)

    #7:
        minimum distance from node to the agent's target given the direction of the agent if this path is chosen

    #8:
        agent in the same direction
        n = number of agents present same direction
        (possible future use: number of other agents in the same direction in this branch)
        0 = no agent present same direction

    #9:
        agent in the opposite direction
        n = number of agents present other direction than myself (so conflict)
        (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
        0 = no agent present other direction than myself

    #10:
        malfunctioning/blocking agents
        n = number of time steps the observed agent remains blocked

    #11:
        slowest observed speed of an agent in same direction
        1 if no agent is observed

        min_fractional speed otherwise

    #12:
        number of agents ready to depart but not yet active

    Missing/padding nodes are filled in with -inf (truncated).
    Missing values in present node are filled in with +inf (truncated).

    In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed]
    In case the target node is reached, the values are [0, 0, 0, 0, 0].

    :param handle: index of the agent in `self.env.agents`
    :return: the root Node of the observation tree, or None when the agent state is neither
        off-map, on-map nor DONE
    """
    # `handle` is used as a list index below, so len(agents) itself is already out of range
    if handle >= len(self.env.agents):
        print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
    agent = self.env.agents[handle]  # TODO: handle being treated as index

    # Cell the observation is rooted at, depending on the agent state
    if agent.state.is_off_map_state():
        agent_virtual_position = agent.initial_position
    elif agent.state.is_on_map_state():
        agent_virtual_position = agent.position
    elif agent.state == TrainState.DONE:
        agent_virtual_position = agent.target
    else:
        return None

    possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
    num_transitions = fast_count_nonzero(possible_transitions)

    # Here information about the agent itself is stored
    distance_map = self.env.distance_map.get()

    root_node_observation = Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
                                 dist_other_agent_encountered=0, dist_potential_conflict=0,
                                 dist_unusable_switch=0, dist_to_next_branch=0,
                                 dist_min_to_target=distance_map[
                                     (handle, *agent_virtual_position,
                                      agent.direction)],
                                 num_agents_same_direction=0, num_agents_opposite_direction=0,
                                 num_agents_malfunctioning=agent.malfunction_handler.malfunction_down_counter,
                                 speed_min_fractional=agent.speed_counter.speed,
                                 num_agents_ready_to_depart=0,
                                 childs={})

    visited = OrderedSet()

    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # If only one transition is possible, the tree is oriented with this transition as the forward branch.
    orientation = agent.direction
    if num_transitions == 1:
        orientation = fast_argmax(possible_transitions)

    branch_directions = [(orientation + offset) % 4 for offset in range(-1, 3)]
    for i, branch_direction in enumerate(branch_directions):
        branch_label = self.tree_explored_actions_char[i]
        if possible_transitions[branch_direction]:
            new_cell = get_new_position(agent_virtual_position, branch_direction)
            branch_observation, branch_visited = \
                self._explore_branch(handle, new_cell, branch_direction, 1, 1)
            root_node_observation.childs[branch_label] = branch_observation
            visited |= branch_visited
        else:
            # add cells filled with infinity if no transition is possible
            root_node_observation.childs[branch_label] = -np.inf

    # Expose the explored cells through the env's dev_obs_dict
    self.env.dev_obs_dict[handle] = visited

    return root_node_observation
def _explore_branch(self, handle, position, direction, tot_dist, depth):
    """
    Utility function to compute tree-based observations.
    We walk along the branch and collect the information documented in the get() function.
    If there is a branching point a new node is created and each possible branch is explored.

    :param handle: index of the observing agent in `self.env.agents`
    :param position: cell at which the walk along this branch starts
    :param direction: direction in which the branch is entered
    :param tot_dist: distance (in cells) of `position` from the agent
    :param depth: depth of the node being built; the recursion stops once it exceeds `self.max_depth`
    :return: `(node, visited)` with the Node built for the end of this branch and the OrderedSet of
        `(row, column, direction)` triples visited on it and below; `([], [])` once the depth limit is exceeded
    """
    # [Recursive branch opened]
    if depth >= self.max_depth + 1:
        return [], []

    # Continue along direction until next switch or
    # until no transitions are possible along the current direction (i.e., dead-ends)
    # We treat dead-ends as nodes, instead of going back, to avoid loops
    exploring = True
    last_is_switch = False
    last_is_dead_end = False
    last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
    last_is_target = False

    visited = OrderedSet()
    agent = self.env.agents[handle]
    distance_map_handle = self.env.distance_map.get()[handle]
    # Used below to turn a distance in cells into a prediction time index
    time_per_cell = 1.0 / agent.speed_counter.speed
    # Distances default to +inf ("not encountered"), counters to 0
    own_target_encountered = np.inf
    other_agent_encountered = np.inf
    other_target_encountered = np.inf
    potential_conflict = np.inf
    unusable_switch = np.inf
    other_agent_same_direction = 0
    other_agent_opposite_direction = 0
    malfunctioning_agent = 0
    min_fractional_speed = 1.
    num_steps = 1
    other_agent_ready_to_depart_encountered = 0
    while exploring:
        # #############################
        # #############################
        # Modify here to compute any useful data required to build the end node's features. This code is called
        # for each cell visited between the previous branching node and the next switch / target / dead-end.
        if self.location_has_agent.get(position, 0) == 1:
            # Keep the smallest distance at which another agent was seen
            if tot_dist < other_agent_encountered:
                other_agent_encountered = tot_dist

            # Check if any of the observed agents is malfunctioning, store agent with longest duration left
            if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                malfunctioning_agent = self.location_has_agent_malfunction[position]

            other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)

            if self.location_has_agent_direction[position] == direction:
                # Cumulate the number of agents on branch with same direction
                other_agent_same_direction += 1

                # Check fractional speed of agents
                current_fractional_speed = self.location_has_agent_speed[position]
                if current_fractional_speed < min_fractional_speed:
                    min_fractional_speed = current_fractional_speed

            else:
                # If no agent in the same direction was found all agents in that position are other direction
                # Attention this counts too many agents as a few might be going off on a switch.
                other_agent_opposite_direction += self.location_has_agent[position]

        # Check number of possible transitions for agent and total number of transitions in cell (type)
        cell_transitions = self.env.rail.get_transitions(*position, direction)
        transition_bit = bin(self.env.rail.get_full_transitions(*position))
        total_transitions = transition_bit.count("1")
        crossing_found = False
        # This exact bitmap is the one handled as a crossing further down
        if int(transition_bit, 2) == int('1000010000100001', 2):
            crossing_found = True

        # Register possible future conflict
        predicted_time = int(tot_dist * time_per_cell)
        if self.predictor and predicted_time < self.max_prediction_depth:
            int_position = coordinate_to_position(self.env.width, [position])
            if tot_dist < self.max_prediction_depth:

                # Neighbouring prediction steps, clamped to the available range
                pre_step = max(0, predicted_time - 1)
                post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                # Look for conflicting paths at distance tot_dist
                # (fast_delete presumably drops the entry at index `handle`, i.e. the own prediction - TODO confirm)
                if int_position in fast_delete(self.predicted_pos[predicted_time], handle):
                    conflicting_agent = fast_where(self.predicted_pos[predicted_time] == int_position)
                    for ca in conflicting_agent:
                        if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
                            self._reverse_dir(
                                self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                            potential_conflict = tot_dist
                        if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

                # Look for conflicting paths at distance num_step-1
                elif int_position in fast_delete(self.predicted_pos[pre_step], handle):
                    conflicting_agent = fast_where(self.predicted_pos[pre_step] == int_position)
                    for ca in conflicting_agent:
                        if direction != self.predicted_dir[pre_step][ca] \
                            and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                            and tot_dist < potential_conflict:  # noqa: E125
                            potential_conflict = tot_dist
                        if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

                # Look for conflicting paths at distance num_step+1
                elif int_position in fast_delete(self.predicted_pos[post_step], handle):
                    conflicting_agent = fast_where(self.predicted_pos[post_step] == int_position)
                    for ca in conflicting_agent:
                        if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
                            self.predicted_dir[post_step][ca])] == 1 \
                            and tot_dist < potential_conflict:  # noqa: E125
                            potential_conflict = tot_dist
                        if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                            potential_conflict = tot_dist

        # Closest target of another agent on this branch
        if position in self.location_has_target and position != agent.target:
            if tot_dist < other_target_encountered:
                other_target_encountered = tot_dist

        # Closest occurrence of the own target on this branch
        if position == agent.target and tot_dist < own_target_encountered:
            own_target_encountered = tot_dist

        # #############################
        # #############################
        # Seeing the same (cell, direction) twice means a cycle: stop here and mark the node as terminal
        if (position[0], position[1], direction) in visited:
            last_is_terminal = True
            break
        visited.add((position[0], position[1], direction))

        # If the target node is encountered, pick that as node. Also, no further branching is possible.
        if fast_position_equal(position, self.env.agents[handle].target):
            last_is_target = True
            break

        # Check if crossing is found --> Not an unusable switch
        if crossing_found:
            # Treat the crossing as a straight rail cell
            total_transitions = 2
        num_transitions = fast_count_nonzero(cell_transitions)

        # Stop unless the single-transition case below decides to keep walking
        exploring = False

        # Detect Switches that can only be used by other agents.
        if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
            unusable_switch = tot_dist

        if num_transitions == 1:
            # Check if dead-end, or if we can go forward along direction
            nbits = total_transitions
            if nbits == 1:
                # Dead-end!
                last_is_dead_end = True

            if not last_is_dead_end:
                # Keep walking through the tree along `direction`
                exploring = True
                # convert one-hot encoding to 0,1,2,3
                direction = fast_argmax(cell_transitions)
                position = get_new_position(position, direction)
                num_steps += 1
                tot_dist += 1
        elif num_transitions > 0:
            # Switch detected
            last_is_switch = True
            break

        elif num_transitions == 0:
            # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
            print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                  position[1], direction)
            last_is_terminal = True
            break

    # `position` is either a terminal node or a switch

    # #############################
    # #############################
    # Modify here to append new / different features for each visited cell!

    if last_is_target:
        dist_to_next_branch = tot_dist
        dist_min_to_target = 0
    elif last_is_terminal:
        dist_to_next_branch = np.inf
        dist_min_to_target = distance_map_handle[position[0], position[1], direction]
    else:
        dist_to_next_branch = tot_dist
        dist_min_to_target = distance_map_handle[position[0], position[1], direction]

    # Node holding the features collected while walking this branch
    node = Node(dist_own_target_encountered=own_target_encountered,
                dist_other_target_encountered=other_target_encountered,
                dist_other_agent_encountered=other_agent_encountered,
                dist_potential_conflict=potential_conflict,
                dist_unusable_switch=unusable_switch,
                dist_to_next_branch=dist_to_next_branch,
                dist_min_to_target=dist_min_to_target,
                num_agents_same_direction=other_agent_same_direction,
                num_agents_opposite_direction=other_agent_opposite_direction,
                num_agents_malfunctioning=malfunctioning_agent,
                speed_min_fractional=min_fractional_speed,
                num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                childs={})

    # #############################
    # #############################
    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # Get the possible transitions
    possible_transitions = self.env.rail.get_transitions(*position, direction)
    for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
        if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                             (branch_direction + 2) % 4):
            # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
            # it back
            new_cell = get_new_position(position, (branch_direction + 2) % 4)
            branch_observation, branch_visited = self._explore_branch(handle,
                                                                      new_cell,
                                                                      (branch_direction + 2) % 4,
                                                                      tot_dist + 1,
                                                                      depth + 1)
            node.childs[self.tree_explored_actions_char[i]] = branch_observation
            # the recursion returns an empty list once the depth limit is exceeded
            if len(branch_visited) != 0:
                visited |= branch_visited
        elif last_is_switch and possible_transitions[branch_direction]:
            new_cell = get_new_position(position, branch_direction)
            branch_observation, branch_visited = self._explore_branch(handle,
                                                                      new_cell,
                                                                      branch_direction,
                                                                      tot_dist + 1,
                                                                      depth + 1)
            node.childs[self.tree_explored_actions_char[i]] = branch_observation
            if len(branch_visited) != 0:
                visited |= branch_visited
        else:
            # no exploring possible, add just cells with infinity
            node.childs[self.tree_explored_actions_char[i]] = -np.inf

    # Nodes at the maximum depth carry no children at all
    if depth == self.max_depth:
        node.childs.clear()
    return node, visited
def util_print_obs_subtree(self, tree: Node):
    """
    Utility function to print tree observations returned by this object.

    The features of `tree` are printed under the label "root", followed by one
    tab-indented subtree per entry of `tree_explored_actions_char`.
    """
    self.print_node_features(tree, "root", "")
    branches = tree.childs
    for label in self.tree_explored_actions_char:
        self.print_subtree(branches[label], label, "\t")
@staticmethod
def print_node_features(node: Node, label, indent):
    """ Print the twelve scalar features of `node` on one line, prefixed with `indent` and `label`. """
    features = (
        node.dist_own_target_encountered,
        node.dist_other_target_encountered,
        node.dist_other_agent_encountered,
        node.dist_potential_conflict,
        node.dist_unusable_switch,
        node.dist_to_next_branch,
        node.dist_min_to_target,
        node.num_agents_same_direction,
        node.num_agents_opposite_direction,
        node.num_agents_malfunctioning,
        node.speed_min_fractional,
        node.num_agents_ready_to_depart,
    )
    # Interleave the values with ", " separators; print() joins all arguments with single spaces
    pieces = []
    for value in features:
        pieces += [value, ", "]
    pieces.pop()  # no separator after the last feature
    print(indent, "Direction ", label, ": ", *pieces)
def print_subtree(self, node, label, indent):
    """
    Recursively print `node` and its children, one extra tab of indentation per level.

    A node equal to -np.inf, or a falsy one, is printed as a "-np.inf" placeholder.
    """
    is_placeholder = node == -np.inf or not node
    if is_placeholder:
        print(indent, "Direction ", label, ": -np.inf")
        return

    self.print_node_features(node, label, indent)

    if node.childs:
        deeper_indent = indent + "\t"
        for direction in self.tree_explored_actions_char:
            self.print_subtree(node.childs[direction], direction, deeper_indent)
def set_env(self, env: Environment):
    """ Attach `env` to this builder and, if a predictor is configured, to the predictor as well. """
    super().set_env(env)
    predictor = self.predictor
    if predictor:
        predictor.set_env(self.env)
def _reverse_dir(self, direction):
    """ Return the direction two quarter turns away from `direction` (modulo 4) as a plain int. """
    opposite = (direction + 2) % 4
    return int(opposite)
class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.
    The observation is composed of the following elements:

        - transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions.

        - obs_agents_state: A 3D array (map_height, map_width, 5) with
            - first channel containing the agents position and direction
            - second channel containing the other agents positions and direction
            - third channel containing agent/other agent malfunctions
            - fourth channel containing agent/other agent fractional speeds
            - fifth channel containing number of other agents ready to depart

        - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given
          agent target and the positions of the other agents targets (flag only, no counter!).
    """

    def __init__(self):
        super().__init__()

    def set_env(self, env: Environment):
        super().set_env(env)

    def reset(self):
        """ Pre-compute the 16-bit transition bitmap of every cell as a (height, width, 16) array. """
        height, width = self.env.height, self.env.width
        self.rail_obs = np.zeros((height, width, 16))
        for row in range(height):
            for col in range(width):
                # binary digits of the cell bitmap, left-padded with zeros to 16 digits
                digits = bin(self.env.rail.get_full_transitions(row, col))[2:].zfill(16)
                self.rail_obs[row, col] = np.array([int(digit) for digit in digits])

    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
        """
        Build the global observation for agent `handle`.

        Returns (rail_obs, obs_agents_state, obs_targets), or None when the agent state is neither
        off-map, on-map nor DONE.
        """
        agents = self.env.agents
        agent = agents[handle]

        # Cell the observing agent is shown at, depending on its state
        if agent.state.is_off_map_state():
            agent_virtual_position = agent.initial_position
        elif agent.state.is_on_map_state():
            agent_virtual_position = agent.position
        elif agent.state == TrainState.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        grid_shape = (self.env.height, self.env.width)
        obs_targets = np.zeros(grid_shape + (2,))
        # channels 0-3 default to -1, the ready-to-depart counter (channel 4) starts at 0
        obs_agents_state = np.full(grid_shape + (5,), -1.0)
        obs_agents_state[:, :, 4] = 0

        obs_agents_state[agent_virtual_position][0] = agent.direction
        obs_targets[agent.target][0] = 1

        for idx, other_agent in enumerate(agents):
            # ignore other agents not in the grid any more
            if other_agent.state == TrainState.DONE:
                continue

            obs_targets[other_agent.target][1] = 1

            # second to fourth channel only if in the grid
            cell = other_agent.position
            if cell is not None:
                # second channel only for other agents
                if idx != handle:
                    obs_agents_state[cell][1] = other_agent.direction
                obs_agents_state[cell][2] = other_agent.malfunction_handler.malfunction_down_counter
                obs_agents_state[cell][3] = other_agent.speed_counter.speed

            # fifth channel: all ready to depart on this position
            if other_agent.state.is_off_map_state():
                obs_agents_state[other_agent.initial_position][4] += 1

        return self.rail_obs, obs_agents_state, obs_targets
class LocalObsForRailEnv(ObservationBuilder):
"""
!!!!!!WARNING!!! THIS IS DEPRECATED AND NOT UPDATED TO FLATLAND 2.0!!!!!
Gives a local observation of the rail environment around the agent.
The observation is composed of the following elements:
- transition map array of the local environment around the given agent, \
with dimensions (view_height,2*view_width+1, 16), \
assuming 16 bits encoding of transitions.
- Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, \
if they are in the agent's vision range, its target position, the positions of the other targets.
- A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \
of the other agents at their position coordinates, if they are in the agent's vision range.
- A 4 elements array with one hot encoding of the direction.
Use the parameters view_width and view_height to define the rectangular view of the agent.
The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has
observation in front of it.
.. deprecated:: 2.0.0
"""
def __init__(self, view_width, view_height, center):
    """
    :param view_width: the view covers 2 * view_width + 1 columns
    :param view_height: number of rows of the view
    :param center: moves the agent along the height axis of the view rectangle
    """
    super().__init__()
    self.center = center
    self.view_height = view_height
    self.view_width = view_width
    padding_along_height = self.view_height - self.center
    self.max_padding = max(self.view_width, padding_along_height)
def reset(self):
    """ Recompute max_padding and the (height, width, 16) transition bitmap array of the env. """
    # Note: this overwrites the padding set in __init__ (which also took `center` into account)
    self.max_padding = max(self.view_width, self.view_height)

    height, width = self.env.height, self.env.width
    self.rail_obs = np.zeros((height, width, 16))
    for row in range(height):
        for col in range(width):
            # binary digits of the cell bitmap, left-padded with zeros to 16 digits
            digits = bin(self.env.rail.get_full_transitions(row, col))[2:].zfill(16)
            self.rail_obs[row, col] = np.array([int(digit) for digit in digits])
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
agents = self.env.agents
agent = agents[handle]
# Correct agents position for padding
# agent_rel_pos[0] = agent.position[0] + self.max_padding
# agent_rel_pos[1] = agent.position[1] + self.max_padding
# Collect visible cells as set to be plotted
visited, rel_coords = self.field_of_view(agent.position, agent.direction, )
local_rail_obs = None
# Add the visible cells to the observed cells
self.env.dev_obs_dict[handle] = set(visited)
# Locate observed agents and their coresponding targets
local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
_idx = 0
for pos in visited:
curr_rel_coord = rel_coords[_idx]
local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
if pos == agent.target:
obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
else:
for tmp_agent in agents:
if pos == tmp_agent.target:
obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
if pos != agent.position:
for tmp_agent in agents:
if pos == tmp_agent.position:
obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
tmp_agent.direction]
_idx += 1
direction = np.identity(4)[agent.direction]
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
def get_many(self, handles: Optional[List[int]] = None) -> Dict[
int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
"""
return super().get_many(handles)
def field_of_view(self, position, direction, state=None):
# Compute the local field of view for an agent in the environment
data_collection = False
if state is not None:
temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
data_collection = True
if direction == 0:
origin = (position[0] + self.center, position[1] - self.view_width)
elif direction == 1:
origin = (position[0] - self.view_width, position[1] - self.center)
elif direction == 2:
origin = (position[0] - self.center, position[1] + self.view_width)
else:
origin = (position[0] + self.view_width, position[1] + self.center)
visible = list()
rel_coords = list()
for h in range(self.view_height):
for w in range(2 * self.view_width + 1):
if direction == 0:
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
visible.append((origin[0] - h, origin[1] + w))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
elif direction == 1:
if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
visible.append((origin[0] + w, origin[1] + h))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
elif direction == 2:
if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
visible.append((origin[0] + h, origin[1] - w))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
else:
if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
visible.append((origin[0] - w, origin[1] - h))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
if data_collection:
return temp_visible_data
else:
return visible, rel_coords
import pickle
import msgpack
import numpy as np
import msgpack_numpy
msgpack_numpy.patch()
from flatland.envs import rail_env
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent, load_env_agent
# cannot import objects / classes directly because of circular import
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
class RailEnvPersister(object):
    """Save and load the state of a rail environment to and from ``pkl`` / ``mpk`` files."""

    @staticmethod
    def _serialize(env_dict, filename):
        """Serialize env_dict with the format selected by the extension of filename.

        Raises
        ------
        ValueError
            If filename ends with neither ``mpk`` nor ``pkl``.
        """
        if filename.endswith("mpk"):
            return msgpack.packb(env_dict)
        if filename.endswith("pkl"):
            return pickle.dumps(env_dict)
        raise ValueError(f"filename {filename} must end with either pkl or mpk")

    @classmethod
    def save(cls, env, filename, save_distance_maps=False):
        """
        Saves environment and distance map information in a file

        Parameters:
        ---------
        filename: string
            Must end with ``pkl`` or ``mpk``.
        save_distance_maps: bool

        Raises
        ------
        ValueError
            If the extension of filename is not supported.
        """
        env_dict = cls.get_full_state(env)

        if save_distance_maps is True:
            oDistMap = env.distance_map.get()
            if oDistMap is not None and len(oDistMap) > 0:
                env_dict["distance_map"] = oDistMap
            else:
                print("[WARNING] Unable to save the distance map for this environment, as none was found !")

        # Serialize before opening the file so that a failure does not leave a truncated file behind.
        data = cls._serialize(env_dict, filename)
        with open(filename, "wb") as file_out:
            file_out.write(data)

    @classmethod
    def save_episode(cls, env, filename):
        """
        Saves the environment state together with the recorded episode and actions.

        Raises
        ------
        ValueError
            If filename ends with neither ``pkl`` nor ``mpk``.
        """
        dict_env = cls.get_full_state(env)

        # Add additional info to dict_env before saving
        dict_env["episode"] = env.cur_episode
        dict_env["actions"] = env.list_actions
        dict_env["shape"] = (env.width, env.height)
        dict_env["max_episode_steps"] = env._max_episode_steps

        data = cls._serialize(dict_env, filename)
        with open(filename, "wb") as file_out:
            file_out.write(data)

    @classmethod
    def load(cls, env, filename, load_from_package=None):
        """
        Load environment with distance map from a file

        Parameters:
        -------
        filename: string
        """
        env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
        cls.set_full_state(env, env_dict)

    @classmethod
    def load_new(cls, filename, load_from_package=None):
        """
        Create a new environment from a file.

        Returns
        -------
        (env, env_dict)
            The newly created environment and the dict it was loaded from.
        """
        env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)

        llGrid = env_dict["grid"]
        height = len(llGrid)
        width = len(llGrid[0])

        # TODO: inefficient - each one of these generators loads the complete env file.
        env = rail_env.RailEnv(
            width=width, height=height,
            rail_generator=rail_gen.rail_from_file(filename,
                                                   load_from_package=load_from_package),
            line_generator=line_gen.line_from_file(filename,
                                                   load_from_package=load_from_package),
            malfunction_generator=mal_gen.FileMalfunctionGen(env_dict),
            obs_builder_object=DummyObservationBuilder(),
            record_steps=True)

        env.rail = GridTransitionMap(1, 1)  # dummy, the real grid is set by set_full_state
        cls.set_full_state(env, env_dict)
        return env, env_dict

    @classmethod
    def load_env_dict(cls, filename, load_from_package=None):
        """
        Read and deserialize an environment dict.

        The data is read from the package resource if load_from_package is given, from the file
        system otherwise. The "agents" entry of the result holds agent objects.
        """
        if load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(load_from_package, filename)
        else:
            with open(filename, "rb") as file_in:
                load_data = file_in.read()

        if filename.endswith("mpk"):
            env_dict = msgpack.unpackb(load_data, use_list=False, raw=False)
        elif filename.endswith("pkl"):
            # SECURITY: pickle.loads can execute arbitrary code - only load files from trusted sources.
            try:
                env_dict = pickle.loads(load_data)
            except (ValueError, pickle.UnpicklingError, EOFError):
                # Data that is not a pickle usually surfaces as UnpicklingError, not ValueError.
                print("pickle failed to load file:", filename, " trying msgpack (deprecated)...")
                env_dict = msgpack.unpackb(load_data, use_list=False, raw=False)
        else:
            print(f"filename {filename} must end with either pkl or mpk")
            env_dict = {}

        # Replace the agents tuple with EnvAgent objects
        if "agents_static" in env_dict:
            env_dict["agents"] = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
            # remove the legacy key
            del env_dict["agents_static"]
        elif "agents" in env_dict:
            env_dict["agents"] = [load_env_agent(d) for d in env_dict["agents"]]

        return env_dict

    @classmethod
    def load_resource(cls, package, resource):
        """
        Load environment (with distance map?) from a binary
        """
        return cls.load_new(resource, load_from_package=package)

    @classmethod
    def set_full_state(cls, env, env_dict):
        """
        Sets environment state from env_dict

        Parameters
        -------
        env_dict: dict
        """
        env.rail.grid = np.array(env_dict["grid"])

        # Initialise the env with the frozen agents in the file
        env.agents = env_dict.get("agents", [])

        # For consistency, set number_of_agents, which is the number which will be generated on reset
        env.number_of_agents = env.get_num_agents()

        env.height, env.width = env.rail.grid.shape
        env.rail.height = env.height
        env.rail.width = env.width
        env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)

    @classmethod
    def get_full_state(cls, env):
        """
        Returns state of environment in dict object, ready for serialization
        """
        # msgpack cannot persist EnvAgent so use the Agent namedtuple.
        return {
            "grid": env.rail.grid.tolist(),
            "agents": [agent.to_agent() for agent in env.agents],
            "malfunction": env.malfunction_process_data,
            "max_episode_steps": env._max_episode_steps,
        }

    ################################################################################################
    # deprecated methods moved from RailEnv. Most likely broken.
    # They are written as instance methods of the environment: `self` is the env.

    def deprecated_get_full_state_msg(self) -> bytes:
        """
        Returns state of environment in msgpack object
        """
        # The former self.get_full_state_dict() is not defined in this file;
        # get_full_state builds the equivalent dict - NOTE(review): confirm this is the intended source.
        msg_data_dict = RailEnvPersister.get_full_state(self)
        return msgpack.packb(msg_data_dict, use_bin_type=True)

    def deprecated_get_agent_state_msg(self) -> bytes:
        """
        Returns agents information in msgpack object
        """
        msg_data = {"agents": [agent.to_agent() for agent in self.agents]}
        return msgpack.packb(msg_data, use_bin_type=True)

    def deprecated_get_full_state_dist_msg(self) -> bytes:
        """
        Returns environment information with distance map information as msgpack object
        """
        msg_data = {
            "grid": self.rail.grid.tolist(),
            "agents": [agent.to_agent() for agent in self.agents],
            "distance_map": self.distance_map.get(),
            "malfunction": self.malfunction_process_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def deprecated_set_full_state_msg(self, msg_data):
        """
        Sets environment state with msgdata object passed as argument

        Parameters
        -------
        msg_data: msgpack object
        """
        # raw=False decodes strings as UTF-8; the `encoding` keyword was removed in msgpack 1.0.
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])

        # agents are always reset as not moving
        if "agents_static" in data:
            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]

        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def deprecated_set_full_state_dist_msg(self, msg_data):
        """
        Sets environment grid state and distance map with msgdata object passed as argument

        Parameters
        -------
        msg_data: msgpack object
        """
        # raw=False decodes strings as UTF-8; the `encoding` keyword was removed in msgpack 1.0.
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])

        # agents are always reset as not moving
        if "agents_static" in data:
            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]

        if "distance_map" in data:
            self.distance_map.set(data["distance_map"])

        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
"""
Collection of environment-specific PredictionBuilder.
"""
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils import transition_utils
class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and tries the actions
    forward, left, right in this order at every step.
    """

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction. If None, predictions are
            computed for all agents.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.
            Agents which are not on the map have no entry.

        Raises
        ------
        Exception
            If none of the candidate actions is valid at some step.
        """
        agents = self.env.agents
        # Compare against None explicitly: 0 is a valid agent handle.
        if handle is not None:
            agents = [self.env.agents[handle]]

        action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
        prediction_dict = {}

        for agent in agents:
            if not agent.state.is_on_map_state():
                # TODO make this generic
                continue

            # Roll the agent forward on local copies so that the live agent is never modified,
            # not even when the rollout fails with an exception.
            position = agent.position
            direction = agent.direction

            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *position, direction, 0]

            for index in range(1, self.max_depth + 1):
                # if we're at the target, stop moving...
                if position == agent.target:
                    prediction[index] = [index, *agent.target, direction, RailEnvActions.STOP_MOVING]
                    continue

                for action in action_priorities:
                    new_cell_valid, new_direction, new_position, transition_valid = \
                        transition_utils.check_action_on_agent(action, self.env.rail, position, direction)
                    if new_cell_valid and transition_valid:
                        # move and face the direction of the transition that was performed
                        position, direction = new_position, new_direction
                        prediction[index] = [index, *new_position, new_direction, action]
                        break
                else:
                    raise Exception("Cannot move further. Something is wrong")

            prediction_dict[agent.handle] = prediction

        return prediction_dict
class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and follows the shortest path.
    """

    def __init__(self, max_depth: int = 20):
        super().__init__(max_depth)

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.
        Does not take into account future positions of other agents!

        If there is no shortest path, the agent just stands still and stops moving.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction. If None, predictions are
            computed for all agents.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (not implemented yet)
            The prediction at 0 is the current position, direction etc.
            For an agent in a state that is neither off map, on map nor DONE, every entry
            except the time offset is NaN.
        """
        agents = self.env.agents
        # Compare against None explicitly: 0 is a valid agent handle.
        if handle is not None:
            agents = [self.env.agents[handle]]
        distance_map: DistanceMap = self.env.distance_map

        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)

        prediction_dict = {}
        for agent in agents:
            if agent.state.is_off_map_state():
                agent_virtual_position = agent.initial_position
            elif agent.state.is_on_map_state():
                agent_virtual_position = agent.position
            elif agent.state == TrainState.DONE:
                agent_virtual_position = agent.target
            else:
                # No position is known for this state. None cannot be stored in a float
                # array (the assignment raises TypeError), so the unknown entries are NaN.
                prediction = np.full((self.max_depth + 1, 5), np.nan)
                prediction[:, 0] = np.arange(self.max_depth + 1)
                prediction_dict[agent.handle] = prediction
                continue

            agent_virtual_direction = agent.direction
            agent_speed = agent.speed_counter.speed
            # number of time steps the agent spends in each cell
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]

            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            new_direction = agent_virtual_direction
            new_position = agent_virtual_position
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):
                # if we're at the target, stop moving until max_depth is reached
                if new_position == agent.target or not shortest_path:
                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                    # record the same direction as the prediction row written above
                    visited.add((*new_position, new_direction))
                    continue

                # advance to the next cell of the path only every times_per_cell steps
                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction

                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # TODO: very bad side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict
"""
Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
Definition of the RailEnv environment.
"""
import numpy as np
import pickle
import random
from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from flatland.envs.env_utils import get_new_position
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from typing import List, Optional, Dict, Tuple
# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap
import numpy as np
from gym.utils import seeding
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs import persistence
from flatland.envs import agent_chains as ac
from flatland.envs.fast_methods import fast_position_equal
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils.transition_utils import check_valid_action
from flatland.envs.step_utils import action_preprocessing
from flatland.envs.step_utils import env_utils
class RailEnv(Environment):
"""
......@@ -27,44 +43,90 @@ class RailEnv(Environment):
to avoid bottlenecks.
The valid actions in the environment are:
0: do nothing
1: turn left and move to the next cell
2: move to the next cell in front of the agent
3: turn right and move to the next cell
- 0: do nothing (continue moving or stay still)
- 1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
- 2: move to the next cell in front of the agent; if the agent was not moving, movement is started
- 3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
- 4: stop moving
Moving forward in a dead-end cell makes the agent turn 180 degrees and step
to the cell it came from.
The actions of the agents are executed in order of their handle to prevent
deadlocks and to allow them to learn relative priorities.
TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
beta to be passed as parameters to __init__().
Reward Function:
It costs each agent a step_penalty for every time-step taken in the environment. Independent of the movement
of the agent. Currently all other penalties such as penalty for stopping, starting and invalid actions are set to 0.
alpha = 0
beta = 0
Reward function parameters:
- invalid_action_penalty = 0
- step_penalty = -alpha
- global_reward = beta
- epsilon = avoid rounding errors
- stop_penalty = 0 # penalty for stopping a moving agent
- start_penalty = 0 # penalty for starting a stopped agent
Stochastic malfunctioning of trains:
Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
action or cell is selected).
Every time an agent stops, it has a certain probability of malfunctioning. Malfunctions of trains follow a
Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes, to keep
complexity manageable.
TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
"""
# Epsilon to avoid rounding errors
epsilon = 0.01
# NEW : REW: Sparse Reward
alpha = 0
beta = 0
step_penalty = -1 * alpha
global_reward = 1 * beta
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
cancellation_factor = 1
cancellation_time_buffer = 0
def __init__(self,
width,
height,
rail_generator=random_rail_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2)):
rail_generator=None,
line_generator=None, # : line_gen.LineGenerator = line_gen.random_line_generator(),
number_of_agents=2,
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(),
malfunction_generator=None,
remove_agents_at_target=True,
random_seed=None,
record_steps=False,
):
"""
Environment init.
Parameters
-------
----------
rail_generator : function
The rail_generator function is a function that takes the width,
height and agents handles of a rail environment, along with the number of times
the env has been reset, and returns a GridTransitionMap object and a list of
starting positions, targets, and initial orientations for agent handle.
Implemented functions are:
random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
a GridTransitionMap object
rail_from_manual_specifications_generator(rail_spec) : generate a rail from
a rail specifications array
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
The rail_generator can pass a distance map in the hints or information for specific line_generators.
Implementations can be found in flatland/envs/rail_generators.py
line_generator : function
The line_generator function is a function that takes the grid, the number of agents and optional hints
and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
Implementations can be found in flatland/envs/line_generators.py
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
......@@ -77,260 +139,635 @@ class RailEnv(Environment):
obs_builder_object: ObservationBuilder object
ObservationBuilder-derived object that takes builds observation
vectors for each agent.
remove_agents_at_target : bool
If remove_agents_at_target is set to True, the agents are removed by placing them at
RailEnv.DEPOT_POSITION once they have reached their target position.
random_seed : int or None
If None, it is ignored; otherwise the random generators are seeded with this number to ensure
that stochastic operations are replicable across multiple operations.
"""
super().__init__()
if malfunction_generator_and_process_data is not None:
print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
elif malfunction_generator is not None:
self.malfunction_generator = malfunction_generator
# malfunction_process_data is not used
# self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
self.malfunction_process_data = self.malfunction_generator.get_process_data()
# replace default values here because we can't use default args values because of cyclic imports
else:
self.malfunction_generator = mal_gen.NoMalfunctionGen()
self.malfunction_process_data = self.malfunction_generator.get_process_data()
self.number_of_agents = number_of_agents
if rail_generator is None:
rail_generator = rail_gen.sparse_rail_generator()
self.rail_generator = rail_generator
self.rail = None
if line_generator is None:
line_generator = line_gen.sparse_line_generator()
self.line_generator = line_generator
self.rail: Optional[GridTransitionMap] = None
self.width = width
self.height = height
# use get_num_agents() instead
# self.number_of_agents = number_of_agents
self.remove_agents_at_target = remove_agents_at_target
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.obs_builder.set_env(self)
self.actions = [0] * number_of_agents
self.rewards = [0] * number_of_agents
self.done = False
self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
self._max_episode_steps: Optional[int] = None
self._elapsed_steps = 0
self.obs_dict = {}
self.rewards_dict = {}
self.dev_obs_dict = {}
self.dev_pred_dict = {}
# self.agents_handles = list(range(self.number_of_agents))
# self.agents_position = []
# self.agents_target = []
# self.agents_direction = []
self.agents = [None] * number_of_agents # live agents
self.agents_static = [None] * number_of_agents # static agent information
self.agents: List[EnvAgent] = []
self.num_resets = 0
self.reset()
self.num_resets = 0 # yes, set it to zero again!
self.distance_map = DistanceMap(self.agents, self.height, self.width)
self.valid_positions = None
self.action_space = [5]
self._seed()
if random_seed:
self._seed(seed=random_seed)
self.agent_positions = None
# save episode timesteps ie agent positions, orientations. (not yet actions / observations)
self.record_steps = record_steps # whether to save timesteps
# save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
self.cur_episode = []
self.list_actions = [] # save actions in here
self.motionCheck = ac.MotionCheck()
def _seed(self, seed=None):
    """Seed the environment's random generators and return the seed used as a one-element list.

    The numpy generator returned by ``seeding.np_random`` is stored on ``self.np_random``; the
    seed it reports is used for Python's ``random`` module, stored on ``self.random_seed`` and
    appended to ``self.seed_history`` unless it equals the most recent entry.
    """
    generator, used_seed = seeding.np_random(seed)
    self.np_random = generator
    random.seed(used_seed)
    self.random_seed = used_seed

    # Keep track of all the seeds in order
    try:
        history = self.seed_history
    except AttributeError:
        history = self.seed_history = [used_seed]
    if history[-1] != used_seed:
        history.append(used_seed)
    return [used_seed]
# no more agent_handles
def get_agent_handles(self):
    """Return the agent handles as a range from 0 up to (excluding) the number of agents."""
    agent_count = self.get_num_agents()
    return range(agent_count)
def get_num_agents(self, static=True):
if static:
return len(self.agents_static)
else:
return len(self.agents)
def get_num_agents(self) -> int:
return len(self.agents)
def add_agent_static(self, agent_static):
def add_agent(self, agent):
""" Add static info for a single agent.
Returns the index of the new agent.
"""
self.agents_static.append(agent_static)
return len(self.agents_static) - 1
self.agents.append(agent)
return len(self.agents) - 1
def restart_agents(self):
""" Reset the agents to their starting positions defined in agents_static
def reset_agents(self):
""" Reset the agents to their starting positions
"""
self.agents = EnvAgent.list_from_static(self.agents_static)
for agent in self.agents:
agent.reset()
self.active_agents = [i for i in range(len(self.agents))]
def reset(self, regen_rail=True, replace_agents=True):
""" if regen_rail then regenerate the rails.
if replace_agents then regenerate the agents static.
Relies on the rail_generator returning agent_static lists (pos, dir, target)
def action_required(self, agent):
"""
tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
Check if an agent needs to provide an action
if regen_rail or self.rail is None:
self.rail = tRailAgents[0]
Parameters
----------
agent: RailEnvAgent
Agent we want to check
if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4])
Returns
-------
True: Agent needs to provide an action
False: Agent cannot provide an action
"""
return agent.state == TrainState.READY_TO_DEPART or \
( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: int = None) -> Tuple[Dict, Dict]:
"""
reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
The method resets the rail environment
Parameters
----------
regenerate_rail : bool, optional
regenerate the rails
regenerate_schedule : bool, optional
regenerate the schedule and the static agents
random_seed : int, optional
random seed for environment
Returns
-------
observation_dict: Dict
Dictionary with an observation for each agent
info_dict: Dict with agent specific information
"""
# Take the agent static info and put (live) agents at the start positions
# self.agents = EnvAgent.list_from_static(self.agents_static[:len(self.agents_handles)])
self.restart_agents()
if random_seed:
self._seed(random_seed)
optionals = {}
if regenerate_rail or self.rail is None:
if "__call__" in dir(self.rail_generator):
rail, optionals = self.rail_generator(
self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
elif "generate" in dir(self.rail_generator):
rail, optionals = self.rail_generator.generate(
self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
else:
raise ValueError("Could not invoke __call__ or generate on rail_generator")
self.rail = rail
self.height, self.width = self.rail.grid.shape
# Do a new set_env call on the obs_builder to ensure
# that obs_builder specific instantiations are made according to the
# specifications of the current environment : like width, height, etc
self.obs_builder.set_env(self)
if optionals and 'distance_map' in optionals:
self.distance_map.set(optionals['distance_map'])
if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
agents_hints = None
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
self.num_resets, self.np_random)
self.agents = EnvAgent.from_line(line)
# Reset distance map - basically initializing
self.distance_map.reset(self.agents, self.rail)
# NEW : Time Schedule Generation
timetable = timetable_generator(self.agents, self.distance_map,
agents_hints, self.np_random)
self._max_episode_steps = timetable.max_episode_steps
for agent_i, agent in enumerate(self.agents):
agent.earliest_departure = timetable.earliest_departures[agent_i]
agent.latest_arrival = timetable.latest_arrivals[agent_i]
else:
self.distance_map.reset(self.agents, self.rail)
# Reset agents to initial states
self.reset_agents()
self.num_resets += 1
self._elapsed_steps = 0
# Agent positions map
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
self._update_agent_positions_map(ignore_old_positions=False)
# for handle in self.agents_handles:
# self.dones[handle] = False
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
# perhaps dones should be part of each agent.
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
# Empty the episode store of agent positions
self.cur_episode = []
info_dict = self.get_info_dict()
# Return the new observation vectors for each agent
return self._get_observations()
observation_dict: Dict = self._get_observations()
if hasattr(self, "renderer") and self.renderer is not None:
self.renderer = None
return observation_dict, info_dict
def _update_agent_positions_map(self, ignore_old_positions=True):
""" Update the agent_positions array for agents that changed positions """
for agent in self.agents:
if not ignore_old_positions or agent.old_position != agent.position:
if agent.position is not None:
self.agent_positions[agent.position] = agent.handle
if agent.old_position is not None:
self.agent_positions[agent.old_position] = -1
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
    """
    Build the StateTransitionSignals object consumed by the agent's state machine.

    Parameters
    ----------
    agent : agent whose handlers/counters are inspected
    preprocessed_action : action after preprocessing for this step
    movement_allowed : whether the agent may move this step

    Returns
    -------
    StateTransitionSignals with all signal fields filled in.
    """
    signals = StateTransitionSignals()
    malfunction = agent.malfunction_handler

    # Malfunction flags are copied straight from the malfunction handler
    signals.in_malfunction = malfunction.in_malfunction
    signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete

    # The train may depart once the elapsed step count reaches its earliest departure
    signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure

    # Action related signals
    signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)
    signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed

    # Current position equals the target position
    signals.target_reached = fast_position_equal(agent.position, agent.target)

    # A conflict is only raised when movement is denied while the speed
    # counter is at the cell exit
    signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit

    return signals
def _handle_end_reward(self, agent: EnvAgent) -> int:
    """
    Compute the end-of-episode reward for a single agent.

    Parameters
    ----------
    agent : EnvAgent

    Returns
    -------
    * DONE agent: ``min(latest_arrival - arrival_time, 0)`` (0 when on time,
      negative when late).
    * off-map agent (never departed): negative penalty
      ``-cancellation_factor * (shortest-path travel time + cancellation_time_buffer)``.
    * on-map agent (departed, not arrived): ``agent.get_current_delay(...)``.
    * ``None`` if none of the state checks match.
    """
    if agent.state == TrainState.DONE:
        return min(agent.latest_arrival - agent.arrival_time, 0)

    end_reward = None

    # Cancelled: the train never left the off-map states
    if agent.state.is_off_map_state():
        end_reward = -1 * self.cancellation_factor * \
            (agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)

    # Departed but did not reach the target
    if agent.state.is_on_map_state():
        end_reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)

    return end_reward
# NOTE(review): this span appears to interleave two revisions of the file:
# the body of preprocess_action and the opening lines of an older
# step(self, action_dict) implementation (reward constants and rewards_dict
# reset). Indentation was lost, so nesting cannot be confirmed from here.
def preprocess_action(self, action, agent):
"""
Preprocess the provided action
* Change to DO_NOTHING if illegal action
* Block all actions when in waiting state
* Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
"""
action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# NOTE(review): the next three lines look like the header of the older step() - confirm
def step(self, action_dict):
alpha = 1.0
beta = 1.0
# Try moving actions on current position
current_position, current_direction = agent.position, agent.direction
if current_position is None: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
# NOTE(review): reward constants below are used by the older step() body - confirm
invalid_action_penalty = -2
step_penalty = -1 * alpha
global_reward = 1 * beta
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
# Reset the step rewards
self.rewards_dict = dict()
# for handle in self.agents_handles:
# self.rewards_dict[handle] = 0
for iAgent in range(self.get_num_agents()):
self.rewards_dict[iAgent] = 0
# Check transitions, bounds for executing the action in the given position and direction;
# a moving action that fails check_valid_action is replaced by STOP_MOVING
if action.is_moving_action() and not check_valid_action(action, self.rail, current_position, current_direction):
action = RailEnvActions.STOP_MOVING
return action
def clear_rewards_dict(self):
    """Replace ``self.rewards_dict`` with a fresh dict mapping every agent index to 0."""
    agent_indices = range(len(self.agents))
    self.rewards_dict = dict.fromkeys(agent_indices, 0)
def get_info_dict(self):
    """
    Collect per-agent information, keyed by agent index.

    Returns
    -------
    dict with the keys
        action_required - result of ``self.action_required(agent)``
        malfunction     - ``malfunction_down_counter`` of the agent's malfunction handler
        speed           - ``speed`` of the agent's speed counter
        state           - the agent's ``state``
    """
    indexed_agents = list(enumerate(self.agents))

    action_required = {idx: self.action_required(agent) for idx, agent in indexed_agents}
    malfunction = {idx: agent.malfunction_handler.malfunction_down_counter for idx, agent in indexed_agents}
    speed = {idx: agent.speed_counter.speed for idx, agent in indexed_agents}
    state = {idx: agent.state for idx, agent in indexed_agents}

    return {
        'action_required': action_required,
        'malfunction': malfunction,
        'speed': speed,
        'state': state,
    }
def update_step_rewards(self, i_agent):
    """
    Per-timestep reward hook for agent id ``i_agent``.

    Currently a no-op: nothing is added to the rewards dict here.
    """
    return None
def end_of_episode_update(self, have_all_agents_ended):
    """
    Apply end-of-episode bookkeeping when the episode is over.

    The episode counts as over when all agents have ended, or when a step
    limit is set and the elapsed step count has reached it. In that case the
    end reward of every agent is added to ``rewards_dict`` and all entries of
    ``dones`` (including "__all__") are set to True.

    Parameters
    ----------
    have_all_agents_ended : indicates if all agents have reached done state
    """
    if not have_all_agents_ended:
        # Without a step limit, or before reaching it, the episode goes on
        if self._max_episode_steps is None or not (self._elapsed_steps >= self._max_episode_steps):
            return

    for handle, agent in enumerate(self.agents):
        end_reward = self._handle_end_reward(agent)
        self.rewards_dict[handle] += end_reward
        self.dones[handle] = True

    self.dones["__all__"] = True
def handle_done_state(self, agent):
    """
    Record the arrival of an agent that has just entered the DONE state.

    Only acts the first time (while ``arrival_time`` is still None): stores
    the elapsed step count as arrival time, flags the agent as done and, if
    ``remove_agents_at_target`` is set, clears its position.
    """
    just_arrived = agent.state == TrainState.DONE and agent.arrival_time is None
    if not just_arrived:
        return

    agent.arrival_time = self._elapsed_steps
    self.dones[agent.handle] = True
    if self.remove_agents_at_target:
        agent.position = None
# NOTE(review): this span appears to interleave two revisions of step():
#   * an older implementation working on `action_dict`, integer actions 0..3,
#     `iAgent` loops, `invalid_action_penalty` / `step_penalty` / `global_reward`;
#   * a newer implementation working on `action_dict_`, `self.motionCheck`,
#     per-agent state machines, speed counters and malfunction handlers.
# Indentation was lost, so the nesting below cannot be confirmed from here.
# Code tokens are left untouched; only comments were added.
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
"""
self._elapsed_steps += 1
# Not allowed to step further once done
if self.dones["__all__"]:
return self._get_observations(), self.rewards_dict, self.dones, {}
# NOTE(review): from here the per-agent loop of the older revision starts
# (it reads `action_dict`, which is not a parameter of this signature) - confirm
# for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()):
# handle = self.agents_handles[i]
transition_isValid = None
agent = self.agents[iAgent]
if iAgent not in action_dict: # no action has been supplied for this agent
continue
if self.dones[iAgent]: # this agent has already completed...
continue
action = action_dict[iAgent]
if action < 0 or action > 3:
print('ERROR: illegal action=', action,
'for agent with index=', iAgent)
return
if action > 0:
# pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions)
movement = agent.direction
# print(nbits,np.sum(possible_transitions))
if action == 1:
movement = agent.direction - 1
if num_transitions <= 1:
transition_isValid = False
elif action == 3:
movement = agent.direction + 1
if num_transitions <= 1:
transition_isValid = False
movement %= 4
if action == 2:
if num_transitions == 1:
# - dead-end, straight line or curved line;
# movement will be the only valid transition
# - take only available transition
movement = np.argmax(possible_transitions)
transition_isValid = True
new_position = get_new_position(agent.position, movement)
# Is it a legal move?
# 1) transition allows the movement in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height-1, self.width-1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction),
movement)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
if all([new_cell_isValid, transition_isValid, cell_isFree]):
# move and change direction to face the movement that was
# performed
# self.agents_position[i] = new_position
# self.agents_direction[i] = movement
agent.position = new_position
agent.direction = movement
else:
# the action was not valid, add penalty
self.rewards_dict[iAgent] += invalid_action_penalty
# if agent is not in target position, add step penalty
# if self.agents_position[i][0] == self.agents_target[i][0] and \
# self.agents_position[i][1] == self.agents_target[i][1]:
# self.dones[handle] = True
if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True
# NOTE(review): the raise below presumably belongs to the newer revision's
# `if self.dones["__all__"]` guard - confirm
raise Exception("Episode is done, cannot call step()")
# NOTE(review): newer revision - first pass over all agents: compute each
# agent's tentative new position/direction and register it with motionCheck
self.clear_rewards_dict()
have_all_agents_ended = True # Boolean flag to check if all agents are done
self.motionCheck = ac.MotionCheck() # reset the motion check
temp_transition_data = {}
for agent in self.agents:
i_agent = agent.handle
agent.old_position = agent.position
agent.old_direction = agent.direction
# Generate malfunction
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
# Get action for the agent
action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
preprocessed_action = self.preprocess_action(action, agent)
# Save moving actions in not already saved
agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
# Train's next position can change if train is at cell's exit and train is not in malfunction
position_update_allowed = agent.speed_counter.is_cell_exit and \
not agent.malfunction_handler.malfunction_down_counter > 0 and \
not preprocessed_action == RailEnvActions.STOP_MOVING
# Calculate new position
# Keep agent in same place if already done
if agent.state == TrainState.DONE:
new_position, new_direction = agent.position, agent.direction
# Add agent to the map if not on it yet
elif agent.position is None and agent.action_saver.is_action_saved:
new_position = agent.initial_position
new_direction = agent.initial_direction
# If movement is allowed apply saved action independent of other agents
elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = env_utils.apply_action_independent(saved_action,
self.rail,
agent.position,
agent.direction)
preprocessed_action = saved_action
else:
# NOTE(review): the step-penalty / global-reward lines below look like
# older-revision residue inside the newer else branch - confirm
self.rewards_dict[iAgent] += step_penalty
# Check for end of episode + add global reward to all rewards!
# num_agents_in_target_position = 0
# for i in range(self.number_of_agents):
# if self.agents_position[i][0] == self.agents_target[i][0] and \
# self.agents_position[i][1] == self.agents_target[i][1]:
# num_agents_in_target_position += 1
# if num_agents_in_target_position == self.number_of_agents:
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
self.dones["__all__"] = True
self.rewards_dict = [r + global_reward for r in self.rewards_dict]
new_position, new_direction = agent.position, agent.direction
temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
direction=new_direction,
preprocessed_action=preprocessed_action)
# This is for storing and later checking for conflicts of agents trying to occupy same cell
self.motionCheck.addAgent(i_agent, agent.position, new_position)
# Find conflicts between trains trying to occupy same cell
self.motionCheck.find_conflicts()
# NOTE(review): newer revision - second pass: resolve conflicts, step each
# agent's state machine, then apply positions, rewards and counters
for agent in self.agents:
i_agent = agent.handle
## Update positions
if agent.malfunction_handler.in_malfunction:
movement_allowed = False
else:
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
movement_inside_cell = agent.state == TrainState.STOPPED and not agent.speed_counter.is_cell_exit
movement_allowed = movement_allowed or movement_inside_cell
# Fetch the saved transition data
agent_transition_data = temp_transition_data[i_agent]
preprocessed_action = agent_transition_data.preprocessed_action
## Update states
state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
# Needed when not removing agents at target
movement_allowed = movement_allowed and agent.state != TrainState.DONE
# Agent is being added to map
if agent.state.is_on_map_state():
if agent.state_machine.previous_state.is_off_map_state():
agent.position = agent.initial_position
agent.direction = agent.initial_direction
# Speed counter completes
elif movement_allowed and (agent.speed_counter.is_cell_exit):
agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction
agent.state_machine.update_if_reached(agent.position, agent.target)
# Off map or on map state and position should match
env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
have_all_agents_ended &= (agent.state == TrainState.DONE)
## Update rewards
self.update_step_rewards(i_agent)
## Update counters (malfunction and speed)
agent.speed_counter.update_counter(agent.state, agent.old_position)
# agent.state_machine.previous_state)
agent.malfunction_handler.update_counter()
# Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry and agent.position is not None:
agent.action_saver.clear_saved_action()
# NOTE(review): the actions reset and the return with an empty info dict
# below look like the tail of the older revision - confirm
# Reset the step actions (in case some agent doesn't 'register_action'
# on the next step)
self.actions = [0] * self.get_num_agents()
return self._get_observations(), self.rewards_dict, self.dones, {}
# Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended)
self._update_agent_positions_map()
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
def record_timestep(self, dActions):
    """
    Append a snapshot of all agents to ``cur_episode`` and the given actions
    to ``list_actions``.

    Each agent contributes the list
    ``[row, col, direction, malfunction_down_counter, status, deadlocked]``,
    where ``deadlocked`` is 1 if the agent's position is contained in
    ``self.motionCheck.svDeadlocked``.
    """
    snapshot = []
    for idx in range(self.get_num_agents()):
        agent = self.agents[idx]

        # Agents may have position None before starting; record them at (0, 0).
        # The int casts avoid numpy types which may cause problems with msgpack.
        row_col = (0, 0) if agent.position is None else (int(agent.position[0]), int(agent.position[1]))

        # NOTE(review): other code in this file reads `agent.state`; confirm
        # that agents still expose a `status` attribute.
        snapshot.append([
            *row_col,
            int(agent.direction),
            agent.malfunction_handler.malfunction_down_counter,
            int(agent.status),
            int(agent.position in self.motionCheck.svDeadlocked),
        ])

    self.cur_episode.append(snapshot)
    self.list_actions.append(dActions)
def _get_observations(self):
self.obs_dict = {}
# for handle in self.agents_handles:
for iAgent in range(self.get_num_agents()):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
"""
Utility which returns the dictionary of observations for an agent with respect to environment
"""
# print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def render(self):
    """Rendering placeholder; does nothing and returns None."""
    # TODO:
    return None
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
    """
    Returns directions in which the agent can move

    The full transitions of cell (row, col) are read from the rail and
    converted with ``Grid4Transitions.get_entry_directions``.
    """
    cell_transitions = self.rail.get_full_transitions(row, col)
    return Grid4Transitions.get_entry_directions(cell_transitions)
def save(self, sFilename):
    """
    Pickle the rail grid and the static agents to the file ``sFilename``.

    The stored object is a dict with the keys "grid" and "agents_static".
    """
    payload = dict()
    payload["grid"] = self.rail.grid
    payload["agents_static"] = self.agents_static

    with open(sFilename, "wb") as out_file:
        pickle.dump(payload, out_file)
def load(self, sFilename):
    """
    Restore grid and static agents from a pickle file written by ``save``.

    Sets ``rail.grid``, ``height``/``width`` (from the grid shape),
    ``agents_static``, a list of ``None`` placeholders in ``agents`` and a
    fresh all-False ``dones`` dict (including "__all__").
    """
    # NOTE: unpickling can execute arbitrary code - only load trusted files.
    with open(sFilename, "rb") as in_file:
        stored = pickle.load(in_file)

    self.rail.grid = stored["grid"]
    n_rows, n_cols = self.rail.grid.shape
    self.height, self.width = n_rows, n_cols
    self.agents_static = stored["agents_static"]
    self.agents = [None for _ in range(self.get_num_agents())]
    self.dones = dict.fromkeys([*range(self.get_num_agents()), "__all__"], False)
def _exp_distirbution_synced(self, rate: float) -> float:
"""
Generates sample from exponential distribution
We need this to guarantee synchronity between different instances with same seed.
:param rate:
:return:
"""
u = self.np_random.rand()
x = - np.log(1 - u) * rate
return x
def _is_agent_ok(self, agent: EnvAgent) -> bool:
"""
Check if an agent is ok, meaning it can move and is not malfuncitoinig
Parameters
----------
agent
Returns
-------
True if agent is ok, False otherwise
"""
return agent.malfunction_handler.in_malfunction
def save(self, filename):
    """
    Deprecated: prints a deprecation notice and delegates to
    ``persistence.RailEnvPersister.save(self, filename)``.
    """
    deprecation_notice = "DEPRECATED call to env.save() - pls call RailEnvPersister.save()"
    print(deprecation_notice)
    persistence.RailEnvPersister.save(self, filename)
def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
           show_debug=False, clear_debug_text=True, show=False,
           screen_height=600, screen_width=800,
           show_observations=False, show_predictions=False,
           show_rowcols=False, return_image=True):
    """
    Render the environment, creating the renderer on first use.

    If no renderer exists yet, ``initialize_renderer`` is called with the
    window/GL related arguments; afterwards the frame is produced by
    ``update_renderer``.

    Parameters
    ----------
    mode

    Returns
    -------
    Whatever ``update_renderer`` returns for the given mode.
    """
    renderer_missing = getattr(self, "renderer", None) is None
    if renderer_missing:
        self.initialize_renderer(
            mode=mode,
            gl=gl,  # gl="TKPILSVG",
            agent_render_variant=agent_render_variant,
            show_debug=show_debug,
            clear_debug_text=clear_debug_text,
            show=show,
            screen_height=screen_height,  # Adjust these parameters to fit your resolution
            screen_width=screen_width,
        )

    frame = self.update_renderer(
        mode=mode,
        show=show,
        show_observations=show_observations,
        show_predictions=show_predictions,
        show_rowcols=show_rowcols,
        return_image=return_image,
    )
    return frame
def initialize_renderer(self, mode, gl,
                        agent_render_variant,
                        show_debug,
                        clear_debug_text,
                        show,
                        screen_height,
                        screen_width):
    """
    Create the RenderTool for this environment and store it in ``self.renderer``.

    ``mode`` is accepted but not used here. After construction the
    renderer's ``show`` flag is set and the renderer is reset.
    """
    render_options = dict(
        gl=gl,  # gl="TKPILSVG",
        agent_render_variant=agent_render_variant,
        show_debug=show_debug,
        clear_debug_text=clear_debug_text,
        screen_height=screen_height,  # Adjust these parameters to fit your resolution
        screen_width=screen_width,  # Adjust these parameters to fit your resolution
    )
    self.renderer = RenderTool(self, **render_options)
    self.renderer.show = show
    self.renderer.reset()
def update_renderer(self, mode, show, show_observations, show_predictions,
                    show_rowcols, return_image):
    """
    This method updates the render.

    Parameters
    ----------
    mode
        'rgb_array' to get the rendered frame back, anything else for no
        return value.
    show, show_observations, show_predictions, show_rowcols, return_image
        Passed through to ``self.renderer.render_env``.

    Returns
    -------
    The first three channels of the rendered image if mode is rgb_array and
    the renderer produced an image, None otherwise
    """
    image = self.renderer.render_env(show=show, show_observations=show_observations,
                                     show_predictions=show_predictions,
                                     show_rowcols=show_rowcols, return_image=return_image)
    # Guard against a renderer that yields no image (presumably the case for
    # return_image=False - confirm) instead of failing on the slice.
    if mode == 'rgb_array' and image is not None:
        return image[:, :, :3]
    return None
def close(self):
    """
    Close the renderer window, if any, and drop the renderer.

    Errors raised while closing the window are reported on stdout and
    otherwise ignored.
    """
    if getattr(self, "renderer", None) is None:
        return

    try:
        window_is_shown = self.renderer.show
        if window_is_shown:
            self.renderer.close_window()
    except Exception as e:
        print("Could Not close window due to:",e)
    self.renderer = None
from enum import IntEnum
from typing import NamedTuple
from flatland.core.grid.grid4 import Grid4TransitionsEnum
class RailEnvActions(IntEnum):
    """Actions an agent can submit to the rail environment."""

    DO_NOTHING = 0  # implies change of direction in a dead-end!
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4

    @staticmethod
    def to_char(a: int):
        """Return the one-letter code of action value ``a`` (KeyError if unknown)."""
        letters = dict(enumerate(('B', 'L', 'F', 'R', 'S')))
        return letters[a]

    @classmethod
    def is_action_valid(cls, action):
        """Return True if ``action`` is the value of one of the enum members."""
        known_values = cls._value2member_map_
        return action in known_values

    def is_moving_action(self):
        """Return True for MOVE_LEFT, MOVE_FORWARD and MOVE_RIGHT."""
        moving = (self.MOVE_RIGHT, self.MOVE_LEFT, self.MOVE_FORWARD)
        return self.value in moving
class RailEnvGridPos(NamedTuple):
    """Grid cell address: row ``r`` and column ``c``."""
    r: int
    c: int


class RailEnvNextAction(NamedTuple):
    """An action together with the position and direction it leads to."""
    action: RailEnvActions
    next_position: RailEnvGridPos
    next_direction: Grid4TransitionsEnum
import math
from typing import Dict, List, Optional, Tuple, Set
import matplotlib.pyplot as plt
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.step_utils.states import TrainState
from flatland.envs.distance_map import DistanceMap
from flatland.envs.fast_methods import fast_count_nonzero
from flatland.envs.rail_env_action import RailEnvActions, RailEnvNextAction
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.utils.ordered_set import OrderedSet
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
                            agent_position: Tuple[int, int],
                            rail: GridTransitionMap) -> Set[RailEnvNextAction]:
    """
    Get the valid move actions (forward, left, right) for an agent.

    TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
    and more elegant. But given the few calls this has no priority now.

    Parameters
    ----------
    agent_direction : Grid4TransitionsEnum
    agent_position: Tuple[int,int]
    rail : GridTransitionMap

    Returns
    -------
    Collection of `RailEnvNextAction` (tuples of (action,position,direction));
    note that the object actually returned is a list.
        Possible move actions (forward,left,right) and the next position/direction they lead to.
        It is not checked that the next cell is free.
    """
    transitions = rail.get_transitions(*agent_position, agent_direction)
    transition_count = fast_count_nonzero(transitions)

    # Dead end: the only way out is to turn around, reported as MOVE_FORWARD
    if rail.is_dead_end(agent_position):
        turn_back = (agent_direction + 2) % 4
        if not transitions[turn_back]:
            return []
        target_cell = get_new_position(agent_position, turn_back)
        return [(RailEnvNextAction(RailEnvActions.MOVE_FORWARD, target_cell, turn_back))]

    # Candidate headings ordered [left, forward, right] relative to the current heading
    headings = [(agent_direction + turn) % 4 for turn in range(-1, 2)]

    # Single transition: the forward action is aligned with it
    if transition_count == 1:
        found = []
        for heading in headings:
            if transitions[heading]:
                target_cell = get_new_position(agent_position, heading)
                found = [(RailEnvNextAction(RailEnvActions.MOVE_FORWARD, target_cell, heading))]
        return found

    # Several transitions: name each open heading relative to the current one
    action_for_heading = {
        agent_direction: RailEnvActions.MOVE_FORWARD,
        (agent_direction + 1) % 4: RailEnvActions.MOVE_RIGHT,
        (agent_direction - 1) % 4: RailEnvActions.MOVE_LEFT,
    }
    found = []
    for heading in headings:
        if not transitions[heading]:
            continue
        move = action_for_heading.get(heading)
        if move is None:
            raise Exception("Illegal state")
        target_cell = get_new_position(agent_position, heading)
        found.append(RailEnvNextAction(move, target_cell, heading))
    return found
def get_new_position_for_action(
        agent_position: Tuple[int, int],
        agent_direction: Grid4TransitionsEnum,
        action: RailEnvActions,
        rail: GridTransitionMap) -> Tuple[int, int, int]:
    """
    Get the next position for this action.

    TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
    and more elegant. But given the few calls this has no priority now.

    Parameters
    ----------
    agent_position
    agent_direction
    action
    rail

    Returns
    -------
    A pair ``(new_position, new_direction)`` where ``new_position`` is the
    result of ``get_new_position``; ``None`` (implicitly) if the action does
    not correspond to an available transition.
    NOTE(review): the annotation promises a flat (row, column, direction)
    triple, which is not what is returned - confirm what callers expect.
    """
    transitions = rail.get_transitions(*agent_position, agent_direction)
    transition_count = fast_count_nonzero(transitions)

    # Dead end: turning around is the forward action
    if rail.is_dead_end(agent_position):
        turn_back = (agent_direction + 2) % 4
        if transitions[turn_back]:
            target_cell = get_new_position(agent_position, turn_back)
            if RailEnvActions.MOVE_FORWARD == action:
                return target_cell, turn_back
        return None

    # Candidate headings ordered [left, forward, right] relative to the current heading
    headings = [(agent_direction + turn) % 4 for turn in range(-1, 2)]

    # Single transition: the forward action is aligned with it
    if transition_count == 1:
        for heading in headings:
            if transitions[heading]:
                target_cell = get_new_position(agent_position, heading)
                if RailEnvActions.MOVE_FORWARD == action:
                    return target_cell, heading
        return None

    # Several transitions: name each open heading relative to the current one
    action_for_heading = {
        agent_direction: RailEnvActions.MOVE_FORWARD,
        (agent_direction + 1) % 4: RailEnvActions.MOVE_RIGHT,
        (agent_direction - 1) % 4: RailEnvActions.MOVE_LEFT,
    }
    for heading in headings:
        if not transitions[heading]:
            continue
        move = action_for_heading.get(heading)
        if move is not None and move == action:
            return get_new_position(agent_position, heading), heading
    return None
def get_action_for_move(
        agent_position: Tuple[int, int],
        agent_direction: Grid4TransitionsEnum,
        next_agent_position: Tuple[int, int],
        next_agent_direction: int,
        rail: GridTransitionMap) -> Optional[RailEnvActions]:
    """
    Get the action (if any) to move from a position and direction to another.

    TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
    and more elegant. But given the few calls this has no priority now.

    Parameters
    ----------
    agent_position
    agent_direction
    next_agent_position
    next_agent_direction
    rail

    Returns
    -------
    Optional[RailEnvActions]
        the action (if direct transition possible) or None.
    """
    transitions = rail.get_transitions(*agent_position, agent_direction)
    transition_count = fast_count_nonzero(transitions)

    def _leads_to_goal(heading):
        # True if moving along `heading` ends exactly at the requested cell and heading
        reached = get_new_position(agent_position, heading)
        return reached == next_agent_position and heading == next_agent_direction

    # Dead end: turning around is the forward action
    if rail.is_dead_end(agent_position):
        turn_back = (agent_direction + 2) % 4
        if transitions[turn_back] and _leads_to_goal(turn_back):
            return RailEnvActions.MOVE_FORWARD
        return None

    # Candidate headings ordered [left, forward, right] relative to the current heading
    headings = [(agent_direction + turn) % 4 for turn in range(-1, 2)]

    # Single transition: the forward action is aligned with it
    if transition_count == 1:
        for heading in headings:
            if transitions[heading] and _leads_to_goal(heading):
                return RailEnvActions.MOVE_FORWARD
        return None

    # Several transitions: name each open heading relative to the current one
    action_for_heading = {
        agent_direction: RailEnvActions.MOVE_FORWARD,
        (agent_direction + 1) % 4: RailEnvActions.MOVE_RIGHT,
        (agent_direction - 1) % 4: RailEnvActions.MOVE_LEFT,
    }
    for heading in headings:
        if not transitions[heading]:
            continue
        move = action_for_heading.get(heading)
        if move is not None and _leads_to_goal(heading):
            return move
    return None
# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None, agent_handle: Optional[int] = None) \
        -> Dict[int, Optional[List[Waypoint]]]:
    """
    Computes the shortest path for each agent to its target and the action to be taken to do so.
    The paths are derived from a `DistanceMap`.

    If there is no path (rail disconnected), the path is given as None.
    The agent state (moving or not) and its speed are not taken into account

    example:
            agent_fixed_travel_paths = get_shortest_paths(env.distance_map, None, agent.handle)
            path = agent_fixed_travel_paths[agent.handle]

    Parameters
    ----------
    distance_map : reference to the distance_map
    max_depth : max path length, if the shortest path is longer, it will be cut
    agent_handle : if set, the shortest for agent.handle will be returned , otherwise for all agents

    Returns
    -------
    Dict[int, Optional[List[Waypoint]]]
    """
    paths_by_handle = dict()

    def _shortest_path_for_agent(agent):
        # Pick the start cell according to where the agent currently is
        if agent.state.is_off_map_state():
            cell = agent.initial_position
        elif agent.state.is_on_map_state():
            cell = agent.position
        elif agent.state == TrainState.DONE:
            cell = agent.target
        else:
            paths_by_handle[agent.handle] = None
            return

        heading = agent.direction
        paths_by_handle[agent.handle] = []
        # Best distance seen so far; carried across iterations, so each step
        # has to get strictly closer to the target
        best_distance = math.inf
        steps_taken = 0

        while (cell != agent.target and (max_depth is None or steps_taken < max_depth)):
            chosen = None
            for candidate in get_valid_move_actions_(heading, cell, distance_map.rail):
                candidate_distance = distance_map.get()[
                    agent.handle,
                    candidate.next_position[0],
                    candidate.next_position[1],
                    candidate.next_direction]
                if candidate_distance < best_distance:
                    chosen = candidate
                    best_distance = candidate_distance

            paths_by_handle[agent.handle].append(Waypoint(cell, heading))
            steps_taken += 1

            # if there is no way to continue, the rail must be disconnected!
            # (or distance map is incorrect)
            if chosen is None:
                paths_by_handle[agent.handle] = None
                return

            cell = chosen.next_position
            heading = chosen.next_direction

        # Close the path with the final waypoint unless the depth limit was hit
        if max_depth is None or steps_taken < max_depth:
            paths_by_handle[agent.handle].append(Waypoint(cell, heading))

    if agent_handle is None:
        for agent in distance_map.agents:
            _shortest_path_for_agent(agent)
    else:
        _shortest_path_for_agent(distance_map.agents[agent_handle])

    return paths_by_handle
def get_k_shortest_paths(env,
                         source_position: Tuple[int, int],
                         source_direction: int,
                         target_position: Tuple[int, int],
                         k: int = 1, debug=False) -> List[Tuple[Waypoint]]:
    """
    Computes the k shortest paths using modified Dijkstra
    following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing
    In contrast to the pseudo-code in wikipedia, we do not a allow for loopy paths.

    Parameters
    ----------
    env :             RailEnv
    source_position:  Tuple[int,int]
    source_direction: int
    target_position:  Tuple[int,int]
        Required. (It used to be declared as ``target_position=Tuple[int, int]``,
        i.e. the type was accidentally used as a default value.)
    k :               int
        max number of shortest paths
    debug:            bool
        print debug statements

    Returns
    -------
    List[Tuple[Waypoint]]
        We use tuples since we need the path elements to be hashable.
        We use a list of paths in order to keep the order of length.
    """
    # P: set of shortest paths from s to t
    # P =empty,
    shortest_paths: List[Tuple[Waypoint]] = []

    # countu: number of shortest paths found to node u
    # countu = 0, for all u in V
    count = {(r, c, d): 0 for r in range(env.height) for c in range(env.width) for d in range(4)}

    # B is a heap data structure containing paths
    # N.B. use OrderedSet to make result deterministic!
    heap: OrderedSet[Tuple[Waypoint]] = OrderedSet()

    # insert path Ps = {s} into B with cost 0
    heap.add((Waypoint(source_position, source_direction),))

    # while B is not empty and countt < K:
    while len(heap) > 0 and len(shortest_paths) < k:
        if debug:
            print("iteration heap={}, shortest_paths={}".format(heap, shortest_paths))

        # – let Pu be the shortest cost path in B with cost C
        # (path cost is its length; min() keeps the first of equally short paths)
        pu = min(heap, key=len)
        u: Waypoint = pu[-1]
        if debug:
            print(" looking at pu={}".format(pu))

        # – B = B − {Pu }
        heap.remove(pu)

        # – countu = countu + 1
        urcd = (*u.position, u.direction)
        count[urcd] += 1

        # – if u = t then P = P U {Pu}
        if u.position == target_position:
            if debug:
                print(" found of length {} {}".format(len(pu), pu))
            shortest_paths.append(pu)

        # – if countu ≤ K then
        # CAVEAT: do not allow for loopy paths
        elif count[urcd] <= k:
            possible_transitions = env.rail.get_transitions(*urcd)
            if debug:
                print(" looking at neighbors of u={}, transitions are {}".format(u, possible_transitions))

            # for each vertex v adjacent to u:
            for new_direction in range(4):
                if debug:
                    print(" looking at new_direction={}".format(new_direction))
                if not possible_transitions[new_direction]:
                    continue

                new_position = get_new_position(u.position, new_direction)
                if debug:
                    print(" looking at neighbor v={}".format((*new_position, new_direction)))
                v = Waypoint(position=new_position, direction=new_direction)

                # CAVEAT: do not allow for loopy paths
                if v in pu:
                    continue

                # – let Pv be a new path with cost C + w(u, v) formed by concatenating edge (u, v) to path Pu
                pv = pu + (v,)
                # – insert Pv into B
                heap.add(pv)

    # return P
    return shortest_paths
def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
    """
    Plot the distance map of a single agent.

    Parameters
    ----------
    distance_map : DistanceMap
        Object whose `get()` returns the distance array. Axis 0 of that array is indexed by
        the agent handle and axis 3 is reduced with a minimum.
    agent_handle : int
        Index of the agent whose distances are shown.
    """
    number_of_agents = distance_map.get().shape[0]
    if agent_handle < number_of_agents:
        # collapse the 4 directions (axis 3) by keeping the smallest distance
        distances_over_directions = np.min(distance_map.get(), axis=3)
        plt.imshow(distances_over_directions[agent_handle])
        plt.show()
    else:
        print("Error: agent_handle cannot be larger than actual number of agents")
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.line_generators import line_from_file
def load_flatland_environment_from_file(file_name: str,
                                        load_from_package: str = None,
                                        obs_builder_object: ObservationBuilder = None,
                                        record_steps = False,
                                        ) -> RailEnv:
    """
    Create a `RailEnv` whose rail and line generators read from a saved file.

    Parameters
    ----------
    file_name : str
        The pickle file.
    load_from_package : str
        The python module to import from. Example: 'env_data.tests'
        This requires that there are `__init__.py` files in the folder structure we load the file from.
    obs_builder_object: ObservationBuilder
        The obs builder for the `RailEnv` that is created. When None, a `TreeObsForRailEnv`
        with `max_depth=2` and a `ShortestPathPredictorForRailEnv(max_depth=10)` is used.
    record_steps
        Passed through unchanged to the `RailEnv` constructor.

    Returns
    -------
    RailEnv
        The environment configured with `rail_from_file` / `line_from_file` generators.
    """
    builder = obs_builder_object
    if builder is None:
        predictor = ShortestPathPredictorForRailEnv(max_depth=10)
        builder = TreeObsForRailEnv(max_depth=2, predictor=predictor)

    rail_generator = rail_from_file(file_name, load_from_package)
    line_generator = line_from_file(file_name, load_from_package)
    # width, height and number_of_agents are fixed to 1 here
    return RailEnv(
        width=1,
        height=1,
        rail_generator=rail_generator,
        line_generator=line_generator,
        number_of_agents=1,
        obs_builder_object=builder,
        record_steps=record_steps,
    )
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import warnings
from typing import Callable, Tuple, Optional, Dict, List
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import direction_to_point
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs import persistence
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
fix_inner_nodes, align_cell_to_city
# Result type of a rail generator: the generated GridTransitionMap paired with
# an optional dict of additional data (None when there is nothing extra).
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
""" A rail generator returns a RailGenerator Product, which is just
a GridTransitionMap followed by an (optional) dict/
"""
# Callable type of a rail generator: four int arguments in, a RailGeneratorProduct out.
# The generators in this module name these arguments width, height, num_agents and num_resets.
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
class RailGen(object):
    """
    Common base for class-based rail generators.

    Work in progress: meant to replace bare generator functions by objects, because the
    unnamed local variables of such functions prevent pickling.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; subclasses record state to reuse in each "generation"."""

    def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
                 np_random: RandomState = None) -> RailGeneratorProduct:
        """Placeholder for subclasses to override; this base version produces nothing."""
        return None

    def __call__(self, *args, **kwargs) -> RailGeneratorProduct:
        """Make the instance usable like a generator function by delegating to `generate`."""
        product = self.generate(*args, **kwargs)
        return product
def empty_rail_generator() -> RailGenerator:
    """Create and return a new `EmptyRailGen` instance."""
    generator = EmptyRailGen()
    return generator
class EmptyRailGen(RailGen):
    """
    Generator that produces an empty rail map with no agents.
    Primarily used by the editor
    """

    def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
                 np_random: RandomState = None) -> RailGeneratorProduct:
        """
        Build a grid of the requested size in which every cell is 0.

        Parameters
        ----------
        width: int
            Width of the grid map to create
        height: int
            Height of the grid map to create
        num_agents: int
            Not read by this generator
        num_resets: int
            Not read by this generator
        np_random: RandomState
            Not read by this generator

        Returns
        -------
        RailGeneratorProduct
            The zero-filled `GridTransitionMap` and None as optionals.
        """
        grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
        # make sure no cell carries any transition bits
        grid_map.grid.fill(0)
        return grid_map, None
def rail_from_file(filename, load_from_package=None) -> RailGenerator:
    """
    Utility to load pickle file

    Parameters
    ----------
    filename : Pickle file generated by env.save() or editor
    load_from_package : forwarded to `RailEnvPersister.load_env_dict` as `load_from_package`

    Returns
    -------
    function
        Generator function that always returns a GridTransitionMap object with
        the matrix of correct 16-bit bitmaps for each rail_spec_of_cell, together
        with an optional dict holding the stored distance map.
    """

    def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
                  np_random: RandomState = None) -> RailGeneratorProduct:
        """
        Load the rail grid from the file; none of the arguments are read.

        The size of the returned map is taken from the stored grid, not from
        `width` / `height`.
        """
        env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
        grid = np.array(env_dict["grid"])
        grid_height, grid_width = np.shape(grid)[0], np.shape(grid)[1]
        rail = GridTransitionMap(width=grid_width, height=grid_height, transitions=RailEnvTransitions())
        rail.grid = grid

        # Only hand out the distance map when the file contains a non-empty one.
        optionals = None
        if "distance_map" in env_dict:
            distance_map = env_dict["distance_map"]
            if len(distance_map) > 0:
                optionals = {'distance_map': distance_map}
        # Always return a tuple so both paths have the same RailGeneratorProduct shape.
        return rail, optionals

    return generator
class RailFromGridGen(RailGen):
    """Generator that hands back the rail map and optionals given at construction time."""

    def __init__(self, rail_map, optionals=None):
        """Keep the rail map and the optional extra data to return on every call."""
        self.rail_map = rail_map
        self.optionals = optionals

    def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
                 np_random: RandomState = None) -> RailGeneratorProduct:
        """Return the stored `(rail_map, optionals)` pair; none of the arguments are read."""
        product = (self.rail_map, self.optionals)
        return product
def rail_from_grid_transition_map(rail_map, optionals=None) -> RailGenerator:
    """Wrap an existing rail map (and optional extra data) in a `RailFromGridGen`."""
    generator = RailFromGridGen(rail_map, optionals)
    return generator
def sparse_rail_generator(*args, **kwargs):
    """Build a `SparseRailGen`, forwarding all positional and keyword arguments to its constructor."""
    generator = SparseRailGen(*args, **kwargs)
    return generator
class SparseRailGen(RailGen):
def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2,
max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator:
"""
Generates railway networks with cities and inner city rails
Parameters
----------
max_num_cities : int
Max number of cities to build. The generator tries to achieve this numbers given all the parameters
grid_mode: Bool
How to distribute the cities in the path, either equally in a grid or random
max_rails_between_cities: int
Max number of rails connecting to a city. This is only the number of connection points at city boarder.
Number of tracks drawn inbetween cities can still vary
max_rail_pairs_in_city: int
Number of parallel tracks in the city. This represents the number of tracks in the trainstations
seed: int
Initiate the seed
Returns
-------
Returns the rail generator object to the rail env constructor
"""
self.max_num_cities = max_num_cities
self.grid_mode = grid_mode
self.max_rails_between_cities = max_rails_between_cities
self.max_rail_pairs_in_city = max_rail_pairs_in_city
self.seed = seed
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
             np_random: RandomState = None) -> RailGeneratorProduct:
    """
    Build a sparse rail network: place cities, connect them and lay out their inner tracks.

    Parameters
    ----------
    width: int
        Width of the environment
    height: int
        Height of the environment
    num_agents:
        Number of agents to be placed within the environment (not read by this method)
    num_resets: int
        Count for how often the environment has been reset (not read by this method)
    np_random: RandomState
        Random source. It is replaced by `RandomState(self.seed)` whenever a seed was given to the
        constructor; if it is None and no seed was given, a newly seeded RandomState is created.

    Returns
    -------
    Returns the grid_map --> The railway infrastructure
    Hints:
    'agents_hints': {
    'city_positions': positions of the cities that were placed
    'train_stations': locations of train stations for start and targets
    'city_orientations' : orientation of cities

    Raises
    ------
    ValueError
        If fewer than two cities fit into a map of the given size.
    """
    if self.seed is not None:
        # A configured seed always wins over the random state passed in
        np_random = RandomState(self.seed)
    elif np_random is None:
        # NOTE(review): assumes numpy's default integer type can hold 2 ** 32 - confirm on
        # platforms where the default integer is 32 bit
        np_random = RandomState(np.random.randint(2 ** 32))
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
    # NEW : SCHED CONST (Pairs of rails (1,2,3 pairs))
    min_nr_rail_pairs_in_city = 1  # (min pair must be 1)
    # Use the configured number of rail pairs, but never less than the minimum
    rail_pairs_in_city = min_nr_rail_pairs_in_city if self.max_rail_pairs_in_city < min_nr_rail_pairs_in_city else self.max_rail_pairs_in_city  # (pairs can be 1,2,3)
    # A city cannot have more outgoing rails than it has tracks (two per rail pair)
    rails_between_cities = (rail_pairs_in_city * 2) if self.max_rails_between_cities > (
        rail_pairs_in_city * 2) else self.max_rails_between_cities
    # We compute the city radius by the given max number of rails it can contain.
    # The radius is equal to the number of tracks divided by 2
    # We add 2 cells to avoid that track length is too short
    city_padding = 2
    # We use ceil if we get uneven numbers of city radius. This is to guarantee that all rails fit within the city.
    city_radius = int(np.ceil((rail_pairs_in_city * 2) / 2)) + city_padding
    # Every cell starts at -1; _get_cells_in_city overwrites the entries of city cells
    vector_field = np.zeros(shape=(height, width)) - 1.
    # Calculate the max number of cities allowed
    # and reduce the number of cities to build to avoid problems
    max_feasible_cities = min(self.max_num_cities,
                              ((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1))))
    if max_feasible_cities < 2:
        raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!")
    # Evenly distribute cities
    if self.grid_mode:
        city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
                                                                    height)
    else:
        # Distribute cities randomly
        city_positions = self._generate_random_city_positions(max_feasible_cities, city_radius, width, height,
                                                              np_random=np_random)
    # reduce num_cities if less were generated in random mode
    num_cities = len(city_positions)
    # If random generation failed just put the cities evenly
    if num_cities < 2:
        warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
        city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
                                                                    height)
    num_cities = len(city_positions)
    # Set up connection points for all cities
    inner_connection_points, outer_connection_points, city_orientations, city_cells = \
        self._generate_city_connection_points(
            city_positions, city_radius, vector_field, rails_between_cities,
            rail_pairs_in_city, np_random=np_random)
    # Connect the cities through the connection points
    inter_city_lines = self._connect_cities(city_positions, outer_connection_points, city_cells,
                                            rail_trans, grid_map)
    # Build inner cities
    free_rails = self._build_inner_cities(city_positions, inner_connection_points,
                                          outer_connection_points,
                                          rail_trans,
                                          grid_map)
    # Populate cities
    train_stations = self._set_trainstation_positions(city_positions, city_radius, free_rails)
    # Fix all transition elements
    self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
    return grid_map, {'agents_hints': {
        'city_positions': city_positions,
        'train_stations': train_stations,
        'city_orientations': city_orientations
    }}
def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int,
height: int, np_random: RandomState = None) -> Tuple[
IntVector2DArray, IntVector2DArray]:
"""
Distribute the cities randomly in the environment while respecting city sizes and guaranteeing that they
don't overlap.
Parameters
----------
num_cities: int
Max number of cities that should be placed
city_radius: int
Radius of each city. Cities are squares with edge length 2 * city_radius + 1
width: int
Width of the environment
height: int
Height of the environment
Returns
-------
Returns a list of all city positions as coordinates (x,y)
"""
city_positions: IntVector2DArray = []
# We track a grid of allowed indexes that can be sampled from for creating a new city
# This removes the old sampling method of retrying a random sample on failure
allowed_grid = np.zeros((height, width), dtype=np.uint8)
city_radius_pad1 = city_radius + 1
# Borders have to be not allowed from the start
# allowed_grid == 1 indicates locations that are allowed
allowed_grid[city_radius_pad1:-city_radius_pad1, city_radius_pad1:-city_radius_pad1] = 1
for _ in range(num_cities):
allowed_indexes = np.where(allowed_grid == 1)
num_allowed_points = len(allowed_indexes[0])
if num_allowed_points == 0:
break
# Sample one of the allowed indexes
point_index = np_random.randint(num_allowed_points)
row = int(allowed_indexes[0][point_index])
col = int(allowed_indexes[1][point_index])
# Need to block city radius and extra margin so that next sampling is correct
# Clipping handles the case for negative indexes being generated
row_start = max(0, row - 2 * city_radius_pad1)
col_start = max(0, col - 2 * city_radius_pad1)
row_end = row + 2 * city_radius_pad1 + 1
col_end = col + 2 * city_radius_pad1 + 1
allowed_grid[row_start: row_end, col_start: col_end] = 0
city_positions.append((row, col))
created_cites = len(city_positions)
if created_cites < num_cities:
city_warning = f"Could not set all required cities! Created {created_cites}/{num_cities}"
warnings.warn(city_warning)
return city_positions
def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int
) -> Tuple[IntVector2DArray, IntVector2DArray]:
"""
Distribute the cities in an evenly spaced grid
Parameters
----------
num_cities: int
Max number of cities that should be placed
city_radius: int
Radius of each city. Cities are squares with edge length 2 * city_radius + 1
width: int
Width of the environment
height: int
Height of the environment
Returns
-------
Returns a list of all city positions as coordinates (x,y)
"""
aspect_ratio = height / width
# Compute max numbe of possible cities per row and col.
# Respect padding at edges of environment
# Respect padding between cities
padding = 2
city_size = 2 * (city_radius + 1)
max_cities_per_row = int((height - padding) // city_size)
max_cities_per_col = int((width - padding) // city_size)
# Choose number of cities per row.
# Limit if it is more then max number of possible cities
cities_per_row = min(int(np.ceil(np.sqrt(num_cities * aspect_ratio))), max_cities_per_row)
cities_per_col = min(int(np.ceil(num_cities / cities_per_row)), max_cities_per_col)
num_build_cities = min(num_cities, cities_per_col * cities_per_row)
row_positions = np.linspace(city_radius + 2, height - (city_radius + 2), cities_per_row, dtype=int)
col_positions = np.linspace(city_radius + 2, width - (city_radius + 2), cities_per_col, dtype=int)
city_positions = []
for city_idx in range(num_build_cities):
row = row_positions[city_idx % cities_per_row]
col = col_positions[city_idx // cities_per_row]
city_positions.append((row, col))
return city_positions
def _generate_city_connection_points(self, city_positions: IntVector2DArray, city_radius: int,
                                     vector_field: IntVector2DArray, rails_between_cities: int,
                                     rail_pairs_in_city: int = 1, np_random: RandomState = None) -> Tuple[
    List[List[List[IntVector2D]]],
    List[List[List[IntVector2D]]],
    List[Grid4TransitionsEnum],
    IntVector2DArray]:
    """
    Generate the city connection points. Internal connection points are used to generate the parallel paths
    within the city.
    External connection points are used to connect different cities together

    Parameters
    ----------
    city_positions: IntVector2DArray
        Vector that contains all the positions of the cities
    city_radius: int
        Radius of each city. Cities are squares with edge length 2 * city_radius + 1
    vector_field: IntVector2DArray
        Vectorfield of the size of the environment. It is used to generate preferred orientations for each cell.
        Each cell contains the preferred orientation of cells. If no preferred orientation is present it is set to -1.
        It is filled in for the cells of every city via `_get_cells_in_city`.
    rails_between_cities: int
        Upper bound for the number of rails that connect out from the city
    rail_pairs_in_city: int
        Upper bound for the number of rail pairs within the city
    np_random: RandomState
        Random source; `randint` is called on it, so it must not be None

    Returns
    -------
    inner_connection_points: List of List of length number of cities
        Contains all the inner connection points for each border of each city.
        [North_Points, East_Points, South_Points, West_Points]
    outer_connection_points: List of List of length number of cities
        Contains all the outer connection points for each border of the city.
        [North_Points, East_Points, South_Points, West_Points]
    city_orientations: List of length number of cities
        Contains all the orientations of cities. This is then used to orient agents according to the rails
    city_cells: List
        List containing the coordinates of all the cells that belong to a city. This is used by other algorithms
        to avoid drawing inter-city-rails through cities.
    """
    inner_connection_points: List[List[List[IntVector2D]]] = []
    outer_connection_points: List[List[List[IntVector2D]]] = []
    city_orientations: List[Grid4TransitionsEnum] = []
    city_cells: IntVector2DArray = []
    for city_position in city_positions:
        # Choose the directions where close cities are situated
        neighb_dist = []
        for neighbour_city in city_positions:
            neighb_dist.append(Vec2dOperations.get_manhattan_distance(city_position, neighbour_city))
        closest_neighb_idx = self.__class__.argsort(neighb_dist)
        # Store the directions to these neighbours and orient city to face closest neighbour
        connection_sides_idx = []
        # Entry 1 of the sorted indices is used: entry 0 is expected to be the city itself
        idx = 1
        if self.grid_mode:
            current_closest_direction = np_random.randint(4)
        else:
            current_closest_direction = direction_to_point(city_position, city_positions[closest_neighb_idx[idx]])
        # A city connects on the chosen side and on the side opposite to it
        connection_sides_idx.append(current_closest_direction)
        connection_sides_idx.append((current_closest_direction + 2) % 4)
        city_orientations.append(current_closest_direction)
        city_cells.extend(self._get_cells_in_city(city_position, city_radius, city_orientations[-1], vector_field))
        # set the number of tracks within a city, at least 2 tracks per city
        connections_per_direction = np.zeros(4, dtype=int)
        # NEW : SCHED CONST
        nr_of_connection_points = np_random.randint(1, rail_pairs_in_city + 1) * 2  # can be (1,2,3)*2 = (2,4,6)
        for idx in connection_sides_idx:
            connections_per_direction[idx] = nr_of_connection_points
        connection_points_coordinates_inner: List[List[IntVector2D]] = [[] for i in range(4)]
        connection_points_coordinates_outer: List[List[IntVector2D]] = [[] for i in range(4)]
        number_of_out_rails = np_random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1)
        # Index of the first track that also gets an outer connection point
        start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
        for direction in range(4):
            connection_slots = np.arange(nr_of_connection_points) - start_idx
            # Offset the rails away from the center of the city
            offset_distances = np.arange(nr_of_connection_points) - int(nr_of_connection_points / 2)
            # The clipping helps offsetting one side more than the other to avoid switches at same locations
            # The magic number plus one is added such that all points have at least one offset
            inner_point_offset = np.abs(offset_distances) + np.clip(offset_distances, 0, 1) + 1
            # Sides without connections have a count of 0 and are skipped by this loop
            for connection_idx in range(connections_per_direction[direction]):
                if direction == 0:
                    tmp_coordinates = (
                        city_position[0] - city_radius + inner_point_offset[connection_idx],
                        city_position[1] + connection_slots[connection_idx])
                    out_tmp_coordinates = (
                        city_position[0] - city_radius, city_position[1] + connection_slots[connection_idx])
                if direction == 1:
                    tmp_coordinates = (
                        city_position[0] + connection_slots[connection_idx],
                        city_position[1] + city_radius - inner_point_offset[connection_idx])
                    out_tmp_coordinates = (
                        city_position[0] + connection_slots[connection_idx], city_position[1] + city_radius)
                if direction == 2:
                    tmp_coordinates = (
                        city_position[0] + city_radius - inner_point_offset[connection_idx],
                        city_position[1] + connection_slots[connection_idx])
                    out_tmp_coordinates = (
                        city_position[0] + city_radius, city_position[1] + connection_slots[connection_idx])
                if direction == 3:
                    tmp_coordinates = (
                        city_position[0] + connection_slots[connection_idx],
                        city_position[1] - city_radius + inner_point_offset[connection_idx])
                    out_tmp_coordinates = (
                        city_position[0] + connection_slots[connection_idx], city_position[1] - city_radius)
                connection_points_coordinates_inner[direction].append(tmp_coordinates)
                # Only the tracks in the window starting at start_idx leave the city
                if connection_idx in range(start_idx, start_idx + number_of_out_rails):
                    connection_points_coordinates_outer[direction].append(out_tmp_coordinates)
        inner_connection_points.append(connection_points_coordinates_inner)
        outer_connection_points.append(connection_points_coordinates_outer)
    return inner_connection_points, outer_connection_points, city_orientations, city_cells
def _connect_cities(self, city_positions: IntVector2DArray, connection_points: List[List[List[IntVector2D]]],
                    city_cells: IntVector2DArray,
                    rail_trans: RailEnvTransitions, grid_map: GridTransitionMap) -> IntVector2DArray:
    """
    Connects cities together through rails. Each city connects from its outgoing connection points to the closest
    cities. This guarantees that all connection points are used.

    Parameters
    ----------
    city_positions: IntVector2DArray
        All coordinates of the cities
    connection_points: List[List[List[IntVector2D]]]
        List of coordinates of all outer connection points
    city_cells: IntVector2DArray
        Coordinates of all the cells contained in any city. This is used to avoid drawing rails through existing
        cities.
    rail_trans: RailEnvTransitions
        Railway transition objects
    grid_map: GridTransitionMap
        The grid map containing the rails. Used to draw new rails

    Returns
    -------
    Returns a flat list of all the cells (coordinates) that belong to a rail path drawn here. This can be used
    to access railway cells later.
    """
    all_paths: IntVector2DArray = []
    grid4_directions = [Grid4TransitionsEnum.NORTH, Grid4TransitionsEnum.EAST, Grid4TransitionsEnum.SOUTH,
                        Grid4TransitionsEnum.WEST]
    for current_city_idx in range(len(city_positions)):
        closest_neighbours = self._closest_neighbour_in_grid4_directions(current_city_idx, city_positions)
        for out_direction in grid4_directions:
            neighbour_idx = self.get_closest_neighbour_for_direction(closest_neighbours, out_direction)
            if neighbour_idx is None:
                # There is no other city to connect to
                continue
            # All outer connection points of the neighbour, over its four sides
            neighbour_points = [point
                                for direction in grid4_directions
                                for point in connection_points[neighbour_idx][direction]]
            for city_out_connection_point in connection_points[current_city_idx][out_direction]:
                if not neighbour_points:
                    # Never reuse a target left over from a previous iteration
                    warnings.warn("[WARNING] No connection point available on neighbouring city")
                    continue
                # Closest point of the neighbour; min() keeps the first one on ties
                neighbour_connection_point = min(
                    neighbour_points,
                    key=lambda point: Vec2dOperations.get_manhattan_distance(city_out_connection_point, point))
                new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point,
                                                    rail_trans, flip_start_node_trans=False,
                                                    flip_end_node_trans=False, respect_transition_validity=False,
                                                    avoid_rail=True,
                                                    forbidden_cells=city_cells)
                if len(new_line) == 0:
                    warnings.warn("[WARNING] No line added between stations")
                elif new_line[-1] != neighbour_connection_point or new_line[0] != city_out_connection_point:
                    warnings.warn("[WARNING] Unable to connect requested stations")
                all_paths.extend(new_line)
    return all_paths
def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction):
    """
    Pick the city index to connect to when leaving a city in a given direction.

    Direction is a 90 degree cone facing the desired direction.
    Example:
    North: The closest neighbour in the North direction is within the cone spanned by a line going
    North-West and North-East

    Parameters
    ----------
    closest_neighbours: List
        List of length 4 containing the index of the closest neighbour in the corresponding direction
        (or None where there is none):
        [North-Neighbour, East-Neighbour, South-Neighbour, West-Neighbour]
    out_direction: int
        Direction we want to get city index from
        North: 0, East: 1, South: 2, West: 3

    Returns
    -------
    Returns the index of the closest neighbour in the desired direction. If none is present, the
    counter-clockwise side is tried, then the clockwise side, and finally the entry of the opposite
    side is returned as it is (which may be None).
    """
    # same side, then counter-clockwise, then clockwise
    for turn in (0, -1, 1):
        candidate = closest_neighbours[(out_direction + turn) % 4]
        if candidate is not None:
            return candidate
    # last resort: the opposite side, returned even if it is None
    return closest_neighbours[(out_direction + 2) % 4]
def _build_inner_cities(self, city_positions: IntVector2DArray,
                        inner_connection_points: List[List[List[IntVector2D]]],
                        outer_connection_points: List[List[List[IntVector2D]]], rail_trans: RailEnvTransitions,
                        grid_map: GridTransitionMap) -> List[List[List[IntVector2D]]]:
    """
    Set the parallel tracks within the city. The center track of the city is of the length of the city, the length
    of the tracks decreases by 2 for every parallel track away from the center
    EG:
    --- Left Track
    ----- Center Track
    --- Right Track

    Parameters
    ----------
    city_positions: IntVector2DArray
        All coordinates of the cities
    inner_connection_points: List[List[List[IntVector2D]]]
        Points on city border that are used to generate inner city track
    outer_connection_points: List[List[List[IntVector2D]]]
        Points where the city is connected to neighboring cities
    rail_trans: RailEnvTransitions
        Railway transition objects
    grid_map: GridTransitionMap
        The grid map containing the rails. Used to draw new rails

    Returns
    -------
    Returns, for every city, the list of its parallel tracks; each track is what
    `connect_straight_line_in_grid_map` returned for it.
    """
    free_rails: List[List[List[IntVector2D]]] = [[] for i in range(len(city_positions))]
    for current_city in range(len(city_positions)):
        # This part only works if we keep the same number of connection points for both directions
        # Also only works with two connection directions at each city
        # Find the first side that has connection points.
        # NOTE: `boarder` stays unbound if a city has no inner connection points on any side.
        for i in range(4):
            if len(inner_connection_points[current_city][i]) > 0:
                boarder = i
                break
        opposite_boarder = (boarder + 2) % 4
        nr_of_connection_points = len(inner_connection_points[current_city][boarder])
        number_of_out_rails = len(outer_connection_points[current_city][boarder])
        # Index of the first track that has an outer connection point
        start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
        # Connect parallel tracks
        for track_id in range(nr_of_connection_points):
            source = inner_connection_points[current_city][boarder][track_id]
            target = inner_connection_points[current_city][opposite_boarder][track_id]
            current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans)
            free_rails[current_city].append(current_track)
        # Second pass, run only after all parallel tracks of the city have been drawn
        for track_id in range(nr_of_connection_points):
            source = inner_connection_points[current_city][boarder][track_id]
            target = inner_connection_points[current_city][opposite_boarder][track_id]
            # Connect parallel tracks with each other
            fix_inner_nodes(
                grid_map, source, rail_trans)
            fix_inner_nodes(
                grid_map, target, rail_trans)
            # Connect outer tracks to inner tracks
            if start_idx <= track_id < start_idx + number_of_out_rails:
                source_outer = outer_connection_points[current_city][boarder][track_id - start_idx]
                target_outer = outer_connection_points[current_city][opposite_boarder][track_id - start_idx]
                connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans)
                connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans)
    return free_rails
def _set_trainstation_positions(self, city_positions: IntVector2DArray, city_radius: int,
free_rails: List[List[List[IntVector2D]]]) -> List[List[Tuple[IntVector2D, int]]]:
"""
Populate the cities with possible start and end positions. Trainstations are set on the center of each paralell
track. Each trainstation gets a coordinate as well as number indicating what track it is on
Parameters
----------
city_positions: IntVector2DArray
All coordinates of the cities
city_radius: int
Radius of each city. Cities are squares with edge length 2 * city_radius + 1
free_rails: List[List[List[IntVector2D]]]
Cells that allow for trainstations to be placed
Returns
-------
Returns a List[List[Tuple[IntVector2D, int]]] containing the coordinates of trainstations as well as their
track number within the city
"""
num_cities = len(city_positions)
train_stations = [[] for i in range(num_cities)]
for current_city in range(len(city_positions)):
for track_nbr in range(len(free_rails[current_city])):
possible_location = free_rails[current_city][track_nbr][
int(len(free_rails[current_city][track_nbr]) / 2)]
train_stations[current_city].append((possible_location, track_nbr))
return train_stations
def _fix_transitions(self, city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray],
                     grid_map: GridTransitionMap, vector_field):
    """
    Check and fix transitions of all the cells that were modified. This is necessary because validity is
    ignored while drawing the rails.

    Parameters
    ----------
    city_cells: IntVector2DArray
        Cells within cities. All of these might have changed and are thus checked
    inter_city_lines: List[IntVector2DArray]
        All cells within rails drawn between cities
    grid_map: GridTransitionMap
        The grid map containing the rails; cells are validated and fixed through it
    vector_field: IntVector2DArray
        Vectorfield of the size of the environment. The entry of an invalid cell is passed on to
        `grid_map.fix_transitions` as its second argument.
    """
    # One row (row, col, vector field entry) per invalid cell, stored as integers.
    # First collect every invalid cell, only then start fixing.
    pending = np.zeros((grid_map.height * grid_map.width * 2, 3), dtype='int')
    pending_count = 0
    for cell in city_cells + inter_city_lines:
        if grid_map.cell_neighbours_valid(cell, True):
            continue
        pending[pending_count, 0] = cell[0]
        pending[pending_count, 1] = cell[1]
        pending[pending_count, 2] = vector_field[cell]
        pending_count += 1

    for row, col, orientation in pending[:pending_count]:
        grid_map.fix_transitions((row, col), orientation)
def _closest_neighbour_in_grid4_directions(self, current_city_idx: int, city_positions: IntVector2DArray) -> List[
    int]:
    """
    Finds the closest city in each direction of the current city

    Parameters
    ----------
    current_city_idx: int
        Index of current city
    city_positions: IntVector2DArray
        Vector containing the coordinates of all cities

    Returns
    -------
    Returns a list of length 4 (NESW) with the index of the closest neighbour in every direction;
    an entry is None where no city lies in that direction
    """
    distances = [
        Vec2dOperations.get_manhattan_distance(city_positions[current_city_idx], other_city)
        for other_city in city_positions
    ]
    nearest: List[int] = [None] * 4
    # walk the cities from near to far; the first sorted entry is skipped as the city itself
    for candidate in np.argsort(distances)[1:]:
        heading = direction_to_point(city_positions[current_city_idx], city_positions[candidate])
        if nearest[heading] is None:
            nearest[heading] = candidate
        # stop early once all 4 directions have a closest neighbour
        if None not in nearest:
            break
    return nearest
@staticmethod
def argsort(seq):
"""
Same as Numpy sort but for lists
Parameters
----------
seq: List
list that we would like to sort from smallest to largest
Returns
-------
Returns the sorted list
"""
# http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
return sorted(range(len(seq)), key=seq.__getitem__)
def _get_cells_in_city(self, center: IntVector2D, radius: int, city_orientation: int,
                       vector_field: IntVector2DArray) -> IntVector2DArray:
    """
    Collect the cells of a city and write the result of `align_cell_to_city` for each of them into the
    vector field.
    Example: City oriented north with a radius of 5, the vectorfield in the city will be as follows:
    |S|S|S|S|S|
    |S|S|S|S|S|
    |S|S|S|S|S| <-- City center
    |N|N|N|N|N|
    |N|N|N|N|N|
    This is used to later orient the switches to avoid infeasible maps.

    Parameters
    ----------
    center: IntVector2D
        center coordinates of city
    radius: int
        radius of city (it is a square)
    city_orientation: int
        Orientation of city
    vector_field: IntVector2DArray
        Array indexed by cell; modified in place for every cell of the city

    Returns
    -------
    flat list of all cell coordinates in the city, ordered by first coordinate and then by second
    """
    first_axis = np.arange(center[0] - radius, center[0] + radius + 1)
    second_axis = np.arange(center[1] - radius, center[1] + radius + 1)
    # square of edge length 2 * radius + 1 around the center
    cells = [(first, second) for first in first_axis for second in second_axis]
    for cell in cells:
        vector_field[cell] = align_cell_to_city(center, city_orientation, cell)
    return cells
@staticmethod
def _are_cities_overlapping(center_1, center_2, radius):
"""
Check if two cities overlap. That is we check if two squares with certain edge length and position overlap
Parameters
----------
center_1: (int, int)
Center of first city
center_2: (int, int)
Center of second city
radius: int
Radius of each city. Cities are squares with edge length 2 * city_radius + 1
Returns
-------
Returns True if the cities overlap and False otherwise
"""
return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius
from typing import NamedTuple, Tuple, List, Dict
# A way point is the entry into a cell, defined by
# - the row and column coordinates of the cell entered
# - the direction in which the agent is facing when it enters the cell.
# This induces a graph on top of the FLATland cells:
# - four possible way points per cell
# - edges are the possible transitions in the cell.
class Waypoint(NamedTuple):
    position: Tuple[int, int]
    direction: int


# A train run is represented by the way points traversed and the times of traversal.
# The terminology follows https://github.com/crowdAI/train-schedule-optimisation-challenge-starter-kit/blob/master/documentation/output_data_model.md
class TrainrunWaypoint(NamedTuple):
    scheduled_at: int
    waypoint: Waypoint


# A train run is the list of an agent's way points together with their scheduled times.
Trainrun = List[TrainrunWaypoint]

# Train runs keyed by an integer.
TrainrunDict = Dict[int, Trainrun]
# This module only exists to fail loudly: the raise below runs at import time, so any attempt to import
# it stops with a message pointing the caller at the renamed module.
raise ImportError(" Schedule Generators is now renamed to line_generators, any reference to schedule should be replaced with line")
\ No newline at end of file
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import sys
import warnings
from typing import Callable, Tuple, Optional, Dict, List
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
fix_inner_nodes, align_cell_to_city
from flatland.envs import persistence
from flatland.envs.rail_generators import RailGeneratorProduct, RailGenerator
from flatland.core.grid.grid_utils import position_to_coordinate
from flatland.envs.agent_utils import TrainState
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
    """
    Normalise a raw action value into a RailEnvActions member.

    Returns `RailEnvActions(action)` when `RailEnvActions.is_action_valid(action)` accepts the value,
    and falls back to DO_NOTHING otherwise.
    """
    if RailEnvActions.is_action_valid(action):
        return RailEnvActions(action)
    return RailEnvActions.DO_NOTHING
def process_do_nothing(state: TrainState, saved_action: RailEnvActions):
    """
    Decide which action replaces a DO_NOTHING request.

    - a train in the MOVING state gets MOVE_FORWARD
    - otherwise a truthy saved action is returned
    - otherwise DO_NOTHING is returned
    """
    if state == TrainState.MOVING:
        return RailEnvActions.MOVE_FORWARD
    if saved_action:
        return saved_action
    return RailEnvActions.DO_NOTHING
def process_left_right(action, rail, position, direction):
    """
    Keep `action` if `check_valid_action` accepts it for the given rail, position and direction;
    otherwise replace it with MOVE_FORWARD.
    """
    if check_valid_action(action, rail, position, direction):
        return action
    return RailEnvActions.MOVE_FORWARD
def preprocess_action_when_waiting(action, state):
    """
    Return DO_NOTHING while the train is in the WAITING state; pass `action` through unchanged otherwise.
    """
    return RailEnvActions.DO_NOTHING if state == TrainState.WAITING else action
def preprocess_raw_action(action, state, saved_action):
    """
    Turn a raw action into the action to use, depending on context.

    - the raw value is first normalised by `process_illegal_action`
    - if the result is DO_NOTHING, `process_do_nothing` picks the replacement from the train state
      and the saved action (e.g. DO_NOTHING is converted to FORWARD if the train is moving)
    - any other result is returned unchanged
    """
    validated = process_illegal_action(action)
    if validated == RailEnvActions.DO_NOTHING:
        return process_do_nothing(state, saved_action)
    return validated
def preprocess_moving_action(action, rail, position, direction):
    """
    Adjust a turning action to what is available at the current cell.

    MOVE_LEFT / MOVE_RIGHT are handed to `process_left_right`, which converts them to FORWARD when the
    turn is not available. Every other action is returned unchanged.

    Open question from the original author: should FORWARD be converted to STOP_MOVING if it leads to a
    dead end? That case is not handled here.
    """
    turning_actions = (RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT)
    if action not in turning_actions:
        return action
    return process_left_right(action, rail, position, direction)
\ No newline at end of file
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
class ActionSaver:
    """
    Holds at most one saved action.

    `saved_action` is None while nothing is saved. `save_action_if_allowed` only fills an empty slot;
    `clear_saved_action` empties it again.
    """

    def __init__(self):
        self.saved_action = None

    @property
    def is_action_saved(self):
        """True when an action is currently stored."""
        return self.saved_action is not None

    def __repr__(self):
        return f"is_action_saved: {self.is_action_saved}, saved_action: {str(self.saved_action)}"

    def save_action_if_allowed(self, action, state):
        """
        Save the action if all conditions are met
            1. It is a movement based action -> Forward, Left, Right
            2. Action is not already saved
            3. Agent is not already done
        """
        if action.is_moving_action() and not self.is_action_saved and not state == TrainState.DONE:
            self.saved_action = action

    def clear_saved_action(self):
        """Forget the stored action, if any."""
        self.saved_action = None

    def to_dict(self):
        """Return the state as a dict with the single key 'saved_action'."""
        return {"saved_action": self.saved_action}

    def from_dict(self, load_dict):
        """Restore the state from a dict shaped like the one `to_dict` returns."""
        self.saved_action = load_dict['saved_action']

    def __eq__(self, other):
        # Comparing against an unrelated object (None, str, ...) used to raise AttributeError.
        # Returning NotImplemented lets Python fall back to its default, so `==` yields False instead.
        if not isinstance(other, ActionSaver):
            return NotImplemented
        return self.saved_action == other.saved_action
from dataclasses import dataclass
from typing import Tuple
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.step_utils import transition_utils
from flatland.envs.rail_env_action import RailEnvActions
from flatland.core.grid.grid4 import Grid4Transitions
@dataclass
class AgentTransitionData:
    """Temporary per-agent data kept while positions are being updated."""
    # Position as a pair of integers.
    position: Tuple[int, int]
    # Direction, expressed with the Grid4Transitions type.
    direction: Grid4Transitions
    # Action after preprocessing.
    preprocessed_action: RailEnvActions
def apply_action_independent(action, rail, position, direction):
    """ Apply the action on the train regardless of the locations of other trains.

    A non-moving action leaves position and direction untouched. For a moving action the new
    direction comes from `transition_utils.check_action` and the new position from
    `get_new_position(position, new_direction)`.
    ---------------------------------------------------------------------
    Parameters: action - Action to execute
                rail - Flatland env.rail object
                position - current position of the train
                direction - current direction of the train
    ---------------------------------------------------------------------
    Returns: new_position - New position after applying the action
             new_direction - New direction after applying the action
    """
    if not action.is_moving_action():
        return position, direction

    # Only the first element of check_action's result is needed here.
    new_direction, _ = transition_utils.check_action(action, position, direction, rail)
    return get_new_position(position, new_direction), new_direction
def state_position_sync_check(state, position, i_agent):
    """ Verify that the on map / off map state of an agent agrees with its position.

    Raises ValueError when the state is an on map state but `position` is None, or when the state is
    an off map state but `position` is set. Returns None otherwise.
    """
    has_position = position is not None

    if state.is_on_map_state() and not has_position:
        message = "Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
            i_agent, str(state), str(position))
        raise ValueError(message)

    if state.is_off_map_state() and has_position:
        message = "Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
            i_agent, str(state), str(position))
        raise ValueError(message)
\ No newline at end of file
def get_number_of_steps_to_break(malfunction_generator, np_random):
    """
    Ask a malfunction generator for a malfunction and return its `num_broken_steps`.

    Two generator shapes are accepted: an object exposing `generate(np_random)`, or a plain callable
    that is invoked as `malfunction_generator(np_random)`.
    """
    # Prefer the `generate` method when present, otherwise call the generator itself.
    produce = getattr(malfunction_generator, "generate", malfunction_generator)
    return produce(np_random).num_broken_steps
class MalfunctionHandler:
    """
    Tracks one malfunction countdown and how many malfunctions were started.

    The down counter holds the remaining broken steps (0 means no malfunction). A new value is only
    accepted once the previous countdown has reached 0.
    """

    def __init__(self):
        self._malfunction_down_counter = 0
        self.num_malfunctions = 0

    def reset(self):
        """Return to the initial state: no malfunction running, none counted."""
        self._malfunction_down_counter = 0
        self.num_malfunctions = 0

    @property
    def in_malfunction(self):
        """True while the down counter is above 0."""
        return self._malfunction_down_counter > 0

    @property
    def malfunction_counter_complete(self):
        """True when the down counter is exactly 0."""
        return self._malfunction_down_counter == 0

    @property
    def malfunction_down_counter(self):
        """Remaining steps of the current malfunction."""
        return self._malfunction_down_counter

    @malfunction_down_counter.setter
    def malfunction_down_counter(self, val):
        self._set_malfunction_down_counter(val)

    def _set_malfunction_down_counter(self, val):
        """
        Start a countdown of `val` steps, unless one is still running.

        Raises ValueError for a negative `val`. A positive `val` that is accepted also increments
        `num_malfunctions`; a value offered while a countdown is running is silently ignored.
        """
        if val < 0:
            raise ValueError("Cannot set a negative value to malfunction down counter")
        # Only set new malfunction value if old malfunction is completed
        if self._malfunction_down_counter == 0:
            self._malfunction_down_counter = val
            if val > 0:
                self.num_malfunctions += 1

    def generate_malfunction(self, malfunction_generator, np_random):
        """Draw a number of broken steps from the generator and try to start a countdown with it."""
        num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
        self._set_malfunction_down_counter(num_broken_steps)

    def update_counter(self):
        """Advance the countdown by one step; does nothing once it has reached 0."""
        if self._malfunction_down_counter > 0:
            self._malfunction_down_counter -= 1

    def __repr__(self):
        # Built from adjacent literals rather than a backslash continuation inside the string, so the
        # source indentation can never leak into the output.
        return (f"malfunction_down_counter: {self._malfunction_down_counter} "
                f"in_malfunction: {self.in_malfunction} "
                f"num_malfunctions: {self.num_malfunctions}")

    def to_dict(self):
        """Return the state as a dict with keys 'malfunction_down_counter' and 'num_malfunctions'."""
        return {"malfunction_down_counter": self._malfunction_down_counter,
                "num_malfunctions": self.num_malfunctions}

    def from_dict(self, load_dict):
        """Restore the state from a dict shaped like the one `to_dict` returns (no validation is applied)."""
        self._malfunction_down_counter = load_dict['malfunction_down_counter']
        self.num_malfunctions = load_dict['num_malfunctions']

    def __eq__(self, other):
        # Comparing against an unrelated object (None, str, ...) used to raise AttributeError.
        # Returning NotImplemented lets Python fall back to its default, so `==` yields False instead.
        if not isinstance(other, MalfunctionHandler):
            return NotImplemented
        return self._malfunction_down_counter == other._malfunction_down_counter and \
            self.num_malfunctions == other.num_malfunctions