
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing 3619 additions and 1683 deletions.
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2D
def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
"""
Assumes pos1 and pos2 are adjacent locations on the grid.
Returns direction (int) that can be used with transitions.
@@ -9,66 +12,41 @@ def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
diff_0 = pos2[0] - pos1[0]
diff_1 = pos2[1] - pos1[1]
if diff_0 < 0:
return 0
return Grid4TransitionsEnum.NORTH
if diff_0 > 0:
return 2
return Grid4TransitionsEnum.SOUTH
if diff_1 > 0:
return 1
return Grid4TransitionsEnum.EAST
if diff_1 < 0:
return 3
return Grid4TransitionsEnum.WEST
raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
def mirror(dir):
return (dir + 2) % 4
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]
def get_new_position(position, movement):
return (position[0] + MOVEMENT_ARRAY[movement][0], position[1] + MOVEMENT_ARRAY[movement][1])
def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = rail_array[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
"""
Returns the compass direction (NESW) of pos2 relative to pos1, based on the dominant axis of the offset between the two positions.
:param pos1: reference position
:param pos2: position whose direction relative to pos1 we want to know
:return: direction NESW as int N:0 E:1 S:2 W:3
"""
diff_vec = np.array((pos1[0] - pos2[0], pos1[1] - pos2[1]))
axis = np.argmax(np.power(diff_vec, 2))
direction = np.sign(diff_vec[axis])
if axis == 0:
if direction > 0:
return Grid4TransitionsEnum.NORTH
else:
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
return Grid4TransitionsEnum.SOUTH
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if new_pos == end_pos:
# need to validate end pos setup as well
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
if direction > 0:
return Grid4TransitionsEnum.WEST
else:
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
if not rail_trans.is_valid(new_trans_e):
return False
# is transition is valid?
return rail_trans.is_valid(new_trans)
def get_new_position(position, movement):
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
return Grid4TransitionsEnum.EAST
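# Illustrative usage sketch, not part of this revision's changes. It assumes the
# helpers above live in flatland.core.grid.grid4_utils, as the imports elsewhere
# in this diff suggest.
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import direction_to_point, get_new_position

start = (5, 5)
# (3, 5) lies two rows above (5, 5): the row axis dominates, so the direction is NORTH.
assert direction_to_point(start, (3, 5)) == Grid4TransitionsEnum.NORTH
# Moving NORTH from (5, 5) decrements the row index.
assert get_new_position(start, Grid4TransitionsEnum.NORTH) == (4, 5)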
@@ -20,9 +20,9 @@ class Grid8Transitions(Transitions):
"""
Grid8Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand).
Special case of `Transitions` over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions'
GridTransitions keeps track of valid transitions supplied as `transitions`
list, each represented as a bitmap of 64 bits.
0=North, 1=North-East, etc.
@@ -82,8 +82,8 @@ class Grid8Transitions(Transitions):
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
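# Worked example, not part of this revision's changes: the mask above isolates the
# 8 bits belonging to one orientation inside the 64-bit Grid8 transition bitmap.
orientation = 0  # North
mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
assert mask == 0xFF00000000000000  # top byte holds North's 8 outgoing bits
orientation = 7  # North-West
mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
assert mask == 0xFF  # bottom byte holds North-West's 8 outgoing bits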
@@ -106,8 +106,8 @@ class Grid8Transitions(Transitions):
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
@@ -131,8 +131,8 @@ class Grid8Transitions(Transitions):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
@@ -150,8 +150,8 @@ class Grid8Transitions(Transitions):
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
if new_transition:
@@ -172,7 +172,7 @@ class Grid8Transitions(Transitions):
64 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 45, 90, 135, 180,
`cell_transition` by. I.e., rotation={0, 45, 90, 135, 180,
225, 270, 315} degrees.
Returns
from math import isnan
from typing import Tuple, Callable, List, Type
import numpy as np
Vector2D: Type = Tuple[float, float]
IntVector2D: Type = Tuple[int, int]
def position_to_coordinate(depth, positions):
"""Converts coordinates to positions:
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
IntVector2DArray: Type = List[IntVector2D]
IntVector2DArrayArray: Type = List[List[IntVector2D]]
Vector2DArray: Type = List[Vector2D]
Vector2DArrayArray: Type = List[List[Vector2D]]
IntVector2DDistance: Type = Callable[[IntVector2D, IntVector2D], float]
class Vec2dOperations:
@staticmethod
def is_equal(node_a: Vector2D, node_b: Vector2D) -> bool:
"""
vector operation : check whether node_a and node_b are equal
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return:
True if node_a and node_b are equal, False otherwise
"""
return node_a[0] == node_b[0] and node_a[1] == node_b[1]
@staticmethod
def subtract(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
"""
vector operation : node_a - node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return node_a[0] - node_b[0], node_a[1] - node_b[1]
@staticmethod
def add(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
"""
vector operation : node_a + node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return node_a[0] + node_b[0], node_a[1] + node_b[1]
@staticmethod
def make_orthogonal(node: Vector2D) -> Vector2D:
"""
vector operation : rotates the 2D vector +90°
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return node[1], -node[0]
@staticmethod
def get_norm(node: Vector2D) -> float:
"""
calculates the euclidean norm of the 2d vector
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
:param node: tuple with coordinate (x,y) or 2d vector
:return:
the euclidean norm of the vector as a float
"""
return np.sqrt(node[0] * node[0] + node[1] * node[1])
@staticmethod
def get_euclidean_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the euclidean distance between node_a and node_b
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
Euclidean distance
"""
return Vec2dOperations.get_norm(Vec2dOperations.subtract(node_b, node_a))
@staticmethod
def get_manhattan_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the manhattan distance between node_a and node_b
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
Manhattan distance
"""
delta = (Vec2dOperations.subtract(node_b, node_a))
return np.abs(delta[0]) + np.abs(delta[1])
@staticmethod
def get_chebyshev_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the chebyshev distance between node_a and node_b
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
the chebyshev distance
"""
delta = (Vec2dOperations.subtract(node_b, node_a))
return max(np.abs(delta[0]), np.abs(delta[1]))
@staticmethod
def normalize(node: Vector2D) -> Tuple[float, float]:
"""
normalize the 2d vector = `v/|v|`
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
n = Vec2dOperations.get_norm(node)
if n > 0.0:
n = 1 / n
return Vec2dOperations.scale(node, n)
@staticmethod
def scale(node: Vector2D, scale: float) -> Vector2D:
"""
scales the 2d vector = node * scale
:param node: tuple with coordinate (x,y) or 2d vector
:param scale: scalar to scale
:return: tuple with coordinate (x,y) or 2d vector
"""
return node[0] * scale, node[1] * scale
@staticmethod
def round(node: Vector2D) -> IntVector2D:
"""
rounds the x and y coordinates and converts them to integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return int(np.round(node[0])), int(np.round(node[1]))
@staticmethod
def ceil(node: Vector2D) -> IntVector2D:
"""
applies the ceiling to the x and y coordinates and converts them to integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return int(np.ceil(node[0])), int(np.ceil(node[1]))
@staticmethod
def floor(node: Vector2D) -> IntVector2D:
"""
applies the floor to the x and y coordinates and converts them to integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return int(np.floor(node[0])), int(np.floor(node[1]))
@staticmethod
def bound(node: Vector2D, min_value: float, max_value: float) -> Vector2D:
"""
clamps the x and y values to the range [min_value, max_value]
:param node: tuple with coordinate (x,y) or 2d vector
:param min_value: scalar value
:param max_value: scalar value
:return:
tuple with coordinate (x,y) or 2d vector
"""
return max(min_value, min(max_value, node[0])), max(min_value, min(max_value, node[1]))
@staticmethod
def rotate(node: Vector2D, rot_in_degree: float) -> Vector2D:
"""
rotates the 2d vector by the given angle in degrees
:param node: tuple with coordinate (x,y) or 2d vector
:param rot_in_degree: angle in degrees
:return:
tuple with coordinate (x,y) or 2d vector
"""
alpha = rot_in_degree / 180.0 * np.pi
x0 = node[0]
y0 = node[1]
x1 = x0 * np.cos(alpha) - y0 * np.sin(alpha)
y1 = x0 * np.sin(alpha) + y0 * np.cos(alpha)
return x1, y1
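# Illustrative usage sketch, not part of this revision's changes, using the import
# path shown later in this diff (flatland.core.grid.grid_utils).
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d

v = Vec2d.subtract((4, 6), (1, 2))  # -> (3, 4)
assert Vec2d.get_norm(v) == 5.0  # Euclidean length of (3, 4)
assert Vec2d.get_manhattan_distance((1, 2), (4, 6)) == 7
# Rotating (1, 0) by 90 degrees and rounding back to ints gives the orthogonal (0, 1).
assert Vec2d.round(Vec2d.rotate((1, 0), 90)) == (0, 1)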
def position_to_coordinate(depth: int, positions: List[int]):
"""Converts coordinates to positions::
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
-->
[ 0 d .. (w-1)*d
1 d+1
...
d-1 2d-1 w*d-1
]
[ 0 d .. (w-1)*d
1 d+1
...
d-1 2d-1 w*d-1
]
:param depth:
:param positions:
:return:
Parameters
----------
depth : int
positions : List[Tuple[int,int]]
"""
coords = ()
for p in positions:
@@ -29,7 +264,8 @@ def position_to_coordinate(depth, positions):
def coordinate_to_position(depth, coords):
"""
Converts positions to coordinates:
Converts coordinates to positions::
[ 0 d .. (w-1)*d
1 d+1
...
@@ -46,13 +282,17 @@ def coordinate_to_position(depth, coords):
:param coords:
:return:
"""
position = np.empty(len(coords), dtype=int)
idx = 0
for t in coords:
position[idx] = int(t[1] * depth + t[0])
idx += 1
position = list(range(len(coords)))
for index, t in enumerate(coords):
if isnan(t[0]):
position[index] = -1
else:
position[index] = int(t[1] * depth + t[0])
return position
def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
def distance_on_rail(pos1, pos2, metric="Euclidean"):
if metric == "Euclidean":
return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
if metric == "Manhattan":
return np.abs(pos1[0] - pos2[0]) + np.abs(pos1[1] - pos2[1])
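# Illustrative usage sketch, not part of this revision's changes: the NaN handling
# added to coordinate_to_position maps missing positions to -1, and distance_on_rail
# now supports both metrics. Import path assumed as for the functions above.
from math import nan
from flatland.core.grid.grid_utils import coordinate_to_position, distance_on_rail

# With depth (grid height) 10, coordinate (row=2, col=3) maps to position 3*10 + 2 = 32.
assert coordinate_to_position(10, [(2, 3), (nan, nan)]) == [32, -1]
assert distance_on_rail((0, 0), (3, 4), metric="Euclidean") == 5.0
assert distance_on_rail((0, 0), (3, 4), metric="Manhattan") == 7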
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.utils.ordered_set import OrderedSet
class RailEnvTransitions(Grid4Transitions):
"""
Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
Special case of `GridTransitions` over a 2D-grid, with a pre-defined set
of transitions mimicking the types of real Swiss rail connections.
--------------------------------------------------------------------------
As no diagonal transitions are allowed in the RailEnv environment, the
possible transitions for RailEnv from a cell to its neighboring ones
are represented over 16 bits.
@@ -44,7 +43,7 @@ class RailEnvTransitions(Grid4Transitions):
)
# create this to make validation faster
self.transitions_all = set()
self.transitions_all = OrderedSet()
for index, trans in enumerate(self.transitions):
self.transitions_all.add(trans)
if index in (2, 4, 6, 7, 8, 9, 10):
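# Illustrative usage sketch, not part of this revision's changes: RailEnvTransitions
# encodes each cell as a 16-bit bitmap; is_valid and rotate_transition are the same
# calls used elsewhere in this diff.
from flatland.core.grid.rail_env_grid import RailEnvTransitions

rail_trans = RailEnvTransitions()
straight = int('1000000000100000', 2)  # Case 1 - north/south straight
assert rail_trans.is_valid(straight)
# A 90-degree rotation yields the east/west straight, also a valid element.
assert rail_trans.is_valid(rail_trans.rotate_transition(straight, 90))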
@@ -7,9 +7,15 @@ from importlib_resources import path
from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
# TODO are these general classes or for grid4 only?
class TransitionMap:
"""
Base TransitionMap class.
@@ -21,7 +27,7 @@ class TransitionMap:
def get_transitions(self, cell_id):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
`cell_id` (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
@@ -41,8 +47,8 @@ class TransitionMap:
def set_transitions(self, cell_id, new_transitions):
"""
Replaces the available transitions in cell `cell_id' with the tuple
`new_transitions'. `new_transitions' must have
Replaces the available transitions in cell `cell_id` with the tuple
`new_transitions'. `new_transitions` must have
one element for each possible transition.
Parameters
@@ -58,8 +64,8 @@ class TransitionMap:
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id' can perform a
movement along transition `transition_index (e.g., the NESW direction
Return the status of whether an agent in cell `cell_id` can perform a
movement along transition `transition_index` (e.g., the NESW direction
of movement, for agents on a grid).
Parameters
@@ -83,8 +89,8 @@ class TransitionMap:
def set_transition(self, cell_id, transition_index, new_transition):
"""
Replaces the validity of transition to `transition_index' in cell
`cell_id' with the new `new_transition'.
Replaces the validity of transition to `transition_index` in cell
`cell_id' with the new `new_transition`.
Parameters
......@@ -111,7 +117,7 @@ class GridTransitionMap(TransitionMap):
GridTransitionMap implements utility functions.
"""
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([])):
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]), random_seed=None):
"""
Builder for GridTransitionMap object.
@@ -130,7 +136,11 @@ class GridTransitionMap(TransitionMap):
self.width = width
self.height = height
self.transitions = transitions
self.random_generator = np.random.RandomState()
if random_seed is None:
self.random_generator.seed(12)
else:
self.random_generator.seed(random_seed)
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def get_full_transitions(self, row, column):
@@ -154,7 +164,7 @@ class GridTransitionMap(TransitionMap):
def get_transitions(self, row, column, orientation):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
`cell_id` (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
@@ -176,8 +186,8 @@ class GridTransitionMap(TransitionMap):
def set_transitions(self, cell_id, new_transitions):
"""
Replaces the available transitions in cell `cell_id' with the tuple
`new_transitions'. `new_transitions' must have
Replaces the available transitions in cell `cell_id` with the tuple
`new_transitions'. `new_transitions` must have
one element for each possible transition.
Parameters
@@ -202,8 +212,8 @@ class GridTransitionMap(TransitionMap):
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id' can perform a
movement along transition `transition_index (e.g., the NESW direction
Return the status of whether an agent in cell `cell_id` can perform a
movement along transition `transition_index` (e.g., the NESW direction
of movement, for agents on a grid).
Parameters
@@ -230,8 +240,8 @@ class GridTransitionMap(TransitionMap):
def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False):
"""
Replaces the validity of transition to `transition_index' in cell
`cell_id' with the new `new_transition'.
Replaces the validity of transition to `transition_index` in cell
`cell_id' with the new `new_transition`.
Parameters
@@ -259,7 +269,7 @@ class GridTransitionMap(TransitionMap):
def save_transition_map(self, filename):
"""
Save the transitions grid as `filename', in npy format.
Save the transitions grid as `filename`, in npy format.
Parameters
----------
@@ -271,9 +281,9 @@ class GridTransitionMap(TransitionMap):
def load_transition_map(self, package, resource):
"""
Load the transitions grid from `filename' (npy format).
Load the transitions grid from `filename` (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be
initialized with the correct `transitions' object anyway.
initialized with the correct `transitions` object anyway.
Parameters
----------
@@ -283,7 +293,7 @@ class GridTransitionMap(TransitionMap):
Name of the file from which to load the transitions grid within the package.
override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
of the map loaded from `filename`. If override_gridsize=False, the transitions grid is either cropped (if
the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than
(height,width) )
@@ -298,12 +308,155 @@ class GridTransitionMap(TransitionMap):
self.height = new_height
self.grid = new_grid
def cell_neighbours_valid(self, rcPos, check_this_cell=False):
def is_dead_end(self, rcPos: IntVector2DArray):
"""
Check if the cell is a dead-end.
Parameters
----------
rcPos: Tuple[int,int]
tuple(row, column) with grid coordinate
Returns
-------
boolean
True if and only if the cell is a dead-end.
"""
cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
return Grid4Transitions.has_deadend(cell_transition)
def is_simple_turn(self, rcPos: IntVector2DArray):
"""
Check if the cell is a left/right simple turn
Parameters
----------
rcPos: Tuple[int,int]
tuple(row, column) with grid coordinate
Returns
-------
boolean
True if and only if the cell is a left/right simple turn.
"""
tmp = self.get_full_transitions(rcPos[0], rcPos[1])
def is_simple_turn(trans):
all_simple_turns = OrderedSet()
for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right
int('0001001000000000', 2)  # Case 1c (9) - simple turn left
]:
for _ in range(3):
trans = self.transitions.rotate_transition(trans, rotation=90)
all_simple_turns.add(trans)
return trans in all_simple_turns
return is_simple_turn(tmp)
def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
"""
Depth-first search (using an explicit stack) for a possible path from one node with a certain orientation to a target node.
:param start: Start cell from where we want to check the path
:param direction: Start direction for the path we are testing
:param end: Cell that we try to reach from the start cell
:return: True if a path exists, False otherwise
"""
visited = OrderedSet()
stack = [(start, direction)]
while stack:
node = stack.pop()
node_position = node[0]
node_direction = node[1]
if Vec2d.is_equal(node_position, end):
return True
if node not in visited:
visited.add(node)
moves = self.get_transitions(node_position[0], node_position[1], node_direction)
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node_position, move_index),
move_index))
return False
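# Illustrative sketch, not part of this revision's changes: the same stack-based
# search pattern as check_path_exists above, applied to a hand-built toy graph of
# (position, heading) states instead of a GridTransitionMap.
def toy_path_exists(moves_from, start, direction, end):
    visited = set()
    stack = [(start, direction)]
    while stack:
        position, heading = stack.pop()
        if position == end:
            return True
        if (position, heading) not in visited:
            visited.add((position, heading))
            for new_heading, new_position in moves_from.get((position, heading), []):
                stack.append((new_position, new_heading))
    return False

# Two cells joined east-west: starting at (0, 0) heading EAST (1) reaches (0, 1),
# but starting with heading WEST (3) does not.
moves = {((0, 0), 1): [(1, (0, 1))]}
assert toy_path_exists(moves, (0, 0), 1, (0, 1))
assert not toy_path_exists(moves, (0, 0), 3, (0, 1))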
def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
- surrounding cells have inbound transitions for all the outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
- Reverse transitions (N -> S) only exist for a dead-end
- a cell contains either no dead-ends or exactly one
Returns: True (valid) or False (invalid)
"""
cell_transition = self.grid[tuple(rcPos)]
if check_this_cell:
if not self.transitions.is_valid(cell_transition):
return False
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out
lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1)
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int
# loop over available outbound directions (indices) for rcPos
for iDirOut in giDirOut:
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then this transition is invalid!
if np.any(gPos2 < 0):
return False
if np.any(gPos2 >= grcMax):
return False
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
t4Trans2 = self.get_transitions(*gPos2, iDirOut)
if any(t4Trans2):
continue
else:
return False
# If the cell is empty but has incoming connections we return false
if binTrans < 1:
connected = 0
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then ignore it for the count of incoming connections
if np.any(gPos2 < 0):
continue
if np.any(gPos2 >= grcMax):
continue
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
return False
return True
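# Worked example, not part of this revision's changes: how the 16-bit cell bitmap is
# unpacked into a 4x4 (inbound heading x outbound direction) matrix, exactly as done
# in cell_neighbours_valid above.
import numpy as np

binTrans = int('1000000000100000', 2)  # north/south straight cell
lnBinTrans = np.array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)
# Heading north (row 0) exits north (column 0); heading south (row 2) exits south.
assert g2binTrans[0, 0] == 1 and g2binTrans[2, 2] == 1
assert g2binTrans.sum() == 2  # only those two bits are set for a straight cell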
def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
- surrounding cells have inbound transitions for all the
outbound transitions of this cell.
- surrounding cells have inbound transitions for all the outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
@@ -346,8 +499,141 @@ class GridTransitionMap(TransitionMap):
if any(t4Trans2):
continue
else:
self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
return False
return True
def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
"""
Fixes broken transitions
"""
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
# Transition elements
transitions = RailEnvTransitions()
cells = transitions.transition_list
simple_switch_east_south = transitions.rotate_transition(cells[10], 90)
simple_switch_west_south = transitions.rotate_transition(cells[2], 270)
symmetrical = cells[6]
double_slip = cells[5]
three_way_transitions = [simple_switch_east_south, simple_switch_west_south]
# loop over available outbound directions (indices) for rcPos
incoming_connections = np.zeros(4)
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then ignore it for the count of incoming connections
if np.any(gPos2 < 0):
continue
if np.any(gPos2 >= grcMax):
continue
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
connected = 0
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
incoming_connections[iDirOut] = 1
number_of_incoming = np.sum(incoming_connections)
# Only one incoming direction --> Straight line set deadend
if number_of_incoming == 1:
if self.get_full_transitions(*rcPos) == 0:
self.set_transitions(rcPos, 0)
else:
self.set_transitions(rcPos, 0)
for direction in range(4):
if incoming_connections[direction] > 0:
self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
# Connect all incoming connections
if number_of_incoming == 2:
self.set_transitions(rcPos, 0)
connect_directions = np.argwhere(incoming_connections > 0)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
# Find feasible connection for three entries
if number_of_incoming == 3:
self.set_transitions(rcPos, 0)
hole = np.argwhere(incoming_connections < 1)[0][0]
if direction >= 0:
switch_type_idx = (direction - hole + 3) % 4
if switch_type_idx == 0:
transition = simple_switch_west_south
elif switch_type_idx == 2:
transition = simple_switch_east_south
else:
transition = self.random_generator.choice(three_way_transitions, 1)[0]
else:
transition = self.random_generator.choice(three_way_transitions, 1)[0]
transition = transitions.rotate_transition(transition, int(hole * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
# Make a double slip switch
if number_of_incoming == 4:
rotation = self.random_generator.randint(2)
transition = transitions.rotate_transition(double_slip, int(rotation * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
return True
def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
new_pos: IntVector2D, end_pos: IntVector2D):
"""
Utility function to test that a path drawn by the A* algorithm uses valid transition objects.
We use this to guide A*, as there are many transition elements that are not allowed in RailEnv.
:param prev_pos: The previous position we were checking
:param current_pos: The current position we are checking
:param new_pos: Possible child position we move into
:param end_pos: End cell of path we are drawing
:return: True if the transition is valid, False if transition element is illegal
"""
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = self.grid[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# check if matches existing layout
new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if Vec2d.is_equal(new_pos, end_pos):
# need to validate end pos setup as well
new_trans_e = self.grid[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# check if matches existing layout
new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1)
if not self.transitions.is_valid(new_trans_e):
return False
# is transition is valid?
return self.transitions.is_valid(new_trans)
def mirror(dir):
return (dir + 2) % 4
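# Illustrative sketch, not part of this revision's changes: mirror() maps a NESW
# heading (N:0, E:1, S:2, W:3) to its opposite; validate_new_transition above uses
# it to derive the reverse leg of a path segment.
assert [mirror(d) for d in range(4)] == [2, 3, 0, 1]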
# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
@@ -12,7 +12,7 @@ class Transitions:
Generic class that implements checks to control whether a
certain transition is allowed (agent facing a direction
`orientation' and moving into direction `orientation')
`orientation' and moving into direction `orientation`)
"""
def get_type(self):
@@ -21,7 +21,7 @@ class Transitions:
def get_transitions(self, cell_transition, orientation):
"""
Return a tuple of transitions available in a cell specified by
`cell_transition' for an agent facing direction `orientation'
`cell_transition' for an agent facing direction `orientation`
(e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
@@ -45,9 +45,9 @@ class Transitions:
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Return a `cell_transition' specification where the transitions
available for an agent facing direction `orientation' are replaced
with the tuple `new_transitions'. `new_orientations' must have
Return a `cell_transition` specification where the transitions
available for an agent facing direction `orientation` are replaced
with the tuple `new_transitions'. `new_orientations` must have
one element for each possible transition.
Parameters
@@ -65,8 +65,8 @@ class Transitions:
-------
[cell-content]
An updated class-specific object that replaces the original
transitions validity of `cell_transition' with `new_transitions',
for the appropriate `orientation'.
transitions validity of `cell_transition' with `new_transitions`,
for the appropriate `orientation`.
"""
raise NotImplementedError()
@@ -74,8 +74,8 @@ class Transitions:
def get_transition(self, cell_transition, orientation, direction):
"""
Return the status of whether an agent oriented in directions
`orientation' and inside a cell with transitions `cell_transition'
can move to the cell in direction `direction' relative
`orientation' and inside a cell with transitions `cell_transition`
can move to the cell in direction `direction` relative
to the current cell.
Parameters
@@ -101,11 +101,11 @@ class Transitions:
def set_transition(self, cell_transition, orientation, direction,
new_transition):
"""
Return a `cell_transition' specification where the status of
whether an agent oriented in direction `orientation' and inside
a cell with transitions `cell_transition' can move to the cell
in direction `direction' relative to the current cell is set
to `new_transition'.
Return a `cell_transition` specification where the status of
whether an agent oriented in direction `orientation` and inside
a cell with transitions `cell_transition` can move to the cell
in direction `direction` relative to the current cell is set
to `new_transition`.
Parameters
----------
@@ -125,8 +125,8 @@ class Transitions:
-------
[cell-content]
An updated class-specific object that replaces the original
transitions validity of `cell_transition' with `new_transitions',
for the appropriate `orientation' to `direction'.
transitions validity of `cell_transition' with `new_transitions`,
for the appropriate `orientation' to `direction`.
"""
raise NotImplementedError()
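# Illustrative sketch, not part of this revision's changes: Grid4Transitions is the
# concrete grid implementation of this interface, packing the 4x4 NESW transition
# matrix into a 16-bit integer.
from flatland.core.grid.grid4 import Grid4Transitions

t4 = Grid4Transitions([])
cell = t4.set_transition(0, 0, 0, 1)  # allow a North-facing agent to exit North
assert t4.get_transition(cell, 0, 0) == 1
assert t4.get_transition(cell, 0, 1) == 0  # exit to the East is still forbidden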
import networkx as nx
import numpy as np
from typing import List, Tuple
import graphviz as gv
class MotionCheck(object):
""" Class to find chains of agents which are "colliding" with a stopped agent.
This is to allow close-packed chains of agents, ie a train of agents travelling
at the same speed with no gaps between them.
"""
def __init__(self):
self.G = nx.DiGraph()
self.nDeadlocks = 0
self.svDeadlocked = set()
def addAgent(self, iAg, rc1, rc2, xlabel=None):
""" add an agent and its motion as row,col tuples of current and next position.
The agent's current position is given an "agent" attribute recording the agent index.
If an agent does not want to move this round (rc1 == rc2) then a self-loop edge is created.
xlabel is used for test cases to give a label (see graphviz)
"""
# Agents which have not yet entered the env have position None.
# Substitute this for the row = -1, column = agent index
if rc1 is None:
rc1 = (-1, iAg)
if rc2 is None:
rc2 = (-1, iAg)
self.G.add_node(rc1, agent=iAg)
if xlabel:
self.G.nodes[rc1]["xlabel"] = xlabel
self.G.add_edge(rc1, rc2)
def find_stops(self):
""" find all the stopped agents as a set of rc position nodes
A stopped agent is a self-loop on a cell node.
"""
# get the (sparse) adjacency matrix
spAdj = nx.linalg.adjacency_matrix(self.G)
# the stopped agents appear as 1s on the diagonal
# the where turns this into a list of indices of the 1s
giStops = np.where(spAdj.diagonal())[0]
# convert the cell/node indices into the node rc values
lvAll = list(self.G.nodes())
# pick out the stops by their indices
lvStops = [ lvAll[i] for i in giStops ]
# make it into a set ready for a set intersection
svStops = set(lvStops)
return svStops
def find_stops2(self):
""" alternative method to find stopped agents, using a networkx call to find selfloop edges
"""
svStops = { u for u,v in nx.classes.function.selfloop_edges(self.G) }
return svStops
def find_stop_preds(self, svStops=None):
""" Find the predecessors to a list of stopped agents (ie the nodes / vertices)
Returns the set of predecessors.
Includes "chained" predecessors.
"""
if svStops is None:
svStops = self.find_stops2()
# Get all the chains of agents - weakly connected components.
# Weakly connected because it's a directed graph and you can traverse a chain of agents
# in only one direction
lWCC = list(nx.algorithms.components.weakly_connected_components(self.G))
svBlocked = set()
for oWCC in lWCC:
#print("Component:", oWCC)
# Get the node details for this WCC in a subgraph
Gwcc = self.G.subgraph(oWCC)
# Find all the stops in this chain or tree
svCompStops = svStops.intersection(Gwcc)
#print(svCompStops)
if len(svCompStops) > 0:
# We need to traverse it in reverse - back up the movement edges
Gwcc_rev = Gwcc.reverse()
for vStop in svCompStops:
# Find all the agents stopped by vStop by following the (reversed) edges
# This traverses a tree - dfs = depth first search
iter_stops = nx.algorithms.traversal.dfs_postorder_nodes(Gwcc_rev, vStop)
lStops = list(iter_stops)
svBlocked.update(lStops)
# the set of all the nodes/agents blocked by this set of stopped nodes
return svBlocked
def find_swaps(self):
""" find all the swap conflicts where two agents are trying to exchange places.
These appear as simple cycles of length 2.
These agents are necessarily deadlocked (since they can't change direction in flatland) -
meaning they will now be stuck for the rest of the episode.
"""
#svStops = self.find_stops2()
llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
return svSwaps
def find_same_dest(self):
""" find groups of agents which are trying to land on the same cell.
ie there is a gap of one cell between them and they are both landing on it.
"""
pass
def block_preds(self, svStops, color="red"):
""" Take a list of stopped agents, and apply a stop color to any chains/trees
of agents trying to head toward those cells.
Count the number of agents blocked, ignoring those which are already marked.
(Otherwise it can double count swaps)
"""
iCount = 0
svBlocked = set()
# The reversed graph allows us to follow directed edges to find affected agents.
Grev = self.G.reverse()
for v in svStops:
# Use depth-first-search to find a tree of agents heading toward the blocked cell.
lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v))
svBlocked |= set(lvPred)
svBlocked.add(v)
#print("node:", v, "set", svBlocked)
# only count those not already marked
for v2 in [v]+lvPred:
if self.G.nodes[v2].get("color") != color:
self.G.nodes[v2]["color"] = color
iCount += 1
return svBlocked
def find_conflicts(self):
svStops = self.find_stops2() # voluntarily stopped agents - have self-loops
svSwaps = self.find_swaps() # deadlocks - adjacent head-on collisions
# Block all swaps and their tree of predecessors
self.svDeadlocked = self.block_preds(svSwaps, color="purple")
# Take the union of the above, and find all the predecessors
#svBlocked = self.find_stop_preds(svStops.union(svSwaps))
# Just look for the the tree of preds for each voluntarily stopped agent
svBlocked = self.find_stop_preds(svStops)
# iterate the nodes v with their predecessors dPred (dict of nodes->{})
for (v, dPred) in self.G.pred.items():
# mark any swaps with purple - these are directly deadlocked
#if v in svSwaps:
# self.G.nodes[v]["color"] = "purple"
# If they are not directly deadlocked, but are in the union of stopped + deadlocked
#elif v in svBlocked:
# if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting
if v in svBlocked:
self.G.nodes[v]["color"] = "red"
# not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node
elif len(dPred)>1:
# if this agent is already red/blocked, ignore. CHECK: why?
# certainly we want to ignore purple so we don't overwrite with red.
if self.G.nodes[v].get("color") in ("red", "purple"):
continue
# if this node has no agent, and >=2 want to enter it.
if self.G.nodes[v].get("agent") is None:
self.G.nodes[v]["color"] = "blue"
# this node has an agent and >=2 want to enter
else:
self.G.nodes[v]["color"] = "magenta"
# predecessors of a contended cell: {agent index -> node}
diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred}
# remove the agent with the lowest index, who wins
iAgWinner = min(diAgCell)
diAgCell.pop(iAgWinner)
# Block all the remaining predecessors, and their tree of preds
#for iAg, v in diAgCell.items():
# self.G.nodes[v]["color"] = "red"
# for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
# self.G.nodes[vPred]["color"] = "red"
self.block_preds(diAgCell.values(), "red")
def check_motion(self, iAgent, rcPos):
""" Returns tuple of boolean can the agent move, and the cell it will move into.
If agent position is None, we use a dummy position of (-1, iAgent)
"""
if rcPos is None:
rcPos = (-1, iAgent)
dAttr = self.G.nodes.get(rcPos)
#print("pos:", rcPos, "dAttr:", dAttr)
if dAttr is None:
dAttr = {}
# If it's been marked red or purple then it can't move
if "color" in dAttr:
sColor = dAttr["color"]
if sColor in [ "red", "purple" ]:
return False
dSucc = self.G.succ[rcPos]
# This should never happen - only the next cell of an agent has no successor
if len(dSucc)==0:
print(f"error condition - agent {iAgent} node {rcPos} has no successor")
return False
# This agent has a successor
rcNext = self.G.successors(rcPos).__next__()
if rcNext == rcPos: # the agent didn't want to move
return False
# The agent wanted to move, and it can
return True
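# Illustrative sketch, not part of this revision's changes, using the MotionCheck
# class defined above: agent 1 is stopped (same current and next cell), so agent 0,
# which wants to enter agent 1's cell, is blocked as well.
omc_demo = MotionCheck()
omc_demo.addAgent(0, (1, 1), (1, 2))
omc_demo.addAgent(1, (1, 2), (1, 2))
omc_demo.find_conflicts()
assert not omc_demo.check_motion(1, (1, 2))  # voluntarily stopped agent
assert not omc_demo.check_motion(0, (1, 1))  # blocked behind the stopped agent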
def render(omc:MotionCheck, horizontal=True):
try:
oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
oAG.layout("dot")
sDot = oAG.to_string()
if horizontal:
sDot = sDot.replace('{', '{ rankdir="LR" ')
#return oAG.draw(format="png")
# This returns a graphviz object which implements __repr_svg
return gv.Source(sDot)
except ImportError as oError:
print("Flatland agent_chains ignoring ImportError - install pygraphviz to render graphs")
return None
class ChainTestEnv(object):
""" Just for testing agent chains
"""
def __init__(self, omc:MotionCheck):
self.iAgNext = 0
self.iRowNext = 1
self.omc = omc
def addAgent(self, rc1, rc2, xlabel=None):
self.omc.addAgent(self.iAgNext, rc1, rc2, xlabel=xlabel)
self.iAgNext+=1
def addAgentToRow(self, c1, c2, xlabel=None):
self.addAgent((self.iRowNext, c1), (self.iRowNext, c2), xlabel=xlabel)
def create_test_chain(self,
nAgents:int,
rcVel:Tuple[int] = (0,1),
liStopped:List[int]=[],
xlabel=None):
""" create a chain of agents
"""
lrcAgPos = [ (self.iRowNext, i * rcVel[1]) for i in range(nAgents) ]
for iAg, rcPos in zip(range(nAgents), lrcAgPos):
if iAg in liStopped:
rcVel1 = (0,0)
else:
rcVel1 = rcVel
self.omc.addAgent(iAg+self.iAgNext, rcPos, (rcPos[0] + rcVel1[0], rcPos[1] + rcVel1[1]) )
if xlabel:
self.omc.G.nodes[lrcAgPos[0]]["xlabel"] = xlabel
self.iAgNext += nAgents
self.iRowNext += 1
def nextRow(self):
self.iRowNext+=1
def create_test_agents(omc:MotionCheck):
# blocked chain
omc.addAgent(1, (1,2), (1,3))
omc.addAgent(2, (1,3), (1,4))
omc.addAgent(3, (1,4), (1,5))
omc.addAgent(31, (1,5), (1,5))
# unblocked chain
omc.addAgent(4, (2,1), (2,2))
omc.addAgent(5, (2,2), (2,3))
# blocked short chain
omc.addAgent(6, (3,1), (3,2))
omc.addAgent(7, (3,2), (3,2))
# solitary agent
omc.addAgent(8, (4,1), (4,2))
# solitary stopped agent
omc.addAgent(9, (5,1), (5,1))
# blocked short chain (opposite direction)
omc.addAgent(10, (6,4), (6,3))
omc.addAgent(11, (6,3), (6,3))
# swap conflict
omc.addAgent(12, (7,1), (7,2))
omc.addAgent(13, (7,2), (7,1))
def create_test_agents2(omc:MotionCheck):
# blocked chain
cte = ChainTestEnv(omc)
cte.create_test_chain(4, liStopped=[3], xlabel="stopped\nchain")
cte.create_test_chain(4, xlabel="running\nchain")
cte.create_test_chain(2, liStopped = [1], xlabel="stopped \nshort\n chain")
cte.addAgentToRow(1, 2, "swap")
cte.addAgentToRow(2, 1)
cte.nextRow()
cte.addAgentToRow(1, 2, "chain\nswap")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 2)
cte.nextRow()
cte.addAgentToRow(1, 2, "midchain\nstop")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(4, 4)
cte.addAgentToRow(5, 6)
cte.addAgentToRow(6, 7)
cte.nextRow()
cte.addAgentToRow(1, 2, "midchain\nswap")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(4, 3)
cte.addAgentToRow(5, 4)
cte.addAgentToRow(6, 5)
cte.nextRow()
cte.addAgentToRow(1, 2, "Land on\nSame")
cte.addAgentToRow(3, 2)
cte.nextRow()
cte.addAgentToRow(1, 2, "chains\nonto\nsame")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(5, 4)
cte.addAgentToRow(6, 5)
cte.addAgentToRow(7, 6)
cte.nextRow()
cte.addAgentToRow(1, 2, "3-way\nsame")
cte.addAgentToRow(3, 2)
cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
cte.nextRow()
if False:
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "4-way\nsame")
cte.addAgentToRow(3, 2)
cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
cte.addAgent((cte.iRowNext-1, 2), (cte.iRowNext, 2))
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "Tee")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgent((cte.iRowNext+1, 3), (cte.iRowNext, 3))
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "Tree")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
r1 = cte.iRowNext
r2 = cte.iRowNext+1
r3 = cte.iRowNext+2
cte.addAgent((r2, 3), (r1, 3))
cte.addAgent((r2, 2), (r2, 3))
cte.addAgent((r3, 2), (r2, 3))
cte.nextRow()
def test_agent_following():
omc = MotionCheck()
create_test_agents2(omc)
svStops = omc.find_stops()
svBlocked = omc.find_stop_preds()
llvSwaps = omc.find_swaps()
svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
print(list(svBlocked))
lvCells = omc.G.nodes()
lColours = [ "magenta" if v in svStops
else "red" if v in svBlocked
else "purple" if v in svSwaps
else "lightblue"
for v in lvCells ]
dPos = dict(zip(lvCells, lvCells))
nx.draw(omc.G,
with_labels=True, arrowsize=20,
pos=dPos,
node_color = lColours)
def main():
test_agent_following()
if __name__=="__main__":
main()
from itertools import starmap
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
from attr import attrs, attrib, Factory
import warnings
@attrs
class EnvAgentStatic(object):
""" EnvAgentStatic - Stores initial position, direction and target.
This is like static data for the environment - it's where an agent starts,
rather than where it is at the moment.
The target should also be stored here.
"""
position = attrib()
direction = attrib()
target = attrib()
moving = attrib(default=False)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
# N.B. we need to use factory since default arguments are not recreated on each call!
speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
from typing import Tuple, Optional, NamedTuple, List
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
speed_datas = []
for i in range(len(positions)):
speed_datas.append({'position_fraction': 0.0,
'speed': speeds[i] if speeds is not None else 1.0,
'transition_action_on_cellexit': 0})
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas)))
from attr import attr, attrs, attrib, Factory
def to_list(self):
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line
# I can't find an expression which works on both tuples, lists and ndarrays
# which converts them all to a list of native python ints.
lPos = self.position
if type(lPos) is np.ndarray:
lPos = lPos.tolist()
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
lTarget = self.target
if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist()
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
('direction', Grid4TransitionsEnum),
('target', Tuple[int, int]),
('moving', bool),
('earliest_departure', int),
('latest_arrival', int),
('handle', int),
('position', Tuple[int, int]),
('arrival_time', int),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data]
def load_env_agent(agent_tuple: Agent):
return EnvAgent(
initial_position = agent_tuple.initial_position,
initial_direction = agent_tuple.initial_direction,
direction = agent_tuple.direction,
target = agent_tuple.target,
moving = agent_tuple.moving,
earliest_departure = agent_tuple.earliest_departure,
latest_arrival = agent_tuple.latest_arrival,
handle = agent_tuple.handle,
position = agent_tuple.position,
arrival_time = agent_tuple.arrival_time,
old_direction = agent_tuple.old_direction,
old_position = agent_tuple.old_position,
speed_counter = agent_tuple.speed_counter,
action_saver = agent_tuple.action_saver,
state_machine = agent_tuple.state_machine,
malfunction_handler = agent_tuple.malfunction_handler,
)
@attrs
class EnvAgent(EnvAgentStatic):
""" EnvAgent - replace separate agent_* lists with a single list
of agent objects. The EnvAgent represent's the environment's view
of the dynamic agent state.
We are duplicating target in the EnvAgent, which seems simpler than
forcing the env to refer to it in the EnvAgentStatic
"""
class EnvAgent:
# INIT FROM HERE IN _from_line()
initial_position = attrib(type=Tuple[int, int])
initial_direction = attrib(type=Grid4TransitionsEnum)
direction = attrib(type=Grid4TransitionsEnum)
target = attrib(type=Tuple[int, int])
moving = attrib(default=False, type=bool)
# NEW : EnvAgent - Schedule properties
earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line()
handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
# Env step facelift
speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling
arrival_time = attrib(default=None, type=int)
# used in rendering
old_direction = attrib(default=None)
old_position = attrib(default=None)
def to_list(self):
return [
self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position, self.moving, self.speed_data]
@classmethod
def from_static(cls, oStatic):
""" Create an EnvAgent from the EnvAgentStatic,
copying all the fields, and adding handle with the default 0.
def reset(self):
"""
return EnvAgent(*oStatic.__dict__, handle=0)
Resets the agent to its initial values for the episode. Called after ScheduleTime generation.
"""
self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction
self.old_position = None
self.old_direction = None
self.moving = False
self.arrival_time = None
self.malfunction_handler.reset()
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
self.state_machine.reset()
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position,
initial_direction=self.initial_direction,
direction=self.direction,
target=self.target,
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
handle=self.handle,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
arrival_time=self.arrival_time,
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
def get_shortest_path(self, distance_map) -> List[Waypoint]:
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
shortest_path = self.get_shortest_path(distance_map)
if shortest_path is not None:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_counter.speed
return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
return self.latest_arrival - elapsed_steps
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
'''
+ve if arrival time is projected before latest arrival
-ve if arrival time is projected after latest arrival
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
@classmethod
def list_from_static(cls, lEnvAgentStatic, handles=None):
""" Create an EnvAgent from the EnvAgentStatic,
copying all the fields, and adding handle with the default 0.
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
if handles is None:
handles = range(len(lEnvAgentStatic))
num_agents = len(line.agent_positions)
agent_list = []
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
target = line.agent_targets[i_agent],
moving = False,
earliest_departure = None,
latest_arrival = None,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
agent_list.append(agent)
return agent_list
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
agents = []
for i, static_agent in enumerate(static_agents_data):
if len(static_agent) >= 6:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_counter=SpeedCounter(static_agent[4]['speed']), handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
speed_counter=SpeedCounter(1.0),
handle=i)
agents.append(agent)
return agents
def __str__(self):
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} \n \
initial_direction: {self.initial_direction} \n \
position: {self.position} \n \
direction: {self.direction} \n \
target: {self.target} \n \
old_position: {self.old_position} \n \
old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} \n \
latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_handler: {self.malfunction_handler} \n \
action_saver: {self.action_saver} \n \
speed_counter: {self.speed_counter}"
@property
def state(self):
return self.state_machine.state
@state.setter
def state(self, state):
self._set_state(state)
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
@property
def malfunction_data(self):
raise ValueError("agent.malunction_data is deprecated, please use agent.malfunction_hander instead")
@property
def speed_data(self):
raise ValueError("agent.speed_data is deprecated, please use agent.speed_counter instead")
return [EnvAgent(**oEAS.__dict__, handle=handle)
for handle, oEAS in zip(handles, lEnvAgentStatic)]
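# Illustrative sketch, not part of this revision's changes: constructing the new
# attrs-based EnvAgent with the step-utils fields left at their defaults. Import
# paths follow the imports shown elsewhere in this diff.
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState

agent = EnvAgent(initial_position=(0, 0), initial_direction=1,
                 direction=1, target=(0, 5), handle=0)
assert agent.state == TrainState.WAITING  # state is delegated to the state machine
assert agent.position is None  # not yet placed on the grid
agent.reset()  # restore initial values for a new episode
assert agent.direction == agent.initial_direction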
from collections import deque
from typing import List, Optional
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
class DistanceMap:
def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
self.env_height = env_height
self.env_width = env_width
self.distance_map = None
self.agents_previous_computation = None
self.reset_was_called = False
self.agents: List[EnvAgent] = agents
self.rail: Optional[GridTransitionMap] = None
def set(self, distance_map: np.ndarray):
"""
Set the distance map
"""
self.distance_map = distance_map
def get(self) -> np.ndarray:
"""
Get the distance map
"""
if self.reset_was_called:
self.reset_was_called = False
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_computation is None and self.distance_map is not None:
compute_distance_map = False
if compute_distance_map:
self._compute(self.agents, self.rail)
elif self.distance_map is None:
self._compute(self.agents, self.rail)
return self.distance_map
def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
Reset the distance map
"""
self.reset_was_called = True
self.agents: List[EnvAgent] = agents
self.rail = rail
self.env_height = rail.height
self.env_width = rail.width
def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
This function computes the distance maps for each unique target. Thus if several targets are the same
we only compute the distance for them once and copy to all targets with same position.
:param agents: All the agents in the environment, independent of their current status
:param rail: The rail transition map
"""
self.agents_previous_computation = self.agents
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
self.env_width,
4))
computed_targets = []
for i, agent in enumerate(agents):
if agent.target not in computed_targets:
self._distance_map_walker(rail, agent.target, i)
else:
# just copy the distance map form other agent with same target (performance)
self.distance_map[i, :, :, :] = np.copy(
self.distance_map[computed_targets.index(agent.target), :, :, :])
computed_targets.append(agent.target)
def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
orientation within it) to each agent's target cell.
"""
# Returns max distance to target, from the farthest away node, while filling in distance_map
self.distance_map[target_nr, position[0], position[1], :] = 0
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
while nodes_queue:
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those
# whose new orientation in the current cell would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors(rail, (node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3] + 1)
return max_distance
def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance,
enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
"""
neighbors = []
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = get_new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env_height and new_cell[1] >= 0 and new_cell[1] < self.env_width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
if isNextCellDeadEnd:
movement = (desired_movement_from_new_cell+2) % 4
"""
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
current_distance + 1)
neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
return neighbors
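# --- Illustrative usage sketch (editor's addition, not part of the original source) ---
# DistanceMap is normally owned by RailEnv, but it can also be driven directly:
# construct it with the agent list and grid size, point it at a GridTransitionMap via
# reset(), and get() lazily computes an array of shape (n_agents, height, width, 4).
def _build_distance_map(agents: List[EnvAgent], rail: GridTransitionMap) -> np.ndarray:
    distance_map = DistanceMap(agents, rail.height, rail.width)
    distance_map.reset(agents, rail)  # marks the map as dirty and stores the rail
    return distance_map.get()         # triggers _compute() on first access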
from typing import Tuple
# Adrian Egli / Michel Marti performance fix (the fast methods bring more than a 50% speed-up)
def fast_isclose(a, b, rtol):
# Note: for rtol >= 0 this reduces to the one-sided check `a < b + rtol`.
return (a < (b + rtol)) or (a < (b - rtol))
def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int], max_value: Tuple[int, int]) -> Tuple[int, int]:
return (
max(min_value[0], min(position[0], max_value[0])),
max(min_value[1], min(position[1], max_value[1]))
)
def fast_argmax(possible_transitions: (int, int, int, int)) -> int:
if possible_transitions[0] == 1:
return 0
if possible_transitions[1] == 1:
return 1
if possible_transitions[2] == 1:
return 2
return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None and pos_2 is None:
return True
if pos_1 is None or pos_2 is None:
return False
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
def fast_delete(lis: list, index) -> list:
new_list = lis.copy()
new_list.pop(index)
return new_list
def fast_where(binary_iterable):
return [index for index, element in enumerate(binary_iterable) if element != 0]
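# --- Illustrative mini-check (editor's addition, not part of the original source) ---
# Demonstrates the expected behaviour of the helpers above on hand-made values
# chosen purely for illustration.
if __name__ == "__main__":
    transitions = (0, 1, 1, 0)                    # possible moves: East and South
    assert fast_count_nonzero(transitions) == 2
    assert fast_argmax(transitions) == 1          # first set direction wins
    assert fast_where(transitions) == [1, 2]
    assert fast_position_equal((3, 4), (3, 4))
    assert not fast_position_equal((3, 4), None)
    assert fast_delete([10, 20, 30], 1) == [10, 30]
    assert fast_clip((5, -2), (0, 0), (3, 3)) == (3, 0)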
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
def load_flatland_environment_from_file(file_name: str,
load_from_package: str = None,
obs_builder_object: ObservationBuilder = None) -> RailEnv:
"""
Parameters
----------
file_name : str
The pickle file.
load_from_package : str
The python module to import from. Example: 'env_data.tests'
This requires that there are `__init__.py` files in the folder structure we load the file from.
obs_builder_object: ObservationBuilder
The obs builder for the `RailEnv` that is created.
Returns
-------
RailEnv
The environment loaded from the pickle file.
"""
if obs_builder_object is None:
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1,
obs_builder_object=obs_builder_object)
return environment
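# --- Illustrative usage sketch (editor's addition, not part of the original source) ---
# The file name below is a placeholder for any pickle produced by env.save() or the
# editor; a custom observation builder can be passed instead of the default tree
# observation. Kept as a comment because the file itself is hypothetical.
#
#   env = load_flatland_environment_from_file("saved_episode.pkl")
#   env.reset()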
import msgpack
import numpy as np
from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.grid4_generators_utils import connect_rail
from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
def empty_rail_generator():
"""
Returns a generator which returns an empty rail map with no agents.
Primarily used by the editor.
"""
def generator(width, height, num_agents=0, num_resets=0):
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
return grid_map, [], [], [], []
return generator
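# --- Illustrative sketch (editor's addition, not part of the original source) ---
# Every generator in this module follows the same callable contract: it receives
# (width, height, num_agents, num_resets) and returns a 5-tuple of
# (grid_map, agents_position, agents_direction, agents_target, agent_speeds).
# A custom generator only needs to respect that shape, e.g. one that ignores the
# requested size and always produces a fixed-size empty grid:
def fixed_size_empty_rail_generator(fixed_width: int, fixed_height: int):
    base = empty_rail_generator()

    def generator(width, height, num_agents=0, num_resets=0):
        return base(fixed_width, fixed_height, num_agents, num_resets)

    return generator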
def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
"""
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_start_goal:
num_agents = nr_start_goal
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
np.random.seed(seed + num_resets)
# generate rail array
# step 1:
# - generate a start and goal position
# - validate min/max distance allowed
# - validate that start/goals are not placed too close to other start/goals
# - draw a rail from [start,goal]
# - if rail crosses existing rail then validate new connection
# - possibility that this fails to create a path to goal
# - on failure generate new start/goal
#
# step 2:
# - add more rails to map randomly between cells that have rails
# - validate all new rails, on failure don't add new rails
#
# step 3:
# - return transition map + list of [start_pos, start_dir, goal_pos] points
#
start_goal = []
start_dir = []
nr_created = 0
created_sanity = 0
sanity_max = 9000
while nr_created < nr_start_goal and created_sanity < sanity_max:
all_ok = False
for _ in range(sanity_max):
start = (np.random.randint(0, height), np.random.randint(0, width))
goal = (np.random.randint(0, height), np.random.randint(0, width))
# check to make sure start,goal pos is empty?
if rail_array[goal] != 0 or rail_array[start] != 0:
continue
# check min/max distance
dist_sg = distance_on_rail(start, goal)
if dist_sg < min_dist:
continue
if dist_sg > max_dist:
continue
# check distance to existing points
sg_new = [start, goal]
def check_all_dist(sg_new):
for sg in start_goal:
for i in range(2):
for j in range(2):
dist = distance_on_rail(sg_new[i], sg[j])
if dist < 2:
return False
return True
if check_all_dist(sg_new):
all_ok = True
break
if not all_ok:
# we might as well give up at this point
break
new_path = connect_rail(rail_trans, rail_array, start, goal)
if len(new_path) >= 2:
nr_created += 1
start_goal.append([start, goal])
start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
else:
# after too many failures we will give up
created_sanity += 1
# add extra connections between existing rail
created_sanity = 0
nr_created = 0
while nr_created < nr_extra and created_sanity < sanity_max:
all_ok = False
for _ in range(sanity_max):
start = (np.random.randint(0, height), np.random.randint(0, width))
goal = (np.random.randint(0, height), np.random.randint(0, width))
# check to make sure start,goal pos are not empty
if rail_array[goal] == 0 or rail_array[start] == 0:
continue
else:
all_ok = True
break
if not all_ok:
break
new_path = connect_rail(rail_trans, rail_array, start, goal)
if len(new_path) >= 2:
nr_created += 1
agents_position = [sg[0] for sg in start_goal[:num_agents]]
agents_target = [sg[1] for sg in start_goal[:num_agents]]
agents_direction = start_dir[:num_agents]
return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
def rail_from_manual_specifications_generator(rail_spec):
"""
Utility to convert a rail given by manual specification as a map of tuples
(cell_type, rotation), to a transition map with the correct 16-bit
transitions specifications.
Parameters
-------
rail_spec : list of list of tuples
List (rows) of lists (columns) of tuples, each specifying a rail_spec_of_cell for
the RailEnv environment as (cell_type, rotation), with rotation being
clock-wise and in [0, 90, 180, 270].
Returns
-------
function
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each rail_spec_of_cell.
"""
def generator(width, height, num_agents, num_resets=0):
rail_env_transitions = RailEnvTransitions()
height = len(rail_spec)
width = len(rail_spec[0])
rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions)
for r in range(height):
for c in range(width):
rail_spec_of_cell = rail_spec[r][c]
index_basic_type_of_cell_ = rail_spec_of_cell[0]
rotation_cell_ = rail_spec_of_cell[1]
if index_basic_type_of_cell_ < 0 or index_basic_type_of_cell_ >= len(rail_env_transitions.transitions):
print("ERROR - invalid rail_spec_of_cell type=", index_basic_type_of_cell_)
return []
basic_type_of_cell_ = rail_env_transitions.transitions[index_basic_type_of_cell_]
effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_)
rail.set_transitions((r, c), effective_transition_cell)
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail,
num_agents)
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
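# --- Illustrative specification (editor's addition, not part of the original source) ---
# A 2x4 spec with one horizontal track in the bottom row, capped by dead ends. The
# cell-type indices used here (0 = empty, 1 = straight, 7 = dead end) refer to the
# RailEnvTransitions template list and should be double-checked against
# flatland.core.grid.rail_env_grid before reuse.
example_rail_spec = [[(0, 0), (0, 0), (0, 0), (0, 0)],
                     [(7, 270), (1, 90), (1, 90), (7, 90)]]
example_manual_generator = rail_from_manual_specifications_generator(example_rail_spec)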
def rail_from_file(filename):
"""
Utility to load pickle file
Parameters
-------
input_file : Pickle file generated by env.save() or editor
Returns
-------
function
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each rail_spec_of_cell.
"""
def generator(width, height, num_agents, num_resets):
rail_env_transitions = RailEnvTransitions()
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False)
grid = np.array(data[b"grid"])
rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
rail.grid = grid
# agents are always reset as not moving
agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
# setup with loaded data
agents_position = [a.position for a in agents_static]
agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static]
if b"distance_maps" in data.keys():
distance_maps = data[b"distance_maps"]
if len(distance_maps) > 0:
return rail, agents_position, agents_direction, agents_target, [1.0] * len(
agents_position), distance_maps
else:
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
else:
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
def rail_from_grid_transition_map(rail_map):
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
Parameters
-------
rail_map : GridTransitionMap object
GridTransitionMap object to return when the generator is called.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_agents, num_resets=0):
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail_map,
num_agents)
return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
"""
Dummy random level generator:
- fill in cells at random in [width-2, height-2]
- keep filling cells in among the unfilled ones, such that all transitions
are legit; if no cell can be filled in without violating some
transitions, pick one among those that can satisfy most transitions
(1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
incompatible.
- keep trying for a total number of insertions
(e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
board and try again from scratch.
- finally pad the border of the map with dead-ends to avoid border issues.
Dead-ends are not allowed inside the grid, only at the border; however, if
no cell type can be inserted in a given cell (because of the neighboring
transitions), dead-ends are allowed if they solve the problem. This was
found to turn most un-generatable levels into valid ones.
Parameters
-------
cell_type_relative_proportion : list of float
Relative proportion of each cell type from the RailEnvTransitions
template list when cells are sampled.
Returns
-------
function
Generator function that returns a tuple
(grid_map, agents_position, agents_direction, agents_target, agent_speeds).
"""
def generator(width, height, num_agents, num_resets=0):
t_utils = RailEnvTransitions()
transition_probability = cell_type_relative_proportion
transitions_templates_ = []
transition_probabilities = []
for i in range(len(t_utils.transitions)): # don't include dead-ends
if t_utils.transitions[i] == int('0010000000000000', 2):
continue
all_transitions = 0
for dir_ in range(4):
trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
all_transitions |= (trans[0] << 3) | \
(trans[1] << 2) | \
(trans[2] << 1) | \
(trans[3])
template = [int(x) for x in bin(all_transitions)[2:]]
template = [0] * (4 - len(template)) + template
# add all rotations
for rot in [0, 90, 180, 270]:
transitions_templates_.append((template,
t_utils.rotate_transition(
t_utils.transitions[i],
rot)))
transition_probabilities.append(transition_probability[i])
template = [template[-1]] + template[:-1]
def get_matching_templates(template):
ret = []
for i in range(len(transitions_templates_)):
is_match = True
for j in range(4):
if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
is_match = False
break
if is_match:
ret.append((transitions_templates_[i][1], transition_probabilities[i]))
return ret
MAX_INSERTIONS = (width - 2) * (height - 2) * 10
MAX_ATTEMPTS_FROM_SCRATCH = 10
attempt_number = 0
while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
cells_to_fill = []
rail = []
for r in range(height):
rail.append([None] * width)
if r > 0 and r < height - 1:
cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]
num_insertions = 0
while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
cells_to_fill.remove(cell)
row = cell[0]
col = cell[1]
# look at its neighbors and see what are the possible transitions
# that can be chosen from, if any.
valid_template = [-1, -1, -1, -1]
for el in [(0, 2, (-1, 0)),
(1, 3, (0, 1)),
(2, 0, (1, 0)),
(3, 1, (0, -1))]: # N, E, S, W
neigh_trans = rail[row + el[2][0]][col + el[2][1]]
if neigh_trans is not None:
# select transition coming from facing direction el[1] and
# moving to direction el[1]
max_bit = 0
for k in range(4):
max_bit |= t_utils.get_transition(neigh_trans, k, el[1])
if max_bit:
valid_template[el[0]] = 1
else:
valid_template[el[0]] = 0
possible_cell_transitions = get_matching_templates(valid_template)
if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS
# no cell can be filled in without violating some transitions
# can a dead-end solve the problem?
if valid_template.count(1) == 1:
for k in range(4):
if valid_template[k] == 1:
rot = 0
if k == 0:
rot = 180
elif k == 1:
rot = 270
elif k == 2:
rot = 0
elif k == 3:
rot = 90
rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot)
num_insertions += 1
break
else:
# can I get valid transitions by removing a single
# neighboring cell?
bestk = -1
besttrans = []
for k in range(4):
tmp_template = valid_template[:]
tmp_template[k] = -1
possible_cell_transitions = get_matching_templates(tmp_template)
if len(possible_cell_transitions) > len(besttrans):
besttrans = possible_cell_transitions
bestk = k
if bestk >= 0:
# Replace the corresponding cell with None, append it
# to cells to fill, fill in a transition in the current
# cell.
replace_row = row - 1
replace_col = col
if bestk == 1:
replace_row = row
replace_col = col + 1
elif bestk == 2:
replace_row = row + 1
replace_col = col
elif bestk == 3:
replace_row = row
replace_col = col - 1
cells_to_fill.append((replace_row, replace_col))
rail[replace_row][replace_col] = None
possible_transitions, possible_probabilities = zip(*besttrans)
possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities)
num_insertions += 1
else:
print('WARNING: still nothing!')
rail[row][col] = int('0000000000000000', 2)
num_insertions += 1
pass
else:
possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities)
num_insertions += 1
if num_insertions == MAX_INSERTIONS:
# Failed to generate a valid level; try again for a number of times
attempt_number += 1
else:
break
if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
print('ERROR: failed to generate level')
# Finally pad the border of the map with dead-ends to avoid border issues;
# at most 1 transition in the neigh cell
for r in range(height):
# Check for transitions coming from [r][1] to WEST
max_bit = 0
neigh_trans = rail[r][1]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & 1)
if max_bit:
rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
else:
rail[r][0] = int('0000000000000000', 2)
# Check for transitions coming from [r][-2] to EAST
max_bit = 0
neigh_trans = rail[r][-2]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
if max_bit:
rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
90)
else:
rail[r][-1] = int('0000000000000000', 2)
for c in range(width):
# Check for transitions coming from [1][c] to NORTH
max_bit = 0
neigh_trans = rail[1][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
if max_bit:
rail[0][c] = int('0010000000000000', 2)
else:
rail[0][c] = int('0000000000000000', 2)
# Check for transitions coming from [-2][c] to SOUTH
max_bit = 0
neigh_trans = rail[-2][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
if max_bit:
rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
else:
rail[-1][c] = int('0000000000000000', 2)
# For display only, wrong levels
for r in range(height):
for c in range(width):
if rail[r][c] is None:
rail[r][c] = int('0000000000000000', 2)
tmp_rail = np.asarray(rail, dtype=np.uint16)
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
return_rail,
num_agents)
return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
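# --- Illustrative helper (editor's addition, not part of the original source) ---
# The border-padding loop above peels the 16-bit cell value apart nibble by nibble:
# nibble k (counted from the most significant end) holds the four outgoing bits
# (N, E, S, W) available to an agent that entered the cell with orientation k.
def _outgoing_bits(cell_transition: int, orientation: int) -> dict:
    nibble = (cell_transition >> ((3 - orientation) * 4)) & 0b1111
    return {'N': bool(nibble & 0b1000),
            'E': bool(nibble & 0b0100),
            'S': bool(nibble & 0b0010),
            'W': bool(nibble & 0b0001)}

# Example: the dead-end template int('0010000000000000', 2) only lets an agent that
# entered facing North leave towards the South, i.e. turn around.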
@@ -7,18 +7,43 @@ a GridTransitionMap object.
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position
def connect_rail(rail_trans, rail_array, start, end):
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point, get_new_position
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
rail_trans: RailEnvTransitions,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
flip_start_node_trans: bool = False, flip_end_node_trans: bool = False,
respect_transition_validity: bool = True, forbidden_cells: IntVector2DArray = None,
avoid_rail=False) -> IntVector2DArray:
"""
Creates a new path [start,end] in rail_array, based on rail_trans.
Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
returns the path created as a list of positions.
:param avoid_rail: if True, the A* search tries to avoid routing through cells that already contain rail
:param rail_trans: basic rail transition object
:param grid_map: grid map
:param start: start position of rail
:param end: end position of rail
:param flip_start_node_trans: if True, make the start position valid by adding a dead-end; otherwise leave the start cell empty
:param flip_end_node_trans: if True, make the end position valid by adding a dead-end; otherwise leave the end cell empty
:param respect_transition_validity: if True, only draw rail where legal rail elements can be used; if False, draw the line
without respecting rail transitions.
:param a_star_distance_function: Define what distance function a-star should use
:param forbidden_cells: cells to avoid when drawing rail. Rail cannot go through this list of cells
:return: List of cells in the path
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, rail_array, start, end)
path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, avoid_rail,
respect_transition_validity,
forbidden_cells)
if len(path) < 2:
return []
current_dir = get_direction(path[0], path[1])
end_pos = path[-1]
for index in range(len(path) - 1):
@@ -26,12 +51,15 @@ def connect_rail(rail_trans, rail_array, start, end):
new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos]
new_trans = grid_map.grid[current_pos]
if index == 0:
if new_trans == 0:
# end-point
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
if flip_start_node_trans:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
new_trans = 0
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
@@ -40,96 +68,108 @@ def connect_rail(rail_trans, rail_array, start, end):
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
rail_array[current_pos] = new_trans
grid_map.grid[current_pos] = new_trans
if new_pos == end_pos:
# setup end pos setup
new_trans_e = rail_array[end_pos]
new_trans_e = grid_map.grid[end_pos]
if new_trans_e == 0:
# end-point
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
if flip_end_node_trans:
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
new_trans_e = 0
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
rail_array[end_pos] = new_trans_e
grid_map.grid[end_pos] = new_trans_e
current_dir = new_dir
return path
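# --- Illustrative usage sketch (editor's addition, not part of the original source) ---
# Connects two cells on a blank map, capping both extremities with dead ends. Grid
# size and coordinates are arbitrary example values.
def _demo_connect_rail() -> IntVector2DArray:
    rail_trans = RailEnvTransitions()
    grid_map = GridTransitionMap(width=10, height=5, transitions=rail_trans)
    grid_map.grid.fill(0)
    path = connect_rail_in_grid_map(grid_map, (2, 1), (2, 8), rail_trans,
                                    flip_start_node_trans=True, flip_end_node_trans=True)
    return path  # list of (row, col) cells that were written into grid_map.grid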
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
def connect_straight_line_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D,
end: IntVector2D, rail_trans: RailEnvTransitions) -> IntVector2DArray:
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
Generates a straight rail line from start cell to end cell.
Diagonal lines are not allowed
:param rail_trans:
:param grid_map:
:param start: Cell coordinates for start of line
:param end: Cell coordinates for end of line
:return: A list of all cells in the path
"""
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = rail.get_transitions(node[0][0], node[0][1], node[1])
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = rail.get_full_transitions(node[0][0], node[0][1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_full_transitions(r, c) > 0:
valid_positions.append((r, c))
re_generate = True
while re_generate:
agents_position = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
agents_direction = [0] * num_agents
re_generate = False
for i in range(num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions(position[0], position[1], direction)
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
if not (start[0] == end[0] or start[1] == end[1]):
print("No straight line possible!")
return []
direction = direction_to_point(start, end)
if direction is Grid4TransitionsEnum.NORTH or direction is Grid4TransitionsEnum.SOUTH:
start_row = min(start[0], end[0])
end_row = max(start[0], end[0]) + 1
rows = np.arange(start_row, end_row)
length = np.abs(end[0] - start[0]) + 1
cols = np.repeat(start[1], length)
else: # Grid4TransitionsEnum.EAST or Grid4TransitionsEnum.WEST
start_col = min(start[1], end[1])
end_col = max(start[1], end[1]) + 1
cols = np.arange(start_col, end_col)
length = np.abs(end[1] - start[1]) + 1
rows = np.repeat(start[0], length)
return agents_position, agents_direction, agents_target
path = list(zip(rows, cols))
for cell in path:
transition = grid_map.grid[cell]
transition = rail_trans.set_transition(transition, direction, direction, 1)
transition = rail_trans.set_transition(transition, mirror(direction), mirror(direction), 1)
grid_map.grid[cell] = transition
return path
def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, rail_trans: RailEnvTransitions):
"""
Fix an inner city node by connecting it to its neighbouring parallel track
:param grid_map:
:param inner_node_pos: inner city node to fix
:param rail_trans:
:return:
"""
corner_directions = []
for direction in range(4):
tmp_pos = get_new_position(inner_node_pos, direction)
if grid_map.grid[tmp_pos] > 0:
corner_directions.append(direction)
if len(corner_directions) == 2:
transition = 0
transition = rail_trans.set_transition(transition, mirror(corner_directions[0]), corner_directions[1], 1)
transition = rail_trans.set_transition(transition, mirror(corner_directions[1]), corner_directions[0], 1)
grid_map.grid[inner_node_pos] = transition
tmp_pos = get_new_position(inner_node_pos, corner_directions[0])
transition = grid_map.grid[tmp_pos]
transition = rail_trans.set_transition(transition, corner_directions[0], mirror(corner_directions[0]), 1)
grid_map.grid[tmp_pos] = transition
tmp_pos = get_new_position(inner_node_pos, corner_directions[1])
transition = grid_map.grid[tmp_pos]
transition = rail_trans.set_transition(transition, corner_directions[1], mirror(corner_directions[1]),
1)
grid_map.grid[tmp_pos] = transition
return
def align_cell_to_city(city_center, city_orientation, cell):
"""
Align the given cell to face the city center along the city orientation
@param city_center: Center needed for orientation
@param city_orientation: Orientation of the city
@param cell: Cell we would like to orient
@return: Orientation of cell towards city center along axis of city orientation
"""
if city_orientation % 2 == 0:
return int(2 * np.clip(cell[0] - city_center[0], 0, 1))
else:
return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
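# --- Illustrative mini-check (editor's addition, not part of the original source) ---
# For a north-south oriented city (orientation 0 or 2) the result is 0 (NORTH) for
# cells at or above the centre row and 2 (SOUTH) below it; for an east-west city the
# result is 1 (EAST) at or east of the centre column and 3 (WEST) west of it.
if __name__ == "__main__":
    assert align_cell_to_city((5, 5), 0, (3, 5)) == 0
    assert align_cell_to_city((5, 5), 0, (7, 5)) == 2
    assert align_cell_to_city((5, 5), 1, (5, 3)) == 3
    assert align_cell_to_city((5, 5), 1, (5, 8)) == 1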
"""Line generators (railway undertaking, "EVU")."""
import warnings
from typing import Tuple, List, Callable, Mapping, Optional, Any
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.timetable_utils import Line
from flatland.envs import persistence
AgentPosition = Tuple[int, int]
LineGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Line]
def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None,
seed: int = None, np_random: RandomState = None) -> List[float]:
"""
Parameters
----------
nb_agents : int
The number of agents to generate a speed for
speed_ratio_map : Mapping[float,float]
A map of speeds mapping to their ratio of appearance. The ratios must sum up to 1.
Returns
-------
List[float]
A list of size nb_agents of speeds with the corresponding probabilistic ratios.
"""
if speed_ratio_map is None:
return [1.0] * nb_agents
nb_classes = len(speed_ratio_map.keys())
speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
return list(map(lambda index: speeds[index], np_random.choice(nb_classes, nb_agents, p=speed_ratios)))
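# --- Illustrative usage sketch (editor's addition, not part of the original source) ---
# Draws speeds for 4 agents where roughly a quarter run at full speed and the rest at
# half speed; the seed is arbitrary and chosen only for the example.
def _demo_speed_initialization() -> List[float]:
    rng = np.random.RandomState(42)
    speeds = speed_initialization_helper(4, {1.0: 0.25, 0.5: 0.75}, np_random=rng)
    assert len(speeds) == 4 and all(s in (1.0, 0.5) for s in speeds)
    return speeds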
class BaseLineGen(object):
def __init__(self, speed_ratio_map: Mapping[float, float] = None, seed: int = 1):
self.speed_ratio_map = speed_ratio_map
self.seed = seed
def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any=None, num_resets: int = 0,
np_random: RandomState = None) -> Line:
pass
def __call__(self, *args, **kwargs):
return self.generate(*args, **kwargs)
def sparse_line_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> LineGenerator:
return SparseLineGen(speed_ratio_map, seed)
class SparseLineGen(BaseLineGen):
"""
This is the line generator which is used for Round 2 of the Flatland challenge. It produces lines
for railway networks provided by sparse_rail_generator.
:param speed_ratio_map: Speed ratios of all agents. They are probabilities of all different speeds and have to
add up to 1.
:param seed: Initiate random seed generator
"""
def decide_orientation(self, rail, start, target, possible_orientations, np_random: RandomState) -> int:
feasible_orientations = []
for orientation in possible_orientations:
if rail.check_path_exists(start[0], orientation, target[0]):
feasible_orientations.append(orientation)
if len(feasible_orientations) > 0:
return np_random.choice(feasible_orientations)
else:
return 0
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int,
np_random: RandomState) -> Line:
"""
The generator that assigns tasks to all the agents
:param rail: Rail infrastructure given by the rail_generator
:param num_agents: Number of agents to include in the line
:param hints: Hints provided by the rail_generator. These include train stations, city positions and city orientations.
:param num_resets: How often the generator has been reset.
:return: a `Line` with the agents' initial positions, directions, targets and speeds
"""
_runtime_seed = self.seed + num_resets
train_stations = hints['train_stations']
city_positions = hints['city_positions']
city_orientation = hints['city_orientations']
# Place agents and targets within available train stations
agents_position = []
agents_target = []
agents_direction = []
city1, city2 = None, None
city1_num_stations, city2_num_stations = None, None
city1_possible_orientations, city2_possible_orientations = None, None
for agent_idx in range(num_agents):
if (agent_idx % 2 == 0):
# Select 2 cities, find their num_stations and possible orientations
city_idx = np_random.choice(len(city_positions), 2, replace=False)
city1 = city_idx[0]
city2 = city_idx[1]
city1_num_stations = len(train_stations[city1])
city2_num_stations = len(train_stations[city2])
city1_possible_orientations = [city_orientation[city1],
(city_orientation[city1] + 2) % 4]
city2_possible_orientations = [city_orientation[city2],
(city_orientation[city2] + 2) % 4]
# Agent 1 : city1 > city2, Agent 2: city2 > city1
agent_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations
agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city2_num_stations
agent_start = train_stations[city1][agent_start_idx]
agent_target = train_stations[city2][agent_target_idx]
agent_orientation = self.decide_orientation(
rail, agent_start, agent_target, city1_possible_orientations, np_random)
else:
agent_start_idx = ((2 * np_random.randint(0, 10))) % city2_num_stations
agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city1_num_stations
agent_start = train_stations[city2][agent_start_idx]
agent_target = train_stations[city1][agent_target_idx]
agent_orientation = self.decide_orientation(
rail, agent_start, agent_target, city2_possible_orientations, np_random)
# store this agent's start, target and orientation
agents_position.append((agent_start[0][0], agent_start[0][1]))
agents_target.append((agent_target[0][0], agent_target[0][1]))
agents_direction.append(agent_orientation)
if self.speed_ratio_map:
speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random)
else:
speeds = [1.0] * len(agents_position)
# We multiply the max number of time steps by extra factors to simplify the task in the Flatland challenge.
# These factors might change in the future.
timedelay_factor = 4
alpha = 2
max_episode_steps = int(
timedelay_factor * alpha * (rail.width + rail.height + num_agents / len(city_positions)))
return Line(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds)
def line_from_file(filename, load_from_package=None) -> LineGenerator:
"""
Utility to load pickle file
Parameters
----------
input_file : Pickle file generated by env.save() or editor
Returns
-------
LineGenerator
Generator returning a `Line` with the initial positions, directions, targets and speeds stored in the file.
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
np_random: RandomState = None) -> Line:
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
max_episode_steps = env_dict.get("max_episode_steps", 0)
if (max_episode_steps==0):
print("This env file has no max_episode_steps (deprecated) - setting to 100")
max_episode_steps = 100
agents = env_dict["agents"]
# setup with loaded data
agents_position = [a.initial_position for a in agents]
# this logic is wrong - we should really load the initial_direction as the direction.
#agents_direction = [a.direction for a in agents]
agents_direction = [a.initial_direction for a in agents]
agents_target = [a.target for a in agents]
agents_speed = [a.speed_counter.speed for a in agents]
return Line(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=agents_speed)
return generator
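# --- Illustrative usage sketch (editor's addition, not part of the original source) ---
# A line generator is normally handed to the environment constructor together with a
# matching rail generator. The RailEnv keyword names below follow the newer API used
# elsewhere in this change set and should be treated as an assumption.
#
#   from flatland.envs.rail_env import RailEnv
#   from flatland.envs.rail_generators import sparse_rail_generator
#   env = RailEnv(width=30, height=30,
#                 rail_generator=sparse_rail_generator(max_num_cities=3),
#                 line_generator=sparse_line_generator(speed_ratio_map={1.0: 0.7, 0.5: 0.3}),
#                 number_of_agents=4)
#   env.reset()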
"""Malfunction generators for rail systems"""
from typing import Callable, NamedTuple, Optional, Tuple
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.envs import persistence
# why do we have both MalfunctionParameters and MalfunctionProcessData - they are both the same!
MalfunctionParameters = NamedTuple('MalfunctionParameters',
[('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
[('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
# Why is the return value Optional? We always return a Malfunction.
MalfunctionGenerator = Callable[[RandomState, bool], Malfunction]
def _malfunction_prob(rate: float) -> float:
"""
Probability of a single agent to break. According to Poisson process with given rate
:param rate:
:return:
"""
if rate <= 0:
return 0.
else:
return 1 - np.exp(-rate)
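# --- Illustrative mini-check (editor's addition, not part of the original source) ---
# For small rates the per-step break probability is close to the rate itself,
# e.g. a rate of 1/100 gives 1 - exp(-0.01), roughly 0.00995.
def _check_malfunction_prob():
    assert _malfunction_prob(0.0) == 0.0
    assert abs(_malfunction_prob(1 / 100) - (1 - np.exp(-0.01))) < 1e-12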
class ParamMalfunctionGen(object):
""" Preserving old behaviour of using MalfunctionParameters for constructor,
but returning MalfunctionProcessData in get_process_data.
Data structure and content is the same.
"""
def __init__(self, parameters: MalfunctionParameters):
#self.mean_malfunction_rate = parameters.malfunction_rate
#self.min_number_of_steps_broken = parameters.min_duration
#self.max_number_of_steps_broken = parameters.max_duration
self.MFP = parameters
def generate(self, np_random: RandomState) -> Malfunction:
if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
num_broken_steps = np_random.randint(self.MFP.min_duration,
self.MFP.max_duration + 1) + 1
else:
num_broken_steps = 0
return Malfunction(num_broken_steps)
def get_process_data(self):
return MalfunctionProcessData(*self.MFP)
class NoMalfunctionGen(ParamMalfunctionGen):
def __init__(self):
super().__init__(MalfunctionParameters(0,0,0))
class FileMalfunctionGen(ParamMalfunctionGen):
def __init__(self, env_dict=None, filename=None, load_from_package=None):
""" uses env_dict if populated, otherwise tries to load from file / package.
"""
if env_dict is None:
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
if env_dict.get('malfunction') is not None:
oMFP = MalfunctionParameters(*env_dict["malfunction"])
else:
oMFP = MalfunctionParameters(0,0,0) # no malfunctions
super().__init__(oMFP)
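# --- Illustrative usage sketch (editor's addition, not part of the original source) ---
# Draws one malfunction sample per step from the parametrised generator; most samples
# are Malfunction(0), and the broken-step count is derived from min_duration/max_duration
# exactly as coded in ParamMalfunctionGen.generate above. The rate and seed are example values.
def _sample_malfunctions(n_steps: int = 100):
    gen = ParamMalfunctionGen(MalfunctionParameters(malfunction_rate=1 / 50,
                                                    min_duration=2, max_duration=5))
    rng = np.random.RandomState(0)
    return [gen.generate(rng).num_broken_steps for _ in range(n_steps)]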
################################################################################################
# OLD / DEPRECATED generator functions below. To be removed.
def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Malfunction generator which generates no malfunctions
Parameters
----------
Nothing
Returns
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
print("DEPRECATED - use NoMalfunctionGen instead of no_malfunction_generator")
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(np_random: RandomState = None) -> Malfunction:
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[
MalfunctionGenerator, MalfunctionProcessData]:
"""
Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent.
Parameters
----------
earlierst_malfunction: Earliest possible malfunction onset
malfunction_duration: The duration of the single malfunction
Returns
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
# Keep track of the total number of malfunctions in the env
global_nr_malfunctions = 0
# Malfunction calls per agent
malfunction_calls = dict()
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
# We use the global variable to assure only a single malfunction in the env
nonlocal global_nr_malfunctions
nonlocal malfunction_calls
# Reset malfunction generator
if reset:
nonlocal global_nr_malfunctions
nonlocal malfunction_calls
global_nr_malfunctions = 0
malfunction_calls = dict()
return Malfunction(0)
# No more malfunctions if we already had one, ignore all updates
if global_nr_malfunctions > 0:
return Malfunction(0)
# Update number of calls per agent
if agent.handle in malfunction_calls:
malfunction_calls[agent.handle] += 1
else:
malfunction_calls[agent.handle] = 1
# Break an agent that is active at the time of the malfunction
if (agent.state == TrainState.MOVING or agent.state == TrainState.STOPPED) \
and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
global_nr_malfunctions += 1
return Malfunction(malfunction_duration)
else:
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load pickle file
Parameters
----------
input_file : Pickle file generated by env.save() or editor
Returns
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
print("DEPRECATED - use FileMalfunctionGen instead of malfunction_from_file")
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
# TODO: make this better by using namedtuple in the pickle file. See issue 282
if env_dict.get('malfunction') is not None:
env_dict['malfunction'] = oMPD = MalfunctionProcessData._make(env_dict['malfunction'])
else:
oMPD = None
if oMPD is not None:
# Mean malfunction in number of time steps
mean_malfunction_rate = oMPD.malfunction_rate
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = oMPD.min_duration
max_number_of_steps_broken = oMPD.max_duration
else:
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
----------
agent
np_random
Returns
-------
int: Number of time steps an agent is broken
"""
# Dummy reset function as we don't implement specific seeding here
if reset:
return Malfunction(0)
if agent.malfunction_handler.malfunction_down_counter < 1:
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
max_number_of_steps_broken + 1) + 1
return Malfunction(num_broken_steps)
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load malfunction from parameters
Parameters
----------
parameters : contains all the parameters of the malfunction
malfunction_rate : float rate per timestep at which each agent malfunctions
min_duration : int minimal duration of a failure
max_number_of_steps_broken : int maximal duration of a failure
Returns
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
print("DEPRECATED - use ParamMalfunctionGen instead of malfunction_from_params")
mean_malfunction_rate = parameters.malfunction_rate
min_number_of_steps_broken = parameters.min_duration
max_number_of_steps_broken = parameters.max_duration
def generator(np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
----------
agent
np_random
Returns
-------
int: Number of time steps an agent is broken
"""
# Dummy reset function as we don't implement specific seeding here
if reset:
return Malfunction(0)
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
max_number_of_steps_broken + 1)
return Malfunction(num_broken_steps)
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
"""
Collection of environment-specific ObservationBuilder.
"""
import pprint
from collections import deque
import collections
from typing import Optional, List, Dict, Tuple
import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.fast_methods import fast_argmax, fast_count_nonzero, fast_position_equal, fast_delete, fast_where
from flatland.envs.step_utils.states import TrainState
from flatland.utils.ordered_set import OrderedSet
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
'dist_potential_conflict '
'dist_unusable_switch '
'dist_to_next_branch '
'dist_min_to_target '
'num_agents_same_direction '
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'num_agents_ready_to_depart '
'childs')
class TreeObsForRailEnv(ObservationBuilder):
@@ -22,157 +42,24 @@ class TreeObsForRailEnv(ObservationBuilder):
For details about the features in the tree observation see the get() function.
"""
def __init__(self, max_depth, predictor=None):
tree_explored_actions_char = ['L', 'F', 'R', 'B']
def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
super().__init__()
self.max_depth = max_depth
self.observation_dim = 9
# Compute the size of the returned observation vector
size = 0
pow4 = 1
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_dim = 9
self.observation_space = [size * self.observation_dim]
self.observation_dim = 11
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.predictor = predictor
self.agents_previous_reset = None
self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
self.distance_map = None
self.distance_map_computed = False
self.location_has_target = None
def reset(self):
agents = self.env.agents
nb_agents = len(agents)
compute_distance_map = True
if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
compute_distance_map = False
for i in range(nb_agents):
if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_reset is None and self.distance_map is not None:
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
compute_distance_map = False
if compute_distance_map:
self._compute_distance_map()
self.agents_previous_reset = agents
def _compute_distance_map(self):
agents = self.env.agents
# For testing only --> To assert if a distance map need to be recomputed.
self.distance_map_computed = True
nb_agents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nb_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nb_agents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _distance_map_walker(self, position, target_nr):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
orientation within it) to each agent's target cell.
"""
# Returns max distance to target, from the farthest away node, while filling in distance_map
self.distance_map[target_nr, position[0], position[1], :] = 0
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
while nodes_queue:
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those
# whose new orientation in the current cell would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3] + 1)
return max_distance
self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
"""
neighbors = []
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = self._new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
if isNextCellDeadEnd:
movement = (desired_movement_from_new_cell+2) % 4
"""
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
current_distance + 1)
neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
return neighbors
def _new_position(self, position, movement):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
def get_many(self, handles=None):
"""
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
"""
if handles is None:
@@ -181,34 +68,64 @@ class TreeObsForRailEnv(ObservationBuilder):
self.max_prediction_depth = 0
self.predicted_pos = {}
self.predicted_dir = {}
self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
self.predictions = self.predictor.get()
if self.predictions:
for t in range(len(self.predictions[0])):
for t in range(self.predictor.max_depth + 1):
pos_list = []
dir_list = []
for a in handles:
if self.predictions[a] is None:
continue
pos_list.append(self.predictions[a][t][1:3])
dir_list.append(self.predictions[a][t][3])
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
self.predicted_dir.update({t: dir_list})
self.max_prediction_depth = len(self.predicted_pos)
observations = {}
for h in handles:
observations[h] = self.get(h)
# Update local lookup table for all agents' positions
# ignore other agents not in the grid (only status active and done)
# self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
# agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.location_has_agent_speed = {}
self.location_has_agent_malfunction = {}
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if not _agent.state.is_off_map_state() and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
self.location_has_agent_malfunction[tuple(_agent.position)] = \
_agent.malfunction_handler.malfunction_down_counter
# Count agents that are ready to depart but still off the map, keyed by their initial position
if _agent.state.is_off_map_state() and \
_agent.initial_position:
self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
# self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
# self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
observations = super().get_many(handles)
return observations
def get(self, handle):
def get(self, handle: int = 0) -> Node:
"""
Computes the current observation for agent `handle' in env
Computes the current observation for agent `handle` in env
The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
the transitions. The order is:
the transitions. The order is::
[data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
Each branch data is organized as:
Each branch data is organized as::
[root node information] +
[recursive branch data from 'left'] +
[... from 'forward'] +
......@@ -217,88 +134,123 @@ class TreeObsForRailEnv(ObservationBuilder):
Each node information is composed of 9 features:
#1: if own target lies on the explored branch the current distance from the agent in number of cells is stored.
#1:
if own target lies on the explored branch the current distance from the agent in number of cells is stored.
#2: if another agent's target is detected the distance in number of cells from the agent's current location
#2:
if another agent's target is detected the distance in number of cells from the agent's current location\
is stored
#3: if another agent is detected the distance in number of cells from current agent position is stored.
#3:
if another agent is detected the distance in number of cells from current agent position is stored.
#4: possible conflict detected
tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the
#4:
possible conflict detected
tot_dist = another agent is predicted to pass along this cell at the same time as the agent; we store the \
distance in number of cells from current agent position
0 = no other agent reserves the same cell at a similar time
#5: if an unusable switch (for the agent) is detected we store the distance.
#5:
if an unusable switch (for the agent) is detected we store the distance.
#6: This feature stores the distance in number of cells to the next branching (current node)
#6:
This feature stores the distance in number of cells to the next branching (current node)
#7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
#7:
minimum distance from node to the agent's target given the direction of the agent if this path is chosen
#8: agent in the same direction
n = number of agents present same direction
#8:
agent in the same direction
n = number of agents present same direction \
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#9: agent in the opposite direction
n = number of agents present other direction than myself (so conflict)
#9:
agent in the opposite direction
n = number of agents present other direction than myself (so conflict) \
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
#10:
malfunctioning/blocking agents
n = number of time steps the observed agent remains blocked
#11:
slowest observed speed of an agent in same direction
1 if no agent is observed
min_fractional speed otherwise
#12:
number of agents ready to depart but not yet active
Missing/padding nodes are filled in with -inf (truncated).
Missing values in present node are filled in with +inf (truncated).
In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed]
In case the target node is reached, the values are [0, 0, 0, 0, 0].
"""
# Update local lookup table for all agents' positions
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
if handle >= len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
num_transitions = fast_count_nonzero(possible_transitions)
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()
# was referring to TreeObsForRailEnv.Node
root_node_observation = Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[
(handle, *agent_virtual_position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_handler.malfunction_down_counter,
speed_min_fractional=agent.speed_counter.speed,
num_agents_ready_to_depart=0,
childs={})
# print("root node type:", type(root_node_observation))
visited = OrderedSet()
visited = set()
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# If only one transition is possible, the tree is oriented with this transition as the forward branch.
orientation = agent.direction
if num_transitions == 1:
orientation = np.argmax(possible_transitions)
orientation = fast_argmax(possible_transitions)
for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
new_cell = get_new_position(agent_virtual_position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
observation = observation + branch_observation
visited = visited.union(branch_visited)
root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation
visited |= branch_visited
else:
# add cells filled with infinity if no transition is possible
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
self.env.dev_obs_dict[handle] = visited
return observation
def _num_cells_to_fill_in(self, remaining_depth):
"""Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
num_observations = 0
pow4 = 1
for i in range(remaining_depth):
num_observations += pow4
pow4 *= 4
return num_observations * self.observation_dim
return root_node_observation
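# Usage sketch (illustrative, not part of the library code): the value returned above is a
# `Node` namedtuple whose children are again `Node`s, or -np.inf where a branch is not
# reachable, keyed by the characters in `tree_explored_actions_char` (typically
# 'L', 'F', 'R', 'B'). A caller could read the root features and one child like this,
# assuming `tree_obs_builder` is a TreeObsForRailEnv already attached to an env:
#
#     obs = tree_obs_builder.get(handle=0)
#     print(obs.dist_min_to_target)        # shortest-path distance from the agent's cell
#     forward = obs.childs['F']
#     if forward != -np.inf:
#         print(forward.dist_to_next_branch, forward.num_agents_opposite_direction)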
def _explore_branch(self, handle, position, direction, tot_dist, depth):
"""
......@@ -306,6 +258,7 @@ class TreeObsForRailEnv(ObservationBuilder):
We walk along the branch and collect the information documented in the get() function.
If there is a branching point a new node is created and each possible branch is explored.
"""
# [Recursive branch opened]
if depth >= self.max_depth + 1:
return [], []
......@@ -319,8 +272,12 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here
last_is_target = False
visited = set()
visited = OrderedSet()
agent = self.env.agents[handle]
distance_map_handle = self.env.distance_map.get()[handle]
time_per_cell = 1.0 / agent.speed_counter.speed
own_target_encountered = np.inf
other_agent_encountered = np.inf
other_target_encountered = np.inf
......@@ -328,50 +285,85 @@ class TreeObsForRailEnv(ObservationBuilder):
unusable_switch = np.inf
other_agent_same_direction = 0
other_agent_opposite_direction = 0
malfunctioning_agent = 0
min_fractional_speed = 1.
num_steps = 1
other_agent_ready_to_depart_encountered = 0
while exploring:
# #############################
# #############################
# Modify here to compute any useful data required to build the end node's features. This code is called
# for each cell visited between the previous branching node and the next switch / target / dead-end.
if position in self.location_has_agent:
if self.location_has_agent.get(position, 0) == 1:
if tot_dist < other_agent_encountered:
other_agent_encountered = tot_dist
# Check if any of the observed agents is malfunctioning, store agent with longest duration left
if self.location_has_agent_malfunction[position] > malfunctioning_agent:
malfunctioning_agent = self.location_has_agent_malfunction[position]
other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)
if self.location_has_agent_direction[position] == direction:
# Accumulate the number of agents on the branch heading in the same direction
other_agent_same_direction += 1
if self.location_has_agent_direction[position] != direction:
# Accumulate the number of agents on the branch heading in the other direction
other_agent_opposite_direction += 1
# Check fractional speed of agents
current_fractional_speed = self.location_has_agent_speed[position]
if current_fractional_speed < min_fractional_speed:
min_fractional_speed = current_fractional_speed
else:
# If no agent in the same direction was found, all agents at that position are heading in the other direction
# Attention: this counts too many agents, as a few might be going off on a switch.
other_agent_opposite_direction += self.location_has_agent[position]
# Check number of possible transitions for agent and total number of transitions in cell (type)
cell_transitions = self.env.rail.get_transitions(*position, direction)
transition_bit = bin(self.env.rail.get_full_transitions(*position))
total_transitions = transition_bit.count("1")
crossing_found = False
if int(transition_bit, 2) == int('1000010000100001', 2):
crossing_found = True
# Register possible future conflict
if self.predictor and num_steps < self.max_prediction_depth:
predicted_time = int(tot_dist * time_per_cell)
if self.predictor and predicted_time < self.max_prediction_depth:
int_position = coordinate_to_position(self.env.width, [position])
if tot_dist < self.max_prediction_depth:
pre_step = max(0, tot_dist - 1)
post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
# Look for opposing paths at distance num_step
if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
for ca in conflicting_agent[0]:
pre_step = max(0, predicted_time - 1)
post_step = min(self.max_prediction_depth - 1, predicted_time + 1)
if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
# Look for conflicting paths at distance tot_dist
if int_position in fast_delete(self.predicted_pos[predicted_time], handle):
conflicting_agent = fast_where(self.predicted_pos[predicted_time] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
self._reverse_dir(
self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step-1
elif int_position in fast_delete(self.predicted_pos[pre_step], handle):
conflicting_agent = fast_where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca] \
and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
# Look for conflicting paths at distance num_step+1
elif int_position in fast_delete(self.predicted_pos[post_step], handle):
conflicting_agent = fast_where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[post_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
......@@ -389,13 +381,15 @@ class TreeObsForRailEnv(ObservationBuilder):
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
if np.array_equal(position, self.env.agents[handle].target):
if fast_position_equal(position, self.env.agents[handle].target):
last_is_target = True
break
cell_transitions = self.env.rail.get_transitions(*position, direction)
total_transitions = bin(self.env.rail.get_full_transitions(*position)).count("1")
num_transitions = np.count_nonzero(cell_transitions)
# Check if crossing is found --> Not an unusable switch
if crossing_found:
# Treat the crossing as a straight rail cell
total_transitions = 2
num_transitions = fast_count_nonzero(cell_transitions)
exploring = False
......@@ -411,11 +405,11 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_dead_end = True
if not last_is_dead_end:
# Keep walking through the tree along `direction'
# Keep walking through the tree along `direction`
exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
direction = fast_argmax(cell_transitions)
position = get_new_position(position, direction)
num_steps += 1
tot_dist += 1
elif num_transitions > 0:
......@@ -430,124 +424,112 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_terminal = True
break
# `position' is either a terminal node or a switch
# `position` is either a terminal node or a switch
# #############################
# #############################
# Modify here to append new / different features for each visited cell!
if last_is_target:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
tot_dist,
0,
other_agent_same_direction,
other_agent_opposite_direction
]
dist_to_next_branch = tot_dist
dist_min_to_target = 0
elif last_is_terminal:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
np.inf,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction
]
dist_to_next_branch = np.inf
dist_min_to_target = distance_map_handle[position[0], position[1], direction]
else:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
tot_dist,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
]
dist_to_next_branch = tot_dist
dist_min_to_target = distance_map_handle[position[0], position[1], direction]
# TreeObsForRailEnv.Node
node = Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered,
dist_potential_conflict=potential_conflict,
dist_unusable_switch=unusable_switch,
dist_to_next_branch=dist_to_next_branch,
dist_min_to_target=dist_min_to_target,
num_agents_same_direction=other_agent_same_direction,
num_agents_opposite_direction=other_agent_opposite_direction,
num_agents_malfunctioning=malfunctioning_agent,
speed_min_fractional=min_fractional_speed,
num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
childs={})
# #############################
# #############################
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions = self.env.rail.get_transitions(*position, direction)
for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
if last_is_dead_end and self.env.rail.get_transition((*position, direction),
(branch_direction + 2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back
new_cell = self._new_position(position, (branch_direction + 2) % 4)
new_cell = get_new_position(position, (branch_direction + 2) % 4)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
node.childs[self.tree_explored_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited = visited.union(branch_visited)
visited |= branch_visited
elif last_is_switch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction)
new_cell = get_new_position(position, branch_direction)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
branch_direction,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
node.childs[self.tree_explored_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited = visited.union(branch_visited)
visited |= branch_visited
else:
# no exploring possible, add just cells with infinity
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
node.childs[self.tree_explored_actions_char[i]] = -np.inf
return observation, visited
if depth == self.max_depth:
node.childs.clear()
return node, visited
def util_print_obs_subtree(self, tree):
def util_print_obs_subtree(self, tree: Node):
"""
Utility function to pretty-print tree observations returned by this object.
Utility function to print tree observations returned by this object.
"""
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(self.unfold_observation_tree(tree))
self.print_node_features(tree, "root", "")
for direction in self.tree_explored_actions_char:
self.print_subtree(tree.childs[direction], direction, "\t")
@staticmethod
def print_node_features(node: Node, label, indent):
print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ",
node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional, ", ",
node.num_agents_ready_to_depart)
def print_subtree(self, node, label, indent):
if node == -np.inf or not node:
print(indent, "Direction ", label, ": -np.inf")
return
def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
"""
Utility function to pretty-print tree observations returned by this object.
"""
if len(tree) < self.observation_dim:
self.print_node_features(node, label, indent)
if not node.childs:
return
depth = 0
tmp = len(tree) / self.observation_dim - 1
pow4 = 4
while tmp > 0:
tmp -= pow4
depth += 1
pow4 *= 4
unfolded = {}
unfolded[''] = tree[0:self.observation_dim]
child_size = (len(tree) - self.observation_dim) // 4
for child in range(4):
child_tree = tree[(self.observation_dim + child * child_size):
(self.observation_dim + (child + 1) * child_size)]
observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
if observation_tree is not None:
if actions_for_display:
label = self.tree_explorted_actions_char[child]
else:
label = self.tree_explored_actions[child]
unfolded[label] = observation_tree
return unfolded
for direction in self.tree_explored_actions_char:
self.print_subtree(node.childs[direction], direction, indent + "\t")
def _set_env(self, env):
self.env = env
def set_env(self, env: Environment):
super().set_env(env)
if self.predictor:
self.predictor._set_env(self.env)
self.predictor.set_env(self.env)
def _reverse_dir(self, direction):
return int((direction + 2) % 4)
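# Illustrative wiring sketch (assumptions: the flatland 3.x constructors named below exist
# with these signatures in your installed version; adjust if they differ). It shows how
# this observation builder is commonly combined with the shortest-path predictor:
#
#     from flatland.envs.rail_env import RailEnv
#     from flatland.envs.rail_generators import sparse_rail_generator
#     from flatland.envs.line_generators import sparse_line_generator
#     from flatland.envs.predictions import ShortestPathPredictorForRailEnv
#
#     obs_builder = TreeObsForRailEnv(max_depth=2,
#                                     predictor=ShortestPathPredictorForRailEnv(max_depth=20))
#     env = RailEnv(width=30, height=30,
#                   rail_generator=sparse_rail_generator(),
#                   line_generator=sparse_line_generator(),
#                   number_of_agents=3,
#                   obs_builder_object=obs_builder)
#     obs, info = env.reset()              # obs[handle] is a TreeObsForRailEnv Node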
class GlobalObsForRailEnv(ObservationBuilder):
......@@ -555,25 +537,25 @@ class GlobalObsForRailEnv(ObservationBuilder):
Gives a global observation of the entire rail environment.
The observation is composed of the following elements:
- transition map array with dimensions (env.height, env.width, 16),
- transition map array with dimensions (env.height, env.width, 16),\
assuming 16 bits encoding of transitions.
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
target and the positions of the other agents targets.
- obs_agents_state: A 3D array (map_height, map_width, 5) with
- first channel containing the agents position and direction
- second channel containing the other agents positions and direction
- third channel containing agent/other agent malfunctions
- fourth channel containing agent/other agent fractional speeds
- fifth channel containing number of other agents ready to depart
- A 3D array (map_height, map_width, 8) with the 4 first channels containing the one hot encoding
of the direction of the given agent and the 4 second channels containing the positions
of the other agents at their position coordinates.
- obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
target and the positions of the other agents targets (flag only, no counter!).
"""
def __init__(self):
self.observation_space = ()
super(GlobalObsForRailEnv, self).__init__()
def _set_env(self, env):
super()._set_env(env)
self.observation_space = [4, self.env.height, self.env.width]
def set_env(self, env: Environment):
super().set_env(env)
def reset(self):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
......@@ -583,169 +565,189 @@ class GlobalObsForRailEnv(ObservationBuilder):
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle):
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 8))
agents = self.env.agents
agent = agents[handle]
direction = np.zeros(4)
direction[agent.direction] = 1
agent_pos = agents[handle].position
obs_agents_state[agent_pos][:4] = direction
obs_targets[agent.target][0] += 1
for i in range(len(agents)):
if i != handle: # TODO: handle used as index...?
agent2 = agents[i]
obs_agents_state[agent2.position][4 + agent2.direction] = 1
obs_targets[agent2.target][1] += 1
direction = self._get_one_hot_for_agent_direction(agent)
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
return self.rail_obs, obs_agents_state, obs_targets, direction
class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
"""
Gives a global observation of the entire rail environment.
The observation is composed of the following elements:
- transition map array with dimensions (env.height, env.width, 16),
assuming 16 bits encoding of transitions, flipped in the direction of the agent
(the agent is always heading north on the flipped view).
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
target and the positions of the other agents targets, also flipped depending on the agent's direction.
- A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other
agents at their position coordinates, and the last channel containing the position of the given agent.
- A 4 elements array with one hot encoding of the direction.
"""
def __init__(self):
self.observation_space = ()
super(GlobalObsForRailEnvDirectionDependent, self).__init__()
def _set_env(self, env):
super()._set_env(env)
self.observation_space = [4, self.env.height, self.env.width]
def reset(self):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
for i in range(self.rail_obs.shape[0]):
for j in range(self.rail_obs.shape[1]):
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
agent = self.env.agents[handle]
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
def get(self, handle):
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 5))
agents = self.env.agents
agent = agents[handle]
direction = agent.direction
obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
idx = np.tile(np.arange(16), 2)
# TODO can we do this more elegantly?
# for r in range(self.env.height):
# for c in range(self.env.width):
# obs_agents_state[(r, c)][4] = 0
obs_agents_state[:, :, 4] = 0
rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]]
if direction == 1:
rail_obs = np.flip(rail_obs, axis=1)
elif direction == 2:
rail_obs = np.flip(rail_obs)
elif direction == 3:
rail_obs = np.flip(rail_obs, axis=0)
obs_agents_state[agent_virtual_position][0] = agent.direction
obs_targets[agent.target][0] = 1
agent_pos = agents[handle].position
obs_agents_state[agent_pos][0] = 1
obs_targets[agent.target][0] += 1
for i in range(len(self.env.agents)):
other_agent: EnvAgent = self.env.agents[i]
idx = np.tile(np.arange(4), 2)
for i in range(len(agents)):
if i != handle: # TODO: handle used as index...?
agent2 = agents[i]
obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1
obs_targets[agent2.target][1] += 1
# ignore other agents not in the grid any more
if other_agent.state == TrainState.DONE:
continue
direction = self._get_one_hot_for_agent_direction(agent)
obs_targets[other_agent.target][1] = 1
return rail_obs, obs_agents_state, obs_targets, direction
# second to fourth channel only if in the grid
if other_agent.position is not None:
# second channel only for other agents
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_agents_state[other_agent.position][2] = other_agent.malfunction_handler.malfunction_down_counter
obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
# fifth channel: all ready to depart on this position
if other_agent.state.is_off_map_state():
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
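# Shape sketch (derived from the arrays built above, illustrative only): for an env created
# with obs_builder_object=GlobalObsForRailEnv(), the per-agent observation unpacks as
#
#     transition_map, obs_agents_state, obs_targets = obs[handle]
#     # transition_map.shape   == (env.height, env.width, 16)
#     # obs_agents_state.shape == (env.height, env.width, 5)
#     # obs_targets.shape      == (env.height, env.width, 2)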
class LocalObsForRailEnv(ObservationBuilder):
"""
!!!!!!WARNING!!! THIS IS DEPRECATED AND NOT UPDATED TO FLATLAND 2.0!!!!!
Gives a local observation of the rail environment around the agent.
The observation is composed of the following elements:
- transition map array of the local environment around the given agent,
with dimensions (2*view_radius + 1, 2*view_radius + 1, 16),
- transition map array of the local environment around the given agent, \
with dimensions (view_height,2*view_width+1, 16), \
assuming 16 bits encoding of transitions.
- Two 2D arrays (2*view_radius + 1, 2*view_radius + 1, 2) containing respectively,
- Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, \
if they are in the agent's vision range, its target position, the positions of the other targets.
- A 3D array (2*view_radius + 1, 2*view_radius + 1, 4) containing the one hot encoding of directions
- A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \
of the other agents at their position coordinates, if they are in the agent's vision range.
- A 4 elements array with one hot encoding of the direction.
Use the parameters view_width and view_height to define the rectangular view of the agent.
The center parameter moves the agent along the height axis of this rectangle. If it is 0, the agent only has
observations in front of it.
.. deprecated:: 2.0.0
"""
def __init__(self, view_radius):
"""
:param view_radius:
"""
def __init__(self, view_width, view_height, center):
super(LocalObsForRailEnv, self).__init__()
self.view_radius = view_radius
self.view_width = view_width
self.view_height = view_height
self.center = center
self.max_padding = max(self.view_width, self.view_height - self.center)
def reset(self):
# We build the transition map with a view_radius empty cells expansion on each side.
# This helps to collect the local transition map view when the agent is close to a border.
self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius,
self.env.width + 2 * self.view_radius, 16))
self.max_padding = max(self.view_width, self.view_height)
self.rail_obs = np.zeros((self.env.height,
self.env.width, 16))
for i in range(self.env.height):
for j in range(self.env.width):
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
agents = self.env.agents
agent = agents[handle]
local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1,
agent.position[1]:agent.position[1] + 2 * self.view_radius + 1]
obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2))
obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4))
def relative_pos(pos):
return [agent.position[0] - pos[0], agent.position[1] - pos[1]]
def is_in(rel_pos):
return (abs(rel_pos[0]) <= self.view_radius) and (abs(rel_pos[1]) <= self.view_radius)
target_rel_pos = relative_pos(agent.target)
if is_in(target_rel_pos):
obs_map_state[self.view_radius + np.array(target_rel_pos)][0] += 1
for i in range(len(agents)):
if i != handle: # TODO: handle used as index...?
agent2 = agents[i]
agent_2_rel_pos = relative_pos(agent2.position)
if is_in(agent_2_rel_pos):
obs_other_agents_state[self.view_radius + agent_2_rel_pos[0],
self.view_radius + agent_2_rel_pos[1]][agent2.direction] += 1
target_rel_pos_2 = relative_pos(agent2.position)
if is_in(target_rel_pos_2):
obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1
# Correct agents position for padding
# agent_rel_pos[0] = agent.position[0] + self.max_padding
# agent_rel_pos[1] = agent.position[1] + self.max_padding
# Collect visible cells as set to be plotted
visited, rel_coords = self.field_of_view(agent.position, agent.direction, )
local_rail_obs = None
# Add the visible cells to the observed cells
self.env.dev_obs_dict[handle] = set(visited)
# Locate observed agents and their corresponding targets
local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
_idx = 0
for pos in visited:
curr_rel_coord = rel_coords[_idx]
local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
if pos == agent.target:
obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
else:
for tmp_agent in agents:
if pos == tmp_agent.target:
obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
if pos != agent.position:
for tmp_agent in agents:
if pos == tmp_agent.position:
obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
tmp_agent.direction]
_idx += 1
direction = np.identity(4)[agent.direction]
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
direction = self._get_one_hot_for_agent_direction(agent)
def get_many(self, handles: Optional[List[int]] = None) -> Dict[
int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
"""
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
return super().get_many(handles)
def field_of_view(self, position, direction, state=None):
# Compute the local field of view for an agent in the environment
data_collection = False
if state is not None:
temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
data_collection = True
if direction == 0:
origin = (position[0] + self.center, position[1] - self.view_width)
elif direction == 1:
origin = (position[0] - self.view_width, position[1] - self.center)
elif direction == 2:
origin = (position[0] - self.center, position[1] + self.view_width)
else:
origin = (position[0] + self.view_width, position[1] + self.center)
visible = list()
rel_coords = list()
for h in range(self.view_height):
for w in range(2 * self.view_width + 1):
if direction == 0:
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
visible.append((origin[0] - h, origin[1] + w))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
elif direction == 1:
if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
visible.append((origin[0] + w, origin[1] + h))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
elif direction == 2:
if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
visible.append((origin[0] + h, origin[1] - w))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
else:
if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
visible.append((origin[0] - w, origin[1] - h))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
if data_collection:
return temp_visible_data
else:
return visible, rel_coords
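# Parameter sketch for the (deprecated) local observation above, illustrative only:
# LocalObsForRailEnv(view_width=2, view_height=5, center=1) produces arrays of shape
# (view_height, 2 * view_width + 1, ...) = (5, 5, ...), i.e. a 5x5 window that extends
# mostly ahead of the agent; with center=0 the window lies entirely in front of it.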
import pickle
import msgpack
import numpy as np
import msgpack_numpy
msgpack_numpy.patch()
from flatland.envs import rail_env
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent, load_env_agent
# cannot import objects / classes directly because of circular import
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
class RailEnvPersister(object):
@classmethod
def save(cls, env, filename, save_distance_maps=False):
"""
Saves environment and distance map information in a file
Parameters:
---------
filename: string
save_distance_maps: bool
"""
env_dict = cls.get_full_state(env)
# We have an unresolved problem with msgpack loading the list of agents
# see also 20 lines below.
# print(f"env save - agents: {env_dict['agents'][0]}")
# a0 = env_dict["agents"][0]
# print("agent type:", type(a0))
if save_distance_maps is True:
oDistMap = env.distance_map.get()
if oDistMap is not None:
if len(oDistMap) > 0:
env_dict["distance_map"] = oDistMap
else:
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
else:
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
with open(filename, "wb") as file_out:
if filename.endswith("mpk"):
data = msgpack.packb(env_dict)
elif filename.endswith("pkl"):
data = pickle.dumps(env_dict)
#pickle.dump(env_dict, file_out)
file_out.write(data)
# We have an unresolved problem with msgpack loading the list of Agents
# with open(filename, "rb") as file_in:
# if filename.endswith("mpk"):
# bytes_in = file_in.read()
# dIn = msgpack.unpackb(data, encoding="utf-8")
# print(f"msgpack check - {dIn.keys()}")
# print(f"msgpack check - {dIn['agents'][0]}")
@classmethod
def save_episode(cls, env, filename):
dict_env = cls.get_full_state(env)
# Add additional info to dict_env before saving
dict_env["episode"] = env.cur_episode
dict_env["actions"] = env.list_actions
dict_env["shape"] = (env.width, env.height)
dict_env["max_episode_steps"] = env._max_episode_steps
with open(filename, "wb") as file_out:
if filename.endswith(".mpk"):
file_out.write(msgpack.packb(dict_env))
elif filename.endswith(".pkl"):
pickle.dump(dict_env, file_out)
@classmethod
def load(cls, env, filename, load_from_package=None):
"""
Load environment with distance map from a file
Parameters:
-------
filename: string
"""
env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
cls.set_full_state(env, env_dict)
@classmethod
def load_new(cls, filename, load_from_package=None):
env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
llGrid = env_dict["grid"]
height = len(llGrid)
width = len(llGrid[0])
# TODO: inefficient - each one of these generators loads the complete env file.
env = rail_env.RailEnv(#width=1, height=1,
width=width, height=height,
rail_generator=rail_gen.rail_from_file(filename,
load_from_package=load_from_package),
line_generator=line_gen.line_from_file(filename,
load_from_package=load_from_package),
#malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename,
# load_from_package=load_from_package),
malfunction_generator=mal_gen.FileMalfunctionGen(env_dict),
obs_builder_object=DummyObservationBuilder(),
record_steps=True)
env.rail = GridTransitionMap(1,1) # dummy
cls.set_full_state(env, env_dict)
return env, env_dict
@classmethod
def load_env_dict(cls, filename, load_from_package=None):
if load_from_package is not None:
from importlib_resources import read_binary
load_data = read_binary(load_from_package, filename)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
if filename.endswith("mpk"):
env_dict = msgpack.unpackb(load_data, use_list=False, raw=False)
elif filename.endswith("pkl"):
try:
env_dict = pickle.loads(load_data)
except ValueError:
print("pickle failed to load file:", filename, " trying msgpack (deprecated)...")
env_dict = msgpack.unpackb(load_data, use_list=False, raw=False)
else:
print(f"filename {filename} must end with either pkl or mpk")
env_dict = {}
# Replace the agents tuple with EnvAgent objects
if "agents_static" in env_dict:
env_dict["agents"] = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
# remove the legacy key
del env_dict["agents_static"]
elif "agents" in env_dict:
# env_dict["agents"] = [EnvAgent(*d[0:len(d)]) for d in env_dict["agents"]]
env_dict["agents"] = [load_env_agent(d) for d in env_dict["agents"]]
return env_dict
@classmethod
def load_resource(cls, package, resource):
"""
Load environment (with distance map?) from a binary
"""
#from importlib_resources import read_binary
#load_data = read_binary(package, resource)
#if resource.endswith("pkl"):
# env_dict = pickle.loads(load_data)
#elif resource.endswith("mpk"):
# env_dict = msgpack.unpackb(load_data, encoding="utf-8")
#cls.set_full_state(env, env_dict)
return cls.load_new(resource, load_from_package=package)
@classmethod
def set_full_state(cls, env, env_dict):
"""
Sets environment state from env_dict
Parameters
-------
env_dict: dict
"""
env.rail.grid = np.array(env_dict["grid"])
# Initialise the env with the frozen agents in the file
env.agents = env_dict.get("agents", [])
# For consistency, set number_of_agents, which is the number which will be generated on reset
env.number_of_agents = env.get_num_agents()
env.height, env.width = env.rail.grid.shape
env.rail.height = env.height
env.rail.width = env.width
env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)
@classmethod
def get_full_state(cls, env):
"""
Returns state of environment in dict object, ready for serialization
"""
grid_data = env.rail.grid.tolist()
# msgpack cannot persist EnvAgent so use the Agent namedtuple.
agent_data = [agent.to_agent() for agent in env.agents]
#print("get_full_state - agent_data:", agent_data)
malfunction_data: mal_gen.MalfunctionProcessData = env.malfunction_process_data
msg_data_dict = {
"grid": grid_data,
"agents": agent_data,
"malfunction": malfunction_data,
"max_episode_steps": env._max_episode_steps,
}
return msg_data_dict
################################################################################################
# deprecated methods moved from RailEnv. Most likely broken.
def deprecated_get_full_state_msg(self) -> msgpack.Packer:
"""
Returns state of environment in msgpack object
"""
msg_data_dict = self.get_full_state_dict()
return msgpack.packb(msg_data_dict, use_bin_type=True)
def deprecated_get_agent_state_msg(self) -> msgpack.Packer:
"""
Returns agents information in msgpack object
"""
agent_data = [agent.to_agent() for agent in self.agents]
msg_data = {
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
def deprecated_get_full_state_dist_msg(self) -> msgpack.Packer:
"""
Returns environment information with distance map information as msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_data = [agent.to_agent() for agent in self.agents]
# I think these calls do nothing - they create packed data and it is discarded
#msgpack.packb(grid_data, use_bin_type=True)
#msgpack.packb(agent_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
malfunction_data: mal_gen.MalfunctionProcessData = self.malfunction_process_data
#msgpack.packb(distance_map_data, use_bin_type=True) # does nothing
msg_data = {
"grid": grid_data,
"agents": agent_data,
"distance_map": distance_map_data,
"malfunction": malfunction_data}
return msgpack.packb(msg_data, use_bin_type=True)
def deprecated_set_full_state_msg(self, msg_data):
"""
Sets environment state with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
if "agents_static" in data:
self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def deprecated_set_full_state_dist_msg(self, msg_data):
"""
Sets environment grid state and distance map with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
if "agents_static" in data:
self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
if "distance_map" in data.keys():
self.distance_map.set(data["distance_map"])
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
......@@ -5,8 +5,12 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils import transition_utils
class DummyPredictorForRailEnv(PredictionBuilder):
......@@ -17,15 +21,13 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, custom_args=None, handle=None):
def get(self, handle: int = None):
"""
Called whenever get_many in the observation build is called.
Parameters
-------
custom_args: dict
Not used in this dummy implementation.
handle : int (optional)
----------
handle : int, optional
Handle of the agent for which to compute the observation vector.
Returns
......@@ -47,11 +49,14 @@ class DummyPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if not agent.state.is_on_map_state():
# TODO make this generic
continue
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
agent_virtual_position = agent.position
agent_virtual_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
for index in range(1, self.max_depth + 1):
action_done = False
# if we're at the target, stop moving...
......@@ -60,8 +65,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
continue
for action in action_priorities:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent)
new_cell_isValid, new_direction, new_position, transition_isValid = \
transition_utils.check_action_on_agent(action, self.env.rail, agent.position, agent.direction)
if all([new_cell_isValid, transition_isValid]):
# move and change direction to face the new_direction that was
# performed
......@@ -73,8 +78,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
if not action_done:
raise Exception("Cannot move further. Something is wrong")
prediction_dict[agent.handle] = prediction
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
agent.position = agent_virtual_position
agent.direction = agent_virtual_direction
return prediction_dict
......@@ -86,16 +91,20 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, custom_args=None, handle=None):
def __init__(self, max_depth: int = 20):
super().__init__(max_depth)
def get(self, handle: int = None):
"""
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
Does not take into account future positions of other agents!
If there is no shortest path, the agent just stands still and stops moving.
Parameters
-------
custom_args: dict
- distance_map : dict
handle : int (optional)
----------
handle : int, optional
Handle of the agent for which to compute the observation vector.
Returns
......@@ -106,68 +115,66 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
- position axis 0
- position axis 1
- direction
- action taken to come here
- action taken to come here (not implemented yet)
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
assert custom_args is not None
distance_map = custom_args.get('distance_map')
assert distance_map is not None
distance_map: DistanceMap = self.env.distance_map
shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)
prediction_dict = {}
for agent in agents:
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
prediction = np.zeros(shape=(self.max_depth + 1, 5))
for i in range(self.max_depth):
prediction[i] = [i, None, None, None, None]
prediction_dict[agent.handle] = prediction
continue
agent_virtual_direction = agent.direction
agent_speed = agent.speed_counter.speed
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
visited = set()
prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
shortest_path = shortest_paths[agent.handle]
# if there is a shortest path, remove the initial position
if shortest_path:
shortest_path = shortest_path[1:]
new_direction = agent_virtual_direction
new_position = agent_virtual_position
visited = OrderedSet()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
visited.add((agent.position[0], agent.position[1], agent.direction))
continue
if not agent.moving:
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
visited.add((agent.position[0], agent.position[1], agent.direction))
# if we're at the target, stop moving until max_depth is reached
if new_position == agent.target or not shortest_path:
prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
visited.add((*new_position, agent.direction))
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
new_position = None
new_direction = None
if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1:
min_dist = np.inf
no_dist_found = True
for direction in range(4):
if cell_transitions[direction] == 1:
neighbour_cell = get_new_position(agent.position, direction)
target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
if target_dist < min_dist or no_dist_found:
min_dist = target_dist
new_direction = direction
no_dist_found = False
new_position = get_new_position(agent.position, new_direction)
else:
raise Exception("No transition possible {}".format(cell_transitions))
# update the agent's position and direction
agent.position = new_position
agent.direction = new_direction
if index % times_per_cell == 0:
new_position = shortest_path[0].position
new_direction = shortest_path[0].direction
shortest_path = shortest_path[1:]
# prediction is ready
prediction[index] = [index, *new_position, new_direction, 0]
visited.add((new_position[0], new_position[1], new_direction))
visited.add((*new_position, new_direction))
# TODO: very bad side effects for visualization only: hand the dev_pred_dict back instead of setting it on env!
self.env.dev_pred_dict[agent.handle] = visited
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
return prediction_dict
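# Reading sketch (illustrative only): each entry of the returned dictionary is a numpy
# array of shape (max_depth + 1, 5), where row t holds [t, row, col, direction, action].
# Assuming `predictor` is a ShortestPathPredictorForRailEnv already bound to an env and
# get() has been triggered, the state predicted five steps ahead for agent 0 is read as:
#
#     preds = predictor.get()
#     t, row, col, direction, action = preds[0][5]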
"""
Definition of the RailEnv environment.
"""
# TODO: _ this is a global method --> utils or remove later
import random
from enum import IntEnum
from typing import List, Optional, Dict, Tuple
import msgpack
import msgpack_numpy as m
import numpy as np
from gym.utils import seeding
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
m.patch()
class RailEnvActions(IntEnum):
DO_NOTHING = 0 # implies change of direction in a dead-end!
MOVE_LEFT = 1
MOVE_FORWARD = 2
MOVE_RIGHT = 3
STOP_MOVING = 4
@staticmethod
def to_char(a: int):
return {
0: 'B',
1: 'L',
2: 'F',
3: 'R',
4: 'S',
}[a]
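# Usage sketch (illustrative only): actions are plain IntEnum values, so an action
# dictionary passed to env.step() may mix the enum members and their integer values:
#
#     action_dict = {0: RailEnvActions.MOVE_FORWARD, 1: 4}    # 4 == STOP_MOVING
#     RailEnvActions.to_char(2)                               # -> 'F'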
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs import persistence
from flatland.envs import agent_chains as ac
from flatland.envs.fast_methods import fast_position_equal
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils.transition_utils import check_valid_action
from flatland.envs.step_utils import action_preprocessing
from flatland.envs.step_utils import env_utils
class RailEnv(Environment):
"""
......@@ -65,43 +62,71 @@ class RailEnv(Environment):
It costs each agent a step_penalty for every time-step taken in the environment, independent of the movement
of the agent. Currently all other penalties, such as penalties for stopping, starting and invalid actions, are set to 0.
alpha = 1
beta = 1
alpha = 0
beta = 0
Reward function parameters:
- invalid_action_penalty = 0
- step_penalty = -alpha
- global_reward = beta
- epsilon = avoid rounding errors
- stop_penalty = 0 # penalty for stopping a moving agent
- start_penalty = 0 # penalty for starting a stopped agent
Stochastic malfunctioning of trains:
Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
action or cell is selected).
Every time an agent stops, it has a certain probability of malfunctioning. Malfunctions of trains follow a
Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep
complexity manageable.
TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
"""
# Epsilon to avoid rounding errors
epsilon = 0.01
# NEW : REW: Sparse Reward
alpha = 0
beta = 0
step_penalty = -1 * alpha
global_reward = 1 * beta
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
cancellation_factor = 1
cancellation_time_buffer = 0
def __init__(self,
width,
height,
rail_generator=random_rail_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
max_episode_steps=None
rail_generator=None,
line_generator=None, # : line_gen.LineGenerator = line_gen.random_line_generator(),
number_of_agents=2,
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(),
malfunction_generator=None,
remove_agents_at_target=True,
random_seed=None,
record_steps=False,
):
"""
Environment init.
Parameters
-------
----------
rail_generator : function
The rail_generator function is a function that takes the width,
height and agents handles of a rail environment, along with the number of times
the env has been reset, and returns a GridTransitionMap object and a list of
starting positions, targets, and initial orientations for agent handle.
Implemented functions are:
random_rail_generator : generate a random rail of given size
rail_from_grid_transition_map(rail_map) : generate a rail from
a GridTransitionMap object
rail_from_manual_specifications_generator(rail_spec) : generate a rail from
a rail specifications array
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
The rail_generator can pass a distance map in the hints or information for specific line_generators.
Implementations can be found in flatland/envs/rail_generators.py
line_generator : function
The line_generator function is a function that takes the grid, the number of agents and optional hints
and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
Implementations can be found in flatland/envs/line_generators.py
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
@@ -114,396 +139,635 @@ class RailEnv(Environment):
obs_builder_object: ObservationBuilder object
ObservationBuilder-derived object that takes builds observation
vectors for each agent.
max_episode_steps : int or None
file_name: you can load a pickle file.
remove_agents_at_target : bool
If remove_agents_at_target is set to true then the agents will be removed by being placed at
RailEnv.DEPOT_POSITION when the agent has reached its target position.
random_seed : int or None
if None, then it is ignored; otherwise the random generators are seeded with this number to ensure
that stochastic operations are replicable across multiple runs
"""
super().__init__()
if malfunction_generator_and_process_data is not None:
print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
elif malfunction_generator is not None:
self.malfunction_generator = malfunction_generator
# malfunction_process_data is not used
# self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
self.malfunction_process_data = self.malfunction_generator.get_process_data()
# replace default values here because we can't use default argument values due to cyclic imports
else:
self.malfunction_generator = mal_gen.NoMalfunctionGen()
self.malfunction_process_data = self.malfunction_generator.get_process_data()
self.number_of_agents = number_of_agents
if rail_generator is None:
rail_generator = rail_gen.sparse_rail_generator()
self.rail_generator = rail_generator
self.rail = None
if line_generator is None:
line_generator = line_gen.sparse_line_generator()
self.line_generator = line_generator
self.rail: Optional[GridTransitionMap] = None
self.width = width
self.height = height
self.rewards = [0] * number_of_agents
self.done = False
self.remove_agents_at_target = remove_agents_at_target
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.obs_builder.set_env(self)
self._max_episode_steps = max_episode_steps
self._max_episode_steps: Optional[int] = None
self._elapsed_steps = 0
self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
self.obs_dict = {}
self.rewards_dict = {}
self.dev_obs_dict = {}
self.dev_pred_dict = {}
self.agents = [None] * number_of_agents # live agents
self.agents_static = [None] * number_of_agents # static agent information
self.agents: List[EnvAgent] = []
self.num_resets = 0
self.distance_map = DistanceMap(self.agents, self.height, self.width)
self.action_space = [5]
self._seed()
if random_seed:
self._seed(seed=random_seed)
self.agent_positions = None
# save episode timesteps ie agent positions, orientations. (not yet actions / observations)
self.record_steps = record_steps # whether to save timesteps
# save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
self.cur_episode = []
self.list_actions = [] # save actions in here
self.motionCheck = ac.MotionCheck()
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
random.seed(seed)
self.random_seed = seed
self.reset()
self.num_resets = 0 # yes, set it to zero again!
# Keep track of all the seeds in order
if not hasattr(self, 'seed_history'):
self.seed_history = [seed]
if self.seed_history[-1] != seed:
self.seed_history.append(seed)
self.valid_positions = None
return [seed]
# no more agent_handles
def get_agent_handles(self):
return range(self.get_num_agents())
def get_num_agents(self, static=True):
if static:
return len(self.agents_static)
else:
return len(self.agents)
def get_num_agents(self) -> int:
return len(self.agents)
def add_agent_static(self, agent_static):
def add_agent(self, agent):
""" Add static info for a single agent.
Returns the index of the new agent.
"""
self.agents_static.append(agent_static)
return len(self.agents_static) - 1
self.agents.append(agent)
return len(self.agents) - 1
def reset_agents(self):
""" Reset the agents to their starting positions
"""
for agent in self.agents:
agent.reset()
self.active_agents = [i for i in range(len(self.agents))]
def action_required(self, agent):
"""
Check if an agent needs to provide an action
def restart_agents(self):
""" Reset the agents to their starting positions defined in agents_static
Parameters
----------
agent: RailEnvAgent
Agent we want to check
Returns
-------
True: Agent needs to provide an action
False: Agent cannot provide an action
"""
self.agents = EnvAgent.list_from_static(self.agents_static)
return agent.state == TrainState.READY_TO_DEPART or \
( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
def reset(self, regen_rail=True, replace_agents=True):
""" if regen_rail then regenerate the rails.
if replace_agents then regenerate the agents static.
Relies on the rail_generator returning agent_static lists (pos, dir, target)
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: int = None) -> Tuple[Dict, Dict]:
"""
tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
reset(regenerate_rail, regenerate_schedule, random_seed)
# Check if generator provided a distance map TODO: Make this check safer!
if len(tRailAgents) > 5:
self.obs_builder.distance_map = tRailAgents[-1]
The method resets the rail environment
if regen_rail or self.rail is None:
self.rail = tRailAgents[0]
Parameters
----------
regenerate_rail : bool, optional
regenerate the rails
regenerate_schedule : bool, optional
regenerate the schedule and the static agents
random_seed : int, optional
random seed for environment
Returns
-------
observation_dict: Dict
Dictionary with an observation for each agent
info_dict: Dict with agent specific information
"""
if random_seed:
self._seed(random_seed)
optionals = {}
if regenerate_rail or self.rail is None:
if "__call__" in dir(self.rail_generator):
rail, optionals = self.rail_generator(
self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
elif "generate" in dir(self.rail_generator):
rail, optionals = self.rail_generator.generate(
self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
else:
raise ValueError("Could not invoke __call__ or generate on rail_generator")
self.rail = rail
self.height, self.width = self.rail.grid.shape
if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
# Do a new set_env call on the obs_builder to ensure
# that obs_builder specific instantiations are made according to the
# specifications of the current environment : like width, height, etc
self.obs_builder.set_env(self)
self.restart_agents()
if optionals and 'distance_map' in optionals:
self.distance_map.set(optionals['distance_map'])
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
agent.speed_data['position_fraction'] = 0.0
if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
agents_hints = None
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
self.num_resets, self.np_random)
self.agents = EnvAgent.from_line(line)
# Reset distance map - basically initializing
self.distance_map.reset(self.agents, self.rail)
# NEW : Time Schedule Generation
timetable = timetable_generator(self.agents, self.distance_map,
agents_hints, self.np_random)
self._max_episode_steps = timetable.max_episode_steps
for agent_i, agent in enumerate(self.agents):
agent.earliest_departure = timetable.earliest_departures[agent_i]
agent.latest_arrival = timetable.latest_arrivals[agent_i]
else:
self.distance_map.reset(self.agents, self.rail)
# Reset agents to initial states
self.reset_agents()
self.num_resets += 1
self._elapsed_steps = 0
# TODO perhaps dones should be part of each agent.
# Agent positions map
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
self._update_agent_positions_map(ignore_old_positions=False)
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
self.observation_space = self.obs_builder.observation_space # <-- change on reset?
# Empty the episode store of agent positions
self.cur_episode = []
info_dict = self.get_info_dict()
# Return the new observation vectors for each agent
return self._get_observations()
observation_dict: Dict = self._get_observations()
if hasattr(self, "renderer") and self.renderer is not None:
self.renderer = None
return observation_dict, info_dict
def step(self, action_dict_):
self._elapsed_steps += 1
action_dict = action_dict_.copy()
def _update_agent_positions_map(self, ignore_old_positions=True):
""" Update the agent_positions array for agents that changed positions """
for agent in self.agents:
if not ignore_old_positions or agent.old_position != agent.position:
if agent.position is not None:
self.agent_positions[agent.position] = agent.handle
if agent.old_position is not None:
self.agent_positions[agent.old_position] = -1
alpha = 1.0
beta = 1.0
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
""" Generate State Transitions Signals used in the state machine """
st_signals = StateTransitionSignals()
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
step_penalty = -1 * alpha
global_reward = 1 * beta
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
# Malfunction starts when in_malfunction is set to true
st_signals.in_malfunction = agent.malfunction_handler.in_malfunction
# Reset the step rewards
self.rewards_dict = dict()
for i_agent in range(self.get_num_agents()):
self.rewards_dict[i_agent] = 0
# Malfunction counter complete - Malfunction ends next timestep
st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete
# Earliest departure reached - Train is allowed to move now
st_signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure
# Stop Action Given
st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)
# Valid Movement action Given
st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed
# Target Reached
st_signals.target_reached = fast_position_equal(agent.position, agent.target)
# Movement conflict - Multiple trains trying to move into same cell
# If speed counter is not in cell exit, the train can enter the cell
st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit
return st_signals
def _handle_end_reward(self, agent: EnvAgent) -> int:
'''
Handles end-of-episode reward for a particular agent.
Parameters
----------
agent : EnvAgent
'''
reward = None
# agent done? (arrival_time is not None)
if agent.state == TrainState.DONE:
# if agent arrived earlier or on time = 0
# if agent arrived later = -ve reward based on how late
reward = min(agent.latest_arrival - agent.arrival_time, 0)
# Agents not done (arrival_time is None)
else:
# CANCELLED check (never departed)
if (agent.state.is_off_map_state()):
reward = -1 * self.cancellation_factor * \
(agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)
# Departed but never reached
if (agent.state.is_on_map_state()):
reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
return reward
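# Illustration of the end-of-episode reward above, with hypothetical numbers:
#   done on time   : latest_arrival=40, arrival_time=38 -> min(40 - 38, 0) = 0
#   done late      : latest_arrival=40, arrival_time=45 -> min(40 - 45, 0) = -5
#   never departed : shortest-path travel time 12, cancellation_factor=1,
#                    cancellation_time_buffer=0         -> -1 * 1 * (12 + 0) = -12
#   departed, not arrived: agent.get_current_delay(elapsed_steps, distance_map), an estimate of the remaining delay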
def preprocess_action(self, action, agent):
"""
Preprocess the provided action
* Change to DO_NOTHING if illegal action
* Block all actions when in waiting state
* Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
"""
action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# Try moving actions on current position
current_position, current_direction = agent.position, agent.direction
if current_position is None: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
# Check transitions and bounds for executing the action in the given position and direction
if action.is_moving_action() and not check_valid_action(action, self.rail, current_position, current_direction):
action = RailEnvActions.STOP_MOVING
return action
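# Illustration (hypothetical cell): on a plain straight cell where only MOVE_FORWARD is available,
# a MOVE_LEFT request is mapped to MOVE_FORWARD by preprocess_moving_action, e.g.
#   preprocess_action(RailEnvActions.MOVE_LEFT, agent)  -> RailEnvActions.MOVE_FORWARD
# and a moving action that is still not a valid transition from the current cell becomes STOP_MOVING.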
def clear_rewards_dict(self):
""" Reset the rewards dictionary """
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
def get_info_dict(self):
"""
Returns dictionary of infos for all agents
dict_keys : action_required - True if the train needs to be provided with an action this step
            malfunction - Counter value; a value > 0 means the train is in malfunction
            speed - Speed of the train
            state - State from the train's state machine
"""
info_dict = {
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
'state': {i: agent.state for i, agent in enumerate(self.agents)}
}
return info_dict
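# Shape of the returned dictionary for a hypothetical two-agent episode (values are illustrative):
#   {'action_required': {0: True, 1: False},
#    'malfunction':     {0: 0, 1: 3},
#    'speed':           {0: 1.0, 1: 0.5},
#    'state':           {0: TrainState.MOVING, 1: TrainState.MALFUNCTION}}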
def update_step_rewards(self, i_agent):
"""
Update the rewards dict for agent id i_agent for every timestep
"""
pass
def end_of_episode_update(self, have_all_agents_ended):
"""
Updates made when episode ends
Parameters: have_all_agents_ended - Indicates if all agents have reached done state
"""
if have_all_agents_ended or \
( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
for i_agent, agent in enumerate(self.agents):
reward = self._handle_end_reward(agent)
self.rewards_dict[i_agent] += reward
self.dones[i_agent] = True
self.dones["__all__"] = True
def handle_done_state(self, agent):
""" Any updates to agent to be made in Done state """
if agent.state == TrainState.DONE and agent.arrival_time is None:
agent.arrival_time = self._elapsed_steps
self.dones[agent.handle] = True
if self.remove_agents_at_target:
agent.position = None
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
"""
self._elapsed_steps += 1
# Not allowed to step further once done
if self.dones["__all__"]:
self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
return self._get_observations(), self.rewards_dict, self.dones, {}
raise Exception("Episode is done, cannot call step()")
# for i in range(len(self.agents_handles)):
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
agent.old_direction = agent.direction
self.clear_rewards_dict()
have_all_agents_ended = True # Boolean flag to check if all agents are done
self.motionCheck = ac.MotionCheck() # reset the motion check
temp_transition_data = {}
for agent in self.agents:
i_agent = agent.handle
agent.old_position = agent.position
if self.dones[i_agent]: # this agent has already completed...
continue
if i_agent not in action_dict: # no action has been supplied for this agent
action_dict[i_agent] = RailEnvActions.DO_NOTHING
if action_dict[i_agent] < 0 or action_dict[i_agent] >= len(RailEnvActions):
print('ERROR: illegal action=', action_dict[i_agent],
'for agent with index=', i_agent,
'"DO NOTHING" will be executed instead')
action_dict[i_agent] = RailEnvActions.DO_NOTHING
action = action_dict[i_agent]
if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
# Only allow halting an agent on entering new cells.
agent.moving = False
self.rewards_dict[i_agent] += stop_penalty
if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action
agent.moving = True
self.rewards_dict[i_agent] += start_penalty
# Now perform a movement.
# If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
# store the desired action in `transition_action_on_cellexit' (only if the desired transition is
# allowed! otherwise DO_NOTHING!)
# Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the
# position_fraction by the speed of the agent (regardless of action taken, as long as no
# STOP_MOVING, but that makes agent.moving=False)
# If the new position fraction is >= 1, reset to 0, and perform the stored
# transition_action_on_cellexit
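# Worked example of the fractional movement described above: an agent with speed 0.25 needs four
# steps per cell; its position_fraction grows 0.25 -> 0.50 -> 0.75 -> 1.00, and only on the fourth
# step is the stored transition_action_on_cellexit executed and the agent moved to the next cell.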
# If the agent can make an action
action_selected = False
if agent.speed_data['position_fraction'] == 0.:
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(action, agent)
if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = action
action_selected = True
else:
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward!
if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
action_selected = True
else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[i_agent] += invalid_action_penalty
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += stop_penalty
agent.moving = False
continue
else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[i_agent] += invalid_action_penalty
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += stop_penalty
agent.moving = False
continue
if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0):
agent.speed_data['position_fraction'] += agent.speed_data['speed']
if agent.speed_data['position_fraction'] >= 1.0:
# Perform stored action to transition to the next cell
# Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
# the cell
cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
if all([new_cell_valid, transition_valid, cell_free]):
agent.position = new_position
agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0
if np.equal(agent.position, agent.target).all():
self.dones[i_agent] = True
agent.old_direction = agent.direction
# Generate malfunction
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
# Get action for the agent
action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
preprocessed_action = self.preprocess_action(action, agent)
# Save moving actions if not already saved
agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
# Train's next position can change if train is at cell's exit and train is not in malfunction
position_update_allowed = agent.speed_counter.is_cell_exit and \
not agent.malfunction_handler.malfunction_down_counter > 0 and \
not preprocessed_action == RailEnvActions.STOP_MOVING
# Calculate new position
# Keep agent in same place if already done
if agent.state == TrainState.DONE:
new_position, new_direction = agent.position, agent.direction
# Add agent to the map if not on it yet
elif agent.position is None and agent.action_saver.is_action_saved:
new_position = agent.initial_position
new_direction = agent.initial_direction
# If movement is allowed apply saved action independent of other agents
elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = env_utils.apply_action_independent(saved_action,
self.rail,
agent.position,
agent.direction)
preprocessed_action = saved_action
else:
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
new_position, new_direction = agent.position, agent.direction
# Check for end of episode + add global reward to all rewards!
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
self.dones["__all__"] = True
self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()}
temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
direction=new_direction,
preprocessed_action=preprocessed_action)
if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
self.dones["__all__"] = True
for k in self.dones.keys():
self.dones[k] = True
return self._get_observations(), self.rewards_dict, self.dones, {}
def _check_action_on_agent(self, action, agent):
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_valid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
new_cell_valid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_full_transitions(*new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_valid is None:
transition_valid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_free = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_free, new_cell_valid, new_direction, new_position, transition_valid
def check_action(self, agent, action):
transition_valid = None
possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
new_direction = agent.direction
if action == RailEnvActions.MOVE_LEFT:
new_direction = agent.direction - 1
if num_transitions <= 1:
transition_valid = False
elif action == RailEnvActions.MOVE_RIGHT:
new_direction = agent.direction + 1
if num_transitions <= 1:
transition_valid = False
new_direction %= 4
if action == RailEnvActions.MOVE_FORWARD:
if num_transitions == 1:
# - dead-end, straight line or curved line;
# new_direction will be the only valid transition
# - take only available transition
new_direction = np.argmax(possible_transitions)
transition_valid = True
return new_direction, transition_valid
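# Illustration of the direction arithmetic above (N=0, E=1, S=2, W=3): for a hypothetical agent facing
# EAST (direction 1) on a cell with more than one possible transition,
#   check_action(agent, RailEnvActions.MOVE_LEFT)  -> (0, None)  # turn towards NORTH; validity checked later
#   check_action(agent, RailEnvActions.MOVE_RIGHT) -> (2, None)  # turn towards SOUTH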
# This is for storing and later checking for conflicts of agents trying to occupy same cell
self.motionCheck.addAgent(i_agent, agent.position, new_position)
# Find conflicts between trains trying to occupy same cell
self.motionCheck.find_conflicts()
for agent in self.agents:
i_agent = agent.handle
## Update positions
if agent.malfunction_handler.in_malfunction:
movement_allowed = False
else:
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
movement_inside_cell = agent.state == TrainState.STOPPED and not agent.speed_counter.is_cell_exit
movement_allowed = movement_allowed or movement_inside_cell
# Fetch the saved transition data
agent_transition_data = temp_transition_data[i_agent]
preprocessed_action = agent_transition_data.preprocessed_action
## Update states
state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
# Needed when not removing agents at target
movement_allowed = movement_allowed and agent.state != TrainState.DONE
# Agent is being added to map
if agent.state.is_on_map_state():
if agent.state_machine.previous_state.is_off_map_state():
agent.position = agent.initial_position
agent.direction = agent.initial_direction
# Speed counter completes
elif movement_allowed and (agent.speed_counter.is_cell_exit):
agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction
agent.state_machine.update_if_reached(agent.position, agent.target)
# Off map or on map state and position should match
env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
have_all_agents_ended &= (agent.state == TrainState.DONE)
## Update rewards
self.update_step_rewards(i_agent)
## Update counters (malfunction and speed)
agent.speed_counter.update_counter(agent.state, agent.old_position)
# agent.state_machine.previous_state)
agent.malfunction_handler.update_counter()
# Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry and agent.position is not None:
agent.action_saver.clear_saved_action()
# Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended)
self._update_agent_positions_map()
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
def record_timestep(self, dActions):
"""
Record the positions and orientations of all agents in memory, in the cur_episode
"""
list_agents_state = []
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
# the int cast is to avoid numpy types which may cause problems with msgpack
# in env v2, agents may have position None, before starting
if agent.position is None:
pos = (0, 0)
else:
pos = (int(agent.position[0]), int(agent.position[1]))
# print("pos:", pos, type(pos[0]))
list_agents_state.append([
*pos, int(agent.direction),
agent.malfunction_handler.malfunction_down_counter,
int(agent.status),
int(agent.position in self.motionCheck.svDeadlocked)
])
self.cur_episode.append(list_agents_state)
self.list_actions.append(dActions)
def _get_observations(self):
"""
Utility which returns the dictionary of observations for an agent with respect to environment
"""
# print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def get_full_state_msg(self):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data)
msgpack.packb(agent_data)
msgpack.packb(agent_static_data)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
def get_agent_state_msg(self):
agent_data = [agent.to_list() for agent in self.agents]
msg_data = {
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
def set_full_state_msg(self, msg_data):
data = msgpack.unpackb(msg_data, use_list=False)
self.rail.grid = np.array(data[b"grid"])
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
"""
Returns directions in which the agent can move
"""
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
def set_full_state_dist_msg(self, msg_data):
data = msgpack.unpackb(msg_data, use_list=False)
self.rail.grid = np.array(data[b"grid"])
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys():
self.obs_builder.distance_map = data[b"distance_maps"]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def _exp_distirbution_synced(self, rate: float) -> float:
"""
Generates a sample from an exponential distribution
We need this to guarantee synchronicity between different instances with the same seed.
:param rate: scale of the distribution (the sample is -log(1 - u) * rate)
:return: the sampled value
"""
u = self.np_random.rand()
x = - np.log(1 - u) * rate
return x
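# This is standard inverse-transform sampling: with u ~ Uniform(0, 1), x = -log(1 - u) * rate is
# exponentially distributed with mean `rate` (the argument acts as a scale). Drawing the single
# uniform sample explicitly keeps the number of RNG calls identical across instances that share
# a seed, which is presumably what the synchronicity note in the docstring refers to.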
def get_full_state_dist_msg(self):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data)
msgpack.packb(agent_data)
msgpack.packb(agent_static_data)
if hasattr(self.obs_builder, 'distance_map'):
distance_map_data = self.obs_builder.distance_map
msgpack.packb(distance_map_data)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data,
"distance_maps": distance_map_data}
else:
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
def _is_agent_ok(self, agent: EnvAgent) -> bool:
"""
Check if an agent is ok, meaning it can move and is not malfunctioning
Parameters
----------
agent
Returns
-------
True if agent is ok, False otherwise
"""
return agent.malfunction_handler.in_malfunction
return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename):
if hasattr(self.obs_builder, 'distance_map'):
if len(self.obs_builder.distance_map) > 0:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_dist_msg())
else:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_msg())
else:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_msg())
def load(self, filename):
if hasattr(self.obs_builder, 'distance_map'):
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_dist_msg(load_data)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_msg(load_data)
print("DEPRECATED call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
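# Sketch of the replacement call mentioned in the deprecation message (file name is hypothetical):
#   >>> persistence.RailEnvPersister.save(env, "my_episode.pkl")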
def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False, clear_debug_text=True, show=False,
screen_height=600, screen_width=800,
show_observations=False, show_predictions=False,
show_rowcols=False, return_image=True):
"""
This method provides the option to render the
environment's behavior as an image or to a window.
Parameters
----------
mode
def load_pkl(self, pkl_data):
self.set_full_state_msg(pkl_data)
Returns
-------
Image if mode is rgb_array, opens a window otherwise
"""
if not hasattr(self, "renderer") or self.renderer is None:
self.initialize_renderer(mode=mode, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
show=show,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width)
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
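# Typical usage sketch, assuming an environment that has already been reset:
#   >>> frame = env.render(mode="rgb_array", show=False, return_image=True)
#   >>> frame.shape      # roughly (screen_height, screen_width, 3), e.g. (600, 800, 3) with the defaults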
def initialize_renderer(self, mode, gl,
agent_render_variant,
show_debug,
clear_debug_text,
show,
screen_height,
screen_width):
# Initiate the renderer
self.renderer = RenderTool(self, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width) # Adjust these parameters to fit your resolution
self.renderer.show = show
self.renderer.reset()
def update_renderer(self, mode, show, show_observations, show_predictions,
show_rowcols, return_image):
"""
This method updates the renderer.
Parameters
----------
mode
def load_resource(self, package, resource):
from importlib_resources import read_binary
load_data = read_binary(package, resource)
self.set_full_state_msg(load_data)
Returns
-------
Image if mode is rgb_array, None otherwise
"""
image = self.renderer.render_env(show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
if mode == 'rgb_array':
return image[:, :, :3]
def close(self):
"""
This method closes any renderer window.
"""
if hasattr(self, "renderer") and self.renderer is not None:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
from enum import IntEnum
from typing import NamedTuple
from flatland.core.grid.grid4 import Grid4TransitionsEnum
class RailEnvActions(IntEnum):
DO_NOTHING = 0 # implies change of direction in a dead-end!
MOVE_LEFT = 1
MOVE_FORWARD = 2
MOVE_RIGHT = 3
STOP_MOVING = 4
@staticmethod
def to_char(a: int):
return {
0: 'B',
1: 'L',
2: 'F',
3: 'R',
4: 'S',
}[a]
@classmethod
def is_action_valid(cls, action):
return action in cls._value2member_map_
def is_moving_action(self):
return self.value in [self.MOVE_RIGHT, self.MOVE_LEFT, self.MOVE_FORWARD]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
('next_direction', Grid4TransitionsEnum)])
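# Small illustration of the helpers above:
#   >>> RailEnvActions.to_char(4)
#   'S'
#   >>> RailEnvActions.is_action_valid(7)
#   False
#   >>> RailEnvActions.MOVE_LEFT.is_moving_action()
#   True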