Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 2398 additions and 931 deletions
import logging
import random
import numpy as np
from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import TrainState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.fast_methods import fast_count_nonzero, fast_argmax
# NOTE(review): this rebinds the `MalfunctionParameters` name that is also imported from
# flatland.envs.malfunction_generators at the top of this file -- confirm the local
# definition is still required.
class MalfunctionParameters(NamedTuple):
    """Malfunction rate together with the minimal and maximal malfunction duration."""

    malfunction_rate: float
    min_duration: int
    max_duration: int
def get_shortest_path_action(env, handle):
    """
    Pick the left/forward/right move with the smallest distance-map value for one agent.

    Parameters
    ----------
    env
        Environment exposing `agents`, `rail.get_transitions(...)` and `distance_map.get()`.
    handle : int
        Index of the agent in `env.agents` (also the first index into the distance map).

    Returns
    -------
    int or None
        `1 + slot`, where `slot` indexes the candidate headings
        (direction - 1, direction, direction + 1), i.e. 1 = left, 2 = forward, 3 = right.
        `None` if the agent's status is not one of the handled `TrainState` values.
    """
    distance_map = env.distance_map.get()
    agent = env.agents[handle]

    # NOTE(review): this function reads `agent.status` whereas other code in this file reads
    # `agent.state` -- confirm which attribute the installed flatland version provides.
    if agent.status in (TrainState.WAITING, TrainState.READY_TO_DEPART,
                        TrainState.MALFUNCTION_OFF_MAP):
        agent_virtual_position = agent.initial_position
    elif agent.status in (TrainState.MALFUNCTION, TrainState.MOVING, TrainState.STOPPED):
        agent_virtual_position = agent.position
    elif agent.status == TrainState.DONE:
        agent_virtual_position = agent.target
    else:
        return None

    # Transitions are always read from the real cell (or the start cell while off the map).
    transition_origin = agent.position if agent.position else agent.initial_position
    possible_transitions = env.rail.get_transitions(*transition_origin, agent.direction)
    num_transitions = fast_count_nonzero(possible_transitions)

    # (distance, slot) for every allowed move; slot 0/1/2 = left/forward/right.
    candidates = []
    for slot, offset in enumerate((-1, 0, 1)):
        direction = (agent.direction + offset) % 4
        if possible_transitions[direction]:
            new_position = get_new_position(agent_virtual_position, direction)
            distance = distance_map[handle, new_position[0], new_position[1], direction]
            candidates.append((distance, slot))

    observation = [0, 0, 0]
    if num_transitions == 1 or not candidates:
        # A single way out (or nothing to compare): keep going forward. This also covers
        # the cases that previously left `observation` unassigned.
        observation[1] = 1
    else:
        # Strict minimum over the allowed moves only; ties resolve to the left-most slot.
        _, best_slot = min(candidates, key=lambda candidate: candidate[0])
        observation[best_slot] = 1
    return fast_argmax(observation) + 1
def small_v0(random_seed, observation_builder, max_width=35, max_height=35):
    """
    Build a small fixed-layout RailEnv (5 trains, 4 cities, no malfunctions).

    Starts with a 30x30 grid and, whenever construction raises `ValueError`, retries with
    width and height enlarged by 5 until either exceeds its maximum.

    Parameters
    ----------
    random_seed
        Seed passed to `random.seed` and to the rail generator.
    observation_builder
        Object handed to `RailEnv` as `obs_builder_object`.
    max_width : int, optional
        Largest width that is still attempted.
    max_height : int, optional
        Largest height that is still attempted.

    Returns
    -------
    RailEnv or None
        The created environment, or `None` if no size within the limits worked.
    """
    random.seed(random_seed)
    width = 30
    height = 30
    nr_trains = 5
    max_num_cities = 4
    grid_mode = False
    max_rails_between_cities = 2
    max_rails_in_city = 3
    # Only reported in the summary line below; this env uses no malfunctions.
    malfunction_rate = 0
    malfunction_min_duration = 0
    malfunction_max_duration = 0

    rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=grid_mode,
                                           max_rails_between_cities=max_rails_between_cities,
                                           max_rail_pairs_in_city=max_rails_in_city)
    speed_ratio_map = None
    line_generator = sparse_line_generator(speed_ratio_map)
    malfunction_generator = no_malfunction_generator()

    while width <= max_width and height <= max_height:
        try:
            # NOTE(review): sibling factories in this file pass `malfunction_generator=`;
            # confirm `malfunction_generator_and_process_data` is still accepted by RailEnv.
            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
                          line_generator=line_generator, number_of_agents=nr_trains,
                          malfunction_generator_and_process_data=malfunction_generator,
                          obs_builder_object=observation_builder, remove_agents_at_target=False)
            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
                random_seed, width, height, max_num_cities, nr_trains, max_rails_between_cities,
                max_rails_in_city, malfunction_rate, malfunction_min_duration, malfunction_max_duration
            ))
            return env
        except ValueError as e:
            logging.error(f"Error: {e}")
            width += 5
            height += 5
            # Placeholders are required; extra args without them make logging fail on emit.
            logging.info("Try again with larger env: (w,h): %s %s", width, height)
    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
    return None
def random_sparse_env_small(random_seed, observation_builder, max_width=45, max_height=45):
    """
    Build a RailEnv whose size, city/train counts and malfunction parameters are drawn at random.

    All random draws come from `random` after seeding it with `random_seed`. Whenever
    construction raises `ValueError`, the grid is enlarged by 5 in both dimensions and
    construction is retried until width or height exceeds its maximum.

    Parameters
    ----------
    random_seed
        Seed passed to `random.seed` and to the rail generator.
    observation_builder
        Object handed to `RailEnv` as `obs_builder_object`.
    max_width : int, optional
        Largest width that is still attempted.
    max_height : int, optional
        Largest height that is still attempted.

    Returns
    -------
    RailEnv or None
        The created environment, or `None` if no size within the limits worked.
    """
    random.seed(random_seed)
    # The order of the random draws below determines the result for a given seed.
    size = random.randint(0, 5)
    width = 20 + size * 5
    height = 20 + size * 5
    nr_cities = 2 + size // 2 + random.randint(0, 2)
    nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5))
    max_rails_between_cities = 2
    max_rails_in_cities = 3 + random.randint(0, size)
    malfunction_rate = 30 + random.randint(0, 100)
    malfunction_min_duration = 3 + random.randint(0, 7)
    malfunction_max_duration = 20 + random.randint(0, 80)

    rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
                                           max_rails_between_cities=max_rails_between_cities,
                                           max_rail_pairs_in_city=max_rails_in_cities)
    stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate,  # Rate of malfunction occurrence
                                            min_duration=malfunction_min_duration,  # Minimal duration of malfunction
                                            max_duration=malfunction_max_duration  # Max duration of malfunction
                                            )
    # Four speed classes with equal probability.
    line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})

    while width <= max_width and height <= max_height:
        try:
            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
                          line_generator=line_generator, number_of_agents=nr_trains,
                          malfunction_generator=ParamMalfunctionGen(stochastic_data),
                          obs_builder_object=observation_builder, remove_agents_at_target=False)
            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
                random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
                max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration
            ))
            return env
        except ValueError as e:
            logging.error(f"Error: {e}")
            width += 5
            height += 5
            # Placeholders are required; extra args without them make logging fail on emit.
            logging.info("Try again with larger env: (w,h): %s %s", width, height)
    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
    return None
def sparse_env_small(random_seed, observation_builder):
    """
    Build a fixed 30x30 RailEnv with 2 trains, 3 cities and rare malfunctions.

    Parameters
    ----------
    random_seed
        Currently unused: the rail generator is always seeded with 10.
        NOTE(review): confirm whether `random_seed` should replace the hard-coded seed.
    observation_builder
        Object handed to `RailEnv` as `obs_builder_object`.

    Returns
    -------
    RailEnv
        The newly constructed environment.
    """
    width = 30  # Width of map
    height = 30  # Height of map
    nr_trains = 2  # Number of trains that have an assigned task in the env
    cities_in_map = 3  # Number of cities where agents can start or end
    seed = 10  # Random seed
    grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
    max_rails_between_cities = 2  # Max number of tracks allowed between cities. This is number of entry point to a city
    max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic trainstation

    rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
                                           seed=seed,
                                           grid_mode=grid_distribution_of_cities,
                                           max_rails_between_cities=max_rails_between_cities,
                                           max_rail_pairs_in_city=max_rail_in_cities,
                                           )

    # Different agent types (trains) with different speeds.
    speed_ratio_map = {1.: 0.25,  # Fast passenger train
                       1. / 2.: 0.25,  # Fast freight train
                       1. / 3.: 0.25,  # Slow commuter train
                       1. / 4.: 0.25}  # Slow freight train

    # The speed profiles belong to the line generator; the original code passed them to
    # `sparse_rail_generator` by mistake, where they would be taken as `max_num_cities`.
    line_generator = sparse_line_generator(speed_ratio_map)

    # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
    # during an episode.
    stochastic_data = MalfunctionParameters(malfunction_rate=1 / 10000,  # Rate of malfunction occurrence
                                            min_duration=15,  # Minimal duration of malfunction
                                            max_duration=50  # Max duration of malfunction
                                            )

    rail_env = RailEnv(width=width,
                       height=height,
                       rail_generator=rail_generator,
                       line_generator=line_generator,
                       number_of_agents=nr_trains,
                       obs_builder_object=observation_builder,
                       malfunction_generator=ParamMalfunctionGen(stochastic_data),
                       remove_agents_at_target=True)
    return rail_env
def _after_step(self, observation, reward, done, info):
    """
    Post-step hook that rotates the video recorder at episode end and captures a frame.

    Parameters
    ----------
    observation, reward, info
        Accepted for signature compatibility; not used by this hook.
    done : bool or dict
        Episode-termination flag. When a dict (or dict subclass) is given, its
        '__all__' entry decides whether the episode is over.

    Returns
    -------
    bool or dict
        `done`, returned unchanged.
    """
    if not self.enabled:
        return done

    # isinstance (rather than an exact type comparison) so that dict subclasses are also
    # read through their '__all__' entry instead of being judged by their truthiness.
    episode_over = done['__all__'] if isinstance(done, dict) else done

    if episode_over and self.env_semantics_autoreset:
        # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
        self.reset_video_recorder()
        self.episode_id += 1
        self._flush()

    # Record stats - Disabled as it causes error in multi-agent set up
    # self.stats_recorder.after_step(observation, reward, done, info)

    # Record video
    self.video_recorder.capture_frame()
    return done
def perc_completion(env):
    """
    Return the percentage of agents whose status is `TrainState.DONE`.

    Agents are read from `env.agents_data` when that attribute exists, otherwise from
    `env.agents`. An empty agent collection yields 0.
    """
    roster = env.agents_data if hasattr(env, "agents_data") else env.agents
    finished = sum(1 for member in roster if member.status == TrainState.DONE)
    denominator = max(1, len(roster))  # guards against division by zero
    return 100 * np.mean(finished / denominator)
from collections import defaultdict
from typing import Dict, Tuple
from flatland.contrib.utils.deadlock_checker import Deadlock_Checker
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.fast_methods import fast_count_nonzero
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.step_utils.states import TrainState
def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
    """
    List the agent's feasible moves, sorted by distance-map value (smallest first).

    Parameters
    ----------
    env : RailEnv
        Environment providing `agents`, `rail` and `distance_map`.
    handle : int
        Index of the agent in `env.agents`.

    Returns
    -------
    list of (RailEnvActions, distance)
        One entry per allowed transition. When exactly one transition exists, the entry is
        duplicated so callers can always index the first and second choice. When the agent
        is neither ready to depart nor on the map, two `(DO_NOTHING, 0)` entries are returned.
    """
    agent = env.agents[handle]

    if agent.state == TrainState.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.state.is_on_map_state():
        agent_virtual_position = agent.position
    else:
        print("no action possible!")
        print("agent state: ", agent.state)
        # NEW: if agent is at target, DO_NOTHING, and distance is zero.
        # NEW: (needs to be tested...)
        return [(RailEnvActions.DO_NOTHING, 0)] * 2

    possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
    print(f"possible transitions: {possible_transitions}")
    distance_map = env.distance_map.get()[handle]

    possible_steps = []
    for movement in range(4):
        if not possible_transitions[movement]:
            continue
        if movement == agent.direction:
            action = RailEnvActions.MOVE_FORWARD
        elif movement == (agent.direction + 1) % 4:
            action = RailEnvActions.MOVE_RIGHT
        elif movement == (agent.direction - 1) % 4:
            action = RailEnvActions.MOVE_LEFT
        else:
            print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}")
            # (direction + 2) % 4 and (direction - 2) % 4 are the same heading, so one
            # comparison suffices (the original second clause had a misplaced parenthesis).
            if movement == (agent.direction + 2) % 4:
                print("it seems that we are turning by 180 degrees. Turning in a dead end?")
            action = RailEnvActions.MOVE_FORWARD

        # The distance map is indexed by (row, col, heading) of the cell being entered.
        distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
        possible_steps.append((action, distance))

    possible_steps.sort(key=lambda step: step[1])

    # if there is only one path to target, this is both the shortest one and the second shortest path.
    if len(possible_steps) == 1:
        return possible_steps * 2
    return possible_steps
class RailEnvWrapper:
    """
    Base class for wrappers around an already-reset RailEnv.

    Commonly used members are forwarded explicitly; every other attribute that is not
    found on the wrapper itself is looked up on the wrapped environment.
    """

    def __init__(self, env: "RailEnv"):
        self.env = env

        assert self.env is not None
        assert self.env.rail is not None, "Reset original environment first!"
        assert self.env.agents is not None, "Reset original environment first!"
        assert len(self.env.agents) > 0, "Reset original environment first!"

    def __getattr__(self, name):
        """Expose any other attributes of the underlying environment."""
        # Python only calls __getattr__ after normal lookup has failed. If `env` itself is
        # missing (e.g. on an instance created without __init__, as copy/pickle do),
        # touching `self.env` below would re-enter this method and recurse without end.
        if name == "env":
            raise AttributeError(name)
        return getattr(self.env, name)

    @property
    def rail(self):
        return self.env.rail

    @property
    def width(self):
        return self.env.width

    @property
    def height(self):
        return self.env.height

    @property
    def agent_positions(self):
        return self.env.agent_positions

    def get_num_agents(self):
        return self.env.get_num_agents()

    def get_agent_handles(self):
        return self.env.get_agent_handles()

    def step(self, action_dict: Dict[int, "RailEnvActions"]):
        """Forward `action_dict` to the wrapped env and return its step result."""
        return self.env.step(action_dict)

    def reset(self, **kwargs):
        """Reset the wrapped env with `kwargs` and return its `(obs, info)` pair."""
        obs, info = self.env.reset(**kwargs)
        return obs, info
class ShortestPathActionWrapper(RailEnvWrapper):
    """
    Wrapper with a reduced action set: 0 is passed through unchanged, 1 and 2 select the
    first and second entry of `possible_actions_sorted_by_distance` for that agent.
    """

    def __init__(self, env: RailEnv):
        super().__init__(env)

    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
        """Translate the reduced actions into env actions and step the wrapped env."""
        # input: action dict with actions in [0, 1, 2].
        env_actions = {}
        for agent_id, choice in action_dict.items():
            if action_is_noop := (choice == 0):
                env_actions[agent_id] = choice
                continue
            ranked_moves = possible_actions_sorted_by_distance(self.env, agent_id)
            # choice 1 -> best move, choice 2 -> second-best move
            env_actions[agent_id] = ranked_moves[choice - 1][0]

        obs, rewards, dones, info = self.env.step(env_actions)
        return obs, rewards, dones, info
def find_all_cells_where_agent_can_choose(env: RailEnv):
    """
    Collect the switch cells of the grid and the cells reachable from them in one move.

    input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv),
    WHICH HAS BEEN RESET ALREADY!
    (o.w., we call env.rail, which is None before reset(), and crash.)

    Returns
    -------
    tuple of three sets
        (switches, switches_neighbors, decision_cells), where decision_cells is the union
        of the first two.
    """
    headings = list(range(4))
    switch_cells = set()
    neighbor_cells = set()

    for row in range(env.height):
        for col in range(env.width):
            cell = (row, col)

            # A cell is a switch if some heading offers more than one outgoing transition.
            branches_here = any(
                fast_count_nonzero(env.rail.get_transitions(*cell, heading)) > 1
                for heading in headings
            )
            if not branches_here:
                continue
            switch_cells.add(cell)

            # Every cell that can be entered from this switch counts as a neighbour.
            for heading in headings:
                outgoing = env.rail.get_transitions(*cell, heading)
                for movement in headings:
                    if outgoing[movement]:
                        neighbor_cells.add(get_new_position(cell, movement))

    return switch_cells, neighbor_cells, switch_cells | neighbor_cells
class SkipNoChoiceCellsWrapper(RailEnvWrapper):
    """
    Wrapper that keeps stepping the wrapped env until at least one agent is done or stands
    on a decision cell, optionally folding the rewards of skipped steps into the reported one.
    """

    # env can be a real RailEnv, or anything that shares the same interface
    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
    def __init__(self, env: RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
        super().__init__(env)
        # kept as attributes so they can be inspected easily.
        self.accumulate_skipped_rewards = accumulate_skipped_rewards
        self.discounting = discounting
        self.skipped_rewards = defaultdict(list)
        self.switches = self.switches_neighbors = self.decision_cells = None
        # fills in switches, switches_neighbors and decision_cells.
        self.reset_cells()

    def on_decision_cell(self, agent: EnvAgent) -> bool:
        """True if the agent has no position, is on its initial position, or is on a decision cell."""
        position = agent.position
        if position is None or position == agent.initial_position:
            return True
        return position in self.decision_cells

    def on_switch(self, agent: EnvAgent) -> bool:
        return agent.position in self.switches

    def next_to_switch(self, agent: EnvAgent) -> bool:
        return agent.position in self.switches_neighbors

    def reset_cells(self) -> None:
        """Recompute the switch / neighbour / decision cell sets from the wrapped env."""
        cell_sets = find_all_cells_where_agent_can_choose(self.env)
        self.switches, self.switches_neighbors, self.decision_cells = cell_sets

    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
        """
        Step the wrapped env repeatedly until some agent needs to be reported.

        Only the first inner step receives `action_dict`; subsequent ones get an empty dict.
        """
        info_keys = ("action_required", "malfunction", "speed", "state")
        out_obs, out_rewards, out_dones = {}, {}, {}
        # the per-key dicts must exist up front because they are indexed by agent id below.
        out_info = {key: dict() for key in info_keys}

        while not out_obs:
            obs, reward, done, info = self.env.step(action_dict)

            for agent_id, agent_obs in obs.items():
                report = done[agent_id] or self.on_decision_cell(self.env.agents[agent_id])
                if not report:
                    # remember the reward of a skipped step for later discounting.
                    if self.accumulate_skipped_rewards:
                        self.skipped_rewards[agent_id].append(reward[agent_id])
                    continue

                out_obs[agent_id] = agent_obs
                out_rewards[agent_id] = reward[agent_id]
                out_dones[agent_id] = done[agent_id]
                for key in info_keys:
                    out_info[key][agent_id] = info[key][agent_id]

                if self.accumulate_skipped_rewards:
                    # fold newest-to-oldest so older skipped rewards are discounted least.
                    folded = out_rewards[agent_id]
                    for earlier in reversed(self.skipped_rewards[agent_id]):
                        folded = self.discounting * folded + earlier
                    out_rewards[agent_id] = folded
                    self.skipped_rewards[agent_id] = []

            out_dones['__all__'] = done['__all__']
            action_dict = {}

        return out_obs, out_rewards, out_dones, out_info

    def reset(self, **kwargs) -> Tuple[Dict, Dict]:
        """Reset the wrapped env, then refresh the cell sets (they can change on reset)."""
        obs, info = self.env.reset(**kwargs)
        # must happen after env.reset().
        self.reset_cells()
        return obs, info
class DeadlockWrapper(RailEnvWrapper):
    """
    Wrapper that overwrites the reward of an agent with `deadlock_reward` in the step in
    which the deadlock checker first reports it as deadlocked.
    """

    def __init__(self, env: RailEnv, deadlock_reward=-100) -> None:
        super().__init__(env)
        self.deadlock_reward = deadlock_reward
        self.deadlock_checker = Deadlock_Checker(env=self.env)

    @property
    def deadlocked_agents(self):
        return self.deadlock_checker.deadlocked_agents

    @property
    def immediate_deadlocks(self):
        return [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]

    # make sure to assign the deadlock reward only once to each deadlocked agent...
    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
        """Step the wrapped env and penalise agents that became deadlocked in this step."""
        # handles that were already deadlocked before this step
        known_ids = [agent.handle for agent in self.deadlocked_agents]

        obs, rewards, dones, info = self.env.step(action_dict)

        # re-evaluate deadlocks after the step (the checker also stores the result itself)
        current_ids = [agent.handle for agent in self.deadlock_checker.check_deadlocks(action_dict)]

        # immediate deadlocks are only reported, not used for the reward
        print(f"immediate deadlocked: {self.immediate_deadlocks}")
        print(f"total deadlocked: {current_ids}")

        for agent_id in current_ids:
            if agent_id in known_ids:
                continue
            print(f"assigning deadlock reward of {self.deadlock_reward} to agent {agent_id}")
            rewards[agent_id] = self.deadlock_reward

        return obs, rewards, dones, info

    def reset(self, **kwargs) -> Tuple[Dict, Dict]:
        """Clear the checker's deadlock lists, then reset the wrapped env."""
        self.deadlock_checker.reset()  # sets all lists of deadlocked agents to empty list
        obs, info = super().reset(**kwargs)
        return obs, info
...@@ -11,11 +11,11 @@ class Environment: ...@@ -11,11 +11,11 @@ class Environment:
Derived environments should implement the following attributes: Derived environments should implement the following attributes:
action_space: tuple with the dimensions of the actions to be passed to the step method action_space: tuple with the dimensions of the actions to be passed to the step method
observation_space: tuple with the dimensions of the observations returned by reset and step
Agents are identified by agent ids (handles). Agents are identified by agent ids (handles).
Examples: Examples:
>>> obs = env.reset()
>>> obs, info = env.reset()
>>> print(obs) >>> print(obs)
{ {
"train_0": [2.4, 1.6], "train_0": [2.4, 1.6],
...@@ -40,18 +40,19 @@ class Environment: ...@@ -40,18 +40,19 @@ class Environment:
"train_0": {}, # info for train_0 "train_0": {}, # info for train_0
"train_1": {}, # info for train_1 "train_1": {}, # info for train_1
} }
""" """
def __init__(self): def __init__(self):
self.action_space = () self.action_space = ()
self.observation_space = ()
pass pass
def reset(self): def reset(self):
""" """
Resets the env and returns observations from agents in the environment. Resets the env and returns observations from agents in the environment.
Returns: Returns
-------
obs : dict obs : dict
New observations for each agent. New observations for each agent.
""" """
...@@ -66,7 +67,7 @@ class Environment: ...@@ -66,7 +67,7 @@ class Environment:
The returns are dicts mapping from agent_id strings to values. The returns are dicts mapping from agent_id strings to values.
Parameters Parameters
------- ----------
action_dict : dict action_dict : dict
Dictionary of actions to execute, indexed by agent id. Dictionary of actions to execute, indexed by agent id.
......
...@@ -2,27 +2,29 @@ ...@@ -2,27 +2,29 @@
ObservationBuilder objects are objects that can be passed to environments designed for customizability. ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle). The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).
+ Reset() is called after each environment reset, to allow for pre-computing relevant data. + `reset()` is called after each environment reset, to allow for pre-computing relevant data.
+ `get()` is called whenever an observation has to be computed, potentially for each agent independently in case of \
multi-agent environments.
+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
""" """
from typing import Optional, List
import numpy as np import numpy as np
from flatland.core.env import Environment
class ObservationBuilder: class ObservationBuilder:
""" """
ObservationBuilder base class. ObservationBuilder base class.
Derived objects must implement and `observation_space' attribute as a tuple with the dimensions of the returned
observations.
""" """
def __init__(self): def __init__(self):
self.observation_space = () self.env = None
def _set_env(self, env): def set_env(self, env: Environment):
self.env = env self.env: Environment = env
def reset(self): def reset(self):
""" """
...@@ -30,35 +32,37 @@ class ObservationBuilder: ...@@ -30,35 +32,37 @@ class ObservationBuilder:
""" """
raise NotImplementedError() raise NotImplementedError()
def get_many(self, handles=[]): def get_many(self, handles: Optional[List[int]] = None):
""" """
Called whenever an observation has to be computed for the `env' environment, for each agent with handle Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles' list. in the `handles` list.
Parameters Parameters
------- ----------
handles : list of handles (optional) handles : list of handles, optional
List with the handles of the agents for which to compute the observation vector. List with the handles of the agents for which to compute the observation vector.
Returns Returns
------- -------
function function
A dictionary of observation structures, specific to the corresponding environment, with handles from A dictionary of observation structures, specific to the corresponding environment, with handles from
`handles' as keys. `handles` as keys.
""" """
observations = {} observations = {}
if handles is None:
handles = []
for h in handles: for h in handles:
observations[h] = self.get(h) observations[h] = self.get(h)
return observations return observations
def get(self, handle=0): def get(self, handle: int = 0):
""" """
Called whenever an observation has to be computed for the `env' environment, possibly Called whenever an observation has to be computed for the `env` environment, possibly
for each agent independently (agent id `handle'). for each agent independently (agent id `handle`).
Parameters Parameters
------- ----------
handle : int (optional) handle : int, optional
Handle of the agent for which to compute the observation vector. Handle of the agent for which to compute the observation vector.
Returns Returns
...@@ -82,16 +86,13 @@ class DummyObservationBuilder(ObservationBuilder): ...@@ -82,16 +86,13 @@ class DummyObservationBuilder(ObservationBuilder):
""" """
def __init__(self): def __init__(self):
self.observation_space = () super().__init__()
def _set_env(self, env):
self.env = env
def reset(self): def reset(self):
pass pass
def get_many(self, handles=[]): def get_many(self, handles: Optional[List[int]] = None) -> bool:
return True return True
def get(self, handle=0): def get(self, handle: int = 0) -> bool:
return True return True
...@@ -3,11 +3,12 @@ PredictionBuilder objects are objects that can be passed to environments designe ...@@ -3,11 +3,12 @@ PredictionBuilder objects are objects that can be passed to environments designe
The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]). The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
If predictions are not required in every step or not for all agents, then If predictions are not required in every step or not for all agents, then
+ Reset() is called after each environment reset, to allow for pre-computing relevant data. + `reset()` is called after each environment reset, to allow for pre-computing relevant data.
+ Get() is called whenever an step has to be computed, potentially for each agent independently in + `get()` is called whenever an step has to be computed, potentially for each agent independently in \
case of multi-agent environments. case of multi-agent environments.
""" """
from flatland.core.env import Environment
class PredictionBuilder: class PredictionBuilder:
...@@ -18,8 +19,9 @@ class PredictionBuilder: ...@@ -18,8 +19,9 @@ class PredictionBuilder:
def __init__(self, max_depth: int = 20): def __init__(self, max_depth: int = 20):
self.max_depth = max_depth self.max_depth = max_depth
self.env = None
def _set_env(self, env): def set_env(self, env: Environment):
self.env = env self.env = env
def reset(self): def reset(self):
...@@ -28,16 +30,13 @@ class PredictionBuilder: ...@@ -28,16 +30,13 @@ class PredictionBuilder:
""" """
pass pass
def get(self, custom_args=None, handle=0): def get(self, handle: int = 0):
""" """
Called whenever get_many in the observation build is called. Called whenever get_many in the observation build is called.
Parameters Parameters
------- ----------
custom_args: dict handle : int, optional
Implementation-dependent custom arguments, see the sub-classes.
handle : int (optional)
Handle of the agent for which to compute the observation vector. Handle of the agent for which to compute the observation vector.
Returns Returns
......
from enum import IntEnum from enum import IntEnum
from typing import Type from functools import lru_cache
from typing import Type, List
import numpy as np import numpy as np
from flatland.core.transitions import Transitions from flatland.core.transitions import Transitions
# maxsize=None could be used because the number of possible transitions is limited (16-bit encoded) and the
# direction/orientation is also limited (2 bits). The 16 bits are only sparsely used (= number of rail types).
# These functions can be cached -> they are independent of the railway (env).
@lru_cache(maxsize=128)
def fast_grid4_get_transitions(cell_transition, orientation):
    """
    Return the four transition bits of `cell_transition` for `orientation` as a 4-tuple.

    The bitmap holds one 4-bit group per orientation, orientation 0 in the highest group;
    the tuple lists that group's bits from most to least significant.
    """
    group = cell_transition >> ((3 - orientation) * 4)
    return tuple((group >> offset) & 1 for offset in (3, 2, 1, 0))
@lru_cache(maxsize=128)
def fast_grid4_get_transition(cell_transition, orientation, direction):
    """Return the single bit (0 or 1) of `cell_transition` for `orientation` and `direction`."""
    # first isolate the 4-bit group of the orientation, then the bit of the direction
    group = cell_transition >> ((4 - 1 - orientation) * 4)
    return (group >> (4 - 1 - direction)) & 1
@lru_cache(maxsize=128)
def fast_grid4_set_transitions(cell_transition, orientation, new_transitions):
    """
    Return `cell_transition` with the 4-bit group of `orientation` replaced.

    `new_transitions` supplies the four new bits (only the lowest bit of each of its first
    four entries is used). Because calls are memoized with `lru_cache`, it must be
    hashable, e.g. a tuple.
    """
    shift = (3 - orientation) * 4
    first, second, third, fourth = (new_transitions[k] & 1 for k in range(4))
    packed = (first << 3) | (second << 2) | (third << 1) | fourth
    # clear the four bits of this orientation, then drop the packed bits in their place
    cleared = cell_transition & ~(0xF << shift)
    return cleared | (packed << shift)
@lru_cache(maxsize=128)
def fast_grid4_remove_deadends(cell_transition):
    """
    Remove all turn-arounds (e.g. N-S, S-N, E-W,...).

    Clears every bit contained in `Grid4Transitions.maskDeadEnds()` and limits the result
    to the low 16 bits.
    """
    dead_end_bits = Grid4Transitions.maskDeadEnds()
    without_dead_ends = cell_transition & (~dead_end_bits) & 0xffff
    return cell_transition & without_dead_ends
@lru_cache(maxsize=128)
def fast_grid4_rotate_transition(cell_transition, rotation=0):
    """
    Return `cell_transition` rotated by `rotation` degrees (used in steps of 90).

    The bits inside each 4-bit orientation group are rotated first, then the groups
    themselves are rotated within the bitmap.
    """
    quarter_turns = rotation // 90
    split = 4 - quarter_turns

    rotated = cell_transition
    for orientation in range(4):
        bits = fast_grid4_get_transitions(rotated, orientation)
        rotated = fast_grid4_set_transitions(rotated, orientation, bits[split:] + bits[:split])

    # Rotate the 4-bits blocks
    wrapped_low = (rotated & (2 ** (quarter_turns * 4) - 1)) << (split * 4)
    shifted_high = rotated >> (quarter_turns * 4)
    return wrapped_low | shifted_high
class Grid4TransitionsEnum(IntEnum): class Grid4TransitionsEnum(IntEnum):
NORTH = 0 NORTH = 0
EAST = 1 EAST = 1
...@@ -24,9 +82,9 @@ class Grid4Transitions(Transitions): ...@@ -24,9 +82,9 @@ class Grid4Transitions(Transitions):
""" """
Grid4Transitions class derived from Transitions. Grid4Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand). Special case of `Transitions` over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed. Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions' GridTransitions keeps track of valid transitions supplied as `transitions`
list, each represented as a bitmap of 16 bits. list, each represented as a bitmap of 16 bits.
Whether a transition is allowed or not depends on which direction an agent Whether a transition is allowed or not depends on which direction an agent
...@@ -57,8 +115,11 @@ class Grid4Transitions(Transitions): ...@@ -57,8 +115,11 @@ class Grid4Transitions(Transitions):
# row,col delta for each direction # row,col delta for each direction
self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
# These bits represent all the possible dead ends # These bits represent all the possible dead ends
self.maskDeadEnds = 0b0010000110000100 @staticmethod
@lru_cache()
def maskDeadEnds():
return 0b0010000110000100
def get_type(self): def get_type(self):
return np.uint16 return np.uint16
...@@ -67,8 +128,8 @@ class Grid4Transitions(Transitions): ...@@ -67,8 +128,8 @@ class Grid4Transitions(Transitions):
""" """
Get the 4 possible transitions ((N,E,S,W), 4 elements tuple Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent oriented if no diagonal transitions allowed) available for an agent oriented
in direction `orientation' and inside a cell with in direction `orientation` and inside a cell with
transitions `cell_transition'. transitions `cell_transition`.
Parameters Parameters
---------- ----------
...@@ -83,16 +144,15 @@ class Grid4Transitions(Transitions): ...@@ -83,16 +144,15 @@ class Grid4Transitions(Transitions):
List of the validity of transitions in the cell. List of the validity of transitions in the cell.
""" """
bits = (cell_transition >> ((3 - orientation) * 4)) return fast_grid4_get_transitions(cell_transition, orientation)
return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
def set_transitions(self, cell_transition, orientation, new_transitions): def set_transitions(self, cell_transition, orientation, new_transitions):
""" """
Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent if no diagonal transitions allowed) available for an agent
oriented in direction `orientation' and inside a cell with transitions oriented in direction `orientation` and inside a cell with transitions
`cell_transition'. A new `cell_transition' is returned with `cell_transition'. A new `cell_transition` is returned with
the specified bits replaced by `new_transitions'. the specified bits replaced by `new_transitions`.
Parameters Parameters
---------- ----------
...@@ -107,28 +167,17 @@ class Grid4Transitions(Transitions): ...@@ -107,28 +167,17 @@ class Grid4Transitions(Transitions):
------- -------
int int
An updated bitmap that replaces the original transitions validity An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate of `cell_transition' with `new_transitions`, for the appropriate
`orientation'. `orientation`.
""" """
mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4)) return fast_grid4_set_transitions(cell_transition, orientation, new_transitions)
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 3 | \
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
return cell_transition
def get_transition(self, cell_transition, orientation, direction): def get_transition(self, cell_transition, orientation, direction):
""" """
Get the transition bit (1 value) that determines whether an agent Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction' `cell_transition' can move to the cell in direction `direction`
relative to the current cell. relative to the current cell.
Parameters Parameters
...@@ -146,13 +195,14 @@ class Grid4Transitions(Transitions): ...@@ -146,13 +195,14 @@ class Grid4Transitions(Transitions):
Validity of the requested transition: 0/1 allowed/not allowed. Validity of the requested transition: 0/1 allowed/not allowed.
""" """
return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1 return fast_grid4_get_transition(cell_transition, orientation, direction)
def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): def set_transition(self, cell_transition, orientation, direction, new_transition,
remove_deadends=False):
""" """
Set the transition bit (1 value) that determines whether an agent Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction' `cell_transition' can move to the cell in direction `direction`
relative to the current cell. relative to the current cell.
Parameters Parameters
...@@ -171,8 +221,8 @@ class Grid4Transitions(Transitions): ...@@ -171,8 +221,8 @@ class Grid4Transitions(Transitions):
------- -------
int int
An updated bitmap that replaces the original transitions validity An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate of `cell_transition' with `new_transitions`, for the appropriate
`orientation'. `orientation`.
""" """
if new_transition: if new_transition:
...@@ -181,7 +231,7 @@ class Grid4Transitions(Transitions): ...@@ -181,7 +231,7 @@ class Grid4Transitions(Transitions):
cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
if remove_deadends: if remove_deadends:
cell_transition = self.remove_deadends(cell_transition) cell_transition = fast_grid4_remove_deadends(cell_transition)
return cell_transition return cell_transition
...@@ -196,7 +246,7 @@ class Grid4Transitions(Transitions): ...@@ -196,7 +246,7 @@ class Grid4Transitions(Transitions):
16 bits used to encode the valid transitions for a cell. 16 bits used to encode the valid transitions for a cell.
rotation : int rotation : int
Angle by which to clock-wise rotate the transition bits in Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees. `cell_transition` by. I.e., rotation={0, 90, 180, 270} degrees.
Returns Returns
------- -------
...@@ -206,27 +256,18 @@ class Grid4Transitions(Transitions): ...@@ -206,27 +256,18 @@ class Grid4Transitions(Transitions):
""" """
# Rotate the individual bits in each block # Rotate the individual bits in each block
value = cell_transition return fast_grid4_rotate_transition(cell_transition, rotation)
rotation = rotation // 90
for i in range(4):
block_tuple = self.get_transitions(value, i)
block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
value = self.set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
cell_transition = value
return cell_transition
def get_direction_enum(self) -> Type[Grid4TransitionsEnum]: def get_direction_enum(self) -> Type[Grid4TransitionsEnum]:
return Grid4TransitionsEnum return Grid4TransitionsEnum
def has_deadend(self, cell_transition): @staticmethod
@lru_cache()
def has_deadend(cell_transition):
""" """
Checks if one entry can only by exited by a turn-around. Checks if one entry can only by exited by a turn-around.
""" """
if cell_transition & self.maskDeadEnds > 0: if cell_transition & Grid4Transitions.maskDeadEnds() > 0:
return True return True
else: else:
return False return False
...@@ -235,5 +276,9 @@ class Grid4Transitions(Transitions): ...@@ -235,5 +276,9 @@ class Grid4Transitions(Transitions):
""" """
Remove all turn-arounds (e.g. N-S, S-N, E-W,...). Remove all turn-arounds (e.g. N-S, S-N, E-W,...).
""" """
cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff return fast_grid4_remove_deadends(cell_transition)
return cell_transition
@staticmethod
@lru_cache()
def get_entry_directions(cell_transition) -> List[int]:
return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
from flatland.core.grid.grid4_utils import validate_new_transition import numpy as np
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
from flatland.core.grid.grid_utils import IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap
from flatland.utils.ordered_set import OrderedSet
class AStarNode():
class AStarNode:
"""A node class for A* Pathfinding""" """A node class for A* Pathfinding"""
def __init__(self, parent=None, pos=None): def __init__(self, pos: IntVector2D, parent=None):
self.parent = parent self.parent = parent
self.pos = pos self.pos: IntVector2D = pos
self.g = 0 self.g = 0.0
self.h = 0 self.h = 0.0
self.f = 0 self.f = 0.0
def __eq__(self, other): def __eq__(self, other):
"""
Parameters
----------
other : AStarNode
"""
return self.pos == other.pos return self.pos == other.pos
def __hash__(self): def __hash__(self):
...@@ -25,16 +37,35 @@ class AStarNode(): ...@@ -25,16 +37,35 @@ class AStarNode():
self.f = other.f self.f = other.f
def a_star(rail_trans, rail_array, start, end): def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, avoid_rails=False,
respect_transition_validity=True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
"""
:param avoid_rails:
:param grid_map: Grid Map where the path is found in
:param start: Start positions as (row,column)
:param end: End position as (row,column)
:param a_star_distance_function: Define the distance function to use as heuristic:
-get_euclidean_distance
-get_manhattan_distance
-get_chebyshev_distance
:param respect_transition_validity: Whether or not a-star respect allowed transitions on the grid map.
- True: Respects the validity of transitions. This generates valid paths, or no path if none can be found
- False: This always finds a path, but the path might be illegal and thus needs to be fixed afterwards
:param forbidden_cells: List of cells where the path cannot pass through. Used to avoid certain areas of Grid map
:return: If a path is found, an ordered list of all cells in the path is returned
"""
""" """
Returns a list of tuples as a path from the given start to end. Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end. If no path is found, returns path to closest point to end.
""" """
rail_shape = rail_array.shape rail_shape = grid_map.grid.shape
start_node = AStarNode(None, start)
end_node = AStarNode(None, end) start_node = AStarNode(start, None)
open_nodes = set() end_node = AStarNode(end, None)
closed_nodes = set() open_nodes = OrderedSet()
closed_nodes = OrderedSet()
open_nodes.add(start_node) open_nodes.add(start_node)
while len(open_nodes) > 0: while len(open_nodes) > 0:
...@@ -58,6 +89,7 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -58,6 +89,7 @@ def a_star(rail_trans, rail_array, start, end):
while current is not None: while current is not None:
path.append(current.pos) path.append(current.pos)
current = current.parent current = current.parent
# return reversed path # return reversed path
return path[::-1] return path[::-1]
...@@ -67,17 +99,28 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -67,17 +99,28 @@ def a_star(rail_trans, rail_array, start, end):
prev_pos = current_node.parent.pos prev_pos = current_node.parent.pos
else: else:
prev_pos = None prev_pos = None
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) # update the "current" pos
node_pos: IntVector2D = Vec2d.add(current_node.pos, new_pos)
# is node_pos inside the grid?
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0: if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
continue continue
# validate positions # validate positions
if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): #
if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos,
end_node.pos) and respect_transition_validity:
continue continue
# create new node # create new node
new_node = AStarNode(current_node, node_pos) new_node = AStarNode(node_pos, current_node)
# Skip paths through forbidden regions if they are provided
if forbidden_cells is not None:
if node_pos in forbidden_cells and new_node != start_node and new_node != end_node:
continue
children.append(new_node) children.append(new_node)
# loop through children # loop through children
...@@ -87,11 +130,12 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -87,11 +130,12 @@ def a_star(rail_trans, rail_array, start, end):
continue continue
# create the f, g, and h values # create the f, g, and h values
child.g = current_node.g + 1 child.g = current_node.g + 1.0
# this heuristic favors diagonal paths:
# child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + ((child.pos[1] - end_node.pos[1]) ** 2) \# noqa: E800
# this heuristic avoids diagonal paths # this heuristic avoids diagonal paths
child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1]) if avoid_rails:
child.h = a_star_distance_function(child.pos, end_node.pos) + np.clip(grid_map.grid[child.pos], 0, 1)
else:
child.h = a_star_distance_function(child.pos, end_node.pos)
child.f = child.g + child.h child.f = child.g + child.h
# already in the open list? # already in the open list?
......
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2D
def get_direction(pos1, pos2) -> Grid4TransitionsEnum: def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
""" """
Assumes pos1 and pos2 are adjacent location on grid. Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions. Returns direction (int) that can be used with transitions.
...@@ -9,66 +12,41 @@ def get_direction(pos1, pos2) -> Grid4TransitionsEnum: ...@@ -9,66 +12,41 @@ def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
diff_0 = pos2[0] - pos1[0] diff_0 = pos2[0] - pos1[0]
diff_1 = pos2[1] - pos1[1] diff_1 = pos2[1] - pos1[1]
if diff_0 < 0: if diff_0 < 0:
return 0 return Grid4TransitionsEnum.NORTH
if diff_0 > 0: if diff_0 > 0:
return 2 return Grid4TransitionsEnum.SOUTH
if diff_1 > 0: if diff_1 > 0:
return 1 return Grid4TransitionsEnum.EAST
if diff_1 < 0: if diff_1 < 0:
return 3 return Grid4TransitionsEnum.WEST
raise Exception("Could not determine direction {}->{}".format(pos1, pos2)) raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
def mirror(dir): def mirror(dir):
return (dir + 2) % 4 return (dir + 2) % 4
# (row, column) offset for each movement index, in the order N, E, S, W
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]


def get_new_position(position, movement):
    """
    Return the (row, column) reached from `position` by one step of `movement`.

    :param position: current cell as (row, column)
    :param movement: index into `MOVEMENT_ARRAY` selecting the step to apply
    :return: tuple (row, column) of the neighbouring cell
    """
    step = MOVEMENT_ARRAY[movement]
    new_row = position[0] + step[0]
    new_col = position[1] + step[1]
    return (new_row, new_col)
def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
# start by getting direction used to get to current node def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
# and direction from current node to possible child node """
new_dir = get_direction(current_pos, new_pos) Returns the closest direction orientation of position 2 relative to position 1
if prev_pos is not None: :param pos1: position we are interested in
current_dir = get_direction(prev_pos, current_pos) :param pos2: position we want to know it is facing
else: :return: direction NESW as int N:0 E:1 S:2 W:3
current_dir = new_dir """
# create new transition that would go to child diff_vec = np.array((pos1[0] - pos2[0], pos1[1] - pos2[1]))
new_trans = rail_array[current_pos] axis = np.argmax(np.power(diff_vec, 2))
if prev_pos is None: direction = np.sign(diff_vec[axis])
if new_trans == 0: if axis == 0:
# need to flip direction because of how end points are defined if direction > 0:
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) return Grid4TransitionsEnum.NORTH
else: else:
# check if matches existing layout return Grid4TransitionsEnum.SOUTH
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
else: else:
# set the forward path if direction > 0:
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) return Grid4TransitionsEnum.WEST
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if new_pos == end_pos:
# need to validate end pos setup as well
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else: else:
# check if matches existing layout return Grid4TransitionsEnum.EAST
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
if not rail_trans.is_valid(new_trans_e):
return False
# is transition is valid?
return rail_trans.is_valid(new_trans)
def get_new_position(position, movement):
    """
    Translate a compass movement on the 2D grid into the resulting (r, c) cell.

    :param position: current cell as (row, column)
    :param movement: one of the Grid4TransitionsEnum compass directions
    :return: tuple (row, column) of the neighbouring cell, or None when
        `movement` matches none of NORTH, EAST, SOUTH, WEST
    """
    if movement == Grid4TransitionsEnum.NORTH:
        return (position[0] - 1, position[1])
    if movement == Grid4TransitionsEnum.EAST:
        return (position[0], position[1] + 1)
    if movement == Grid4TransitionsEnum.SOUTH:
        return (position[0] + 1, position[1])
    if movement == Grid4TransitionsEnum.WEST:
        return (position[0], position[1] - 1)
    # no compass direction matched
    return None
...@@ -20,9 +20,9 @@ class Grid8Transitions(Transitions): ...@@ -20,9 +20,9 @@ class Grid8Transitions(Transitions):
""" """
Grid8Transitions class derived from Transitions. Grid8Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand). Special case of `Transitions` over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed. Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions' GridTransitions keeps track of valid transitions supplied as `transitions`
list, each represented as a bitmap of 64 bits. list, each represented as a bitmap of 64 bits.
0=North, 1=North-East, etc. 0=North, 1=North-East, etc.
...@@ -82,8 +82,8 @@ class Grid8Transitions(Transitions): ...@@ -82,8 +82,8 @@ class Grid8Transitions(Transitions):
------- -------
int int
An updated bitmap that replaces the original transitions validity An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate of `cell_transition' with `new_transitions`, for the appropriate
`orientation'. `orientation`.
""" """
mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8)) mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
...@@ -106,8 +106,8 @@ class Grid8Transitions(Transitions): ...@@ -106,8 +106,8 @@ class Grid8Transitions(Transitions):
def get_transition(self, cell_transition, orientation, direction): def get_transition(self, cell_transition, orientation, direction):
""" """
Get the transition bit (1 value) that determines whether an agent Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction' `cell_transition' can move to the cell in direction `direction`
relative to the current cell. relative to the current cell.
Parameters Parameters
...@@ -131,8 +131,8 @@ class Grid8Transitions(Transitions): ...@@ -131,8 +131,8 @@ class Grid8Transitions(Transitions):
""" """
Set the transition bit (1 value) that determines whether an agent Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction' `cell_transition' can move to the cell in direction `direction`
relative to the current cell. relative to the current cell.
Parameters Parameters
...@@ -150,8 +150,8 @@ class Grid8Transitions(Transitions): ...@@ -150,8 +150,8 @@ class Grid8Transitions(Transitions):
------- -------
int int
An updated bitmap that replaces the original transitions validity An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate of `cell_transition' with `new_transitions`, for the appropriate
`orientation'. `orientation`.
""" """
if new_transition: if new_transition:
...@@ -172,7 +172,7 @@ class Grid8Transitions(Transitions): ...@@ -172,7 +172,7 @@ class Grid8Transitions(Transitions):
64 bits used to encode the valid transitions for a cell. 64 bits used to encode the valid transitions for a cell.
rotation : int rotation : int
Angle by which to clock-wise rotate the transition bits in Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 45, 90, 135, 180, `cell_transition` by. I.e., rotation={0, 45, 90, 135, 180,
225, 270, 315} degrees. 225, 270, 315} degrees.
Returns Returns
......
from math import isnan
from typing import Tuple, Callable, List, Type
import numpy as np import numpy as np
Vector2D: Type = Tuple[float, float]
IntVector2D: Type = Tuple[int, int]
def position_to_coordinate(depth, positions): IntVector2DArray: Type = List[IntVector2D]
"""Converts coordinates to positions: IntVector2DArrayArray: Type = List[List[IntVector2D]]
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1) Vector2DArray: Type = List[Vector2D]
... Vector2DArrayArray: Type = List[List[Vector2D]]
(d-1,0) (d-1,1) (d-1,w-1)
] IntVector2DDistance: Type = Callable[[IntVector2D, IntVector2D], float]
class Vec2dOperations:
    """
    Collection of stateless helpers for 2D vectors given as indexable pairs.

    Every helper reads only elements ``[0]`` and ``[1]`` of its vector
    arguments and returns either a plain 2-tuple or a scalar.
    """

    @staticmethod
    def is_equal(node_a: Vector2D, node_b: Vector2D) -> bool:
        """
        vector operation : node_a == node_b (component-wise)

        :param node_a: tuple with coordinate (x,y) or 2d vector
        :param node_b: tuple with coordinate (x,y) or 2d vector
        :return: truthy when both components of node_a and node_b match
        """
        first_matches = node_a[0] == node_b[0]
        return first_matches and node_a[1] == node_b[1]

    @staticmethod
    def subtract(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
        """
        vector operation : node_a - node_b

        :param node_a: tuple with coordinate (x,y) or 2d vector
        :param node_b: tuple with coordinate (x,y) or 2d vector
        :return: tuple with the component-wise difference
        """
        diff_0 = node_a[0] - node_b[0]
        diff_1 = node_a[1] - node_b[1]
        return diff_0, diff_1

    @staticmethod
    def add(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
        """
        vector operation : node_a + node_b

        :param node_a: tuple with coordinate (x,y) or 2d vector
        :param node_b: tuple with coordinate (x,y) or 2d vector
        :return: tuple with the component-wise sum
        """
        sum_0 = node_a[0] + node_b[0]
        sum_1 = node_a[1] + node_b[1]
        return sum_0, sum_1

    @staticmethod
    def make_orthogonal(node: Vector2D) -> Vector2D:
        """
        vector operation : returns the perpendicular vector (node[1], -node[0])

        This is a quarter turn of the input; which way it appears to turn
        depends on the axis convention used by the caller.

        :param node: tuple with coordinate (x,y) or 2d vector
        :return: tuple with coordinate (x,y) or 2d vector
        """
        swapped_first = node[1]
        negated_second = -node[0]
        return swapped_first, negated_second

    @staticmethod
    def get_norm(node: Vector2D) -> float:
        """
        calculates the euclidean norm of the 2d vector
        [see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]

        :param node: tuple with coordinate (x,y) or 2d vector
        :return: the euclidean length, computed with ``np.sqrt``
        """
        squared_length = node[0] * node[0] + node[1] * node[1]
        return np.sqrt(squared_length)

    @staticmethod
    def get_euclidean_distance(node_a: Vector2D, node_b: Vector2D) -> float:
        """
        calculates the euclidean distance between the two 2d vectors

        :param node_a: tuple with coordinate (x,y) or 2d vector
        :param node_b: tuple with coordinate (x,y) or 2d vector
        :return: euclidean norm of (node_b - node_a)
        """
        offset = Vec2dOperations.subtract(node_b, node_a)
        return Vec2dOperations.get_norm(offset)

    @staticmethod
    def get_manhattan_distance(node_a: Vector2D, node_b: Vector2D) -> float:
        """
        calculates the manhattan distance between the two 2d vectors
        [see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]

        :param node_a: tuple with coordinate (x,y) or 2d vector
        :param node_b: tuple with coordinate (x,y) or 2d vector
        :return: sum of the absolute component differences
        """
        offset = Vec2dOperations.subtract(node_b, node_a)
        abs_0 = np.abs(offset[0])
        abs_1 = np.abs(offset[1])
        return abs_0 + abs_1

    @staticmethod
    def get_chebyshev_distance(node_a: Vector2D, node_b: Vector2D) -> float:
        """
        calculates the chebyshev distance between the two 2d vectors
        [see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]

        :param node_a: tuple with coordinate (x,y) or 2d vector
        :param node_b: tuple with coordinate (x,y) or 2d vector
        :return: the larger of the two absolute component differences
        """
        offset = Vec2dOperations.subtract(node_b, node_a)
        abs_0 = np.abs(offset[0])
        abs_1 = np.abs(offset[1])
        return max(abs_0, abs_1)

    @staticmethod
    def normalize(node: Vector2D) -> Tuple[float, float]:
        """
        normalize the 2d vector = `v/|v|`

        When the norm is not greater than zero the vector is scaled by that
        norm itself instead of its reciprocal, so no division by zero occurs.

        :param node: tuple with coordinate (x,y) or 2d vector
        :return: tuple with coordinate (x,y) or 2d vector
        """
        length = Vec2dOperations.get_norm(node)
        factor = 1 / length if length > 0.0 else length
        return Vec2dOperations.scale(node, factor)

    @staticmethod
    def scale(node: Vector2D, scale: float) -> Vector2D:
        """
        scales the 2d vector = node * scale

        :param node: tuple with coordinate (x,y) or 2d vector
        :param scale: scalar to scale
        :return: tuple with coordinate (x,y) or 2d vector
        """
        scaled_0 = node[0] * scale
        scaled_1 = node[1] * scale
        return scaled_0, scaled_1

    @staticmethod
    def round(node: Vector2D) -> IntVector2D:
        """
        rounds both coordinates with ``np.round`` and converts them to int

        :param node: tuple with coordinate (x,y) or 2d vector
        :return: tuple of two integers
        """
        rounded_0 = int(np.round(node[0]))
        rounded_1 = int(np.round(node[1]))
        return rounded_0, rounded_1

    @staticmethod
    def ceil(node: Vector2D) -> IntVector2D:
        """
        applies ``np.ceil`` to both coordinates and converts them to int

        :param node: tuple with coordinate (x,y) or 2d vector
        :return: tuple of two integers
        """
        upper_0 = int(np.ceil(node[0]))
        upper_1 = int(np.ceil(node[1]))
        return upper_0, upper_1

    @staticmethod
    def floor(node: Vector2D) -> IntVector2D:
        """
        applies ``np.floor`` to both coordinates and converts them to int

        :param node: tuple with coordinate (x,y) or 2d vector
        :return: tuple of two integers
        """
        lower_0 = int(np.floor(node[0]))
        lower_1 = int(np.floor(node[1]))
        return lower_0, lower_1

    @staticmethod
    def bound(node: Vector2D, min_value: float, max_value: float) -> Vector2D:
        """
        force the values x and y to be between min_value and max_value

        :param node: tuple with coordinate (x,y) or 2d vector
        :param min_value: scalar value
        :param max_value: scalar value
        :return: tuple with both coordinates clamped to the given range
        """
        # cap at max_value first, then lift to min_value
        clamped_0 = max(min_value, min(max_value, node[0]))
        clamped_1 = max(min_value, min(max_value, node[1]))
        return clamped_0, clamped_1

    @staticmethod
    def rotate(node: Vector2D, rot_in_degree: float) -> Vector2D:
        """
        rotate the 2d vector with given angle in degree

        :param node: tuple with coordinate (x,y) or 2d vector
        :param rot_in_degree: angle in degree
        :return: tuple with the rotated coordinates
        """
        # degrees -> radians
        alpha = rot_in_degree / 180.0 * np.pi
        cos_a = np.cos(alpha)
        sin_a = np.sin(alpha)
        x0 = node[0]
        y0 = node[1]
        # standard 2x2 rotation matrix applied to (x0, y0)
        x1 = x0 * cos_a - y0 * sin_a
        y1 = x0 * sin_a + y0 * cos_a
        return x1, y1
def position_to_coordinate(depth: int, positions: List[int]):
"""Converts coordinates to positions::
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
--> -->
[ 0 d .. (w-1)*d [ 0 d .. (w-1)*d
1 d+1 1 d+1
... ...
d-1 2d-1 w*d-1 d-1 2d-1 w*d-1
] ]
:param depth: Parameters
:param positions: ----------
:return: depth : int
positions : List[Tuple[int,int]]
""" """
coords = () coords = ()
for p in positions: for p in positions:
...@@ -29,7 +264,8 @@ def position_to_coordinate(depth, positions): ...@@ -29,7 +264,8 @@ def position_to_coordinate(depth, positions):
def coordinate_to_position(depth, coords): def coordinate_to_position(depth, coords):
""" """
Converts positions to coordinates: Converts positions to coordinates::
[ 0 d .. (w-1)*d [ 0 d .. (w-1)*d
1 d+1 1 d+1
... ...
...@@ -46,13 +282,17 @@ def coordinate_to_position(depth, coords): ...@@ -46,13 +282,17 @@ def coordinate_to_position(depth, coords):
:param coords: :param coords:
:return: :return:
""" """
position = np.empty(len(coords), dtype=int) position = list(range(len(coords)))
idx = 0 for index, t in enumerate(coords):
for t in coords: if isnan(t[0]):
position[idx] = int(t[1] * depth + t[0]) position[index] = -1
idx += 1 else:
position[index] = int(t[1] * depth + t[0])
return position return position
def distance_on_rail(pos1, pos2): def distance_on_rail(pos1, pos2, metric="Euclidean"):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) if metric == "Euclidean":
return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
if metric == "Manhattan":
return np.abs(pos1[0] - pos2[0]) + np.abs(pos1[1] - pos2[1])
from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4 import Grid4Transitions
from flatland.utils.ordered_set import OrderedSet
class RailEnvTransitions(Grid4Transitions): class RailEnvTransitions(Grid4Transitions):
""" """
Special case of `GridTransitions' over a 2D-grid, with a pre-defined set Special case of `GridTransitions` over a 2D-grid, with a pre-defined set
of transitions mimicking the types of real Swiss rail connections. of transitions mimicking the types of real Swiss rail connections.
--------------------------------------------------------------------------
As no diagonal transitions are allowed in the RailEnv environment, the As no diagonal transitions are allowed in the RailEnv environment, the
possible transitions for RailEnv from a cell to its neighboring ones possible transitions for RailEnv from a cell to its neighboring ones
are represented over 16 bits. are represented over 16 bits.
...@@ -44,7 +43,7 @@ class RailEnvTransitions(Grid4Transitions): ...@@ -44,7 +43,7 @@ class RailEnvTransitions(Grid4Transitions):
) )
# create this to make validation faster # create this to make validation faster
self.transitions_all = set() self.transitions_all = OrderedSet()
for index, trans in enumerate(self.transitions): for index, trans in enumerate(self.transitions):
self.transitions_all.add(trans) self.transitions_all.add(trans)
if index in (2, 4, 6, 7, 8, 9, 10): if index in (2, 4, 6, 7, 8, 9, 10):
......
...@@ -7,9 +7,15 @@ from importlib_resources import path ...@@ -7,9 +7,15 @@ from importlib_resources import path
from numpy import array from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
# TODO are these general classes or for grid4 only?
class TransitionMap: class TransitionMap:
""" """
Base TransitionMap class. Base TransitionMap class.
...@@ -21,7 +27,7 @@ class TransitionMap: ...@@ -21,7 +27,7 @@ class TransitionMap:
def get_transitions(self, cell_id): def get_transitions(self, cell_id):
""" """
Return a tuple of transitions available in a cell specified by Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions, `cell_id` (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between, with values 0 or 1, or potentially in between,
for stochastic transitions). for stochastic transitions).
...@@ -41,8 +47,8 @@ class TransitionMap: ...@@ -41,8 +47,8 @@ class TransitionMap:
def set_transitions(self, cell_id, new_transitions): def set_transitions(self, cell_id, new_transitions):
""" """
Replaces the available transitions in cell `cell_id' with the tuple Replaces the available transitions in cell `cell_id` with the tuple
`new_transitions'. `new_transitions' must have `new_transitions'. `new_transitions` must have
one element for each possible transition. one element for each possible transition.
Parameters Parameters
...@@ -58,8 +64,8 @@ class TransitionMap: ...@@ -58,8 +64,8 @@ class TransitionMap:
def get_transition(self, cell_id, transition_index): def get_transition(self, cell_id, transition_index):
""" """
Return the status of whether an agent in cell `cell_id' can perform a Return the status of whether an agent in cell `cell_id` can perform a
movement along transition `transition_index (e.g., the NESW direction movement along transition `transition_index` (e.g., the NESW direction
of movement, for agents on a grid). of movement, for agents on a grid).
Parameters Parameters
...@@ -83,8 +89,8 @@ class TransitionMap: ...@@ -83,8 +89,8 @@ class TransitionMap:
def set_transition(self, cell_id, transition_index, new_transition): def set_transition(self, cell_id, transition_index, new_transition):
""" """
Replaces the validity of transition to `transition_index' in cell Replaces the validity of transition to `transition_index` in cell
`cell_id' with the new `new_transition'. `cell_id' with the new `new_transition`.
Parameters Parameters
...@@ -111,7 +117,7 @@ class GridTransitionMap(TransitionMap): ...@@ -111,7 +117,7 @@ class GridTransitionMap(TransitionMap):
GridTransitionMap implements utility functions. GridTransitionMap implements utility functions.
""" """
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([])): def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]), random_seed=None):
""" """
Builder for GridTransitionMap object. Builder for GridTransitionMap object.
...@@ -130,7 +136,11 @@ class GridTransitionMap(TransitionMap): ...@@ -130,7 +136,11 @@ class GridTransitionMap(TransitionMap):
self.width = width self.width = width
self.height = height self.height = height
self.transitions = transitions self.transitions = transitions
self.random_generator = np.random.RandomState()
if random_seed is None:
self.random_generator.seed(12)
else:
self.random_generator.seed(random_seed)
self.grid = np.zeros((height, width), dtype=self.transitions.get_type()) self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def get_full_transitions(self, row, column): def get_full_transitions(self, row, column):
...@@ -154,7 +164,7 @@ class GridTransitionMap(TransitionMap): ...@@ -154,7 +164,7 @@ class GridTransitionMap(TransitionMap):
def get_transitions(self, row, column, orientation): def get_transitions(self, row, column, orientation):
""" """
Return a tuple of transitions available in a cell specified by Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions, `cell_id` (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between, with values 0 or 1, or potentially in between,
for stochastic transitions). for stochastic transitions).
...@@ -176,8 +186,8 @@ class GridTransitionMap(TransitionMap): ...@@ -176,8 +186,8 @@ class GridTransitionMap(TransitionMap):
def set_transitions(self, cell_id, new_transitions): def set_transitions(self, cell_id, new_transitions):
""" """
Replaces the available transitions in cell `cell_id' with the tuple Replaces the available transitions in cell `cell_id` with the tuple
`new_transitions'. `new_transitions' must have `new_transitions'. `new_transitions` must have
one element for each possible transition. one element for each possible transition.
Parameters Parameters
...@@ -202,8 +212,8 @@ class GridTransitionMap(TransitionMap): ...@@ -202,8 +212,8 @@ class GridTransitionMap(TransitionMap):
def get_transition(self, cell_id, transition_index): def get_transition(self, cell_id, transition_index):
""" """
Return the status of whether an agent in cell `cell_id' can perform a Return the status of whether an agent in cell `cell_id` can perform a
movement along transition `transition_index (e.g., the NESW direction movement along transition `transition_index` (e.g., the NESW direction
of movement, for agents on a grid). of movement, for agents on a grid).
Parameters Parameters
...@@ -230,8 +240,8 @@ class GridTransitionMap(TransitionMap): ...@@ -230,8 +240,8 @@ class GridTransitionMap(TransitionMap):
def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False): def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False):
""" """
Replaces the validity of transition to `transition_index' in cell Replaces the validity of transition to `transition_index` in cell
`cell_id' with the new `new_transition'. `cell_id' with the new `new_transition`.
Parameters Parameters
...@@ -259,7 +269,7 @@ class GridTransitionMap(TransitionMap): ...@@ -259,7 +269,7 @@ class GridTransitionMap(TransitionMap):
def save_transition_map(self, filename): def save_transition_map(self, filename):
""" """
Save the transitions grid as `filename', in npy format. Save the transitions grid as `filename`, in npy format.
Parameters Parameters
---------- ----------
...@@ -271,9 +281,9 @@ class GridTransitionMap(TransitionMap): ...@@ -271,9 +281,9 @@ class GridTransitionMap(TransitionMap):
def load_transition_map(self, package, resource): def load_transition_map(self, package, resource):
""" """
Load the transitions grid from `filename' (npy format). Load the transitions grid from `filename` (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be The load function only updates the transitions grid, and possibly width and height, but the object has to be
initialized with the correct `transitions' object anyway. initialized with the correct `transitions` object anyway.
Parameters Parameters
---------- ----------
...@@ -283,7 +293,7 @@ class GridTransitionMap(TransitionMap): ...@@ -283,7 +293,7 @@ class GridTransitionMap(TransitionMap):
Name of the file from which to load the transitions grid within the package. Name of the file from which to load the transitions grid within the package.
override_gridsize : bool override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if of the map loaded from `filename`. If override_gridsize=False, the transitions grid is either cropped (if
the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than
(height,width) ) (height,width) )
...@@ -298,12 +308,155 @@ class GridTransitionMap(TransitionMap): ...@@ -298,12 +308,155 @@ class GridTransitionMap(TransitionMap):
self.height = new_height self.height = new_height
self.grid = new_grid self.grid = new_grid
def is_dead_end(self, rcPos: IntVector2DArray):
    """
    Tell whether the cell at `rcPos` is a dead-end.

    Parameters
    ----------
    rcPos: Tuple[int,int]
        tuple(row, column) with grid coordinate

    Returns
    -------
    boolean
        The result of `Grid4Transitions.has_deadend` applied to the full
        transitions value of the cell.
    """
    row, column = rcPos[0], rcPos[1]
    return Grid4Transitions.has_deadend(self.get_full_transitions(row, column))
def is_simple_turn(self, rcPos: IntVector2DArray):
    """
    Check if the cell is a left/right simple turn

    Parameters
    ----------
    rcPos: Tuple[int,int]
        tuple(row, column) with grid coordinate

    Returns
    -------
    boolean
        True if and only if the cell's full transitions value equals one of
        the simple-turn patterns (either base pattern or one of its 90-degree
        rotations).
    """
    cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])

    base_turns = (
        int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
        int('0001001000000000', 2),  # Case 1c (9)  - simple turn left
    )

    # Only membership is tested, so a plain set is sufficient here.
    all_simple_turns = set()
    for base_turn in base_turns:
        all_simple_turns.add(base_turn)
        rotated = base_turn
        for _ in range(3):
            rotated = self.transitions.rotate_transition(rotated, rotation=90)
            all_simple_turns.add(rotated)

    # NOTE: the previous implementation reused the name of the value under
    # test as its loop variable, so it compared the last generated rotation
    # against the set that already contained it and always answered True.
    return cell_transition in all_simple_turns
def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
    """
    Search for a possible path from one cell with a certain orientation to a target cell.

    Pending (position, direction) states are kept on a list that is popped from
    its end, so the graph is explored in depth-first order.

    :param start: Start cell from where we want to check the path
    :param direction: Start direction for the path we are testing
    :param end: Cell that we try to reach from the start cell
    :return: True if a path exists, False otherwise
    """
    seen = OrderedSet()
    pending = [(start, direction)]

    while pending:
        state = pending.pop()
        position, heading = state

        if Vec2d.is_equal(position, end):
            return True

        # every (position, direction) state is expanded at most once
        if state in seen:
            continue
        seen.add(state)

        moves = self.get_transitions(position[0], position[1], heading)
        for next_heading in range(4):
            if moves[next_heading]:
                pending.append((get_new_position(position, next_heading), next_heading))

    return False
def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
    """
    Check validity of cell at rcPos = tuple(row, column)
    Checks that:
    - surrounding cells have inbound transitions for all the outbound transitions of this cell.
    - an empty cell (full transitions value < 1) has no transition pointing into it
      from any in-bounds neighbour.

    These are NOT checked - see transition.is_valid:
    - all transitions have the mirror transitions (N->E <=> W->S)
    - Reverse transitions (N -> S) only exist for a dead-end
    - a cell contains either no dead-ends or exactly one

    If `check_this_cell` is true, `self.transitions.is_valid` is additionally
    applied to the cell's own value first.

    Returns: True (valid) or False (invalid)
    """
    cell_transition = self.grid[tuple(rcPos)]

    if check_this_cell:
        if not self.transitions.is_valid(cell_transition):
            return False

    gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
    grcPos = array(rcPos)
    grcMax = self.grid.shape

    binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out
    lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
    g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 x uint8 binary(0,1)
    gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
    giDirOut = np.argwhere(gDirOut)[:, 0]  # valid outbound directions as array of int

    # loop over available outbound directions (indices) for rcPos
    for iDirOut in giDirOut:
        gdRC = gDir2dRC[iDirOut]  # row,col increment
        gPos2 = grcPos + gdRC  # next cell in that direction

        # Check the adjacent cell is within bounds
        # if not, then this transition is invalid!
        if np.any(gPos2 < 0):
            return False
        if np.any(gPos2 >= grcMax):
            return False

        # Get the transitions out of gPos2, using iDirOut as the inbound direction
        # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
        t4Trans2 = self.get_transitions(*gPos2, iDirOut)
        if any(t4Trans2):
            continue
        else:
            return False

    # If the cell is empty but has incoming connections we return false
    if binTrans < 1:
        connected = 0

        for iDirOut in np.arange(4):
            gdRC = gDir2dRC[iDirOut]  # row,col increment
            gPos2 = grcPos + gdRC  # next cell in that direction

            # Check the adjacent cell is within bounds
            # if not, then ignore it for the count of incoming connections
            if np.any(gPos2 < 0):
                continue
            if np.any(gPos2 >= grcMax):
                continue

            # Sum, over every orientation in the neighbour, the transition that
            # heads back toward rcPos (direction mirror(iDirOut)).
            for orientation in range(4):
                connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
        if connected > 0:
            return False

    return True
def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
""" """
Check validity of cell at rcPos = tuple(row, column) Check validity of cell at rcPos = tuple(row, column)
Checks that: Checks that:
- surrounding cells have inbound transitions for all the - surrounding cells have inbound transitions for all the outbound transitions of this cell.
outbound transitions of this cell.
These are NOT checked - see transition.is_valid: These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S) - all transitions have the mirror transitions (N->E <=> W->S)
...@@ -346,8 +499,141 @@ class GridTransitionMap(TransitionMap): ...@@ -346,8 +499,141 @@ class GridTransitionMap(TransitionMap):
if any(t4Trans2): if any(t4Trans2):
continue continue
else: else:
self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
return False return False
return True return True
def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
    """
    Fixes broken transitions

    Counts from how many of the four neighbouring cells a transition points
    into `rcPos`, then overwrites the transitions of `rcPos` according to
    that count (1, 2, 3 or 4 incoming directions).

    :param rcPos: (row, column) of the cell to fix
    :param direction: only read in the three-incoming case, where a
        non-negative value influences which switch element is chosen;
        the default -1 leaves the choice to `self.random_generator`
    :return: always True
    """
    gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
    grcPos = array(rcPos)
    grcMax = self.grid.shape

    # Transition elements
    transitions = RailEnvTransitions()
    cells = transitions.transition_list
    simple_switch_east_south = transitions.rotate_transition(cells[10], 90)
    simple_switch_west_south = transitions.rotate_transition(cells[2], 270)
    symmetrical = cells[6]  # NOTE(review): not used anywhere below
    double_slip = cells[5]
    three_way_transitions = [simple_switch_east_south, simple_switch_west_south]

    # loop over all four directions around rcPos and record which neighbours connect into it
    incoming_connections = np.zeros(4)
    for iDirOut in np.arange(4):
        gdRC = gDir2dRC[iDirOut]  # row,col increment
        gPos2 = grcPos + gdRC  # next cell in that direction

        # Check the adjacent cell is within bounds
        # if not, then ignore it for the count of incoming connections
        if np.any(gPos2 < 0):
            continue
        if np.any(gPos2 >= grcMax):
            continue

        # Sum, over every orientation in the neighbour, the transition that
        # heads back toward rcPos (direction mirror(iDirOut)).
        connected = 0
        for orientation in range(4):
            connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
        if connected > 0:
            incoming_connections[iDirOut] = 1

    number_of_incoming = np.sum(incoming_connections)

    # Only one incoming direction --> Straight line set deadend
    if number_of_incoming == 1:
        # Both branches clear the cell.
        # NOTE(review): the loop below is placed inside the `else` branch; the
        # original nesting could not be confirmed from the available source -
        # TODO confirm against the repository.
        if self.get_full_transitions(*rcPos) == 0:
            self.set_transitions(rcPos, 0)
        else:
            self.set_transitions(rcPos, 0)

            # NOTE: this loop variable rebinds the `direction` parameter.
            for direction in range(4):
                if incoming_connections[direction] > 0:
                    self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)

    # Connect all incoming connections
    if number_of_incoming == 2:
        self.set_transitions(rcPos, 0)

        # np.argwhere returns one row per incoming direction
        connect_directions = np.argwhere(incoming_connections > 0)
        self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
        self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)

    # Find feasible connection for three entries
    if number_of_incoming == 3:
        self.set_transitions(rcPos, 0)
        # index of the single direction without an incoming connection
        hole = np.argwhere(incoming_connections < 1)[0][0]
        if direction >= 0:
            switch_type_idx = (direction - hole + 3) % 4
            if switch_type_idx == 0:
                transition = simple_switch_west_south
            elif switch_type_idx == 2:
                transition = simple_switch_east_south
            else:
                transition = self.random_generator.choice(three_way_transitions, 1)[0]
        else:
            transition = self.random_generator.choice(three_way_transitions, 1)[0]

        # rotate the chosen switch by 90 degrees per step of `hole`
        transition = transitions.rotate_transition(transition, int(hole * 90))
        self.set_transitions((rcPos[0], rcPos[1]), transition)

    # Make a double slip switch
    if number_of_incoming == 4:
        rotation = self.random_generator.randint(2)
        transition = transitions.rotate_transition(double_slip, int(rotation * 90))
        self.set_transitions((rcPos[0], rcPos[1]), transition)

    return True
def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
                            new_pos: IntVector2D, end_pos: IntVector2D):
    """
    Utility function to test that a path drawn by the A-star algorithm uses valid transition objects.
    We use this to guide A-star, as there are many transition elements that are not allowed in RailEnv.

    :param prev_pos: The previous position we were checking (None at the start of a path)
    :param current_pos: The current position we are checking
    :param new_pos: Possible child position we move into
    :param end_pos: End cell of path we are drawing
    :return: True if the transition is valid, False if transition element is illegal
    """
    is_path_start = prev_pos is None

    # direction from the current node to the candidate child, and the
    # direction used to reach the current node (same as new_dir at the start)
    new_dir = get_direction(current_pos, new_pos)
    current_dir = new_dir if is_path_start else get_direction(prev_pos, current_pos)

    # candidate transition value for the current cell
    new_trans = self.grid[current_pos]
    if is_path_start and new_trans == 0:
        # need to flip direction because of how end points are defined
        new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
    elif is_path_start:
        # check if matches existing layout
        new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
    else:
        # forward path first, then the backwards path
        new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
        new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)

    if Vec2d.is_equal(new_pos, end_pos):
        # the end cell has to accept the arriving path as well
        new_trans_e = self.grid[end_pos]
        if new_trans_e == 0:
            # need to flip direction because of how end points are defined
            exit_dir = mirror(new_dir)
        else:
            # check if matches existing layout
            exit_dir = new_dir
        new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, exit_dir, 1)

        if not self.transitions.is_valid(new_trans_e):
            return False

    # finally validate the transition of the current cell
    return self.transitions.is_valid(new_trans)
def mirror(dir):
    """Return the direction index opposite to `dir`: (dir + 2) modulo 4."""
    opposite = dir + 2
    return opposite % 4
# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?) # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
...@@ -12,7 +12,7 @@ class Transitions: ...@@ -12,7 +12,7 @@ class Transitions:
Generic class that implements checks to control whether a Generic class that implements checks to control whether a
certain transition is allowed (agent facing a direction certain transition is allowed (agent facing a direction
`orientation' and moving into direction `orientation') `orientation' and moving into direction `orientation`)
""" """
def get_type(self): def get_type(self):
...@@ -21,7 +21,7 @@ class Transitions: ...@@ -21,7 +21,7 @@ class Transitions:
def get_transitions(self, cell_transition, orientation): def get_transitions(self, cell_transition, orientation):
""" """
Return a tuple of transitions available in a cell specified by Return a tuple of transitions available in a cell specified by
`cell_transition' for an agent facing direction `orientation' `cell_transition' for an agent facing direction `orientation`
(e.g., a tuple of size of the maximum number of transitions, (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between, with values 0 or 1, or potentially in between,
for stochastic transitions). for stochastic transitions).
...@@ -45,9 +45,9 @@ class Transitions: ...@@ -45,9 +45,9 @@ class Transitions:
def set_transitions(self, cell_transition, orientation, new_transitions): def set_transitions(self, cell_transition, orientation, new_transitions):
""" """
Return a `cell_transition' specification where the transitions Return a `cell_transition` specification where the transitions
available for an agent facing direction `orientation' are replaced available for an agent facing direction `orientation` are replaced
with the tuple `new_transitions'. `new_orientations' must have with the tuple `new_transitions'. `new_orientations` must have
one element for each possible transition. one element for each possible transition.
Parameters Parameters
...@@ -65,8 +65,8 @@ class Transitions: ...@@ -65,8 +65,8 @@ class Transitions:
------- -------
[cell-content] [cell-content]
An updated class-specific object that replaces the original An updated class-specific object that replaces the original
transitions validity of `cell_transition' with `new_transitions', transitions validity of `cell_transition' with `new_transitions`,
for the appropriate `orientation'. for the appropriate `orientation`.
""" """
raise NotImplementedError() raise NotImplementedError()
...@@ -74,8 +74,8 @@ class Transitions: ...@@ -74,8 +74,8 @@ class Transitions:
def get_transition(self, cell_transition, orientation, direction): def get_transition(self, cell_transition, orientation, direction):
""" """
Return the status of whether an agent oriented in directions Return the status of whether an agent oriented in directions
`orientation' and inside a cell with transitions `cell_transition' `orientation' and inside a cell with transitions `cell_transition`
can move to the cell in direction `direction' relative can move to the cell in direction `direction` relative
to the current cell. to the current cell.
Parameters Parameters
...@@ -101,11 +101,11 @@ class Transitions: ...@@ -101,11 +101,11 @@ class Transitions:
def set_transition(self, cell_transition, orientation, direction, def set_transition(self, cell_transition, orientation, direction,
new_transition): new_transition):
""" """
Return a `cell_transition' specification where the status of Return a `cell_transition` specification where the status of
whether an agent oriented in direction `orientation' and inside whether an agent oriented in direction `orientation` and inside
a cell with transitions `cell_transition' can move to the cell a cell with transitions `cell_transition` can move to the cell
in direction `direction' relative to the current cell is set in direction `direction` relative to the current cell is set
to `new_transition'. to `new_transition`.
Parameters Parameters
---------- ----------
...@@ -125,8 +125,8 @@ class Transitions: ...@@ -125,8 +125,8 @@ class Transitions:
------- -------
[cell-content] [cell-content]
An updated class-specific object that replaces the original An updated class-specific object that replaces the original
transitions validity of `cell_transition' with `new_transitions', transitions validity of `cell_transition' with `new_transitions`,
for the appropriate `orientation' to `direction'. for the appropriate `orientation' to `direction`.
""" """
raise NotImplementedError() raise NotImplementedError()
......
import networkx as nx
import numpy as np
from typing import List, Tuple
import graphviz as gv
class MotionCheck(object):
    """ Class to find chains of agents which are "colliding" with a stopped agent.
        This is to allow close-packed chains of agents, ie a train of agents travelling
        at the same speed with no gaps between them.

        Agents are stored in a directed graph `G`: each node is a (row, col) cell,
        and each agent contributes an edge from its current cell to the cell it
        wants to occupy next.
    """

    def __init__(self):
        self.G = nx.DiGraph()
        self.nDeadlocks = 0
        self.svDeadlocked = set()

    def addAgent(self, iAg, rc1, rc2, xlabel=None):
        """ add an agent and its motion as row,col tuples of current and next position.
            The agent's current position is given an "agent" attribute recording the agent index.
            If an agent does not want to move this round (rc1 == rc2) then a self-loop edge is created.
            xlabel is used for test cases to give a label (see graphviz)
        """
        # Agents which have not yet entered the env have position None.
        # Substitute this for the row = -1, column = agent index
        if rc1 is None:
            rc1 = (-1, iAg)
        if rc2 is None:
            rc2 = (-1, iAg)

        self.G.add_node(rc1, agent=iAg)
        if xlabel:
            self.G.nodes[rc1]["xlabel"] = xlabel
        self.G.add_edge(rc1, rc2)

    def find_stops(self):
        """ find all the stopped agents as a set of rc position nodes
            A stopped agent is a self-loop on a cell node.
        """
        # get the (sparse) adjacency matrix
        spAdj = nx.linalg.adjacency_matrix(self.G)

        # the stopped agents appear as 1s on the diagonal
        # the where turns this into a list of indices of the 1s
        giStops = np.where(spAdj.diagonal())[0]

        # convert the cell/node indices into the node rc values
        lvAll = list(self.G.nodes())

        # pick out the stops by their indices, as a set ready for a set intersection
        return {lvAll[i] for i in giStops}

    def find_stops2(self):
        """ alternative method to find stopped agents, using a networkx call to find selfloop edges
        """
        return {u for u, v in nx.classes.function.selfloop_edges(self.G)}

    def find_stop_preds(self, svStops=None):
        """ Find the predecessors to a list of stopped agents (ie the nodes / vertices)
            Returns the set of predecessors.
            Includes "chained" predecessors.
            If svStops is None, the stops found by find_stops2() are used.
        """
        if svStops is None:
            svStops = self.find_stops2()

        # Get all the chains of agents - weakly connected components.
        # Weakly connected because it's a directed graph and you can traverse a chain of agents
        # in only one direction
        lWCC = list(nx.algorithms.components.weakly_connected_components(self.G))

        svBlocked = set()

        for oWCC in lWCC:
            # Get the node details for this WCC in a subgraph
            Gwcc = self.G.subgraph(oWCC)

            # Find all the stops in this chain or tree
            svCompStops = svStops.intersection(Gwcc)

            if len(svCompStops) > 0:
                # We need to traverse it in reverse - back up the movement edges
                Gwcc_rev = Gwcc.reverse()
                for vStop in svCompStops:
                    # Find all the agents stopped by vStop by following the (reversed) edges
                    # This traverses a tree - dfs = depth first search
                    iter_stops = nx.algorithms.traversal.dfs_postorder_nodes(Gwcc_rev, vStop)
                    svBlocked.update(iter_stops)

        # the set of all the nodes/agents blocked by this set of stopped nodes
        return svBlocked

    def find_swaps(self):
        """ find all the swap conflicts where two agents are trying to exchange places.
            These appear as simple cycles of length 2.
            These agents are necessarily deadlocked (since they can't change direction in flatland) -
            meaning they will now be stuck for the rest of the episode.
            Returns the flat set of nodes taking part in any such swap.
        """
        llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
        llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2]
        svSwaps = {v for lvSwap in llvSwaps for v in lvSwap}
        return svSwaps

    def find_same_dest(self):
        """ find groups of agents which are trying to land on the same cell.
            ie there is a gap of one cell between them and they are both landing on it.
            (Not implemented - returns None.)
        """
        pass

    def block_preds(self, svStops, color="red"):
        """ Take a list of stopped agents, and apply a stop color to any chains/trees
            of agents trying to head toward those cells.
            Returns the set of all nodes which were given the color (the stopped
            nodes themselves plus everything upstream of them).
        """
        svBlocked = set()

        # The reversed graph allows us to follow directed edges to find affected agents.
        Grev = self.G.reverse()
        for v in svStops:
            # Use depth-first-search to find a tree of agents heading toward the blocked cell.
            lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v))
            svBlocked.add(v)
            svBlocked.update(lvPred)

            # Re-assigning the same color to an already marked node is a no-op,
            # so no "already marked" test is needed.
            self.G.nodes[v]["color"] = color
            for v2 in lvPred:
                self.G.nodes[v2]["color"] = color

        return svBlocked

    def find_conflicts(self):
        """ Color the graph nodes according to the conflicts found:
            purple - swaps (deadlocks) and everything heading toward them,
            red    - agents blocked by a stopped agent or losing a contended cell,
            blue   - empty cell which two or more agents want to enter,
            magenta - occupied cell which two or more agents want to enter.
            The purple set is stored in self.svDeadlocked.
        """
        svStops = self.find_stops2()  # voluntarily stopped agents - have self-loops
        svSwaps = self.find_swaps()   # deadlocks - adjacent head-on collisions

        # Block all swaps and their tree of predecessors
        self.svDeadlocked = self.block_preds(svSwaps, color="purple")

        # Just look for the tree of preds for each voluntarily stopped agent
        svBlocked = self.find_stop_preds(svStops)

        # iterate the nodes v with their predecessors dPred (dict of nodes->{})
        for (v, dPred) in self.G.pred.items():
            # if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting
            if v in svBlocked:
                self.G.nodes[v]["color"] = "red"
            # not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node
            elif len(dPred) > 1:
                # if this agent is already red/blocked, ignore. CHECK: why?
                # certainly we want to ignore purple so we don't overwrite with red.
                if self.G.nodes[v].get("color") in ("red", "purple"):
                    continue

                # if this node has no agent, and >=2 want to enter it.
                if self.G.nodes[v].get("agent") is None:
                    self.G.nodes[v]["color"] = "blue"
                # this node has an agent and >=2 want to enter
                else:
                    self.G.nodes[v]["color"] = "magenta"

                # predecessors of a contended cell: {agent index -> node}
                diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred}

                # remove the agent with the lowest index, who wins
                iAgWinner = min(diAgCell)
                diAgCell.pop(iAgWinner)

                # Block all the remaining predecessors, and their tree of preds
                self.block_preds(diAgCell.values(), "red")

    def check_motion(self, iAgent, rcPos):
        """ Returns a boolean: can the agent move this step?
            If agent position is None, we use a dummy position of (-1, iAgent)
            A position which was never added to the graph yields False.
        """
        if rcPos is None:
            rcPos = (-1, iAgent)

        # A position that is not a node of the graph has neither attributes nor
        # successors; report "cannot move" rather than failing on the lookup below.
        if rcPos not in self.G:
            return False

        # If it's been marked red or purple then it can't move
        if self.G.nodes[rcPos].get("color") in ("red", "purple"):
            return False

        dSucc = self.G.succ[rcPos]

        # This should never happen - only the next cell of an agent has no successor
        if len(dSucc) == 0:
            print(f"error condition - agent {iAgent} node {rcPos} has no successor")
            return False

        # This agent has a successor; a self-loop means it didn't want to move
        rcNext = next(iter(dSucc))
        return rcNext != rcPos
def render(omc:MotionCheck, horizontal=True):
    """ Lay out the motion graph with graphviz "dot" and return it as a graphviz Source.

        :param omc: MotionCheck whose graph `G` is rendered
        :param horizontal: if true, a left-to-right rank direction is requested
        :return: a graphviz Source object, or None if an ImportError was raised
    """
    try:
        oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
        oAG.layout("dot")
        sDot = oAG.to_string()
        if horizontal:
            # Only the first brace opens the graph body; later braces (eg inside
            # a label) must be left untouched.
            sDot = sDot.replace('{', '{ rankdir="LR" ', 1)
        # This returns a graphviz object which implements __repr_svg
        return gv.Source(sDot)
    except ImportError:
        print("Flatland agent_chains ignoring ImportError - install pygraphviz to render graphs")
        return None
class ChainTestEnv(object):
    """ Just for testing agent chains

        Hands out consecutive agent indices (iAgNext) and keeps a "current row"
        (iRowNext) so that test scenarios can be laid out row by row.
    """
    def __init__(self, omc:MotionCheck):
        self.iAgNext = 0
        self.iRowNext = 1
        self.omc = omc

    def addAgent(self, rc1, rc2, xlabel=None):
        """ add one agent moving from rc1 to rc2, using the next free agent index
        """
        self.omc.addAgent(self.iAgNext, rc1, rc2, xlabel=xlabel)
        self.iAgNext += 1

    def addAgentToRow(self, c1, c2, xlabel=None):
        """ add one agent on the current row, moving from column c1 to column c2
        """
        self.addAgent((self.iRowNext, c1), (self.iRowNext, c2), xlabel=xlabel)

    def create_test_chain(self,
            nAgents:int,
            rcVel:Tuple[int, int] = (0,1),
            liStopped:List[int]=None,
            xlabel=None):
        """ create a chain of agents on the current row, then advance to the next row

            :param nAgents: number of agents in the chain
            :param rcVel: (row, col) step each moving agent wants to take
            :param liStopped: indices within the chain of agents which do not move;
                None (the default) means no agent is stopped
            :param xlabel: optional label stored on the first agent's node
        """
        # None instead of a shared mutable [] default
        if liStopped is None:
            liStopped = []

        lrcAgPos = [ (self.iRowNext, i * rcVel[1]) for i in range(nAgents) ]

        for iAg, rcPos in enumerate(lrcAgPos):
            rcVel1 = (0,0) if iAg in liStopped else rcVel
            self.omc.addAgent(iAg+self.iAgNext, rcPos, (rcPos[0] + rcVel1[0], rcPos[1] + rcVel1[1]) )

        # an empty chain has no first node to label
        if xlabel and lrcAgPos:
            self.omc.G.nodes[lrcAgPos[0]]["xlabel"] = xlabel

        self.iAgNext += nAgents
        self.iRowNext += 1

    def nextRow(self):
        """ advance the current row by one
        """
        self.iRowNext += 1
def create_test_agents(omc:MotionCheck):
    """ Populate omc with a fixed set of test agents.
        Each entry is (agent index, current cell, wanted next cell).
    """
    ltAgentMoves = [
        # blocked chain
        (1, (1,2), (1,3)),
        (2, (1,3), (1,4)),
        (3, (1,4), (1,5)),
        (31, (1,5), (1,5)),
        # unblocked chain
        (4, (2,1), (2,2)),
        (5, (2,2), (2,3)),
        # blocked short chain
        (6, (3,1), (3,2)),
        (7, (3,2), (3,2)),
        # solitary agent
        (8, (4,1), (4,2)),
        # solitary stopped agent
        (9, (5,1), (5,1)),
        # blocked short chain (opposite direction)
        (10, (6,4), (6,3)),
        (11, (6,3), (6,3)),
        # swap conflict
        (12, (7,1), (7,2)),
        (13, (7,2), (7,1)),
    ]

    for iAg, rcFrom, rcTo in ltAgentMoves:
        omc.addAgent(iAg, rcFrom, rcTo)
def create_test_agents2(omc:MotionCheck):
    """ Populate omc with a series of labelled test scenarios, laid out row by row
        through a ChainTestEnv. The first agent of each scenario carries the
        scenario name as its xlabel.
    """
    # blocked chain
    cte = ChainTestEnv(omc)
    cte.create_test_chain(4, liStopped=[3], xlabel="stopped\nchain")
    cte.create_test_chain(4, xlabel="running\nchain")

    cte.create_test_chain(2, liStopped = [1], xlabel="stopped \nshort\n chain")

    # two agents exchanging cells
    cte.addAgentToRow(1, 2, "swap")
    cte.addAgentToRow(2, 1)

    cte.nextRow()

    # a follower heading into a swap
    cte.addAgentToRow(1, 2, "chain\nswap")
    cte.addAgentToRow(2, 3)
    cte.addAgentToRow(3, 2)

    cte.nextRow()

    # the agent at column 4 does not move (4 -> 4)
    cte.addAgentToRow(1, 2, "midchain\nstop")
    cte.addAgentToRow(2, 3)
    cte.addAgentToRow(3, 4)
    cte.addAgentToRow(4, 4)
    cte.addAgentToRow(5, 6)
    cte.addAgentToRow(6, 7)

    cte.nextRow()

    # swap between columns 3 and 4 with agents queued on both sides
    cte.addAgentToRow(1, 2, "midchain\nswap")
    cte.addAgentToRow(2, 3)
    cte.addAgentToRow(3, 4)
    cte.addAgentToRow(4, 3)
    cte.addAgentToRow(5, 4)
    cte.addAgentToRow(6, 5)

    cte.nextRow()

    # two agents heading for column 2 from either side
    cte.addAgentToRow(1, 2, "Land on\nSame")
    cte.addAgentToRow(3, 2)

    cte.nextRow()
    # two chains converging on column 4
    cte.addAgentToRow(1, 2, "chains\nonto\nsame")
    cte.addAgentToRow(2, 3)
    cte.addAgentToRow(3, 4)
    cte.addAgentToRow(5, 4)
    cte.addAgentToRow(6, 5)
    cte.addAgentToRow(7, 6)

    cte.nextRow()
    # third agent arrives from the row below
    cte.addAgentToRow(1, 2, "3-way\nsame")
    cte.addAgentToRow(3, 2)
    cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
    cte.nextRow()

    # Disabled scenario - this branch never runs.
    # NOTE(review): the extent of this block was reconstructed (the "4-way"
    # scenario up to and including its first nextRow) - TODO confirm against
    # the repository.
    if False:
        cte.nextRow()
        cte.nextRow()
        cte.addAgentToRow(1, 2, "4-way\nsame")
        cte.addAgentToRow(3, 2)
        cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
        cte.addAgent((cte.iRowNext-1, 2), (cte.iRowNext, 2))
        cte.nextRow()

    cte.nextRow()
    # chain along the row plus one agent joining at column 3 from the row below
    cte.addAgentToRow(1, 2, "Tee")
    cte.addAgentToRow(2, 3)
    cte.addAgentToRow(3, 4)
    cte.addAgent((cte.iRowNext+1, 3), (cte.iRowNext, 3))
    cte.nextRow()
    cte.nextRow()

    # chain along the row plus agents from the two rows below feeding into column 3
    cte.addAgentToRow(1, 2, "Tree")
    cte.addAgentToRow(2, 3)
    cte.addAgentToRow(3, 4)
    r1 = cte.iRowNext
    r2 = cte.iRowNext+1
    r3 = cte.iRowNext+2
    cte.addAgent((r2, 3), (r1, 3))
    cte.addAgent((r2, 2), (r2, 3))
    cte.addAgent((r3, 2), (r2, 3))

    cte.nextRow()
def test_agent_following():
    """Build the second test scenario, classify its cells and draw the graph.

    Each node of ``omc.G`` gets the colour of the first category it belongs
    to: "magenta" for cells returned by ``find_stops``, "red" for cells
    returned by ``find_stop_preds``, "purple" for cells that appear in any
    group returned by ``find_swaps``, and "lightblue" otherwise. The blocked
    cells are also printed.
    """
    omc = MotionCheck()
    create_test_agents2(omc)

    stopped_cells = omc.find_stops()
    blocked_cells = omc.find_stop_preds()
    swap_groups = omc.find_swaps()

    # Flatten the swap groups into one set of cells.
    swap_cells = set()
    for group in swap_groups:
        swap_cells.update(group)

    print(list(blocked_cells))

    def colour_of(cell):
        # Order matters: the first matching category decides the colour.
        if cell in stopped_cells:
            return "magenta"
        if cell in blocked_cells:
            return "red"
        if cell in swap_cells:
            return "purple"
        return "lightblue"

    cells = omc.G.nodes()
    colours = [colour_of(cell) for cell in cells]
    # Each node is drawn at the coordinates given by the node itself.
    positions = {cell: cell for cell in cells}

    nx.draw(omc.G,
            with_labels=True, arrowsize=20,
            pos=positions,
            node_color=colours)
def main():
    """Script entry point: run the agent-following demonstration."""
    test_agent_following()


if __name__ == "__main__":
    main()
from itertools import starmap from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np import numpy as np
from attr import attrs, attrib, Factory import warnings
@attrs from typing import Tuple, Optional, NamedTuple, List
class EnvAgentStatic(object):
""" EnvAgentStatic - Stores initial position, direction and target.
This is like static data for the environment - it's where an agent starts,
rather than where it is at the moment.
The target should also be stored here.
"""
position = attrib()
direction = attrib()
target = attrib()
moving = attrib(default=False)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
# N.B. we need to use factory since default arguments are not recreated on each call!
speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
@classmethod from attr import attr, attrs, attrib, Factory
def from_lists(cls, positions, directions, targets, speeds=None):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
speed_datas = []
for i in range(len(positions)):
speed_datas.append({'position_fraction': 0.0,
'speed': speeds[i] if speeds is not None else 1.0,
'transition_action_on_cellexit': 0})
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas)))
def to_list(self): from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line
# I can't find an expression which works on both tuples, lists and ndarrays from flatland.envs.step_utils.action_saver import ActionSaver
# which converts them all to a list of native python ints. from flatland.envs.step_utils.speed_counter import SpeedCounter
lPos = self.position from flatland.envs.step_utils.state_machine import TrainStateMachine
if type(lPos) is np.ndarray: from flatland.envs.step_utils.states import TrainState
lPos = lPos.tolist() from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
lTarget = self.target Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
if type(lTarget) is np.ndarray: ('initial_direction', Grid4TransitionsEnum),
lTarget = lTarget.tolist() ('direction', Grid4TransitionsEnum),
('target', Tuple[int, int]),
('moving', bool),
('earliest_departure', int),
('latest_arrival', int),
('handle', int),
('position', Tuple[int, int]),
('arrival_time', int),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data]
def load_env_agent(agent_tuple: Agent):
    """Rebuild an ``EnvAgent`` from an ``Agent`` tuple.

    Every attribute named below is read from ``agent_tuple`` and passed to the
    ``EnvAgent`` constructor under the same keyword.
    """
    field_names = (
        'initial_position', 'initial_direction', 'direction', 'target',
        'moving', 'earliest_departure', 'latest_arrival', 'handle',
        'position', 'arrival_time', 'old_direction', 'old_position',
        'speed_counter', 'action_saver', 'state_machine', 'malfunction_handler',
    )
    constructor_kwargs = {name: getattr(agent_tuple, name) for name in field_names}
    return EnvAgent(**constructor_kwargs)
@attrs @attrs
class EnvAgent(EnvAgentStatic): class EnvAgent:
""" EnvAgent - replace separate agent_* lists with a single list # INIT FROM HERE IN _from_line()
of agent objects. The EnvAgent represent's the environment's view initial_position = attrib(type=Tuple[int, int])
of the dynamic agent state. initial_direction = attrib(type=Grid4TransitionsEnum)
We are duplicating target in the EnvAgent, which seems simpler than direction = attrib(type=Grid4TransitionsEnum)
forcing the env to refer to it in the EnvAgentStatic target = attrib(type=Tuple[int, int])
""" moving = attrib(default=False, type=bool)
# NEW : EnvAgent - Schedule properties
earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line()
handle = attrib(default=None) handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
# Env step facelift
speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling
arrival_time = attrib(default=None, type=int)
# used in rendering
old_direction = attrib(default=None) old_direction = attrib(default=None)
old_position = attrib(default=None) old_position = attrib(default=None)
def to_list(self):
return [
self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position, self.moving, self.speed_data]
@classmethod def reset(self):
def from_static(cls, oStatic):
""" Create an EnvAgent from the EnvAgentStatic,
copying all the fields, and adding handle with the default 0.
""" """
return EnvAgent(*oStatic.__dict__, handle=0) Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
"""
self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction
self.old_position = None
self.old_direction = None
self.moving = False
self.arrival_time = None
self.malfunction_handler.reset()
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
self.state_machine.reset()
def to_agent(self) -> Agent:
    """Copy this agent's fields into an ``Agent`` named tuple."""
    field_names = (
        'initial_position', 'initial_direction', 'direction', 'target',
        'moving', 'earliest_departure', 'latest_arrival', 'handle',
        'position', 'old_direction', 'old_position', 'speed_counter',
        'action_saver', 'arrival_time', 'state_machine', 'malfunction_handler',
    )
    snapshot = {name: getattr(self, name) for name in field_names}
    return Agent(**snapshot)
def get_shortest_path(self, distance_map) -> List[Waypoint]:
    """Return the shortest path computed for this agent's handle."""
    # Imported here rather than at module level to break a circular dependency.
    from flatland.envs.rail_env_shortest_paths import get_shortest_paths
    paths_by_handle = get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)
    return paths_by_handle[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
    """Return ``ceil(len(shortest path) / speed)`` as an int.

    A missing path (``None``) is treated as a path of length 0.
    """
    path = self.get_shortest_path(distance_map)
    path_length = 0 if path is None else len(path)
    return int(np.ceil(path_length / self.speed_counter.speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
    """Return ``latest_arrival`` minus the number of elapsed steps."""
    deadline = self.latest_arrival
    return deadline - elapsed_steps
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
    """Return the remaining time until latest arrival minus the shortest-path travel time.

    +ve if arrival time is projected before latest arrival
    -ve if arrival time is projected after latest arrival
    """
    time_left = self.get_time_remaining_until_latest_arrival(elapsed_steps)
    travel_time = self.get_travel_time_on_shortest_path(distance_map)
    return time_left - travel_time
@classmethod @classmethod
def list_from_static(cls, lEnvAgentStatic, handles=None): def from_line(cls, line: Line):
""" Create an EnvAgent from the EnvAgentStatic, """ Create a list of EnvAgent from lists of positions, directions and targets
copying all the fields, and adding handle with the default 0.
""" """
if handles is None: num_agents = len(line.agent_positions)
handles = range(len(lEnvAgentStatic))
agent_list = []
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
target = line.agent_targets[i_agent],
moving = False,
earliest_departure = None,
latest_arrival = None,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
agent_list.append(agent)
return agent_list
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
    """Convert legacy static-agent records into a list of ``EnvAgent`` objects.

    Records are read positionally: entry 0 is the initial position, entry 1 is
    used for both the initial and the current direction, entry 2 is the target.
    Records with at least six entries also provide ``moving`` (entry 3) and a
    speed (``record[4]['speed']``); shorter records get ``moving=False`` and a
    speed of 1.0. The handle is the record's index in ``static_agents_data``.
    """
    def convert(handle, record):
        if len(record) >= 6:
            moving, speed = record[3], record[4]['speed']
        else:
            moving, speed = False, 1.0
        return EnvAgent(initial_position=record[0], initial_direction=record[1],
                        direction=record[1], target=record[2], moving=moving,
                        speed_counter=SpeedCounter(speed), handle=handle)

    return [convert(index, record) for index, record in enumerate(static_agents_data)]
def __str__(self):
# Human-readable, multi-line dump of the agent's fields. `state` is read through the `state`
# property; the other values are plain attributes.
# NOTE(review): the literal relies on backslash line continuations, so any leading whitespace
# on the continuation lines becomes part of the returned string - re-indent with care.
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} \n \
initial_direction: {self.initial_direction} \n \
position: {self.position} \n \
direction: {self.direction} \n \
target: {self.target} \n \
old_position: {self.old_position} \n \
old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} \n \
latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_handler: {self.malfunction_handler} \n \
action_saver: {self.action_saver} \n \
speed_counter: {self.speed_counter}"
@property
def state(self):
    """State currently held by the agent's state machine."""
    machine = self.state_machine
    return machine.state

@state.setter
def state(self, state):
    # Assignment is routed through _set_state, which warns before overriding.
    self._set_state(state)
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
@property
def malfunction_data(self):
    """Deprecated accessor: always raises and points callers at ``malfunction_handler``.

    Raises
    ------
    ValueError
        On every access.
    """
    # The message previously misspelled both attribute names.
    raise ValueError("agent.malfunction_data is deprecated, please use agent.malfunction_handler instead")
@property
def speed_data(self):
    """Deprecated accessor: always raises and points callers at ``speed_counter``."""
    message = "agent.speed_data is deprecated, please use agent.speed_counter instead"
    raise ValueError(message)
return [EnvAgent(**oEAS.__dict__, handle=handle)
for handle, oEAS in zip(handles, lEnvAgentStatic)]
from collections import deque
from typing import List, Optional
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
class DistanceMap:
    """Lazily computed distances from every cell and orientation to each agent's target.

    The table is a float array indexed ``[agent, row, column, orientation]``.
    It is created filled with ``inf`` and only entries reached by the
    backwards search from a target are lowered.
    """

    def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
        self.agents: List[EnvAgent] = agents
        self.rail: Optional[GridTransitionMap] = None
        self.env_height = env_height
        self.env_width = env_width
        self.distance_map = None
        self.agents_previous_computation = None
        self.reset_was_called = False

    def set(self, distance_map: np.ndarray):
        """
        Set the distance map
        """
        self.distance_map = distance_map

    def get(self) -> np.ndarray:
        """
        Get the distance map, computing it first if needed.

        After a ``reset`` the map is recomputed, unless a map is already
        present and nothing was ever computed by this object (i.e. the map was
        supplied through ``set``). Without a pending reset, the map is
        computed only when none exists yet.
        """
        if self.reset_was_called:
            self.reset_was_called = False
            # Don't compute the distance map if it was loaded
            was_loaded = self.agents_previous_computation is None and self.distance_map is not None
            if not was_loaded:
                self._compute(self.agents, self.rail)
        elif self.distance_map is None:
            self._compute(self.agents, self.rail)

        return self.distance_map

    def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
        """
        Reset the distance map: remember the new agents and rail, take the grid
        size from the rail and flag that the next ``get`` follows a reset.
        """
        self.reset_was_called = True
        self.agents: List[EnvAgent] = agents
        self.rail = rail
        self.env_height = rail.height
        self.env_width = rail.width

    def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
        """
        This function computes the distance maps for each unique target. Thus if several targets are the same
        we only compute the distance for them once and copy to all targets with same position.
        :param agents: All the agents in the environment, independent of their current status
        :param rail: The rail transition map
        """
        self.agents_previous_computation = self.agents
        table_shape = (len(agents), self.env_height, self.env_width, 4)
        self.distance_map = np.full(table_shape, np.inf)

        # computed_targets[k] is the target of agent k, so .index() yields the
        # first agent that already owns a computed map for a given target.
        computed_targets = []
        for agent_index, agent in enumerate(agents):
            if agent.target in computed_targets:
                # just copy the distance map from the other agent with the same target (performance)
                source_index = computed_targets.index(agent.target)
                self.distance_map[agent_index, :, :, :] = np.copy(
                    self.distance_map[source_index, :, :, :])
            else:
                self._distance_map_walker(rail, agent.target, agent_index)
            computed_targets.append(agent.target)

    def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int):
        """
        Utility function to compute distance maps from each cell in the rail network (and each possible
        orientation within it) to each agent's target cell.

        Returns the largest distance recorded while filling in ``distance_map``.
        """
        target_row, target_col = position[0], position[1]
        self.distance_map[target_nr, target_row, target_col, :] = 0

        # Seed the search with the (up to) 4 neighbours of the target.
        nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0,
                                                           enforce_target_direction=-1))

        # BFS from the target to all reachable nodes; the target itself counts
        # as visited in every orientation so the search never re-enters it.
        visited = {(target_row, target_col, orientation) for orientation in range(4)}

        max_distance = 0
        while nodes_queue:
            row, col, orientation, distance = nodes_queue.popleft()
            node_id = (row, col, orientation)
            if node_id in visited:
                continue
            visited.add(node_id)

            # From the list of possible neighbors that have at least a path to the current node, only keep
            # those whose new orientation in the current cell would allow a transition to `orientation`.
            valid_neighbors = self._get_and_update_neighbors(rail, (row, col), target_nr, distance, orientation)
            nodes_queue.extend(valid_neighbors)

            if valid_neighbors:
                max_distance = max(max_distance, distance + 1)

        return max_distance

    def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance,
                                  enforce_target_direction=-1):
        """
        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
        minimum distances from each target cell.

        Returns a list of ``(row, col, orientation, distance)`` tuples, one per
        neighbouring cell/orientation from which ``position`` can be entered.
        """
        if enforce_target_direction >= 0:
            # The agent must land into the current cell with orientation `enforce_target_direction'.
            # This is only possible if the agent has arrived from the cell in the opposite direction!
            possible_directions = [(enforce_target_direction + 2) % 4]
        else:
            possible_directions = [0, 1, 2, 3]

        neighbors = []
        for neigh_direction in possible_directions:
            new_cell = get_new_position(position, neigh_direction)
            new_row, new_col = new_cell[0], new_cell[1]

            inside_grid = 0 <= new_row < self.env_height and 0 <= new_col < self.env_width
            if not inside_grid:
                continue

            desired_movement_from_new_cell = (neigh_direction + 2) % 4

            # Check all possible transitions in new_cell
            for agent_orientation in range(4):
                # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
                is_valid = rail.get_transition((new_row, new_col, agent_orientation),
                                               desired_movement_from_new_cell)
                if not is_valid:
                    continue

                # TODO: check that it works with deadends! -- still bugged!
                new_distance = min(self.distance_map[target_nr, new_row, new_col, agent_orientation],
                                   current_distance + 1)
                neighbors.append((new_row, new_col, agent_orientation, new_distance))
                self.distance_map[target_nr, new_row, new_col, agent_orientation] = new_distance

        return neighbors
from typing import Tuple
# Adrian Egli / Michel Marti performance fix (the fast methods bring more than 50%)
def fast_isclose(a, b, rtol):
    """Return True when ``a`` lies strictly within ``rtol`` of ``b``.

    Despite its name, ``rtol`` is used as an absolute tolerance: the test is
    ``b - rtol < a < b + rtol``.

    The previous expression, ``a < b + rtol or a < b - rtol``, had no lower
    bound (its second clause is implied by the first), so any ``a`` far below
    ``b`` was reported as close.
    """
    return (b - rtol) < a < (b + rtol)
def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int], max_value: Tuple[int, int]) -> Tuple[int, int]:
    """Clamp each coordinate of ``position`` into ``[min_value[i], max_value[i]]``.

    Returns the clamped ``(row, column)`` pair; the return annotation used to
    claim ``bool``.
    """
    return (
        max(min_value[0], min(position[0], max_value[0])),
        max(min_value[1], min(position[1], max_value[1]))
    )
def fast_argmax(possible_transitions: Tuple[int, int, int, int]) -> int:
    """Return the index of the first entry equal to 1 among the first three, else 3.

    Index 3 is the fallback and is returned without being inspected, so an
    all-zero input also yields 3. The return annotation used to claim ``bool``.
    """
    for index in range(3):
        if possible_transitions[index] == 1:
            return index
    return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
    """Compare two positions by their first two coordinates.

    Two ``None`` positions are equal; ``None`` never equals a real position.
    """
    if pos_1 is None or pos_2 is None:
        # Equal only when both are missing.
        return pos_1 is None and pos_2 is None
    same_row = pos_1[0] == pos_2[0]
    return same_row and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
    """Return the sum of the four entries.

    For 0/1 entries this equals the number of non-zero entries.
    """
    north = possible_transitions[0]
    east = possible_transitions[1]
    south = possible_transitions[2]
    west = possible_transitions[3]
    return north + east + south + west
def fast_delete(lis: list, index) -> list:
    """Return a copy of ``lis`` without the element at ``index``; ``lis`` is left untouched."""
    trimmed = lis.copy()
    del trimmed[index]
    return trimmed
def fast_where(binary_iterable):
    """Return the indices of all entries that are not equal to 0."""
    hits = []
    for position, value in enumerate(binary_iterable):
        if value != 0:
            hits.append(position)
    return hits
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
def load_flatland_environment_from_file(file_name: str,
                                        load_from_package: str = None,
                                        obs_builder_object: ObservationBuilder = None) -> RailEnv:
    """
    Create a `RailEnv` whose rail and schedule are both read from one file.

    Parameters
    ----------
    file_name : str
        The pickle file.
    load_from_package : str
        The python module to import from. Example: 'env_data.tests'
        This requires that there are `__init__.py` files in the folder structure we load the file from.
    obs_builder_object: ObservationBuilder
        The obs builder for the `RailEnv` that is created. When omitted, a
        `TreeObsForRailEnv` with `max_depth=2` and a
        `ShortestPathPredictorForRailEnv(max_depth=10)` predictor is used.

    Returns
    -------
    RailEnv
        The environment loaded from the pickle file.
    """
    builder = obs_builder_object
    if builder is None:
        predictor = ShortestPathPredictorForRailEnv(max_depth=10)
        builder = TreeObsForRailEnv(max_depth=2, predictor=predictor)

    # width/height/number_of_agents are passed as 1 here; the generators read from the file.
    rail_generator = rail_from_file(file_name, load_from_package)
    schedule_generator = schedule_from_file(file_name, load_from_package)
    return RailEnv(width=1, height=1,
                   rail_generator=rail_generator,
                   schedule_generator=schedule_generator,
                   number_of_agents=1,
                   obs_builder_object=builder)
import msgpack
import numpy as np
from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.grid4_generators_utils import connect_rail
from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
def empty_rail_generator():
    """
    Returns a generator which returns an empty rail map with no agents.
    Primarily used by the editor
    """

    def generator(width, height, num_agents=0, num_resets=0):
        # Build a grid of the requested size and zero every cell.
        grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
        grid_map.grid.fill(0)

        # No agent positions, directions, targets or speeds.
        return grid_map, [], [], [], []

    return generator
def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
    """
    Build a rail generator that connects random start/goal pairs and then adds
    extra connections between cells that already carry rail.

    Parameters
    -------
    nr_start_goal : int
        Number of start/goal pairs to try to connect.
    nr_extra : int
        Number of extra connections to try to add between existing rail cells.
    min_dist : int
        Smallest `distance_on_rail` accepted between a start and its goal.
    max_dist : int
        Largest `distance_on_rail` accepted between a start and its goal.
    seed : int
        Base seed; each call seeds NumPy's global RNG with `seed + num_resets`.

    Returns
    -------
    function
        `generator(width, height, num_agents, num_resets=0)` returning the
        grid map, the agents' positions, directions and targets, and a list of
        speeds (1.0 for every agent).
    """

    def generator(width, height, num_agents, num_resets=0):

        if num_agents > nr_start_goal:
            num_agents = nr_start_goal
            print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
        rail_trans = RailEnvTransitions()
        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
        rail_array = grid_map.grid
        rail_array.fill(0)

        np.random.seed(seed + num_resets)

        # generate rail array
        # step 1:
        # - generate a start and goal position
        #   - validate min/max distance allowed
        #   - validate that start/goals are not placed too close to other start/goals
        #   - draw a rail from [start,goal]
        #     - if rail crosses existing rail then validate new connection
        #     - possibility that this fails to create a path to goal
        #     - on failure generate new start/goal
        #
        # step 2:
        # - add more rails to map randomly between cells that have rails
        #   - validate all new rails, on failure don't add new rails
        #
        # step 3:
        # - return transition map + list of [start_pos, start_dir, goal_pos] points
        #

        start_goal = []
        start_dir = []

        def check_all_dist(sg_new):
            # A candidate pair is rejected when either of its endpoints is closer
            # than 2 to an endpoint of a pair that was already accepted.
            # Defined once here (it reads `start_goal` through the closure)
            # instead of being re-created on every sampling attempt.
            for sg in start_goal:
                for i in range(2):
                    for j in range(2):
                        dist = distance_on_rail(sg_new[i], sg[j])
                        if dist < 2:
                            return False
            return True

        nr_created = 0
        created_sanity = 0
        sanity_max = 9000
        while nr_created < nr_start_goal and created_sanity < sanity_max:
            all_ok = False
            for _ in range(sanity_max):
                start = (np.random.randint(0, height), np.random.randint(0, width))
                goal = (np.random.randint(0, height), np.random.randint(0, width))

                # check to make sure start,goal pos is empty?
                if rail_array[goal] != 0 or rail_array[start] != 0:
                    continue
                # check min/max distance
                dist_sg = distance_on_rail(start, goal)
                if dist_sg < min_dist:
                    continue
                if dist_sg > max_dist:
                    continue
                # check distance to existing points
                sg_new = [start, goal]

                if check_all_dist(sg_new):
                    all_ok = True
                    break

            if not all_ok:
                # we might as well give up at this point
                break

            new_path = connect_rail(rail_trans, rail_array, start, goal)
            if len(new_path) >= 2:
                nr_created += 1
                start_goal.append([start, goal])
                start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
            else:
                # after too many failures we will give up
                created_sanity += 1

        # add extra connections between existing rail
        created_sanity = 0
        nr_created = 0
        while nr_created < nr_extra and created_sanity < sanity_max:
            all_ok = False
            for _ in range(sanity_max):
                start = (np.random.randint(0, height), np.random.randint(0, width))
                goal = (np.random.randint(0, height), np.random.randint(0, width))
                # check to make sure start,goal pos are not empty
                if rail_array[goal] == 0 or rail_array[start] == 0:
                    continue
                else:
                    all_ok = True
                    break
            if not all_ok:
                break
            new_path = connect_rail(rail_trans, rail_array, start, goal)
            if len(new_path) >= 2:
                nr_created += 1
            else:
                # Count failed attempts, as the first loop does. Without this
                # the `created_sanity` guard in the loop condition could never
                # trigger and repeated failures would loop forever.
                created_sanity += 1

        agents_position = [sg[0] for sg in start_goal[:num_agents]]
        agents_target = [sg[1] for sg in start_goal[:num_agents]]
        agents_direction = start_dir[:num_agents]

        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)

    return generator
def rail_from_manual_specifications_generator(rail_spec):
    """
    Utility to convert a rail given by manual specification as a map of tuples
    (cell_type, rotation), to a transition map with the correct 16-bit
    transitions specifications.

    Parameters
    -------
    rail_spec : list of list of tuples
        List (rows) of lists (columns) of tuples, each specifying a rail_spec_of_cell for
        the RailEnv environment as (cell_type, rotation), with rotation being
        clock-wise and in [0, 90, 180, 270].

    Returns
    -------
    function
        Generator function that always returns a GridTransitionMap object with
        the matrix of correct 16-bit bitmaps for each rail_spec_of_cell.
    """

    def generator(width, height, num_agents, num_resets=0):
        transitions = RailEnvTransitions()

        # The grid size is taken from the specification; the `width` and
        # `height` arguments are not used.
        n_rows = len(rail_spec)
        n_cols = len(rail_spec[0])
        rail = GridTransitionMap(width=n_cols, height=n_rows, transitions=transitions)

        for r in range(n_rows):
            for c in range(n_cols):
                cell_spec = rail_spec[r][c]
                cell_type = cell_spec[0]
                rotation = cell_spec[1]
                if not 0 <= cell_type < len(transitions.transitions):
                    print("ERROR - invalid rail_spec_of_cell type=", cell_type)
                    return []
                rotated = transitions.rotate_transition(transitions.transitions[cell_type], rotation)
                rail.set_transitions((r, c), rotated)

        placement = get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents)
        agents_position, agents_direction, agents_target = placement

        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)

    return generator
def rail_from_file(filename):
    """
    Utility to load pickle file

    Parameters
    -------
    filename : Pickle file generated by env.save() or editor

    Returns
    -------
    function
        Generator function that always returns a GridTransitionMap object with
        the matrix of correct 16-bit bitmaps for each rail_spec_of_cell.
        When the file holds a non-empty `distance_maps` entry it is appended
        to the returned tuple as a sixth element.
    """

    def generator(width, height, num_agents, num_resets):
        rail_env_transitions = RailEnvTransitions()
        with open(filename, "rb") as file_in:
            raw_bytes = file_in.read()
        data = msgpack.unpackb(raw_bytes, use_list=False)

        # The grid dimensions come from the stored array, not from the arguments.
        grid = np.array(data[b"grid"])
        n_rows, n_cols = np.shape(grid)[0], np.shape(grid)[1]
        rail = GridTransitionMap(width=n_cols, height=n_rows, transitions=rail_env_transitions)
        rail.grid = grid

        # agents are always reset as not moving
        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]

        # setup with loaded data
        agents_position, agents_direction, agents_target = [], [], []
        for agent in agents_static:
            agents_position.append(agent.position)
            agents_direction.append(agent.direction)
            agents_target.append(agent.target)
        speeds = [1.0] * len(agents_position)

        if b"distance_maps" in data.keys():
            distance_maps = data[b"distance_maps"]
            if len(distance_maps) > 0:
                return rail, agents_position, agents_direction, agents_target, speeds, distance_maps
        return rail, agents_position, agents_direction, agents_target, speeds

    return generator
def rail_from_grid_transition_map(rail_map):
    """
    Utility to wrap an existing GridTransitionMap, carrying the correct
    16-bit transitions specifications, in a generator.

    Parameters
    -------
    rail_map : GridTransitionMap object
        GridTransitionMap object to return when the generator is called.

    Returns
    -------
    function
        Generator function that always returns the given `rail_map' object,
        together with agent positions, directions and targets obtained from
        `get_rnd_agents_pos_tgt_dir_on_rail` and a speed of 1.0 per agent.
    """

    def generator(width, height, num_agents, num_resets=0):
        placement = get_rnd_agents_pos_tgt_dir_on_rail(rail_map, num_agents)
        agents_position, agents_direction, agents_target = placement
        speeds = [1.0] * len(agents_position)
        return rail_map, agents_position, agents_direction, agents_target, speeds

    return generator
def random_rail_generator(cell_type_relative_proportion=None):
    """
    Dummy random level generator:
    - fill in cells at random in [width-2, height-2]
    - keep filling cells in among the unfilled ones, such that all transitions
      are legit;  if no cell can be filled in without violating some
      transitions, pick one among those that can satisfy most transitions
      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
      incompatible.
    - keep trying for a total number of insertions
      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
      board and try again from scratch.
    - finally pad the border of the map with dead-ends to avoid border issues.

    Dead-ends are not allowed inside the grid, only at the border; however, if
    no cell type can be inserted in a given cell (because of the neighboring
    transitions), deadends are allowed if they solve the problem. This was
    found to turn most un-generatable levels into valid ones.

    Parameters
    -------
    cell_type_relative_proportion : list of float, optional
        Relative weight of each entry of `RailEnvTransitions().transitions`,
        indexed like that list. Defaults to `[1.0] * 11` (equal weights).

    Returns
    -------
    function
        `generator(width, height, num_agents, num_resets=0)` returning a
        GridTransitionMap holding the 16-bit bitmaps for each cell, the agents'
        positions, directions and targets, and a speed of 1.0 per agent.
    """
    # A list default would be one shared object for every call of this
    # function; build the default weights per call instead.
    if cell_type_relative_proportion is None:
        cell_type_relative_proportion = [1.0] * 11

    def generator(width, height, num_agents, num_resets=0):
        t_utils = RailEnvTransitions()

        transition_probability = cell_type_relative_proportion

        transitions_templates_ = []
        transition_probabilities = []
        for i in range(len(t_utils.transitions)):  # don't include dead-ends
            if t_utils.transitions[i] == int('0010000000000000', 2):
                continue

            # Collapse the per-direction transitions into one 4-bit mask.
            all_transitions = 0
            for dir_ in range(4):
                trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
                all_transitions |= (trans[0] << 3) | \
                                   (trans[1] << 2) | \
                                   (trans[2] << 1) | \
                                   (trans[3])

            template = [int(x) for x in bin(all_transitions)[2:]]
            template = [0] * (4 - len(template)) + template

            # add all rotations
            for rot in [0, 90, 180, 270]:
                transitions_templates_.append((template,
                                               t_utils.rotate_transition(
                                                   t_utils.transitions[i],
                                                   rot)))
                transition_probabilities.append(transition_probability[i])
                template = [template[-1]] + template[:-1]

        def get_matching_templates(template):
            # Entries < 0 in `template` are wildcards; all other entries must match.
            ret = []
            for i in range(len(transitions_templates_)):
                is_match = True
                for j in range(4):
                    if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
                        is_match = False
                        break
                if is_match:
                    ret.append((transitions_templates_[i][1], transition_probabilities[i]))
            return ret

        MAX_INSERTIONS = (width - 2) * (height - 2) * 10
        MAX_ATTEMPTS_FROM_SCRATCH = 10

        attempt_number = 0
        while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
            cells_to_fill = []
            rail = []
            for r in range(height):
                rail.append([None] * width)
                if r > 0 and r < height - 1:
                    cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]

            num_insertions = 0
            while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
                cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
                cells_to_fill.remove(cell)
                row = cell[0]
                col = cell[1]

                # look at its neighbors and see what are the possible transitions
                # that can be chosen from, if any.
                valid_template = [-1, -1, -1, -1]

                for el in [(0, 2, (-1, 0)),
                           (1, 3, (0, 1)),
                           (2, 0, (1, 0)),
                           (3, 1, (0, -1))]:  # N, E, S, W
                    neigh_trans = rail[row + el[2][0]][col + el[2][1]]
                    if neigh_trans is not None:
                        # select transition coming from facing direction el[1] and
                        # moving to direction el[1]
                        max_bit = 0
                        for k in range(4):
                            max_bit |= t_utils.get_transition(neigh_trans, k, el[1])

                        if max_bit:
                            valid_template[el[0]] = 1
                        else:
                            valid_template[el[0]] = 0

                possible_cell_transitions = get_matching_templates(valid_template)

                if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
                    # no cell can be filled in without violating some transitions
                    # can a dead-end solve the problem?
                    if valid_template.count(1) == 1:
                        for k in range(4):
                            if valid_template[k] == 1:
                                rot = 0
                                if k == 0:
                                    rot = 180
                                elif k == 1:
                                    rot = 270
                                elif k == 2:
                                    rot = 0
                                elif k == 3:
                                    rot = 90

                                rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot)
                                num_insertions += 1

                                break

                    else:
                        # can I get valid transitions by removing a single
                        # neighboring cell?
                        bestk = -1
                        besttrans = []
                        for k in range(4):
                            tmp_template = valid_template[:]
                            tmp_template[k] = -1
                            possible_cell_transitions = get_matching_templates(tmp_template)
                            if len(possible_cell_transitions) > len(besttrans):
                                besttrans = possible_cell_transitions
                                bestk = k

                        if bestk >= 0:
                            # Replace the corresponding cell with None, append it
                            # to cells to fill, fill in a transition in the current
                            # cell.
                            replace_row = row - 1
                            replace_col = col
                            if bestk == 1:
                                replace_row = row
                                replace_col = col + 1
                            elif bestk == 2:
                                replace_row = row + 1
                                replace_col = col
                            elif bestk == 3:
                                replace_row = row
                                replace_col = col - 1

                            cells_to_fill.append((replace_row, replace_col))
                            rail[replace_row][replace_col] = None

                            possible_transitions, possible_probabilities = zip(*besttrans)
                            possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]

                            rail[row][col] = np.random.choice(possible_transitions,
                                                              p=possible_probabilities)
                            num_insertions += 1

                        else:
                            print('WARNING: still nothing!')
                            rail[row][col] = int('0000000000000000', 2)
                            num_insertions += 1

                else:
                    possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
                    possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]

                    rail[row][col] = np.random.choice(possible_transitions,
                                                      p=possible_probabilities)
                    num_insertions += 1

            if num_insertions == MAX_INSERTIONS:
                # Failed to generate a valid level; try again for a number of times
                attempt_number += 1
            else:
                break

        if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
            print('ERROR: failed to generate level')

        # Finally pad the border of the map with dead-ends to avoid border issues;
        # at most 1 transition in the neigh cell
        for r in range(height):
            # Check for transitions coming from [r][1] to WEST
            max_bit = 0
            neigh_trans = rail[r][1]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & 1)
            if max_bit:
                rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
            else:
                rail[r][0] = int('0000000000000000', 2)

            # Check for transitions coming from [r][-2] to EAST
            max_bit = 0
            neigh_trans = rail[r][-2]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
            if max_bit:
                rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
                                                        90)
            else:
                rail[r][-1] = int('0000000000000000', 2)

        for c in range(width):
            # Check for transitions coming from [1][c] to NORTH
            max_bit = 0
            neigh_trans = rail[1][c]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
            if max_bit:
                rail[0][c] = int('0010000000000000', 2)
            else:
                rail[0][c] = int('0000000000000000', 2)

            # Check for transitions coming from [-2][c] to SOUTH
            max_bit = 0
            neigh_trans = rail[-2][c]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
            if max_bit:
                rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
            else:
                rail[-1][c] = int('0000000000000000', 2)

        # For display only, wrong levels
        for r in range(height):
            for c in range(width):
                if rail[r][c] is None:
                    rail[r][c] = int('0000000000000000', 2)

        tmp_rail = np.asarray(rail, dtype=np.uint16)

        return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
        return_rail.grid = tmp_rail

        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
            return_rail,
            num_agents)

        return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)

    return generator
...@@ -7,18 +7,43 @@ a GridTransitionMap object. ...@@ -7,18 +7,43 @@ a GridTransitionMap object.
import numpy as np import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point, get_new_position
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
def connect_rail(rail_trans, rail_array, start, end): from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
rail_trans: RailEnvTransitions,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
flip_start_node_trans: bool = False, flip_end_node_trans: bool = False,
respect_transition_validity: bool = True, forbidden_cells: IntVector2DArray = None,
avoid_rail=False) -> IntVector2DArray:
""" """
Creates a new path [start,end] in rail_array, based on rail_trans. Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
returns the path created as a list of positions.
:param avoid_rail:
:param rail_trans: basic rail transition object
:param grid_map: grid map
:param start: start position of rail
:param end: end position of rail
:param flip_start_node_trans: make valid start position by adding dead-end, empty start if False
:param flip_end_node_trans: make valid end position by adding dead-end, empty end if False
:param respect_transition_validity: Only draw rail maps if legal rail elements can be use, False, draw line without
respecting rail transitions.
:param a_star_distance_function: Define what distance function a-star should use
:param forbidden_cells: cells to avoid when drawing rail. Rail cannot go through this list of cells
:return: List of cells in the path
""" """
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, rail_array, start, end) path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, avoid_rail,
respect_transition_validity,
forbidden_cells)
if len(path) < 2: if len(path) < 2:
return [] return []
current_dir = get_direction(path[0], path[1]) current_dir = get_direction(path[0], path[1])
end_pos = path[-1] end_pos = path[-1]
for index in range(len(path) - 1): for index in range(len(path) - 1):
...@@ -26,12 +51,15 @@ def connect_rail(rail_trans, rail_array, start, end): ...@@ -26,12 +51,15 @@ def connect_rail(rail_trans, rail_array, start, end):
new_pos = path[index + 1] new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos) new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos] new_trans = grid_map.grid[current_pos]
if index == 0: if index == 0:
if new_trans == 0: if new_trans == 0:
# end-point # end-point
# need to flip direction because of how end points are defined if flip_start_node_trans:
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) # need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
new_trans = 0
else: else:
# into existing rail # into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
...@@ -40,96 +68,108 @@ def connect_rail(rail_trans, rail_array, start, end): ...@@ -40,96 +68,108 @@ def connect_rail(rail_trans, rail_array, start, end):
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path # set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
rail_array[current_pos] = new_trans grid_map.grid[current_pos] = new_trans
if new_pos == end_pos: if new_pos == end_pos:
# setup end pos setup # setup end pos setup
new_trans_e = rail_array[end_pos] new_trans_e = grid_map.grid[end_pos]
if new_trans_e == 0: if new_trans_e == 0:
# end-point # end-point
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) if flip_end_node_trans:
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
new_trans_e = 0
else: else:
# into existing rail # into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
rail_array[end_pos] = new_trans_e grid_map.grid[end_pos] = new_trans_e
current_dir = new_dir current_dir = new_dir
return path return path
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): def connect_straight_line_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D,
end: IntVector2D, rail_trans: RailEnvTransitions) -> IntVector2DArray:
""" """
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). Generates a straight rail line from start cell to end cell.
Diagonal lines are not allowed
TODO: add extensive documentation, as users may need this function to simplify their custom level generators. :param rail_trans:
:param grid_map:
:param start: Cell coordinates for start of line
:param end: Cell coordinates for end of line
:return: A list of all cells in the path
""" """
def _path_exists(rail, start, direction, end): if not (start[0] == end[0] or start[1] == end[1]):
# BFS - Check if a path exists between the 2 nodes print("No straight line possible!")
return []
visited = set()
stack = [(start, direction)] direction = direction_to_point(start, end)
while stack:
node = stack.pop() if direction is Grid4TransitionsEnum.NORTH or direction is Grid4TransitionsEnum.SOUTH:
if node[0][0] == end[0] and node[0][1] == end[1]: start_row = min(start[0], end[0])
return 1 end_row = max(start[0], end[0]) + 1
if node not in visited: rows = np.arange(start_row, end_row)
visited.add(node) length = np.abs(end[0] - start[0]) + 1
moves = rail.get_transitions(node[0][0], node[0][1], node[1]) cols = np.repeat(start[1], length)
for move_index in range(4):
if moves[move_index]: else: # Grid4TransitionsEnum.EAST or Grid4TransitionsEnum.WEST
stack.append((get_new_position(node[0], move_index), start_col = min(start[1], end[1])
move_index)) end_col = max(start[1], end[1]) + 1
cols = np.arange(start_col, end_col)
# If cell is a dead-end, append previous node with reversed length = np.abs(end[1] - start[1]) + 1
# orientation! rows = np.repeat(start[0], length)
nbits = 0
tmp = rail.get_full_transitions(node[0][0], node[0][1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_full_transitions(r, c) > 0:
valid_positions.append((r, c))
re_generate = True
while re_generate:
agents_position = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
agents_direction = [0] * num_agents
re_generate = False
for i in range(num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions(position[0], position[1], direction)
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
return agents_position, agents_direction, agents_target path = list(zip(rows, cols))
for cell in path:
transition = grid_map.grid[cell]
transition = rail_trans.set_transition(transition, direction, direction, 1)
transition = rail_trans.set_transition(transition, mirror(direction), mirror(direction), 1)
grid_map.grid[cell] = transition
return path
def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, rail_trans: RailEnvTransitions):
    """
    Connect an inner city node to its two neighbouring rail cells.

    The four grid neighbours of `inner_node_pos` are inspected; a neighbour counts
    as rail when its value in `grid_map.grid` is greater than zero. Only when
    exactly two neighbours are rail is anything modified:

    * the cell at `inner_node_pos` is overwritten with a transition that links
      the two neighbouring directions in both travel directions;
    * each of the two neighbours additionally gets the transition
      (direction -> mirror(direction)) set on top of its existing value.

    In every other case the grid is left untouched.

    NOTE(review): neighbour positions are not bounds-checked before indexing the
    grid -- presumably inner nodes never lie on the grid border; confirm with callers.

    :param grid_map: grid map whose `grid` array is modified in place
    :param inner_node_pos: inner city node to fix
    :param rail_trans: transition helper providing `set_transition`
    :return: None
    """
    # Directions (0..3) whose adjacent cell already carries some transition.
    occupied = [
        heading for heading in range(4)
        if grid_map.grid[get_new_position(inner_node_pos, heading)] > 0
    ]
    if len(occupied) != 2:
        return

    first, second = occupied

    # Build the inner node's cell from scratch: entering from one neighbour
    # leads out towards the other, and vice versa.
    curve = rail_trans.set_transition(0, mirror(first), second, 1)
    curve = rail_trans.set_transition(curve, mirror(second), first, 1)
    grid_map.grid[inner_node_pos] = curve

    # Add the (heading -> mirror(heading)) transition to each neighbour,
    # keeping whatever transitions the neighbour already had.
    for heading in (first, second):
        neighbour = get_new_position(inner_node_pos, heading)
        grid_map.grid[neighbour] = rail_trans.set_transition(
            grid_map.grid[neighbour], heading, mirror(heading), 1)
def align_cell_to_city(city_center, city_orientation, cell):
    """
    Align a cell to the city center along the axis of the city orientation.

    For an even `city_orientation` the row offset `cell[0] - city_center[0]` is
    clipped to [0, 1] and doubled, giving 0 or 2 for integer coordinates.
    For an odd `city_orientation` the column offset `city_center[1] - cell[1]` is
    clipped to [0, 1], doubled and incremented by one, giving 1 or 3 for integer
    coordinates.

    @param city_center: Center needed for orientation, indexable as (row, column)
    @param city_orientation: Orientation of the city; only its parity is used
    @param cell: Cell we would like to orient, indexable as (row, column)
    @return: Orientation of the cell as an int (presumably a Grid4 direction index)
    """
    if city_orientation % 2 == 0:
        # Even orientation: compare rows, result lands on an even value.
        offset = cell[0] - city_center[0]
        base = 0
    else:
        # Odd orientation: compare columns, result lands on an odd value.
        offset = city_center[1] - cell[1]
        base = 1

    step = np.clip(offset, 0, 1)
    return int(2 * step) + base