Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 2713 additions and 1568 deletions
from collections import defaultdict
from typing import Dict, Tuple
from flatland.contrib.utils.deadlock_checker import Deadlock_Checker
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.fast_methods import fast_count_nonzero
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.step_utils.states import TrainState
def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
agent = env.agents[handle]
if agent.state == TrainState.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
else:
print("no action possible!")
print("agent state: ", agent.state)
# NEW: if agent is at target, DO_NOTHING, and distance is zero.
# NEW: (needs to be tested...)
return [(RailEnvActions.DO_NOTHING, 0)] * 2
possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
print(f"possible transitions: {possible_transitions}")
distance_map = env.distance_map.get()[handle]
possible_steps = []
for movement in list(range(4)):
if possible_transitions[movement]:
if movement == agent.direction:
action = RailEnvActions.MOVE_FORWARD
elif movement == (agent.direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif movement == (agent.direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}")
if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4:
print("it seems that we are turning by 180 degrees. Turning in a dead end?")
action = RailEnvActions.MOVE_FORWARD
distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
possible_steps.append((action, distance))
possible_steps = sorted(possible_steps, key=lambda step: step[1])
# if there is only one path to target, this is both the shortest one and the second shortest path.
if len(possible_steps) == 1:
return possible_steps * 2
else:
return possible_steps
class RailEnvWrapper:
def __init__(self, env:RailEnv):
self.env = env
assert self.env is not None
assert self.env.rail is not None, "Reset original environment first!"
assert self.env.agents is not None, "Reset original environment first!"
assert len(self.env.agents) > 0, "Reset original environment first!"
# @property
# def number_of_agents(self):
# return self.env.number_of_agents
# @property
# def agents(self):
# return self.env.agents
# @property
# def _seed(self):
# return self.env._seed
# @property
# def obs_builder(self):
# return self.env.obs_builder
def __getattr__(self, name):
try:
return super().__getattr__(self,name)
except:
"""Expose any other attributes of the underlying environment."""
return getattr(self.env, name)
@property
def rail(self):
return self.env.rail
@property
def width(self):
return self.env.width
@property
def height(self):
return self.env.height
@property
def agent_positions(self):
return self.env.agent_positions
def get_num_agents(self):
return self.env.get_num_agents()
def get_agent_handles(self):
return self.env.get_agent_handles()
def step(self, action_dict: Dict[int, RailEnvActions]):
return self.env.step(action_dict)
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
return obs, info
class ShortestPathActionWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv):
super().__init__(env)
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
# input: action dict with actions in [0, 1, 2].
transformed_action_dict = {}
for agent_id, action in action_dict.items():
if action == 0:
transformed_action_dict[agent_id] = action
else:
#assert action in [1, 2]
#assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
#assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
obs, rewards, dones, info = self.env.step(transformed_action_dict)
return obs, rewards, dones, info
def find_all_cells_where_agent_can_choose(env: RailEnv):
"""
input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv),
WHICH HAS BEEN RESET ALREADY!
(o.w., we call env.rail, which is None before reset(), and crash.)
"""
switches = []
switches_neighbors = []
directions = list(range(4))
for h in range(env.height):
for w in range(env.width):
pos = (h, w)
is_switch = False
# Check for switch: if there is more than one outgoing transition
for orientation in directions:
possible_transitions = env.rail.get_transitions(*pos, orientation)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
switches.append(pos)
is_switch = True
break
if is_switch:
# Add all neighbouring rails, if pos is a switch
for orientation in directions:
possible_transitions = env.rail.get_transitions(*pos, orientation)
for movement in directions:
if possible_transitions[movement]:
switches_neighbors.append(get_new_position(pos, movement))
decision_cells = switches + switches_neighbors
return tuple(map(set, (switches, switches_neighbors, decision_cells)))
class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# env can be a real RailEnv, or anything that shares the same interface
# e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
super().__init__(env)
# save these so they can be inspected easier.
self.accumulate_skipped_rewards = accumulate_skipped_rewards
self.discounting = discounting
self.switches = None
self.switches_neighbors = None
self.decision_cells = None
self.skipped_rewards = defaultdict(list)
# sets initial values for switches, decision_cells, etc.
self.reset_cells()
def on_decision_cell(self, agent: EnvAgent) -> bool:
return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
def on_switch(self, agent: EnvAgent) -> bool:
return agent.position in self.switches
def next_to_switch(self, agent: EnvAgent) -> bool:
return agent.position in self.switches_neighbors
def reset_cells(self) -> None:
self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
o, r, d, i = {}, {}, {}, {}
# need to initialize i["..."]
# as we will access i["..."][agent_id]
i["action_required"] = dict()
i["malfunction"] = dict()
i["speed"] = dict()
i["state"] = dict()
while len(o) == 0:
obs, reward, done, info = self.env.step(action_dict)
for agent_id, agent_obs in obs.items():
if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
o[agent_id] = agent_obs
r[agent_id] = reward[agent_id]
d[agent_id] = done[agent_id]
i["action_required"][agent_id] = info["action_required"][agent_id]
i["malfunction"][agent_id] = info["malfunction"][agent_id]
i["speed"][agent_id] = info["speed"][agent_id]
i["state"][agent_id] = info["state"][agent_id]
if self.accumulate_skipped_rewards:
discounted_skipped_reward = r[agent_id]
for skipped_reward in reversed(self.skipped_rewards[agent_id]):
discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
r[agent_id] = discounted_skipped_reward
self.skipped_rewards[agent_id] = []
elif self.accumulate_skipped_rewards:
self.skipped_rewards[agent_id].append(reward[agent_id])
# end of for-loop
d['__all__'] = done['__all__']
action_dict = {}
# end of while-loop
return o, r, d, i
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
obs, info = self.env.reset(**kwargs)
# resets decision cells, switches, etc. These can change with an env.reset(...)!
# needs to be done after env.reset().
self.reset_cells()
return obs, info
class DeadlockWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv, deadlock_reward=-100) -> None:
super().__init__(env)
self.deadlock_reward = deadlock_reward
self.deadlock_checker = Deadlock_Checker(env=self.env)
@property
def deadlocked_agents(self):
return self.deadlock_checker.deadlocked_agents
@property
def immediate_deadlocks(self):
return [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]
# make sure to assign the deadlock reward only once to each deadlocked agent...
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
# agents which are already deadlocked from previous steps
already_deadlocked_ids = [agent.handle for agent in self.deadlocked_agents]
# step environment
obs, rewards, dones, info = self.env.step(action_dict)
# compute new list of deadlocked agents (ids) after stepping the environment
deadlocked_agents = self.deadlock_checker.check_deadlocks(action_dict) # also stored in self.deadlocked_checker.deadlocked_agents
deadlocked_agents_ids = [agent.handle for agent in deadlocked_agents]
# immediate deadlocked ids only used for prints
immediate_deadlocked_ids = [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]
print(f"immediate deadlocked: {immediate_deadlocked_ids}")
print(f"total deadlocked: {deadlocked_agents_ids}")
newly_deadlocked_agents_ids = [agent_id for agent_id in deadlocked_agents_ids if agent_id not in already_deadlocked_ids]
# assign deadlock rewards
for agent_id in newly_deadlocked_agents_ids:
print(f"assigning deadlock reward of {self.deadlock_reward} to agent {agent_id}")
rewards[agent_id] = self.deadlock_reward
return obs, rewards, dones, info
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
self.deadlock_checker.reset() # sets all lists of deadlocked agents to empty list
obs, info = super().reset(**kwargs)
return obs, info
......@@ -11,11 +11,11 @@ class Environment:
Derived environments should implement the following attributes:
action_space: tuple with the dimensions of the actions to be passed to the step method
observation_space: tuple with the dimensions of the observations returned by reset and step
Agents are identified by agent ids (handles).
Examples:
>>> obs = env.reset()
>>> obs, info = env.reset()
>>> print(obs)
{
"train_0": [2.4, 1.6],
......@@ -40,18 +40,19 @@ class Environment:
"train_0": {}, # info for train_0
"train_1": {}, # info for train_1
}
"""
def __init__(self):
self.action_space = ()
self.observation_space = ()
pass
def reset(self):
"""
Resets the env and returns observations from agents in the environment.
Returns:
Returns
-------
obs : dict
New observations for each agent.
"""
......@@ -66,7 +67,7 @@ class Environment:
The returns are dicts mapping from agent_id strings to values.
Parameters
-------
----------
action_dict : dict
Dictionary of actions to execute, indexed by agent id.
......@@ -84,27 +85,6 @@ class Environment:
"""
raise NotImplementedError()
def predict(self):
"""
Predictions step.
Returns predictions for the agents.
The returns are dicts mapping from agent_id strings to values.
Returns
-------
predictions : dict
New predictions for each ready agent.
"""
raise NotImplementedError()
def render(self):
"""
Perform rendering of the environment.
"""
raise NotImplementedError()
def get_agent_handles(self):
"""
Returns a list of agents' handles to be used as keys in the step()
......
......@@ -2,27 +2,29 @@
ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).
+ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+ `reset()` is called after each environment reset, to allow for pre-computing relevant data.
+ `get()` is called whenever an observation has to be computed, potentially for each agent independently in case of \
multi-agent environments.
+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""
from typing import Optional, List
import numpy as np
from flatland.core.env import Environment
class ObservationBuilder:
"""
ObservationBuilder base class.
Derived objects must implement and `observation_space' attribute as a tuple with the dimensions of the returned
observations.
"""
def __init__(self):
self.observation_space = ()
self.env = None
def _set_env(self, env):
self.env = env
def set_env(self, env: Environment):
self.env: Environment = env
def reset(self):
"""
......@@ -30,14 +32,37 @@ class ObservationBuilder:
"""
raise NotImplementedError()
def get(self, handle=0):
def get_many(self, handles: Optional[List[int]] = None):
"""
Called whenever an observation has to be computed for the `env' environment, possibly
for each agent independently (agent id `handle').
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
Parameters
----------
handles : list of handles, optional
List with the handles of the agents for which to compute the observation vector.
Returns
-------
handle : int (optional)
function
A dictionary of observation structures, specific to the corresponding environment, with handles from
`handles` as keys.
"""
observations = {}
if handles is None:
handles = []
for h in handles:
observations[h] = self.get(h)
return observations
def get(self, handle: int = 0):
"""
Called whenever an observation has to be computed for the `env` environment, possibly
for each agent independently (agent id `handle`).
Parameters
----------
handle : int, optional
Handle of the agent for which to compute the observation vector.
Returns
......@@ -52,3 +77,22 @@ class ObservationBuilder:
direction = np.zeros(4)
direction[agent.direction] = 1
return direction
class DummyObservationBuilder(ObservationBuilder):
"""
DummyObservationBuilder class which returns dummy observations
This is used in the evaluation service
"""
def __init__(self):
super().__init__()
def reset(self):
pass
def get_many(self, handles: Optional[List[int]] = None) -> bool:
return True
def get(self, handle: int = 0) -> bool:
return True
......@@ -3,22 +3,25 @@ PredictionBuilder objects are objects that can be passed to environments designe
The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
If predictions are not required in every step or not for all agents, then
+ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+ `reset()` is called after each environment reset, to allow for pre-computing relevant data.
+ Get() is called whenever an step has to be computed, potentially for each agent independently in
+ `get()` is called whenever an step has to be computed, potentially for each agent independently in \
case of multi-agent environments.
"""
from flatland.core.env import Environment
class PredictionBuilder:
"""
PredictionBuilder base class.
"""
def __init__(self, max_depth: int = 20):
self.max_depth = max_depth
self.env = None
def _set_env(self, env):
def set_env(self, env: Environment):
self.env = env
def reset(self):
......@@ -27,13 +30,13 @@ class PredictionBuilder:
"""
pass
def get(self, handle=0):
def get(self, handle: int = 0):
"""
Called whenever step_prediction is called on the environment.
Called whenever get_many in the observation build is called.
Parameters
-------
handle : int (optional)
----------
handle : int, optional
Handle of the agent for which to compute the observation vector.
Returns
......
from enum import IntEnum
from functools import lru_cache
from typing import Type, List
import numpy as np
from flatland.core.transitions import Transitions
# maxsize=None can be used because the number of possible transition is limited (16 bit encoded) and the
# direction/orientation is also limited (2bit). Where the 16bit are only sparse used = number of rail types
# Those methods can be cached -> the are independant of the railways (env)
@lru_cache(maxsize=128)
def fast_grid4_get_transitions(cell_transition, orientation):
bits = (cell_transition >> ((3 - orientation) * 4))
return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
@lru_cache(maxsize=128)
def fast_grid4_get_transition(cell_transition, orientation, direction):
return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
@lru_cache(maxsize=128)
def fast_grid4_set_transitions(cell_transition, orientation, new_transitions):
mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 3 | \
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
return cell_transition
@lru_cache(maxsize=128)
def fast_grid4_remove_deadends(cell_transition):
"""
Remove all turn-arounds (e.g. N-S, S-N, E-W,...).
"""
maskDeadEnds = Grid4Transitions.maskDeadEnds()
cell_transition &= cell_transition & (~maskDeadEnds) & 0xffff
return cell_transition
@lru_cache(maxsize=128)
def fast_grid4_rotate_transition(cell_transition, rotation=0):
value = cell_transition
rotation = rotation // 90
for i in range(4):
block_tuple = fast_grid4_get_transitions(value, i)
block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
value = fast_grid4_set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (
value >> (rotation * 4))
cell_transition = value
return cell_transition
class Grid4TransitionsEnum(IntEnum):
NORTH = 0
EAST = 1
SOUTH = 2
WEST = 3
@staticmethod
def to_char(int: int):
return {0: 'N',
1: 'E',
2: 'S',
3: 'W'}[int]
class Grid4Transitions(Transitions):
"""
Grid4Transitions class derived from Transitions.
Special case of `Transitions` over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions`
list, each represented as a bitmap of 16 bits.
Whether a transition is allowed or not depends on which direction an agent
inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which
direction the agent wants to move to
(North, East, South, West, relative to the cell).
Each transition (orientation, direction)
can be allowed (1) or forbidden (0).
For example, in case of no diagonal transitions on the grid, the 16 bits
of the transition bitmaps are organized in 4 blocks of 4 bits each, the
direction that the agent is facing.
E.g., the most-significant 4-bits represent the possible movements (NESW)
if the agent is facing North, etc...
agent's direction: North East South West
agent's allowed movements: [nesw] [nesw] [nesw] [nesw]
example: 1000 0000 0010 0000
In the example, the agent can move from North to South and viceversa.
"""
def __init__(self, transitions):
self.transitions = transitions
self.sDirs = "NESW"
self.lsDirs = list(self.sDirs)
# row,col delta for each direction
self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
# These bits represent all the possible dead ends
@staticmethod
@lru_cache()
def maskDeadEnds():
return 0b0010000110000100
def get_type(self):
return np.uint16
def get_transitions(self, cell_transition, orientation):
"""
Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent oriented
in direction `orientation` and inside a cell with
transitions `cell_transition`.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
Returns
-------
tuple
List of the validity of transitions in the cell.
"""
return fast_grid4_get_transitions(cell_transition, orientation)
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent
oriented in direction `orientation` and inside a cell with transitions
`cell_transition'. A new `cell_transition` is returned with
the specified bits replaced by `new_transitions`.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
new_transitions : tuple
Tuple of new transitions validitiy for the cell.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
return fast_grid4_set_transitions(cell_transition, orientation, new_transitions)
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
Returns
-------
int
Validity of the requested transition: 0/1 allowed/not allowed.
"""
return fast_grid4_get_transition(cell_transition, orientation, direction)
def set_transition(self, cell_transition, orientation, direction, new_transition,
remove_deadends=False):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
new_transition : int
Validity of the requested transition: 0/1 allowed/not allowed.
remove_deadends -- boolean, default False
remove all deadend transitions.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
if new_transition:
cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
else:
cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
if remove_deadends:
cell_transition = fast_grid4_remove_deadends(cell_transition)
return cell_transition
def rotate_transition(self, cell_transition, rotation=0):
"""
Clockwise-rotate a 16-bit transition bitmap by
rotation={0, 90, 180, 270} degrees.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition` by. I.e., rotation={0, 90, 180, 270} degrees.
Returns
-------
int
An updated bitmap that replaces the original transitions bits
with the equivalent bitmap after rotation.
"""
# Rotate the individual bits in each block
return fast_grid4_rotate_transition(cell_transition, rotation)
def get_direction_enum(self) -> Type[Grid4TransitionsEnum]:
return Grid4TransitionsEnum
@staticmethod
@lru_cache()
def has_deadend(cell_transition):
"""
Checks if one entry can only by exited by a turn-around.
"""
if cell_transition & Grid4Transitions.maskDeadEnds() > 0:
return True
else:
return False
def remove_deadends(self, cell_transition):
"""
Remove all turn-arounds (e.g. N-S, S-N, E-W,...).
"""
return fast_grid4_remove_deadends(cell_transition)
@staticmethod
@lru_cache()
def get_entry_directions(cell_transition) -> List[int]:
return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
import numpy as np
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
from flatland.core.grid.grid_utils import IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap
from flatland.utils.ordered_set import OrderedSet
class AStarNode:
"""A node class for A* Pathfinding"""
def __init__(self, pos: IntVector2D, parent=None):
self.parent = parent
self.pos: IntVector2D = pos
self.g = 0.0
self.h = 0.0
self.f = 0.0
def __eq__(self, other):
"""
Parameters
----------
other : AStarNode
"""
return self.pos == other.pos
def __hash__(self):
return hash(self.pos)
def update_if_better(self, other):
if other.g < self.g:
self.parent = other.parent
self.g = other.g
self.h = other.h
self.f = other.f
def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, avoid_rails=False,
respect_transition_validity=True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
"""
:param avoid_rails:
:param grid_map: Grid Map where the path is found in
:param start: Start positions as (row,column)
:param end: End position as (row,column)
:param a_star_distance_function: Define the distance function to use as heuristc:
-get_euclidean_distance
-get_manhattan_distance
-get_chebyshev_distance
:param respect_transition_validity: Whether or not a-star respect allowed transitions on the grid map.
- True: Respects the validity of transition. This generates valid paths, of no path if it cannot be found
- False: This always finds a path, but the path might be illegal and thus needs to be fixed afterwards
:param forbidden_cells: List of cells where the path cannot pass through. Used to avoid certain areas of Grid map
:return: IF a path is found a ordered list of al cells in path is returned
"""
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
"""
rail_shape = grid_map.grid.shape
start_node = AStarNode(start, None)
end_node = AStarNode(end, None)
open_nodes = OrderedSet()
closed_nodes = OrderedSet()
open_nodes.add(start_node)
while len(open_nodes) > 0:
# get node with current shortest est. path (lowest f)
current_node = None
for item in open_nodes:
if current_node is None:
current_node = item
continue
if item.f < current_node.f:
current_node = item
# pop current off open list, add to closed list
open_nodes.remove(current_node)
closed_nodes.add(current_node)
# found the goal
if current_node == end_node:
path = []
current = current_node
while current is not None:
path.append(current.pos)
current = current.parent
# return reversed path
return path[::-1]
# generate children
children = []
if current_node.parent is not None:
prev_pos = current_node.parent.pos
else:
prev_pos = None
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
# update the "current" pos
node_pos: IntVector2D = Vec2d.add(current_node.pos, new_pos)
# is node_pos inside the grid?
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
continue
# validate positions
#
if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos,
end_node.pos) and respect_transition_validity:
continue
# create new node
new_node = AStarNode(node_pos, current_node)
# Skip paths through forbidden regions if they are provided
if forbidden_cells is not None:
if node_pos in forbidden_cells and new_node != start_node and new_node != end_node:
continue
children.append(new_node)
# loop through children
for child in children:
# already in closed list?
if child in closed_nodes:
continue
# create the f, g, and h values
child.g = current_node.g + 1.0
# this heuristic avoids diagonal paths
if avoid_rails:
child.h = a_star_distance_function(child.pos, end_node.pos) + np.clip(grid_map.grid[child.pos], 0, 1)
else:
child.h = a_star_distance_function(child.pos, end_node.pos)
child.f = child.g + child.h
# already in the open list?
if child in open_nodes:
continue
# add the child to the open list
open_nodes.add(child)
# no full path found
if len(open_nodes) == 0:
return []
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2D
def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
"""
Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions.
"""
diff_0 = pos2[0] - pos1[0]
diff_1 = pos2[1] - pos1[1]
if diff_0 < 0:
return Grid4TransitionsEnum.NORTH
if diff_0 > 0:
return Grid4TransitionsEnum.SOUTH
if diff_1 > 0:
return Grid4TransitionsEnum.EAST
if diff_1 < 0:
return Grid4TransitionsEnum.WEST
raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
def mirror(dir):
return (dir + 2) % 4
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]
def get_new_position(position, movement):
return (position[0] + MOVEMENT_ARRAY[movement][0], position[1] + MOVEMENT_ARRAY[movement][1])
def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
"""
Returns the closest direction orientation of position 2 relative to position 1
:param pos1: position we are interested in
:param pos2: position we want to know it is facing
:return: direction NESW as int N:0 E:1 S:2 W:3
"""
diff_vec = np.array((pos1[0] - pos2[0], pos1[1] - pos2[1]))
axis = np.argmax(np.power(diff_vec, 2))
direction = np.sign(diff_vec[axis])
if axis == 0:
if direction > 0:
return Grid4TransitionsEnum.NORTH
else:
return Grid4TransitionsEnum.SOUTH
else:
if direction > 0:
return Grid4TransitionsEnum.WEST
else:
return Grid4TransitionsEnum.EAST
from enum import IntEnum
import numpy as np
from flatland.core.transitions import Transitions
class Grid8TransitionsEnum(IntEnum):
NORTH = 0
NORTH_EAST = 1
EAST = 2
SOUTH_EAST = 3
SOUTH = 4
SOUTH_WEST = 5
WEST = 6
NORTH_WEST = 7
class Grid8Transitions(Transitions):
"""
Grid8Transitions class derived from Transitions.
Special case of `Transitions` over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions`
list, each represented as a bitmap of 64 bits.
0=North, 1=North-East, etc.
"""
def __init__(self, transitions):
self.transitions = transitions
def get_type(self):
return np.uint64
def get_transitions(self, cell_transition, orientation):
"""
Get the 8 possible transitions.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
Returns
-------
tuple
List of the validity of transitions in the cell.
"""
bits = (np.uint64(cell_transition) >> np.uint64((7 - orientation) * 8))
cell_transition = (
(bits >> np.uint64(7)) & np.uint64(1),
(bits >> np.uint64(6)) & np.uint64(1),
(bits >> np.uint64(5)) & np.uint64(1),
(bits >> np.uint64(4)) & np.uint64(1),
(bits >> np.uint64(3)) & np.uint64(1),
(bits >> np.uint64(2)) & np.uint64(1),
(bits >> np.uint64(1)) & np.uint64(1),
bits & np.uint64(1))
return cell_transition
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Set the possible transitions.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
new_transitions : tuple
Tuple of new transitions validitiy for the cell.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
negmask = ~mask
new_transitions = \
(int(new_transitions[0]) & 1) << 7 | \
(int(new_transitions[1]) & 1) << 6 | \
(int(new_transitions[2]) & 1) << 5 | \
(int(new_transitions[3]) & 1) << 4 | \
(int(new_transitions[4]) & 1) << 3 | \
(int(new_transitions[5]) & 1) << 2 | \
(int(new_transitions[6]) & 1) << 1 | \
(int(new_transitions[7]) & 1)
cell_transition = (int(cell_transition) & negmask) | (new_transitions << ((7 - orientation) * 8))
return cell_transition
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
Returns
-------
int
Validity of the requested transition: 0/1 allowed/not allowed.
"""
return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1
def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
new_transition : int
Validity of the requested transition: 0/1 allowed/not allowed.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
if new_transition:
cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
else:
cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
return cell_transition
def rotate_transition(self, cell_transition, rotation=0):
"""
Clockwise-rotate a 64-bit transition bitmap by
rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition` by. I.e., rotation={0, 45, 90, 135, 180,
225, 270, 315} degrees.
Returns
-------
int
An updated bitmap that replaces the original transitions bits
with the equivalent bitmap after rotation.
"""
# TODO: WARNING: this part of the function has never been tested!
# Rotate the individual bits in each block
value = cell_transition
rotation = rotation // 45
for i in range(8):
block_tuple = self.get_transitions(value, i)
block_tuple = block_tuple[rotation:] + block_tuple[:rotation]
value = self.set_transitions(value, i, block_tuple)
# Rotate the 8bits blocks
value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8))
cell_transition = value
return cell_transition
def get_direction_enum(self) -> IntEnum:
return Grid8TransitionsEnum
from math import isnan
from typing import Tuple, Callable, List, Type
import numpy as np
Vector2D: Type = Tuple[float, float]
IntVector2D: Type = Tuple[int, int]
IntVector2DArray: Type = List[IntVector2D]
IntVector2DArrayArray: Type = List[List[IntVector2D]]
Vector2DArray: Type = List[Vector2D]
Vector2DArrayArray: Type = List[List[Vector2D]]
IntVector2DDistance: Type = Callable[[IntVector2D, IntVector2D], float]
class Vec2dOperations:
@staticmethod
def is_equal(node_a: Vector2D, node_b: Vector2D) -> bool:
"""
vector operation : node_a + node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return:
check if node_a and nobe_b are equal
"""
return node_a[0] == node_b[0] and node_a[1] == node_b[1]
@staticmethod
def subtract(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
"""
vector operation : node_a - node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return node_a[0] - node_b[0], node_a[1] - node_b[1]
@staticmethod
def add(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
"""
vector operation : node_a + node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return node_a[0] + node_b[0], node_a[1] + node_b[1]
@staticmethod
def make_orthogonal(node: Vector2D) -> Vector2D:
"""
vector operation : rotates the 2D vector +90°
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return node[1], -node[0]
@staticmethod
def get_norm(node: Vector2D) -> float:
"""
calculates the euclidean norm of the 2d vector
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
:param node: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return np.sqrt(node[0] * node[0] + node[1] * node[1])
@staticmethod
def get_euclidean_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the euclidean norm of the 2d vector
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
Euclidean distance
"""
return Vec2dOperations.get_norm(Vec2dOperations.subtract(node_b, node_a))
@staticmethod
def get_manhattan_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the manhattan distance of the 2d vector
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
Mahnhattan distance
"""
delta = (Vec2dOperations.subtract(node_b, node_a))
return np.abs(delta[0]) + np.abs(delta[1])
@staticmethod
def get_chebyshev_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the chebyshev norm of the 2d vector
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
the chebyshev distance
"""
delta = (Vec2dOperations.subtract(node_b, node_a))
return max(np.abs(delta[0]), np.abs(delta[1]))
@staticmethod
def normalize(node: Vector2D) -> Tuple[float, float]:
"""
normalize the 2d vector = `v/|v|`
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
n = Vec2dOperations.get_norm(node)
if n > 0.0:
n = 1 / n
return Vec2dOperations.scale(node, n)
@staticmethod
def scale(node: Vector2D, scale: float) -> Vector2D:
"""
scales the 2d vector = node * scale
:param node: tuple with coordinate (x,y) or 2d vector
:param scale: scalar to scale
:return: tuple with coordinate (x,y) or 2d vector
"""
return node[0] * scale, node[1] * scale
@staticmethod
def round(node: Vector2D) -> IntVector2D:
"""
rounds the x and y coordinate and convert them to an integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return int(np.round(node[0])), int(np.round(node[1]))
@staticmethod
def ceil(node: Vector2D) -> IntVector2D:
"""
ceiling the x and y coordinate and convert them to an integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return int(np.ceil(node[0])), int(np.ceil(node[1]))
@staticmethod
def floor(node: Vector2D) -> IntVector2D:
"""
floor the x and y coordinate and convert them to an integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return int(np.floor(node[0])), int(np.floor(node[1]))
@staticmethod
def bound(node: Vector2D, min_value: float, max_value: float) -> Vector2D:
"""
force the values x and y to be between min_value and max_value
:param node: tuple with coordinate (x,y) or 2d vector
:param min_value: scalar value
:param max_value: scalar value
:return:
tuple with coordinate (x,y) or 2d vector
"""
return max(min_value, min(max_value, node[0])), max(min_value, min(max_value, node[1]))
@staticmethod
def rotate(node: Vector2D, rot_in_degree: float) -> Vector2D:
"""
rotate the 2d vector with given angle in degree
:param node: tuple with coordinate (x,y) or 2d vector
:param rot_in_degree: angle in degree
:return:
tuple with coordinate (x,y) or 2d vector
"""
alpha = rot_in_degree / 180.0 * np.pi
x0 = node[0]
y0 = node[1]
x1 = x0 * np.cos(alpha) - y0 * np.sin(alpha)
y1 = x0 * np.sin(alpha) + y0 * np.cos(alpha)
return x1, y1
def position_to_coordinate(depth: int, positions: List[int]):
"""Converts coordinates to positions::
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
-->
[ 0 d .. (w-1)*d
1 d+1
...
d-1 2d-1 w*d-1
]
Parameters
----------
depth : int
positions : List[Tuple[int,int]]
"""
coords = ()
for p in positions:
coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim
return coords
def coordinate_to_position(depth, coords):
"""
Converts positions to coordinates::
[ 0 d .. (w-1)*d
1 d+1
...
d-1 2d-1 w*d-1
]
-->
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
:param depth:
:param coords:
:return:
"""
position = list(range(len(coords)))
for index, t in enumerate(coords):
if isnan(t[0]):
position[index] = -1
else:
position[index] = int(t[1] * depth + t[0])
return position
def distance_on_rail(pos1, pos2, metric="Euclidean"):
if metric == "Euclidean":
return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
if metric == "Manhattan":
return np.abs(pos1[0] - pos2[0]) + np.abs(pos1[1] - pos2[1])
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.utils.ordered_set import OrderedSet
class RailEnvTransitions(Grid4Transitions):
"""
Special case of `GridTransitions` over a 2D-grid, with a pre-defined set
of transitions mimicking the types of real Swiss rail connections.
As no diagonal transitions are allowed in the RailEnv environment, the
possible transitions for RailEnv from a cell to its neighboring ones
are represented over 16 bits.
The 16 bits are organized in 4 blocks of 4 bits each, the direction that
the agent is facing.
E.g., the most-significant 4-bits represent the possible movements (NESW)
if the agent is facing North, etc...
agent's direction: North East South West
agent's allowed movements: [nesw] [nesw] [nesw] [nesw]
example: 1000 0000 0010 0000
In the example, the agent can move from North to South and viceversa.
"""
# Contains the basic transitions;
# the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions.
transition_list = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip
int('1100110000110011', 2), # Case 5 - double slip
int('0101001000000010', 2), # Case 6 - symmetrical
int('0010000000000000', 2), # Case 7 - dead end
int('0100000000000010', 2), # Case 1b (8) - simple turn right
int('0001001000000000', 2), # Case 1c (9) - simple turn left
int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored
def __init__(self):
super(RailEnvTransitions, self).__init__(
transitions=self.transition_list
)
# create this to make validation faster
self.transitions_all = OrderedSet()
for index, trans in enumerate(self.transitions):
self.transitions_all.add(trans)
if index in (2, 4, 6, 7, 8, 9, 10):
for _ in range(3):
trans = self.rotate_transition(trans, rotation=90)
self.transitions_all.add(trans)
elif index in (1, 5):
trans = self.rotate_transition(trans, rotation=90)
self.transitions_all.add(trans)
def print(self, cell_transition):
print(" NESW")
print("N", format(cell_transition >> (3 * 4) & 0xF, '04b'))
print("E", format(cell_transition >> (2 * 4) & 0xF, '04b'))
print("S", format(cell_transition >> (1 * 4) & 0xF, '04b'))
print("W", format(cell_transition >> (0 * 4) & 0xF, '04b'))
def is_valid(self, cell_transition):
"""
Checks if a cell transition is a valid cell setup.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
Returns
-------
Boolean
True or False
"""
return cell_transition in self.transitions_all
......@@ -3,11 +3,19 @@ TransitionMap and derived classes.
"""
import numpy as np
from importlib_resources import path
from numpy import array
from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
# TODO are these general classes or for grid4 only?
class TransitionMap:
"""
Base TransitionMap class.
......@@ -19,7 +27,7 @@ class TransitionMap:
def get_transitions(self, cell_id):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
`cell_id` (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
......@@ -39,8 +47,8 @@ class TransitionMap:
def set_transitions(self, cell_id, new_transitions):
"""
Replaces the available transitions in cell `cell_id' with the tuple
`new_transitions'. `new_transitions' must have
Replaces the available transitions in cell `cell_id` with the tuple
`new_transitions'. `new_transitions` must have
one element for each possible transition.
Parameters
......@@ -56,8 +64,8 @@ class TransitionMap:
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id' can perform a
movement along transition `transition_index (e.g., the NESW direction
Return the status of whether an agent in cell `cell_id` can perform a
movement along transition `transition_index` (e.g., the NESW direction
of movement, for agents on a grid).
Parameters
......@@ -72,7 +80,7 @@ class TransitionMap:
Returns
-------
int or float (depending on derived class)
int or float (depending on Transitions used)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
......@@ -81,8 +89,8 @@ class TransitionMap:
def set_transition(self, cell_id, transition_index, new_transition):
"""
Replaces the validity of transition to `transition_index' in cell
`cell_id' with the new `new_transition'.
Replaces the validity of transition to `transition_index` in cell
`cell_id' with the new `new_transition`.
Parameters
......@@ -94,7 +102,7 @@ class TransitionMap:
Index of the transition to probe, as index in the tuple returned by
get_transitions(). e.g., the NESW direction of movement, for agents
on a grid.
new_transition : int or float (depending on derived class)
new_transition : int or float (depending on Transitions used)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
......@@ -109,7 +117,7 @@ class GridTransitionMap(TransitionMap):
GridTransitionMap implements utility functions.
"""
def __init__(self, width, height, transitions=Grid4Transitions([])):
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]), random_seed=None):
"""
Builder for GridTransitionMap object.
......@@ -128,16 +136,35 @@ class GridTransitionMap(TransitionMap):
self.width = width
self.height = height
self.transitions = transitions
self.random_generator = np.random.RandomState()
if random_seed is None:
self.random_generator.seed(12)
else:
self.random_generator.seed(random_seed)
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
if isinstance(self.transitions, Grid4Transitions) or isinstance(self.transitions, RailEnvTransitions):
self.grid = np.ndarray((height, width), dtype=np.uint16)
elif isinstance(self.transitions, Grid8Transitions):
self.grid = np.ndarray((height, width), dtype=np.uint64)
def get_full_transitions(self, row, column):
"""
Returns the full transitions for the cell at (row, column) in the format transition_map's transitions.
def get_transitions(self, cell_id):
Parameters
----------
row: int
column: int
(row,column) specifies the cell in this transition map.
Returns
-------
self.transitions.get_type()
The cell content int the format of this map's Transitions.
"""
return self.grid[row][column]
def get_transitions(self, row, column, orientation):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
`cell_id` (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
......@@ -152,22 +179,15 @@ class GridTransitionMap(TransitionMap):
Returns
-------
tuple
List of the validity of transitions in the cell.
List of the validity of transitions in the cell as given by the maps transitions.
"""
if len(cell_id) == 3:
return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2])
elif len(cell_id) == 2:
return self.grid[cell_id[0]][cell_id[1]]
else:
print('GridTransitionMap.get_transitions() ERROR: \
wrong cell_id tuple.')
return ()
return self.transitions.get_transitions(self.grid[row][column], orientation)
def set_transitions(self, cell_id, new_transitions):
"""
Replaces the available transitions in cell `cell_id' with the tuple
`new_transitions'. `new_transitions' must have
Replaces the available transitions in cell `cell_id` with the tuple
`new_transitions'. `new_transitions` must have
one element for each possible transition.
Parameters
......@@ -181,20 +201,19 @@ class GridTransitionMap(TransitionMap):
Tuple of new transitions validitiy for the cell.
"""
assert len(cell_id) in (2, 3), \
'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.'
if len(cell_id) == 3:
self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transitions(self.grid[cell_id[0]][cell_id[1]],
cell_id[2],
new_transitions)
elif len(cell_id) == 2:
self.grid[cell_id[0]][cell_id[1]] = new_transitions
else:
print('GridTransitionMap.get_transitions() ERROR: \
wrong cell_id tuple.')
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id' can perform a
movement along transition `transition_index (e.g., the NESW direction
Return the status of whether an agent in cell `cell_id` can perform a
movement along transition `transition_index` (e.g., the NESW direction
of movement, for agents on a grid).
Parameters
......@@ -209,21 +228,20 @@ class GridTransitionMap(TransitionMap):
Returns
-------
int or float (depending on derived class)
int or float (depending on Transitions used in the )
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
"""
if len(cell_id) != 3:
print('GridTransitionMap.get_transition() ERROR: \
wrong cell_id tuple.')
return ()
assert len(cell_id) == 3, \
'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 2 or 3.'
return self.transitions.get_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index)
def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False):
"""
Replaces the validity of transition to `transition_index' in cell
`cell_id' with the new `new_transition'.
Replaces the validity of transition to `transition_index` in cell
`cell_id' with the new `new_transition`.
Parameters
......@@ -235,15 +253,13 @@ class GridTransitionMap(TransitionMap):
Index of the transition to probe, as index in the tuple returned by
get_transitions(). e.g., the NESW direction of movement, for agents
on a grid.
new_transition : int or float (depending on derived class)
new_transition : int or float (depending on Transitions used in the map.)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
"""
if len(cell_id) != 3:
print('GridTransitionMap.set_transition() ERROR: \
wrong cell_id tuple.')
return
assert len(cell_id) == 3, \
'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.'
self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition(
self.grid[cell_id[0]][cell_id[1]],
cell_id[2],
......@@ -253,7 +269,7 @@ class GridTransitionMap(TransitionMap):
def save_transition_map(self, filename):
"""
Save the transitions grid as `filename', in npy format.
Save the transitions grid as `filename`, in npy format.
Parameters
----------
......@@ -263,57 +279,184 @@ class GridTransitionMap(TransitionMap):
"""
np.save(filename, self.grid)
def load_transition_map(self, filename, override_gridsize=True):
def load_transition_map(self, package, resource):
"""
Load the transitions grid from `filename' (npy format).
Load the transitions grid from `filename` (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be
initialized with the correct `transitions' object anyway.
initialized with the correct `transitions` object anyway.
Parameters
----------
filename : string
Name of the file from which to load the transitions grid.
package : string
Name of the package from which to load the transitions grid.
resource : string
Name of the file from which to load the transitions grid within the package.
override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
of the map loaded from `filename`. If override_gridsize=False, the transitions grid is either cropped (if
the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than
(height,width) )
"""
new_grid = np.load(filename)
with path(package, resource) as file_in:
new_grid = np.load(file_in)
new_height = new_grid.shape[0]
new_width = new_grid.shape[1]
if override_gridsize:
self.width = new_width
self.height = new_height
self.grid = new_grid
self.width = new_width
self.height = new_height
self.grid = new_grid
else:
if new_grid.dtype == np.uint16:
self.grid = np.zeros((self.height, self.width), dtype=np.uint16)
elif new_grid.dtype == np.uint64:
self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
def is_dead_end(self, rcPos: IntVector2DArray):
"""
Check if the cell is a dead-end.
Parameters
----------
rcPos: Tuple[int,int]
tuple(row, column) with grid coordinate
Returns
-------
boolean
True if and only if the cell is a dead-end.
"""
cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
return Grid4Transitions.has_deadend(cell_transition)
self.grid[0:min(self.height, new_height),
0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
0:min(self.width, new_width)]
def is_simple_turn(self, rcPos: IntVector2DArray):
"""
Check if the cell is a left/right simple turn
def is_cell_valid(self, rcPos):
Parameters
----------
rcPos: Tuple[int,int]
tuple(row, column) with grid coordinate
Returns
-------
boolean
True if and only if the cell is a left/right simple turn.
"""
tmp = self.get_full_transitions(rcPos[0], rcPos[1])
def is_simple_turn(trans):
all_simple_turns = OrderedSet()
for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right
int('0001001000000000', 2) # Case 1c (9) - simple turn left]:
]:
for _ in range(3):
trans = self.transitions.rotate_transition(trans, rotation=90)
all_simple_turns.add(trans)
return trans in all_simple_turns
return is_simple_turn(tmp)
def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
"""
Breath first search for a possible path from one node with a certain orientation to a target node.
:param start: Start cell rom where we want to check the path
:param direction: Start direction for the path we are testing
:param end: Cell that we try to reach from the start cell
:return: True if a path exists, False otherwise
"""
visited = OrderedSet()
stack = [(start, direction)]
while stack:
node = stack.pop()
node_position = node[0]
node_direction = node[1]
if Vec2d.is_equal(node_position, end):
return True
if node not in visited:
visited.add(node)
moves = self.get_transitions(node_position[0], node_position[1], node_direction)
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node_position, move_index),
move_index))
return False
def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
- surrounding cells have inbound transitions for all the outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
- Reverse transitions (N -> S) only exist for a dead-end
- a cell contains either no dead-ends or exactly one
Returns: True (valid) or False (invalid)
"""
cell_transition = self.grid[tuple(rcPos)]
if not self.transitions.is_valid(cell_transition):
return False
else:
return True
if check_this_cell:
if not self.transitions.is_valid(cell_transition):
return False
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out
lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1)
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int
def cell_neighbours_valid(self, rcPos, check_this_cell=False):
# loop over available outbound directions (indices) for rcPos
for iDirOut in giDirOut:
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then this transition is invalid!
if np.any(gPos2 < 0):
return False
if np.any(gPos2 >= grcMax):
return False
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
t4Trans2 = self.get_transitions(*gPos2, iDirOut)
if any(t4Trans2):
continue
else:
return False
# If the cell is empty but has incoming connections we return false
if binTrans < 1:
connected = 0
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then ignore it for the count of incoming connections
if np.any(gPos2 < 0):
continue
if np.any(gPos2 >= grcMax):
continue
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
return False
return True
def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
- surrounding cells have inbound transitions for all the
outbound transitions of this cell.
- surrounding cells have inbound transitions for all the outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
......@@ -332,7 +475,7 @@ class GridTransitionMap(TransitionMap):
grcPos = array(rcPos)
grcMax = self.grid.shape
binTrans = self.get_transitions(rcPos) # 16bit integer - all trans in/out
binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out
lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1)
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
......@@ -352,19 +495,145 @@ class GridTransitionMap(TransitionMap):
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
t4Trans2 = self.get_transitions((*gPos2, iDirOut))
t4Trans2 = self.get_transitions(*gPos2, iDirOut)
if any(t4Trans2):
continue
else:
self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
return False
return True
def cell_repr(self, rcPos):
return self.transitions.repr(self.get_transitions(rcPos))
def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
"""
Fixes broken transitions
"""
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
# Transition elements
transitions = RailEnvTransitions()
cells = transitions.transition_list
simple_switch_east_south = transitions.rotate_transition(cells[10], 90)
simple_switch_west_south = transitions.rotate_transition(cells[2], 270)
symmetrical = cells[6]
double_slip = cells[5]
three_way_transitions = [simple_switch_east_south, simple_switch_west_south]
# loop over available outbound directions (indices) for rcPos
incoming_connections = np.zeros(4)
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then ignore it for the count of incoming connections
if np.any(gPos2 < 0):
continue
if np.any(gPos2 >= grcMax):
continue
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
connected = 0
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
incoming_connections[iDirOut] = 1
number_of_incoming = np.sum(incoming_connections)
# Only one incoming direction --> Straight line set deadend
if number_of_incoming == 1:
if self.get_full_transitions(*rcPos) == 0:
self.set_transitions(rcPos, 0)
else:
self.set_transitions(rcPos, 0)
for direction in range(4):
if incoming_connections[direction] > 0:
self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
# Connect all incoming connections
if number_of_incoming == 2:
self.set_transitions(rcPos, 0)
connect_directions = np.argwhere(incoming_connections > 0)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
# Find feasible connection for three entries
if number_of_incoming == 3:
self.set_transitions(rcPos, 0)
hole = np.argwhere(incoming_connections < 1)[0][0]
if direction >= 0:
switch_type_idx = (direction - hole + 3) % 4
if switch_type_idx == 0:
transition = simple_switch_west_south
elif switch_type_idx == 2:
transition = simple_switch_east_south
else:
transition = self.random_generator.choice(three_way_transitions, 1)[0]
else:
transition = self.random_generator.choice(three_way_transitions, 1)[0]
transition = transitions.rotate_transition(transition, int(hole * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
# Make a double slip switch
if number_of_incoming == 4:
rotation = self.random_generator.randint(2)
transition = transitions.rotate_transition(double_slip, int(rotation * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
return True
def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
new_pos: IntVector2D, end_pos: IntVector2D):
"""
Utility function to test that a path drawn by a-start algorithm uses valid transition objects.
We us this to quide a-star as there are many transition elements that are not allowed in RailEnv
:param prev_pos: The previous position we were checking
:param current_pos: The current position we are checking
:param new_pos: Possible child position we move into
:param end_pos: End cell of path we are drawing
:return: True if the transition is valid, False if transition element is illegal
"""
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = self.grid[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# check if matches existing layout
new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if Vec2d.is_equal(new_pos, end_pos):
# need to validate end pos setup as well
new_trans_e = self.grid[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# check if matches existing layout
new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1)
if not self.transitions.is_valid(new_trans_e):
return False
# is transition is valid?
return self.transitions.is_valid(new_trans)
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for
# slicing over the 3 dimensions? I'd say both perhaps.
# TODO: override __getitem__ and __setitem__ (cell contents, not transitions?)
def mirror(dir):
return (dir + 2) % 4
# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
......@@ -3,8 +3,7 @@ The transitions module defines the base Transitions class and a
derived GridTransitions class, which allows for the specification of
possible transitions over a 2D grid.
"""
import numpy as np
from enum import IntEnum
class Transitions:
......@@ -13,13 +12,16 @@ class Transitions:
Generic class that implements checks to control whether a
certain transition is allowed (agent facing a direction
`orientation' and moving into direction `direction')
`orientation' and moving into direction `orientation`)
"""
def get_type(self):
raise NotImplementedError()
def get_transitions(self, cell_transition, orientation):
"""
Return a tuple of transitions available in a cell specified by
`cell_transition' for an agent facing direction `orientation'
`cell_transition' for an agent facing direction `orientation`
(e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
......@@ -43,9 +45,9 @@ class Transitions:
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Return a `cell_transition' specification where the transitions
available for an agent facing direction `orientation' are replaced
with the tuple `new_transitions'. `new_orientations' must have
Return a `cell_transition` specification where the transitions
available for an agent facing direction `orientation` are replaced
with the tuple `new_transitions'. `new_orientations` must have
one element for each possible transition.
Parameters
......@@ -63,8 +65,8 @@ class Transitions:
-------
[cell-content]
An updated class-specific object that replaces the original
transitions validity of `cell_transition' with `new_transitions',
for the appropriate `orientation'.
transitions validity of `cell_transition' with `new_transitions`,
for the appropriate `orientation`.
"""
raise NotImplementedError()
......@@ -72,8 +74,8 @@ class Transitions:
def get_transition(self, cell_transition, orientation, direction):
"""
Return the status of whether an agent oriented in directions
`orientation' and inside a cell with transitions `cell_transition'
can move to the cell in direction `direction' relative
`orientation' and inside a cell with transitions `cell_transition`
can move to the cell in direction `direction` relative
to the current cell.
Parameters
......@@ -99,11 +101,11 @@ class Transitions:
def set_transition(self, cell_transition, orientation, direction,
new_transition):
"""
Return a `cell_transition' specification where the status of
whether an agent oriented in direction `orientation' and inside
a cell with transitions `cell_transition' can move to the cell
in direction `direction' relative to the current cell is set
to `new_transition'.
Return a `cell_transition` specification where the status of
whether an agent oriented in direction `orientation` and inside
a cell with transitions `cell_transition` can move to the cell
in direction `direction` relative to the current cell is set
to `new_transition`.
Parameters
----------
......@@ -123,510 +125,11 @@ class Transitions:
-------
[cell-content]
An updated class-specific object that replaces the original
transitions validity of `cell_transition' with `new_transitions',
for the appropriate `orientation' to `direction'.
transitions validity of `cell_transition' with `new_transitions`,
for the appropriate `orientation' to `direction`.
"""
raise NotImplementedError()
class Grid4Transitions(Transitions):
"""
Grid4Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions'
list, each represented as a bitmap of 16 bits.
Whether a transition is allowed or not depends on which direction an agent
inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which
direction the agent wants to move to
(North, East, South, West, relative to the cell).
Each transition (orientation, direction)
can be allowed (1) or forbidden (0).
For example, in case of no diagonal transitions on the grid, the 16 bits
of the transition bitmaps are organized in 4 blocks of 4 bits each, the
direction that the agent is facing.
E.g., the most-significant 4-bits represent the possible movements (NESW)
if the agent is facing North, etc...
agent's direction: North East South West
agent's allowed movements: [nesw] [nesw] [nesw] [nesw]
example: 1000 0000 0010 0000
In the example, the agent can move from North to South and viceversa.
"""
def __init__(self, transitions):
self.transitions = transitions
self.sDirs = "NESW"
self.lsDirs = list(self.sDirs)
# row,col delta for each direction
self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
def get_transitions(self, cell_transition, orientation):
"""
Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent oriented
in direction `orientation' and inside a cell with
transitions `cell_transition'.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
Returns
-------
tuple
List of the validity of transitions in the cell.
"""
bits = (cell_transition >> ((3 - orientation) * 4))
return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition'. A new `cell_transition' is returned with
the specified bits replaced by `new_transitions'.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
new_transitions : tuple
Tuple of new transitions validitiy for the cell.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
"""
mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 3 | \
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
return cell_transition
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
relative to the current cell.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
Returns
-------
int
Validity of the requested transition: 0/1 allowed/not allowed.
"""
return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
relative to the current cell.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
new_transition : int
Validity of the requested transition: 0/1 allowed/not allowed.
remove_deadends -- boolean, default False
remove all deadend transitions.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
"""
if new_transition:
cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
else:
cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
if remove_deadends:
cell_transition = self.remove_deadends(cell_transition)
return cell_transition
def rotate_transition(self, cell_transition, rotation=0):
"""
Clockwise-rotate a 16-bit transition bitmap by
rotation={0, 90, 180, 270} degrees.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees.
Returns
-------
int
An updated bitmap that replaces the original transitions bits
with the equivalent bitmap after rotation.
"""
# Rotate the individual bits in each block
value = cell_transition
rotation = rotation // 90
for i in range(4):
block_tuple = self.get_transitions(value, i)
block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
value = self.set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
cell_transition = value
return cell_transition
class Grid8Transitions(Transitions):
"""
Grid8Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions'
list, each represented as a bitmap of 64 bits.
0=North, 1=North-East, etc.
"""
def __init__(self, transitions):
self.transitions = transitions
def get_transitions(self, cell_transition, orientation):
"""
Get the 8 possible transitions.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
Returns
-------
tuple
List of the validity of transitions in the cell.
"""
bits = (cell_transition >> ((7 - orientation) * 8))
cell_transition = (
(bits >> 7) & 1,
(bits >> 6) & 1,
(bits >> 5) & 1,
(bits >> 4) & 1,
(bits >> 3) & 1,
(bits >> 2) & 1,
(bits >> 1) & 1,
(bits) & 1)
return cell_transition
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Set the possible transitions.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
new_transitions : tuple
Tuple of new transitions validitiy for the cell.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
"""
mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 7 | \
(new_transitions[1] & 1) << 6 | \
(new_transitions[2] & 1) << 5 | \
(new_transitions[3] & 1) << 4 | \
(new_transitions[4] & 1) << 3 | \
(new_transitions[5] & 1) << 2 | \
(new_transitions[6] & 1) << 1 | \
(new_transitions[7] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((7 - orientation) * 8))
return cell_transition
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
relative to the current cell.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
Returns
-------
int
Validity of the requested transition: 0/1 allowed/not allowed.
"""
return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1
def set_transition(self, cell_transition, orientation, direction,
new_transition):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
relative to the current cell.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
new_transition : int
Validity of the requested transition: 0/1 allowed/not allowed.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
"""
if new_transition:
cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
else:
cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
return cell_transition
def rotate_transition(self, cell_transition, rotation=0):
"""
Clockwise-rotate a 64-bit transition bitmap by
rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 45, 90, 135, 180,
225, 270, 315} degrees.
Returns
-------
int
An updated bitmap that replaces the original transitions bits
with the equivalent bitmap after rotation.
"""
# TODO: WARNING: this part of the function has never been tested!
# Rotate the individual bits in each block
value = cell_transition
rotation = rotation // 45
for i in range(8):
block_tuple = self.get_transitions(value, i)
block_tuple = block_tuple[rotation:] + block_tuple[:rotation]
value = self.set_transitions(value, i, block_tuple)
# Rotate the 8bits blocks
value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8))
cell_transition = value
return cell_transition
class RailEnvTransitions(Grid4Transitions):
"""
Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
of transitions mimicking the types of real Swiss rail connections.
--------------------------------------------------------------------------
As no diagonal transitions are allowed in the RailEnv environment, the
possible transitions for RailEnv from a cell to its neighboring ones
are represented over 16 bits.
The 16 bits are organized in 4 blocks of 4 bits each, the direction that
the agent is facing.
E.g., the most-significant 4-bits represent the possible movements (NESW)
if the agent is facing North, etc...
agent's direction: North East South West
agent's allowed movements: [nesw] [nesw] [nesw] [nesw]
example: 1000 0000 0010 0000
In the example, the agent can move from North to South and viceversa.
"""
"""
transitions[] is indexed by case type/id, and returns the 4x4-bit [NESW]
transitions available as a function of the agent's orientation
(north, east, south, west)
"""
transition_list = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip
int('1100110000110011', 2), # Case 5 - double slip
int('0101001000000010', 2), # Case 6 - symmetrical
int('0010000000000000', 2), # Case 7 - dead end
int('0100000000000010', 2), # Case 1b (8) - simple turn right
int('0001001000000000', 2), # Case 1c (9) - simple turn left
int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored
def __init__(self):
super(RailEnvTransitions, self).__init__(
transitions=self.transition_list
)
# These bits represent all the possible dead ends
self.maskDeadEnds = 0b0010000110000100
# create this to make validation faster
self.transitions_all = set()
for index, trans in enumerate(self.transitions):
self.transitions_all.add(trans)
if index in (2, 4, 6, 7, 8, 9, 10):
for _ in range(3):
trans = self.rotate_transition(trans, rotation=90)
self.transitions_all.add(trans)
elif index in (1, 5):
trans = self.rotate_transition(trans, rotation=90)
self.transitions_all.add(trans)
def print(self, cell_transition):
print(" NESW")
print("N", format(cell_transition >> (3 * 4) & 0xF, '04b'))
print("E", format(cell_transition >> (2 * 4) & 0xF, '04b'))
print("S", format(cell_transition >> (1 * 4) & 0xF, '04b'))
print("W", format(cell_transition >> (0 * 4) & 0xF, '04b'))
def repr(self, cell_transition, version=0):
"""
Provide a string representation of the cell transitions.
This class doesn't represent an individual cell,
but a way of interpreting the contents of a cell.
So using the ad hoc name repr rather than __repr__.
"""
# binary format string without leading 0b
sbinTrans = format(cell_transition, "#018b")[2:]
if version == 0:
sRepr = " ".join([
"{}:{}".format(sDir, sbinTrans[i:(i + 4)])
for i, sDir in
zip(
range(0, len(sbinTrans), 4),
self.lsDirs)]) # NESW
return sRepr
if version == 1:
lsRepr = []
for iDirIn in range(0, 4):
sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)]
if sDirTrans == "0000":
continue
sDirsOut = [
self.lsDirs[iDirOut]
for iDirOut in range(0, 4)
if sDirTrans[iDirOut] == "1"]
lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut))
return ", ".join(lsRepr)
def is_valid(self, cell_transition):
"""
Checks if a cell transition is a valid cell setup.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
Returns
-------
Boolean
True or False
"""
return cell_transition in self.transitions_all
def has_deadend(self, cell_transition):
if cell_transition & self.maskDeadEnds > 0:
return True
else:
return False
def remove_deadends(self, cell_transition):
cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
return cell_transition
def get_direction_enum(self) -> IntEnum:
raise NotImplementedError()
import networkx as nx
import numpy as np
from typing import List, Tuple
import graphviz as gv
class MotionCheck(object):
""" Class to find chains of agents which are "colliding" with a stopped agent.
This is to allow close-packed chains of agents, ie a train of agents travelling
at the same speed with no gaps between them,
"""
def __init__(self):
self.G = nx.DiGraph()
self.nDeadlocks = 0
self.svDeadlocked = set()
def addAgent(self, iAg, rc1, rc2, xlabel=None):
""" add an agent and its motion as row,col tuples of current and next position.
The agent's current position is given an "agent" attribute recording the agent index.
If an agent does not want to move this round (rc1 == rc2) then a self-loop edge is created.
xlabel is used for test cases to give a label (see graphviz)
"""
# Agents which have not yet entered the env have position None.
# Substitute this for the row = -1, column = agent index
if rc1 is None:
rc1 = (-1, iAg)
if rc2 is None:
rc2 = (-1, iAg)
self.G.add_node(rc1, agent=iAg)
if xlabel:
self.G.nodes[rc1]["xlabel"] = xlabel
self.G.add_edge(rc1, rc2)
def find_stops(self):
""" find all the stopped agents as a set of rc position nodes
A stopped agent is a self-loop on a cell node.
"""
# get the (sparse) adjacency matrix
spAdj = nx.linalg.adjacency_matrix(self.G)
# the stopped agents appear as 1s on the diagonal
# the where turns this into a list of indices of the 1s
giStops = np.where(spAdj.diagonal())[0]
# convert the cell/node indices into the node rc values
lvAll = list(self.G.nodes())
# pick out the stops by their indices
lvStops = [ lvAll[i] for i in giStops ]
# make it into a set ready for a set intersection
svStops = set(lvStops)
return svStops
def find_stops2(self):
""" alternative method to find stopped agents, using a networkx call to find selfloop edges
"""
svStops = { u for u,v in nx.classes.function.selfloop_edges(self.G) }
return svStops
def find_stop_preds(self, svStops=None):
""" Find the predecessors to a list of stopped agents (ie the nodes / vertices)
Returns the set of predecessors.
Includes "chained" predecessors.
"""
if svStops is None:
svStops = self.find_stops2()
# Get all the chains of agents - weakly connected components.
# Weakly connected because it's a directed graph and you can traverse a chain of agents
# in only one direction
lWCC = list(nx.algorithms.components.weakly_connected_components(self.G))
svBlocked = set()
for oWCC in lWCC:
#print("Component:", oWCC)
# Get the node details for this WCC in a subgraph
Gwcc = self.G.subgraph(oWCC)
# Find all the stops in this chain or tree
svCompStops = svStops.intersection(Gwcc)
#print(svCompStops)
if len(svCompStops) > 0:
# We need to traverse it in reverse - back up the movement edges
Gwcc_rev = Gwcc.reverse()
for vStop in svCompStops:
# Find all the agents stopped by vStop by following the (reversed) edges
# This traverses a tree - dfs = depth first seearch
iter_stops = nx.algorithms.traversal.dfs_postorder_nodes(Gwcc_rev, vStop)
lStops = list(iter_stops)
svBlocked.update(lStops)
# the set of all the nodes/agents blocked by this set of stopped nodes
return svBlocked
def find_swaps(self):
""" find all the swap conflicts where two agents are trying to exchange places.
These appear as simple cycles of length 2.
These agents are necessarily deadlocked (since they can't change direction in flatland) -
meaning they will now be stuck for the rest of the episode.
"""
#svStops = self.find_stops2()
llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
return svSwaps
def find_same_dest(self):
""" find groups of agents which are trying to land on the same cell.
ie there is a gap of one cell between them and they are both landing on it.
"""
pass
def block_preds(self, svStops, color="red"):
""" Take a list of stopped agents, and apply a stop color to any chains/trees
of agents trying to head toward those cells.
Count the number of agents blocked, ignoring those which are already marked.
(Otherwise it can double count swaps)
"""
iCount = 0
svBlocked = set()
# The reversed graph allows us to follow directed edges to find affected agents.
Grev = self.G.reverse()
for v in svStops:
# Use depth-first-search to find a tree of agents heading toward the blocked cell.
lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v))
svBlocked |= set(lvPred)
svBlocked.add(v)
#print("node:", v, "set", svBlocked)
# only count those not already marked
for v2 in [v]+lvPred:
if self.G.nodes[v2].get("color") != color:
self.G.nodes[v2]["color"] = color
iCount += 1
return svBlocked
def find_conflicts(self):
svStops = self.find_stops2() # voluntarily stopped agents - have self-loops
svSwaps = self.find_swaps() # deadlocks - adjacent head-on collisions
# Block all swaps and their tree of predessors
self.svDeadlocked = self.block_preds(svSwaps, color="purple")
# Take the union of the above, and find all the predecessors
#svBlocked = self.find_stop_preds(svStops.union(svSwaps))
# Just look for the the tree of preds for each voluntarily stopped agent
svBlocked = self.find_stop_preds(svStops)
# iterate the nodes v with their predecessors dPred (dict of nodes->{})
for (v, dPred) in self.G.pred.items():
# mark any swaps with purple - these are directly deadlocked
#if v in svSwaps:
# self.G.nodes[v]["color"] = "purple"
# If they are not directly deadlocked, but are in the union of stopped + deadlocked
#elif v in svBlocked:
# if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting
if v in svBlocked:
self.G.nodes[v]["color"] = "red"
# not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node
elif len(dPred)>1:
# if this agent is already red/blocked, ignore. CHECK: why?
# certainly we want to ignore purple so we don't overwrite with red.
if self.G.nodes[v].get("color") in ("red", "purple"):
continue
# if this node has no agent, and >=2 want to enter it.
if self.G.nodes[v].get("agent") is None:
self.G.nodes[v]["color"] = "blue"
# this node has an agent and >=2 want to enter
else:
self.G.nodes[v]["color"] = "magenta"
# predecessors of a contended cell: {agent index -> node}
diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred}
# remove the agent with the lowest index, who wins
iAgWinner = min(diAgCell)
diAgCell.pop(iAgWinner)
# Block all the remaining predessors, and their tree of preds
#for iAg, v in diAgCell.items():
# self.G.nodes[v]["color"] = "red"
# for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
# self.G.nodes[vPred]["color"] = "red"
self.block_preds(diAgCell.values(), "red")
def check_motion(self, iAgent, rcPos):
""" Returns tuple of boolean can the agent move, and the cell it will move into.
If agent position is None, we use a dummy position of (-1, iAgent)
"""
if rcPos is None:
rcPos = (-1, iAgent)
dAttr = self.G.nodes.get(rcPos)
#print("pos:", rcPos, "dAttr:", dAttr)
if dAttr is None:
dAttr = {}
# If it's been marked red or purple then it can't move
if "color" in dAttr:
sColor = dAttr["color"]
if sColor in [ "red", "purple" ]:
return False
dSucc = self.G.succ[rcPos]
# This should never happen - only the next cell of an agent has no successor
if len(dSucc)==0:
print(f"error condition - agent {iAgent} node {rcPos} has no successor")
return False
# This agent has a successor
rcNext = self.G.successors(rcPos).__next__()
if rcNext == rcPos: # the agent didn't want to move
return False
# The agent wanted to move, and it can
return True
def render(omc:MotionCheck, horizontal=True):
try:
oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
oAG.layout("dot")
sDot = oAG.to_string()
if horizontal:
sDot = sDot.replace('{', '{ rankdir="LR" ')
#return oAG.draw(format="png")
# This returns a graphviz object which implements __repr_svg
return gv.Source(sDot)
except ImportError as oError:
print("Flatland agent_chains ignoring ImportError - install pygraphviz to render graphs")
return None
class ChainTestEnv(object):
""" Just for testing agent chains
"""
def __init__(self, omc:MotionCheck):
self.iAgNext = 0
self.iRowNext = 1
self.omc = omc
def addAgent(self, rc1, rc2, xlabel=None):
self.omc.addAgent(self.iAgNext, rc1, rc2, xlabel=xlabel)
self.iAgNext+=1
def addAgentToRow(self, c1, c2, xlabel=None):
self.addAgent((self.iRowNext, c1), (self.iRowNext, c2), xlabel=xlabel)
def create_test_chain(self,
nAgents:int,
rcVel:Tuple[int] = (0,1),
liStopped:List[int]=[],
xlabel=None):
""" create a chain of agents
"""
lrcAgPos = [ (self.iRowNext, i * rcVel[1]) for i in range(nAgents) ]
for iAg, rcPos in zip(range(nAgents), lrcAgPos):
if iAg in liStopped:
rcVel1 = (0,0)
else:
rcVel1 = rcVel
self.omc.addAgent(iAg+self.iAgNext, rcPos, (rcPos[0] + rcVel1[0], rcPos[1] + rcVel1[1]) )
if xlabel:
self.omc.G.nodes[lrcAgPos[0]]["xlabel"] = xlabel
self.iAgNext += nAgents
self.iRowNext += 1
def nextRow(self):
self.iRowNext+=1
def create_test_agents(omc:MotionCheck):
# blocked chain
omc.addAgent(1, (1,2), (1,3))
omc.addAgent(2, (1,3), (1,4))
omc.addAgent(3, (1,4), (1,5))
omc.addAgent(31, (1,5), (1,5))
# unblocked chain
omc.addAgent(4, (2,1), (2,2))
omc.addAgent(5, (2,2), (2,3))
# blocked short chain
omc.addAgent(6, (3,1), (3,2))
omc.addAgent(7, (3,2), (3,2))
# solitary agent
omc.addAgent(8, (4,1), (4,2))
# solitary stopped agent
omc.addAgent(9, (5,1), (5,1))
# blocked short chain (opposite direction)
omc.addAgent(10, (6,4), (6,3))
omc.addAgent(11, (6,3), (6,3))
# swap conflict
omc.addAgent(12, (7,1), (7,2))
omc.addAgent(13, (7,2), (7,1))
def create_test_agents2(omc:MotionCheck):
# blocked chain
cte = ChainTestEnv(omc)
cte.create_test_chain(4, liStopped=[3], xlabel="stopped\nchain")
cte.create_test_chain(4, xlabel="running\nchain")
cte.create_test_chain(2, liStopped = [1], xlabel="stopped \nshort\n chain")
cte.addAgentToRow(1, 2, "swap")
cte.addAgentToRow(2, 1)
cte.nextRow()
cte.addAgentToRow(1, 2, "chain\nswap")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 2)
cte.nextRow()
cte.addAgentToRow(1, 2, "midchain\nstop")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(4, 4)
cte.addAgentToRow(5, 6)
cte.addAgentToRow(6, 7)
cte.nextRow()
cte.addAgentToRow(1, 2, "midchain\nswap")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(4, 3)
cte.addAgentToRow(5, 4)
cte.addAgentToRow(6, 5)
cte.nextRow()
cte.addAgentToRow(1, 2, "Land on\nSame")
cte.addAgentToRow(3, 2)
cte.nextRow()
cte.addAgentToRow(1, 2, "chains\nonto\nsame")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(5, 4)
cte.addAgentToRow(6, 5)
cte.addAgentToRow(7, 6)
cte.nextRow()
cte.addAgentToRow(1, 2, "3-way\nsame")
cte.addAgentToRow(3, 2)
cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
cte.nextRow()
if False:
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "4-way\nsame")
cte.addAgentToRow(3, 2)
cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
cte.addAgent((cte.iRowNext-1, 2), (cte.iRowNext, 2))
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "Tee")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgent((cte.iRowNext+1, 3), (cte.iRowNext, 3))
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "Tree")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
r1 = cte.iRowNext
r2 = cte.iRowNext+1
r3 = cte.iRowNext+2
cte.addAgent((r2, 3), (r1, 3))
cte.addAgent((r2, 2), (r2, 3))
cte.addAgent((r3, 2), (r2, 3))
cte.nextRow()
def test_agent_following():
omc = MotionCheck()
create_test_agents2(omc)
svStops = omc.find_stops()
svBlocked = omc.find_stop_preds()
llvSwaps = omc.find_swaps()
svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
print(list(svBlocked))
lvCells = omc.G.nodes()
lColours = [ "magenta" if v in svStops
else "red" if v in svBlocked
else "purple" if v in svSwaps
else "lightblue"
for v in lvCells ]
dPos = dict(zip(lvCells, lvCells))
nx.draw(omc.G,
with_labels=True, arrowsize=20,
pos=dPos,
node_color = lColours)
def main():
test_agent_following()
if __name__=="__main__":
main()
from itertools import starmap
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
from attr import attrs, attrib
@attrs
class EnvDescription(object):
""" EnvDescription - This is a description of a random env,
based around the rail_generator and stats like size and n_agents.
It mirrors the parameters given to the RailEnv constructor.
Not currently used.
"""
n_agents = attrib()
height = attrib()
width = attrib()
rail_generator = attrib()
obs_builder = attrib() # not sure if this should closer to the agent than the env
import warnings
from typing import Tuple, Optional, NamedTuple, List
from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
('direction', Grid4TransitionsEnum),
('target', Tuple[int, int]),
('moving', bool),
('earliest_departure', int),
('latest_arrival', int),
('handle', int),
('position', Tuple[int, int]),
('arrival_time', int),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
def load_env_agent(agent_tuple: Agent):
return EnvAgent(
initial_position = agent_tuple.initial_position,
initial_direction = agent_tuple.initial_direction,
direction = agent_tuple.direction,
target = agent_tuple.target,
moving = agent_tuple.moving,
earliest_departure = agent_tuple.earliest_departure,
latest_arrival = agent_tuple.latest_arrival,
handle = agent_tuple.handle,
position = agent_tuple.position,
arrival_time = agent_tuple.arrival_time,
old_direction = agent_tuple.old_direction,
old_position = agent_tuple.old_position,
speed_counter = agent_tuple.speed_counter,
action_saver = agent_tuple.action_saver,
state_machine = agent_tuple.state_machine,
malfunction_handler = agent_tuple.malfunction_handler,
)
@attrs
class EnvAgentStatic(object):
""" EnvAgentStatic - Stores initial position, direction and target.
This is like static data for the environment - it's where an agent starts,
rather than where it is at the moment.
The target should also be stored here.
"""
position = attrib()
direction = attrib()
target = attrib()
moving = attrib()
def __init__(self, position, direction, target, moving=False):
self.position = position
self.direction = direction
self.target = target
self.moving = moving
class EnvAgent:
# INIT FROM HERE IN _from_line()
initial_position = attrib(type=Tuple[int, int])
initial_direction = attrib(type=Grid4TransitionsEnum)
direction = attrib(type=Grid4TransitionsEnum)
target = attrib(type=Tuple[int, int])
moving = attrib(default=False, type=bool)
# NEW : EnvAgent - Schedule properties
earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line()
@classmethod
def from_lists(cls, positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions))))
def to_list(self):
# I can't find an expression which works on both tuples, lists and ndarrays
# which converts them all to a list of native python ints.
lPos = self.position
if type(lPos) is np.ndarray:
lPos = lPos.tolist()
handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
lTarget = self.target
if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist()
# Env step facelift
speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
return [lPos, int(self.direction), lTarget, int(self.moving)]
position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling
arrival_time = attrib(default=None, type=int)
@attrs
class EnvAgent(EnvAgentStatic):
""" EnvAgent - replace separate agent_* lists with a single list
of agent objects. The EnvAgent represent's the environment's view
of the dynamic agent state.
We are duplicating target in the EnvAgent, which seems simpler than
forcing the env to refer to it in the EnvAgentStatic
"""
handle = attrib(default=None)
# used in rendering
old_direction = attrib(default=None)
old_position = attrib(default=None)
def __init__(self, position, direction, target, handle, old_direction, old_position):
super(EnvAgent, self).__init__(position, direction, target)
self.handle = handle
self.old_direction = old_direction
self.old_position = old_position
def to_list(self):
return [
self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position, self.moving]
def reset(self):
"""
Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
"""
self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction
self.old_position = None
self.old_direction = None
self.moving = False
self.arrival_time = None
self.malfunction_handler.reset()
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
self.state_machine.reset()
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position,
initial_direction=self.initial_direction,
direction=self.direction,
target=self.target,
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
handle=self.handle,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
arrival_time=self.arrival_time,
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
def get_shortest_path(self, distance_map) -> List[Waypoint]:
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
shortest_path = self.get_shortest_path(distance_map)
if shortest_path is not None:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_counter.speed
return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
return self.latest_arrival - elapsed_steps
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
'''
+ve if arrival time is projected before latest arrival
-ve if arrival time is projected after latest arrival
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
@classmethod
def from_static(cls, oStatic):
""" Create an EnvAgent from the EnvAgentStatic,
copying all the fields, and adding handle with the default 0.
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
return EnvAgent(*oStatic.__dict__, handle=0)
num_agents = len(line.agent_positions)
agent_list = []
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
target = line.agent_targets[i_agent],
moving = False,
earliest_departure = None,
latest_arrival = None,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
agent_list.append(agent)
return agent_list
@classmethod
def list_from_static(cls, lEnvAgentStatic, handles=None):
""" Create an EnvAgent from the EnvAgentStatic,
copying all the fields, and adding handle with the default 0.
"""
if handles is None:
handles = range(len(lEnvAgentStatic))
def load_legacy_static_agent(cls, static_agents_data: Tuple):
agents = []
for i, static_agent in enumerate(static_agents_data):
if len(static_agent) >= 6:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_counter=SpeedCounter(static_agent[4]['speed']), handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
speed_counter=SpeedCounter(1.0),
handle=i)
agents.append(agent)
return agents
def __str__(self):
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} \n \
initial_direction: {self.initial_direction} \n \
position: {self.position} \n \
direction: {self.direction} \n \
target: {self.target} \n \
old_position: {self.old_position} \n \
old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} \n \
latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_handler: {self.malfunction_handler} \n \
action_saver: {self.action_saver} \n \
speed_counter: {self.speed_counter}"
@property
def state(self):
return self.state_machine.state
@state.setter
def state(self, state):
self._set_state(state)
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
@property
def malfunction_data(self):
raise ValueError("agent.malunction_data is deprecated, please use agent.malfunction_hander instead")
@property
def speed_data(self):
raise ValueError("agent.speed_data is deprecated, please use agent.speed_counter instead")
return [EnvAgent(**oEAS.__dict__, handle=handle)
for handle, oEAS in zip(handles, lEnvAgentStatic)]
from collections import deque
from typing import List, Optional
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
class DistanceMap:
def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
self.env_height = env_height
self.env_width = env_width
self.distance_map = None
self.agents_previous_computation = None
self.reset_was_called = False
self.agents: List[EnvAgent] = agents
self.rail: Optional[GridTransitionMap] = None
def set(self, distance_map: np.ndarray):
"""
Set the distance map
"""
self.distance_map = distance_map
def get(self) -> np.ndarray:
"""
Get the distance map
"""
if self.reset_was_called:
self.reset_was_called = False
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_computation is None and self.distance_map is not None:
compute_distance_map = False
if compute_distance_map:
self._compute(self.agents, self.rail)
elif self.distance_map is None:
self._compute(self.agents, self.rail)
return self.distance_map
def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
Reset the distance map
"""
self.reset_was_called = True
self.agents: List[EnvAgent] = agents
self.rail = rail
self.env_height = rail.height
self.env_width = rail.width
def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
This function computes the distance maps for each unique target. Thus if several targets are the same
we only compute the distance for them once and copy to all targets with same position.
:param agents: All the agents in the environment, independent of their current status
:param rail: The rail transition map
"""
self.agents_previous_computation = self.agents
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
self.env_width,
4))
computed_targets = []
for i, agent in enumerate(agents):
if agent.target not in computed_targets:
self._distance_map_walker(rail, agent.target, i)
else:
# just copy the distance map form other agent with same target (performance)
self.distance_map[i, :, :, :] = np.copy(
self.distance_map[computed_targets.index(agent.target), :, :, :])
computed_targets.append(agent.target)
def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
orientation within it) to each agent's target cell.
"""
# Returns max distance to target, from the farthest away node, while filling in distance_map
self.distance_map[target_nr, position[0], position[1], :] = 0
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
while nodes_queue:
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those
# whose new orientation in the current cell would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors(rail, (node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3] + 1)
return max_distance
def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance,
enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
"""
neighbors = []
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = get_new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env_height and new_cell[1] >= 0 and new_cell[1] < self.env_width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
if isNextCellDeadEnd:
movement = (desired_movement_from_new_cell+2) % 4
"""
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
current_distance + 1)
neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
return neighbors
"""
Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
import numpy as np
def get_direction(pos1, pos2):
"""
Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions.
"""
diff_0 = pos2[0] - pos1[0]
diff_1 = pos2[1] - pos1[1]
if diff_0 < 0:
return 0
if diff_0 > 0:
return 2
if diff_1 > 0:
return 1
if diff_1 < 0:
return 3
return 0
def mirror(dir):
return (dir + 2) % 4
def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = rail_array[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if new_pos == end_pos:
# need to validate end pos setup as well
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
if not rail_trans.is_valid(new_trans_e):
return False
# is transition is valid?
return rail_trans.is_valid(new_trans)
class AStarNode():
"""A node class for A* Pathfinding"""
def __init__(self, parent=None, pos=None):
self.parent = parent
self.pos = pos
self.g = 0
self.h = 0
self.f = 0
def __eq__(self, other):
return self.pos == other.pos
def __hash__(self):
return hash(self.pos)
def update_if_better(self, other):
if other.g < self.g:
self.parent = other.parent
self.g = other.g
self.h = other.h
self.f = other.f
def a_star(rail_trans, rail_array, start, end):
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
"""
rail_shape = rail_array.shape
start_node = AStarNode(None, start)
end_node = AStarNode(None, end)
open_nodes = set()
closed_nodes = set()
open_nodes.add(start_node)
while len(open_nodes) > 0:
# get node with current shortest est. path (lowest f)
current_node = None
for item in open_nodes:
if current_node is None:
current_node = item
continue
if item.f < current_node.f:
current_node = item
# pop current off open list, add to closed list
open_nodes.remove(current_node)
closed_nodes.add(current_node)
# found the goal
if current_node == end_node:
path = []
current = current_node
while current is not None:
path.append(current.pos)
current = current.parent
# return reversed path
return path[::-1]
# generate children
children = []
if current_node.parent is not None:
prev_pos = current_node.parent.pos
else:
prev_pos = None
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
continue
# validate positions
if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos):
continue
# create new node
new_node = AStarNode(current_node, node_pos)
children.append(new_node)
# loop through children
for child in children:
# already in closed list?
if child in closed_nodes:
continue
# create the f, g, and h values
child.g = current_node.g + 1
# this heuristic favors diagonal paths:
# child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + ((child.pos[1] - end_node.pos[1]) ** 2) \# noqa: E800
# this heuristic avoids diagonal paths
child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1])
child.f = child.g + child.h
# already in the open list?
if child in open_nodes:
continue
# add the child to the open list
open_nodes.add(child)
# no full path found
if len(open_nodes) == 0:
return []
def connect_rail(rail_trans, rail_array, start, end):
"""
Creates a new path [start,end] in rail_array, based on rail_trans.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, rail_array, start, end)
if len(path) < 2:
return []
current_dir = get_direction(path[0], path[1])
end_pos = path[-1]
for index in range(len(path) - 1):
current_pos = path[index]
new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos]
if index == 0:
if new_trans == 0:
# end-point
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
rail_array[current_pos] = new_trans
if new_pos == end_pos:
# setup end pos setup
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# end-point
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
rail_array[end_pos] = new_trans_e
current_dir = new_dir
return path
def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
def get_new_position(position, movement):
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
"""
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = rail.get_transitions((node[0][0], node[0][1], node[1]))
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = rail.get_transitions((node[0][0], node[0][1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c))
re_generate = True
while re_generate:
agents_position = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
agents_direction = [0] * num_agents
re_generate = False
for i in range(num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions((position[0], position[1], direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
return agents_position, agents_direction, agents_target
from typing import Tuple
# Adrian Egli / Michel Marti performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int], max_value: Tuple[int, int]) -> bool:
return (
max(min_value[0], min(position[0], max_value[0])),
max(min_value[1], min(position[1], max_value[1]))
)
def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
if possible_transitions[0] == 1:
return 0
if possible_transitions[1] == 1:
return 1
if possible_transitions[2] == 1:
return 2
return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None and pos_2 is None:
return True
if pos_1 is None or pos_2 is None:
return False
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
def fast_delete(lis: list, index) -> list:
new_list = lis.copy()
new_list.pop(index)
return new_list
def fast_where(binary_iterable):
return [index for index, element in enumerate(binary_iterable) if element != 0]
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
def load_flatland_environment_from_file(file_name: str,
load_from_package: str = None,
obs_builder_object: ObservationBuilder = None) -> RailEnv:
"""
Parameters
----------
file_name : str
The pickle file.
load_from_package : str
The python module to import from. Example: 'env_data.tests'
This requires that there are `__init__.py` files in the folder structure we load the file from.
obs_builder_object: ObservationBuilder
The obs builder for the `RailEnv` that is created.
Returns
-------
RailEnv
The environment loaded from the pickle file.
"""
if obs_builder_object is None:
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1,
obs_builder_object=obs_builder_object)
return environment
import numpy as np
from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror
from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
def empty_rail_generator():
"""
Returns a generator which returns an empty rail mail with no agents.
Primarily used by the editor
"""
def generator(width, height, num_agents=0, num_resets=0):
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
return grid_map, [], [], []
return generator
def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
"""
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_start_goal:
num_agents = nr_start_goal
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
np.random.seed(seed + num_resets)
# generate rail array
# step 1:
# - generate a start and goal position
# - validate min/max distance allowed
# - validate that start/goals are not placed too close to other start/goals
# - draw a rail from [start,goal]
# - if rail crosses existing rail then validate new connection
# - possibility that this fails to create a path to goal
# - on failure generate new start/goal
#
# step 2:
# - add more rails to map randomly between cells that have rails
# - validate all new rails, on failure don't add new rails
#
# step 3:
# - return transition map + list of [start_pos, start_dir, goal_pos] points
#
start_goal = []
start_dir = []
nr_created = 0
created_sanity = 0
sanity_max = 9000
while nr_created < nr_start_goal and created_sanity < sanity_max:
all_ok = False
for _ in range(sanity_max):
start = (np.random.randint(0, width), np.random.randint(0, height))
goal = (np.random.randint(0, height), np.random.randint(0, height))
# check to make sure start,goal pos is empty?
if rail_array[goal] != 0 or rail_array[start] != 0:
continue
# check min/max distance
dist_sg = distance_on_rail(start, goal)
if dist_sg < min_dist:
continue
if dist_sg > max_dist:
continue
# check distance to existing points
sg_new = [start, goal]
def check_all_dist(sg_new):
for sg in start_goal:
for i in range(2):
for j in range(2):
dist = distance_on_rail(sg_new[i], sg[j])
if dist < 2:
return False
return True
if check_all_dist(sg_new):
all_ok = True
break
if not all_ok:
# we can might as well give up at this point
break
new_path = connect_rail(rail_trans, rail_array, start, goal)
if len(new_path) >= 2:
nr_created += 1
start_goal.append([start, goal])
start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
else:
# after too many failures we will give up
created_sanity += 1
# add extra connections between existing rail
created_sanity = 0
nr_created = 0
while nr_created < nr_extra and created_sanity < sanity_max:
all_ok = False
for _ in range(sanity_max):
start = (np.random.randint(0, width), np.random.randint(0, height))
goal = (np.random.randint(0, height), np.random.randint(0, height))
# check to make sure start,goal pos are not empty
if rail_array[goal] == 0 or rail_array[start] == 0:
continue
else:
all_ok = True
break
if not all_ok:
break
new_path = connect_rail(rail_trans, rail_array, start, goal)
if len(new_path) >= 2:
nr_created += 1
agents_position = [sg[0] for sg in start_goal[:num_agents]]
agents_target = [sg[1] for sg in start_goal[:num_agents]]
agents_direction = start_dir[:num_agents]
return grid_map, agents_position, agents_direction, agents_target
return generator
def rail_from_manual_specifications_generator(rail_spec):
"""
Utility to convert a rail given by manual specification as a map of tuples
(cell_type, rotation), to a transition map with the correct 16-bit
transitions specifications.
Parameters
-------
rail_spec : list of list of tuples
List (rows) of lists (columns) of tuples, each specifying a cell for
the RailEnv environment as (cell_type, rotation), with rotation being
clock-wise and in [0, 90, 180, 270].
Returns
-------
function
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_agents, num_resets=0):
t_utils = RailEnvTransitions()
height = len(rail_spec)
width = len(rail_spec[0])
rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
for r in range(height):
for c in range(width):
cell = rail_spec[r][c]
if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
print("ERROR - invalid cell type=", cell[0])
return []
rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1]))
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail,
num_agents)
return rail, agents_position, agents_direction, agents_target
return generator
def rail_from_GridTransitionMap_generator(rail_map):
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
Parameters
-------
rail_map : GridTransitionMap object
GridTransitionMap object to return when the generator is called.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_agents, num_resets=0):
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail_map,
num_agents)
return rail_map, agents_position, agents_direction, agents_target
return generator
def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
"""
Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
Parameters
-------
list_of_filenames : list
List of filenames with the saved grids to load.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_agents, num_resets=0):
t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
if rail_map.grid.dtype == np.uint64:
rail_map.transitions = Grid8Transitions()
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail_map,
num_agents)
return rail_map, agents_position, agents_direction, agents_target
return generator
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0):
return generate_rail_from_manual_specifications(list_of_specifications)
return generator
"""
def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
"""
Dummy random level generator:
- fill in cells at random in [width-2, height-2]
- keep filling cells in among the unfilled ones, such that all transitions
are legit; if no cell can be filled in without violating some
transitions, pick one among those that can satisfy most transitions
(1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
incompatible.
- keep trying for a total number of insertions
(e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
board and try again from scratch.
- finally pad the border of the map with dead-ends to avoid border issues.
Dead-ends are not allowed inside the grid, only at the border; however, if
no cell type can be inserted in a given cell (because of the neighboring
transitions), deadends are allowed if they solve the problem. This was
found to turn most un-genereatable levels into valid ones.
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_agents, num_resets=0):
t_utils = RailEnvTransitions()
transition_probability = cell_type_relative_proportion
transitions_templates_ = []
transition_probabilities = []
for i in range(len(t_utils.transitions)): # don't include dead-ends
if t_utils.transitions[i] == int('0010000000000000', 2):
continue
all_transitions = 0
for dir_ in range(4):
trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
all_transitions |= (trans[0] << 3) | \
(trans[1] << 2) | \
(trans[2] << 1) | \
(trans[3])
template = [int(x) for x in bin(all_transitions)[2:]]
template = [0] * (4 - len(template)) + template
# add all rotations
for rot in [0, 90, 180, 270]:
transitions_templates_.append((template,
t_utils.rotate_transition(
t_utils.transitions[i],
rot)))
transition_probabilities.append(transition_probability[i])
template = [template[-1]] + template[:-1]
def get_matching_templates(template):
ret = []
for i in range(len(transitions_templates_)):
is_match = True
for j in range(4):
if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
is_match = False
break
if is_match:
ret.append((transitions_templates_[i][1], transition_probabilities[i]))
return ret
MAX_INSERTIONS = (width - 2) * (height - 2) * 10
MAX_ATTEMPTS_FROM_SCRATCH = 10
attempt_number = 0
while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
cells_to_fill = []
rail = []
for r in range(height):
rail.append([None] * width)
if r > 0 and r < height - 1:
cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]
num_insertions = 0
while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
cells_to_fill.remove(cell)
row = cell[0]
col = cell[1]
# look at its neighbors and see what are the possible transitions
# that can be chosen from, if any.
valid_template = [-1, -1, -1, -1]
for el in [(0, 2, (-1, 0)),
(1, 3, (0, 1)),
(2, 0, (1, 0)),
(3, 1, (0, -1))]: # N, E, S, W
neigh_trans = rail[row + el[2][0]][col + el[2][1]]
if neigh_trans is not None:
# select transition coming from facing direction el[1] and
# moving to direction el[1]
max_bit = 0
for k in range(4):
max_bit |= t_utils.get_transition(neigh_trans, k, el[1])
if max_bit:
valid_template[el[0]] = 1
else:
valid_template[el[0]] = 0
possible_cell_transitions = get_matching_templates(valid_template)
if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS
# no cell can be filled in without violating some transitions
# can a dead-end solve the problem?
if valid_template.count(1) == 1:
for k in range(4):
if valid_template[k] == 1:
rot = 0
if k == 0:
rot = 180
elif k == 1:
rot = 270
elif k == 2:
rot = 0
elif k == 3:
rot = 90
rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot)
num_insertions += 1
break
else:
# can I get valid transitions by removing a single
# neighboring cell?
bestk = -1
besttrans = []
for k in range(4):
tmp_template = valid_template[:]
tmp_template[k] = -1
possible_cell_transitions = get_matching_templates(tmp_template)
if len(possible_cell_transitions) > len(besttrans):
besttrans = possible_cell_transitions
bestk = k
if bestk >= 0:
# Replace the corresponding cell with None, append it
# to cells to fill, fill in a transition in the current
# cell.
replace_row = row - 1
replace_col = col
if bestk == 1:
replace_row = row
replace_col = col + 1
elif bestk == 2:
replace_row = row + 1
replace_col = col
elif bestk == 3:
replace_row = row
replace_col = col - 1
cells_to_fill.append((replace_row, replace_col))
rail[replace_row][replace_col] = None
possible_transitions, possible_probabilities = zip(*besttrans)
possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities)
num_insertions += 1
else:
print('WARNING: still nothing!')
rail[row][col] = int('0000000000000000', 2)
num_insertions += 1
pass
else:
possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities)
num_insertions += 1
if num_insertions == MAX_INSERTIONS:
# Failed to generate a valid level; try again for a number of times
attempt_number += 1
else:
break
if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
print('ERROR: failed to generate level')
# Finally pad the border of the map with dead-ends to avoid border issues;
# at most 1 transition in the neigh cell
for r in range(height):
# Check for transitions coming from [r][1] to WEST
max_bit = 0
neigh_trans = rail[r][1]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & 1)
if max_bit:
rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
else:
rail[r][0] = int('0000000000000000', 2)
# Check for transitions coming from [r][-2] to EAST
max_bit = 0
neigh_trans = rail[r][-2]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
if max_bit:
rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
90)
else:
rail[r][-1] = int('0000000000000000', 2)
for c in range(width):
# Check for transitions coming from [1][c] to NORTH
max_bit = 0
neigh_trans = rail[1][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
if max_bit:
rail[0][c] = int('0010000000000000', 2)
else:
rail[0][c] = int('0000000000000000', 2)
# Check for transitions coming from [-2][c] to SOUTH
max_bit = 0
neigh_trans = rail[-2][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
if max_bit:
rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
else:
rail[-1][c] = int('0000000000000000', 2)
# For display only, wrong levels
for r in range(height):
for c in range(width):
if rail[r][c] is None:
rail[r][c] = int('0000000000000000', 2)
tmp_rail = np.asarray(rail, dtype=np.uint16)
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
return_rail,
num_agents)
return return_rail, agents_position, agents_direction, agents_target
return generator