Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 2459 additions and 367 deletions
from ray import tune
from ray.tune.registry import register_env
# from ray.rllib.utils import try_import_tf
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import numpy as np
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# Custom observation builder with predictor, uncomment line below if you want to try this one
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
def env_creator(args):
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
return env
if __name__ == "__main__":
env_name = "flatland_pettyzoo"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
test_env = ParallelPettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
def gen_policy(i):
config = {
"gamma": 0.99,
}
return (None, obs_space, act_space, config)
policies = {"policy_0": gen_policy(0)}
policy_ids = list(policies.keys())
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 5000000},
checkpoint_freq=10,
local_dir="~/ray_results/"+env_name,
config={
# Environment specific
"env": env_name,
# https://github.com/ray-project/ray/issues/10761
"no_done_at_end": True,
# "soft_horizon" : True,
"num_gpus": 0,
"num_workers": 2,
"num_envs_per_worker": 1,
"compress_observations": False,
"batch_mode": 'truncate_episodes',
"clip_rewards": False,
"vf_clip_param": 500.0,
"entropy_coeff": 0.01,
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
"train_batch_size": 1000, # 5000
"rollout_fragment_length": 50, # 100
"sgd_minibatch_size": 100, # 500
"vf_share_layers": False
},
)
# __sphinx_doc_end__
import numpy as np
import os
import PIL
import shutil
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3 import PPO
import supersuit as ss
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import fnmatch
import wandb
"""
https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
"""
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# env = flatland_env.env(environment = rail_env, use_renderer = False)
if wandb_log:
run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True,
config={}, name=experiment_name, save_code=True)
env_steps = 1000 # 2 * env.width * env.height # Code uses 1.5 to calculate max_steps
rollout_fragment_length = 50
env = ss.pettingzoo_env_to_vec_env_v0(env)
# env.black_death = True
env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
n_steps=rollout_fragment_length, ent_coef=0.01,
learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3,
batch_size=150, seed=seed)
# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
train_timesteps = 100000
model.learn(total_timesteps=train_timesteps)
model.save(f"policy_flatland_{train_timesteps}")
# __sphinx_doc_end__
model = PPO.load(f"policy_flatland_{train_timesteps}")
env = flatland_env.env(environment=rail_env, use_renderer=True)
if wandb_log:
artifact = wandb.Artifact('model', type='model')
artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
run.log_artifact(artifact)
# Model Interference
seed = 100
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < 1:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = model.predict(obs, deterministic=True)[0] if not done else None
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
if step % 100 == 0:
print(f"env step:{step} and action taken:{act}")
completion = env_generators.perc_completion(env)
print("Agents Completed:", completion)
completion = env_generators.perc_completion(env)
print("Final Agents Completed:", completion)
ep_no += 1
frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
def find(pattern, path):
result = []
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result.append(os.path.join(root, name))
return result
if wandb_log:
extn = "gif"
_video_file = f'*.{extn}'
_found_videos = find(_video_file, experiment_name)
print(_found_videos)
for _found_video in _found_videos:
wandb.log({_found_video: wandb.Video(_found_video, format=extn)})
run.join()
from typing import List
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.step_utils import env_utils
class Deadlock_Checker:
def __init__(self, env):
self.env = env
self.deadlocked_agents = []
self.immediate_deadlocked = []
def reset(self) -> None:
self.deadlocked_agents = []
self.immediate_deadlocked = []
# an immediate deadlock consists of two trains "trying to pass through each other".
# An agent may have a free possible transition, but took a bad action and "ran into another train". This is now a deadlock, and the other free
# direction can not be chosen anymore!
def check_immediate_deadlocks(self, action_dict) -> List[EnvAgent]:
"""
output: list of agents who are in immediate deadlocks
"""
env = self.env
newly_deadlocked_agents = []
# TODO: check restrictions to relevant agents (status ACTIVE, etc.)
relevant_agents = [agent for agent in env.agents if agent.state != TrainState.DONE and agent.position is not None]
for agent in relevant_agents:
other_agents = [other_agent for other_agent in env.agents if other_agent != agent] # check if this is a good test for inequality. Maybe use handles...
# get the transitions the agent can take from his current position and orientation
# an indicator array of the form e.g. (0,1,1,0) meaning that he can only go to east and south, not to north and west.
possible_transitions = env.rail.get_transitions(*agent.position, agent.direction)
#print(f"possible transitions: {possible_transitions}")
# the directions are: 0(north), 1(east), 2(south) and 3(west)
#possible_directions = [direction for direction, flag in enumerate(possible_transitions) if flag == 1]
#print(f"possible directions: {possible_directions}")
################### only consider direction for actually chosen action ###############################
new_position, new_direction = env_utils.apply_action_independent(action=action_dict[agent.handle], rail=env.rail, position=agent.position, direction=agent.direction)
#assert new_direction in possible_directions, "Error, action leads to impossible direction"
assert new_position == get_new_position(agent.position, new_direction), "Error, something is wrong with new position"
opposed_agent_id = env.agent_positions[new_position] # TODO: check that agent_positions now works correctly in flatland V3 (i.e. gets correctly updated...)
# agent_positions[cell] is an agent_id if an agent is there, otherwise -1.
if opposed_agent_id != -1:
opposed_agent = env.agents[opposed_agent_id]
# other agent with opposing direction is in the way --> deadlock
# an opposing direction means having a different direction than our agent would have if he moved to the new cell. (180 degrees or 90 degrees to our agent)
if opposed_agent.direction != new_direction:
if agent not in newly_deadlocked_agents: # to avoid duplicates
newly_deadlocked_agents.append(agent)
if opposed_agent not in newly_deadlocked_agents: # to avoid duplicates
newly_deadlocked_agents.append(opposed_agent)
self.immediate_deadlocked = newly_deadlocked_agents
return newly_deadlocked_agents
# main method to check for all deadlocks
def check_deadlocks(self, action_dict) -> List[EnvAgent]:
env = self.env
relevant_agents = [agent for agent in env.agents if agent.state != TrainState.DONE and agent.position is not None]
immediate_deadlocked = self.check_immediate_deadlocks(action_dict)
self.immediate_deadlocked = immediate_deadlocked
deadlocked = immediate_deadlocked[:]
# now we have to "close": each train which is blocked by another deadlocked train becomes deadlocked itself.
still_changing = True
while still_changing:
still_changing = False # will be overwritten below if a change did occur
# check if for any agent, there is a new deadlock found
for agent in relevant_agents:
#possible_transitions = env.rail.get_transitions(*agent.position, agent.direction)
#print(f"possible transitions: {possible_transitions}")
# the directions are: 0 (north), 1(east), 2(south) and 3(west)
#possible_directions = [direction for direction, flag in enumerate(possible_transitions) if flag == 1]
#print(f"possible directions: {possible_directions}")
new_position, new_direction = env_utils.apply_action_independent(action=action_dict[agent.handle], rail=env.rail, position=agent.position, direction=agent.direction)
#assert new_direction in possible_directions, "Error, action leads to impossible direction"
assert new_position == get_new_position(agent.position, new_direction), "Error, something is wrong with new position"
opposed_agent_id = env.agent_positions[new_position]
if opposed_agent_id != -1: # there is an opposed agent there
opposed_agent = env.agents[opposed_agent_id]
if opposed_agent in deadlocked:
if agent not in deadlocked: # to avoid duplicates
deadlocked.append(agent)
still_changing = True
self.deadlocked_agents = deadlocked
return deadlocked
\ No newline at end of file
import logging
import random
import numpy as np
from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import TrainState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.fast_methods import fast_count_nonzero, fast_argmax
MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
def get_shortest_path_action(env,handle):
distance_map = env.distance_map.get()
agent = env.agents[handle]
if agent.status in [TrainState.WAITING, TrainState.READY_TO_DEPART,
TrainState.MALFUNCTION_OFF_MAP]:
agent_virtual_position = agent.initial_position
elif agent.status in [TrainState.MALFUNCTION, TrainState.MOVING, TrainState.STOPPED]:
agent_virtual_position = agent.position
elif agent.status == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
if agent.position:
possible_transitions = env.rail.get_transitions(
*agent.position, agent.direction)
else:
possible_transitions = env.rail.get_transitions(
*agent.initial_position, agent.direction)
num_transitions = fast_count_nonzero(possible_transitions)
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(
agent_virtual_position, direction)
min_distances.append(
distance_map[handle, new_position[0],
new_position[1], direction])
else:
min_distances.append(np.inf)
if num_transitions == 1:
observation = [0, 1, 0]
elif num_transitions == 2:
idx = np.argpartition(np.array(min_distances), 2)
observation = [0, 0, 0]
observation[idx[0]] = 1
return fast_argmax(observation) + 1
def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
random.seed(random_seed)
width = 30
height = 30
nr_trains = 5
max_num_cities = 4
grid_mode = False
max_rails_between_cities = 2
max_rails_in_city = 3
malfunction_rate = 0
malfunction_min_duration = 0
malfunction_max_duration = 0
rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
speed_ratio_map = None
line_generator = sparse_line_generator(speed_ratio_map)
malfunction_generator = no_malfunction_generator()
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator_and_process_data=malfunction_generator,
obs_builder_object=observation_builder, remove_agents_at_target=False)
print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
random_seed, width, height, max_num_cities, nr_trains, max_rails_between_cities,
max_rails_in_city, malfunction_rate, malfunction_min_duration, malfunction_max_duration
))
return env
except ValueError as e:
logging.error(f"Error: {e}")
width += 5
height += 5
logging.info("Try again with larger env: (w,h):", width, height)
logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_height}, max_height={max_height}")
return None
def random_sparse_env_small(random_seed, observation_builder, max_width = 45, max_height = 45):
random.seed(random_seed)
size = random.randint(0, 5)
width = 20 + size * 5
height = 20 + size * 5
nr_cities = 2 + size // 2 + random.randint(0, 2)
nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5)) # , 10 + random.randint(0, 10))
max_rails_between_cities = 2
max_rails_in_cities = 3 + random.randint(0, size)
malfunction_rate = 30 + random.randint(0, 100)
malfunction_min_duration = 3 + random.randint(0, 7)
malfunction_max_duration = 20 + random.randint(0, 80)
rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_cities)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator=ParamMalfunctionGen(stochastic_data),
obs_builder_object=observation_builder, remove_agents_at_target=False)
print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration
))
return env
except ValueError as e:
logging.error(f"Error: {e}")
width += 5
height += 5
logging.info("Try again with larger env: (w,h):", width, height)
logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_height}, max_height={max_height}")
return None
def sparse_env_small(random_seed, observation_builder):
width = 30 # With of map
height = 30 # Height of map
nr_trains = 2 # Number of trains that have an assigned task in the env
cities_in_map = 3 # Number of cities where agents can start or end
seed = 10 # Random seed
grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed
max_rails_between_cities = 2 # Max number of tracks allowed between cities. This is number of entry point to a city
max_rail_in_cities = 6 # Max number of parallel tracks within a city, representing a realistic trainstation
rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
seed=seed,
grid_mode=grid_distribution_of_cities,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_in_cities,
)
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
# We can now initiate the schedule generator with the given speed profiles
line_generator = sparse_rail_generator(speed_ration_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
stochastic_data = MalfunctionParameters(malfunction_rate=1/10000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
rail_env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator=ParamMalfunctionGen(stochastic_data),
remove_agents_at_target=True)
return rail_env
def _after_step(self, observation, reward, done, info):
if not self.enabled: return done
if type(done)== dict:
_done_check = done['__all__']
else:
_done_check = done
if _done_check and self.env_semantics_autoreset:
# For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
self.reset_video_recorder()
self.episode_id += 1
self._flush()
# Record stats - Disabled as it causes error in multi-agent set up
# self.stats_recorder.after_step(observation, reward, done, info)
# Record video
self.video_recorder.capture_frame()
return done
def perc_completion(env):
tasks_finished = 0
if hasattr(env, "agents_data"):
agent_data = env.agents_data
else:
agent_data = env.agents
for current_agent in agent_data:
if current_agent.status == TrainState.DONE:
tasks_finished += 1
return 100 * np.mean(tasks_finished / max(
1, len(agent_data)))
from collections import defaultdict
from typing import Dict, Tuple
from flatland.contrib.utils.deadlock_checker import Deadlock_Checker
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.fast_methods import fast_count_nonzero
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.step_utils.states import TrainState
def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
agent = env.agents[handle]
if agent.state == TrainState.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
else:
print("no action possible!")
print("agent state: ", agent.state)
# NEW: if agent is at target, DO_NOTHING, and distance is zero.
# NEW: (needs to be tested...)
return [(RailEnvActions.DO_NOTHING, 0)] * 2
possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
print(f"possible transitions: {possible_transitions}")
distance_map = env.distance_map.get()[handle]
possible_steps = []
for movement in list(range(4)):
if possible_transitions[movement]:
if movement == agent.direction:
action = RailEnvActions.MOVE_FORWARD
elif movement == (agent.direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif movement == (agent.direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}")
if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4:
print("it seems that we are turning by 180 degrees. Turning in a dead end?")
action = RailEnvActions.MOVE_FORWARD
distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
possible_steps.append((action, distance))
possible_steps = sorted(possible_steps, key=lambda step: step[1])
# if there is only one path to target, this is both the shortest one and the second shortest path.
if len(possible_steps) == 1:
return possible_steps * 2
else:
return possible_steps
class RailEnvWrapper:
def __init__(self, env:RailEnv):
self.env = env
assert self.env is not None
assert self.env.rail is not None, "Reset original environment first!"
assert self.env.agents is not None, "Reset original environment first!"
assert len(self.env.agents) > 0, "Reset original environment first!"
# @property
# def number_of_agents(self):
# return self.env.number_of_agents
# @property
# def agents(self):
# return self.env.agents
# @property
# def _seed(self):
# return self.env._seed
# @property
# def obs_builder(self):
# return self.env.obs_builder
def __getattr__(self, name):
try:
return super().__getattr__(self,name)
except:
"""Expose any other attributes of the underlying environment."""
return getattr(self.env, name)
@property
def rail(self):
return self.env.rail
@property
def width(self):
return self.env.width
@property
def height(self):
return self.env.height
@property
def agent_positions(self):
return self.env.agent_positions
def get_num_agents(self):
return self.env.get_num_agents()
def get_agent_handles(self):
return self.env.get_agent_handles()
def step(self, action_dict: Dict[int, RailEnvActions]):
return self.env.step(action_dict)
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
return obs, info
class ShortestPathActionWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv):
super().__init__(env)
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
# input: action dict with actions in [0, 1, 2].
transformed_action_dict = {}
for agent_id, action in action_dict.items():
if action == 0:
transformed_action_dict[agent_id] = action
else:
#assert action in [1, 2]
#assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
#assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
obs, rewards, dones, info = self.env.step(transformed_action_dict)
return obs, rewards, dones, info
def find_all_cells_where_agent_can_choose(env: RailEnv):
"""
input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv),
WHICH HAS BEEN RESET ALREADY!
(o.w., we call env.rail, which is None before reset(), and crash.)
"""
switches = []
switches_neighbors = []
directions = list(range(4))
for h in range(env.height):
for w in range(env.width):
pos = (h, w)
is_switch = False
# Check for switch: if there is more than one outgoing transition
for orientation in directions:
possible_transitions = env.rail.get_transitions(*pos, orientation)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
switches.append(pos)
is_switch = True
break
if is_switch:
# Add all neighbouring rails, if pos is a switch
for orientation in directions:
possible_transitions = env.rail.get_transitions(*pos, orientation)
for movement in directions:
if possible_transitions[movement]:
switches_neighbors.append(get_new_position(pos, movement))
decision_cells = switches + switches_neighbors
return tuple(map(set, (switches, switches_neighbors, decision_cells)))
class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# env can be a real RailEnv, or anything that shares the same interface
# e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
super().__init__(env)
# save these so they can be inspected easier.
self.accumulate_skipped_rewards = accumulate_skipped_rewards
self.discounting = discounting
self.switches = None
self.switches_neighbors = None
self.decision_cells = None
self.skipped_rewards = defaultdict(list)
# sets initial values for switches, decision_cells, etc.
self.reset_cells()
def on_decision_cell(self, agent: EnvAgent) -> bool:
return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
def on_switch(self, agent: EnvAgent) -> bool:
return agent.position in self.switches
def next_to_switch(self, agent: EnvAgent) -> bool:
return agent.position in self.switches_neighbors
def reset_cells(self) -> None:
self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
o, r, d, i = {}, {}, {}, {}
# need to initialize i["..."]
# as we will access i["..."][agent_id]
i["action_required"] = dict()
i["malfunction"] = dict()
i["speed"] = dict()
i["state"] = dict()
while len(o) == 0:
obs, reward, done, info = self.env.step(action_dict)
for agent_id, agent_obs in obs.items():
if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
o[agent_id] = agent_obs
r[agent_id] = reward[agent_id]
d[agent_id] = done[agent_id]
i["action_required"][agent_id] = info["action_required"][agent_id]
i["malfunction"][agent_id] = info["malfunction"][agent_id]
i["speed"][agent_id] = info["speed"][agent_id]
i["state"][agent_id] = info["state"][agent_id]
if self.accumulate_skipped_rewards:
discounted_skipped_reward = r[agent_id]
for skipped_reward in reversed(self.skipped_rewards[agent_id]):
discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
r[agent_id] = discounted_skipped_reward
self.skipped_rewards[agent_id] = []
elif self.accumulate_skipped_rewards:
self.skipped_rewards[agent_id].append(reward[agent_id])
# end of for-loop
d['__all__'] = done['__all__']
action_dict = {}
# end of while-loop
return o, r, d, i
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
obs, info = self.env.reset(**kwargs)
# resets decision cells, switches, etc. These can change with an env.reset(...)!
# needs to be done after env.reset().
self.reset_cells()
return obs, info
class DeadlockWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv, deadlock_reward=-100) -> None:
super().__init__(env)
self.deadlock_reward = deadlock_reward
self.deadlock_checker = Deadlock_Checker(env=self.env)
@property
def deadlocked_agents(self):
return self.deadlock_checker.deadlocked_agents
@property
def immediate_deadlocks(self):
return [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]
# make sure to assign the deadlock reward only once to each deadlocked agent...
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
# agents which are already deadlocked from previous steps
already_deadlocked_ids = [agent.handle for agent in self.deadlocked_agents]
# step environment
obs, rewards, dones, info = self.env.step(action_dict)
# compute new list of deadlocked agents (ids) after stepping the environment
deadlocked_agents = self.deadlock_checker.check_deadlocks(action_dict) # also stored in self.deadlocked_checker.deadlocked_agents
deadlocked_agents_ids = [agent.handle for agent in deadlocked_agents]
# immediate deadlocked ids only used for prints
immediate_deadlocked_ids = [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]
print(f"immediate deadlocked: {immediate_deadlocked_ids}")
print(f"total deadlocked: {deadlocked_agents_ids}")
newly_deadlocked_agents_ids = [agent_id for agent_id in deadlocked_agents_ids if agent_id not in already_deadlocked_ids]
# assign deadlock rewards
for agent_id in newly_deadlocked_agents_ids:
print(f"assigning deadlock reward of {self.deadlock_reward} to agent {agent_id}")
rewards[agent_id] = self.deadlock_reward
return obs, rewards, dones, info
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
self.deadlock_checker.reset() # sets all lists of deadlocked agents to empty list
obs, info = super().reset(**kwargs)
return obs, info
......@@ -11,11 +11,11 @@ class Environment:
Derived environments should implement the following attributes:
action_space: tuple with the dimensions of the actions to be passed to the step method
observation_space: tuple with the dimensions of the observations returned by reset and step
Agents are identified by agent ids (handles).
Examples:
>>> obs = env.reset()
>>> obs, info = env.reset()
>>> print(obs)
{
"train_0": [2.4, 1.6],
......@@ -40,18 +40,19 @@ class Environment:
"train_0": {}, # info for train_0
"train_1": {}, # info for train_1
}
"""
def __init__(self):
self.action_space = ()
self.observation_space = ()
pass
def reset(self):
"""
Resets the env and returns observations from agents in the environment.
Returns:
Returns
-------
obs : dict
New observations for each agent.
"""
......@@ -66,7 +67,7 @@ class Environment:
The returns are dicts mapping from agent_id strings to values.
Parameters
-------
----------
action_dict : dict
Dictionary of actions to execute, indexed by agent id.
......
......@@ -2,27 +2,29 @@
ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).
+ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+ `reset()` is called after each environment reset, to allow for pre-computing relevant data.
+ `get()` is called whenever an observation has to be computed, potentially for each agent independently in case of \
multi-agent environments.
+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""
from typing import Optional, List
import numpy as np
from flatland.core.env import Environment
class ObservationBuilder:
"""
ObservationBuilder base class.
Derived objects must implement and `observation_space' attribute as a tuple with the dimensions of the returned
observations.
"""
def __init__(self):
self.observation_space = ()
self.env = None
def _set_env(self, env):
self.env = env
def set_env(self, env: Environment):
self.env: Environment = env
def reset(self):
"""
......@@ -30,35 +32,37 @@ class ObservationBuilder:
"""
raise NotImplementedError()
def get_many(self, handles=[]):
def get_many(self, handles: Optional[List[int]] = None):
"""
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
Parameters
-------
handles : list of handles (optional)
----------
handles : list of handles, optional
List with the handles of the agents for which to compute the observation vector.
Returns
-------
function
A dictionary of observation structures, specific to the corresponding environment, with handles from
`handles' as keys.
`handles` as keys.
"""
observations = {}
if handles is None:
handles = []
for h in handles:
observations[h] = self.get(h)
return observations
def get(self, handle=0):
def get(self, handle: int = 0):
"""
Called whenever an observation has to be computed for the `env' environment, possibly
for each agent independently (agent id `handle').
Called whenever an observation has to be computed for the `env` environment, possibly
for each agent independently (agent id `handle`).
Parameters
-------
handle : int (optional)
----------
handle : int, optional
Handle of the agent for which to compute the observation vector.
Returns
......@@ -82,16 +86,13 @@ class DummyObservationBuilder(ObservationBuilder):
"""
def __init__(self):
self.observation_space = ()
def _set_env(self, env):
self.env = env
super().__init__()
def reset(self):
pass
def get_many(self, handles=[]):
def get_many(self, handles: Optional[List[int]] = None) -> bool:
return True
def get(self, handle=0):
def get(self, handle: int = 0) -> bool:
return True
......@@ -3,11 +3,12 @@ PredictionBuilder objects are objects that can be passed to environments designe
The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
If predictions are not required in every step or not for all agents, then
+ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+ `reset()` is called after each environment reset, to allow for pre-computing relevant data.
+ Get() is called whenever an step has to be computed, potentially for each agent independently in
+ `get()` is called whenever an step has to be computed, potentially for each agent independently in \
case of multi-agent environments.
"""
from flatland.core.env import Environment
class PredictionBuilder:
......@@ -18,8 +19,9 @@ class PredictionBuilder:
def __init__(self, max_depth: int = 20):
self.max_depth = max_depth
self.env = None
def _set_env(self, env):
def set_env(self, env: Environment):
self.env = env
def reset(self):
......@@ -28,16 +30,13 @@ class PredictionBuilder:
"""
pass
def get(self, custom_args=None, handle=0):
def get(self, handle: int = 0):
"""
Called whenever get_many in the observation build is called.
Parameters
-------
custom_args: dict
Implementation-dependent custom arguments, see the sub-classes.
handle : int (optional)
----------
handle : int, optional
Handle of the agent for which to compute the observation vector.
Returns
......
from enum import IntEnum
from typing import Type
from functools import lru_cache
from typing import Type, List
import numpy as np
from flatland.core.transitions import Transitions
# maxsize=None can be used because the number of possible transition is limited (16 bit encoded) and the
# direction/orientation is also limited (2bit). Where the 16bit are only sparse used = number of rail types
# Those methods can be cached -> the are independant of the railways (env)
@lru_cache(maxsize=128)
def fast_grid4_get_transitions(cell_transition, orientation):
bits = (cell_transition >> ((3 - orientation) * 4))
return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
@lru_cache(maxsize=128)
def fast_grid4_get_transition(cell_transition, orientation, direction):
return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
@lru_cache(maxsize=128)
def fast_grid4_set_transitions(cell_transition, orientation, new_transitions):
mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 3 | \
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
return cell_transition
@lru_cache(maxsize=128)
def fast_grid4_remove_deadends(cell_transition):
"""
Remove all turn-arounds (e.g. N-S, S-N, E-W,...).
"""
maskDeadEnds = Grid4Transitions.maskDeadEnds()
cell_transition &= cell_transition & (~maskDeadEnds) & 0xffff
return cell_transition
@lru_cache(maxsize=128)
def fast_grid4_rotate_transition(cell_transition, rotation=0):
value = cell_transition
rotation = rotation // 90
for i in range(4):
block_tuple = fast_grid4_get_transitions(value, i)
block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
value = fast_grid4_set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (
value >> (rotation * 4))
cell_transition = value
return cell_transition
class Grid4TransitionsEnum(IntEnum):
NORTH = 0
EAST = 1
......@@ -24,9 +82,9 @@ class Grid4Transitions(Transitions):
"""
Grid4Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand).
Special case of `Transitions` over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions'
GridTransitions keeps track of valid transitions supplied as `transitions`
list, each represented as a bitmap of 16 bits.
Whether a transition is allowed or not depends on which direction an agent
......@@ -57,8 +115,11 @@ class Grid4Transitions(Transitions):
# row,col delta for each direction
self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
# These bits represent all the possible dead ends
self.maskDeadEnds = 0b0010000110000100
# These bits represent all the possible dead ends
@staticmethod
@lru_cache()
def maskDeadEnds():
return 0b0010000110000100
def get_type(self):
return np.uint16
......@@ -67,8 +128,8 @@ class Grid4Transitions(Transitions):
"""
Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent oriented
in direction `orientation' and inside a cell with
transitions `cell_transition'.
in direction `orientation` and inside a cell with
transitions `cell_transition`.
Parameters
----------
......@@ -83,16 +144,15 @@ class Grid4Transitions(Transitions):
List of the validity of transitions in the cell.
"""
bits = (cell_transition >> ((3 - orientation) * 4))
return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
return fast_grid4_get_transitions(cell_transition, orientation)
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition'. A new `cell_transition' is returned with
the specified bits replaced by `new_transitions'.
oriented in direction `orientation` and inside a cell with transitions
`cell_transition'. A new `cell_transition` is returned with
the specified bits replaced by `new_transitions`.
Parameters
----------
......@@ -107,28 +167,17 @@ class Grid4Transitions(Transitions):
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 3 | \
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
return cell_transition
return fast_grid4_set_transitions(cell_transition, orientation, new_transitions)
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
......@@ -146,13 +195,14 @@ class Grid4Transitions(Transitions):
Validity of the requested transition: 0/1 allowed/not allowed.
"""
return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
return fast_grid4_get_transition(cell_transition, orientation, direction)
def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
def set_transition(self, cell_transition, orientation, direction, new_transition,
remove_deadends=False):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
......@@ -171,8 +221,8 @@ class Grid4Transitions(Transitions):
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
if new_transition:
......@@ -181,7 +231,7 @@ class Grid4Transitions(Transitions):
cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
if remove_deadends:
cell_transition = self.remove_deadends(cell_transition)
cell_transition = fast_grid4_remove_deadends(cell_transition)
return cell_transition
......@@ -196,7 +246,7 @@ class Grid4Transitions(Transitions):
16 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees.
`cell_transition` by. I.e., rotation={0, 90, 180, 270} degrees.
Returns
-------
......@@ -206,27 +256,18 @@ class Grid4Transitions(Transitions):
"""
# Rotate the individual bits in each block
value = cell_transition
rotation = rotation // 90
for i in range(4):
block_tuple = self.get_transitions(value, i)
block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
value = self.set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
cell_transition = value
return cell_transition
return fast_grid4_rotate_transition(cell_transition, rotation)
def get_direction_enum(self) -> Type[Grid4TransitionsEnum]:
return Grid4TransitionsEnum
def has_deadend(self, cell_transition):
@staticmethod
@lru_cache()
def has_deadend(cell_transition):
"""
Checks if one entry can only by exited by a turn-around.
"""
if cell_transition & self.maskDeadEnds > 0:
if cell_transition & Grid4Transitions.maskDeadEnds() > 0:
return True
else:
return False
......@@ -235,5 +276,9 @@ class Grid4Transitions(Transitions):
"""
Remove all turn-arounds (e.g. N-S, S-N, E-W,...).
"""
cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
return cell_transition
return fast_grid4_remove_deadends(cell_transition)
@staticmethod
@lru_cache()
def get_entry_directions(cell_transition) -> List[int]:
return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
from flatland.core.grid.grid4_utils import validate_new_transition
import numpy as np
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
from flatland.core.grid.grid_utils import IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap
from flatland.utils.ordered_set import OrderedSet
class AStarNode():
class AStarNode:
"""A node class for A* Pathfinding"""
def __init__(self, parent=None, pos=None):
def __init__(self, pos: IntVector2D, parent=None):
self.parent = parent
self.pos = pos
self.g = 0
self.h = 0
self.f = 0
self.pos: IntVector2D = pos
self.g = 0.0
self.h = 0.0
self.f = 0.0
def __eq__(self, other):
"""
Parameters
----------
other : AStarNode
"""
return self.pos == other.pos
def __hash__(self):
......@@ -25,16 +37,35 @@ class AStarNode():
self.f = other.f
def a_star(rail_trans, rail_array, start, end):
def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, avoid_rails=False,
respect_transition_validity=True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
"""
:param avoid_rails:
:param grid_map: Grid Map where the path is found in
:param start: Start positions as (row,column)
:param end: End position as (row,column)
:param a_star_distance_function: Define the distance function to use as heuristc:
-get_euclidean_distance
-get_manhattan_distance
-get_chebyshev_distance
:param respect_transition_validity: Whether or not a-star respect allowed transitions on the grid map.
- True: Respects the validity of transition. This generates valid paths, of no path if it cannot be found
- False: This always finds a path, but the path might be illegal and thus needs to be fixed afterwards
:param forbidden_cells: List of cells where the path cannot pass through. Used to avoid certain areas of Grid map
:return: IF a path is found a ordered list of al cells in path is returned
"""
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
"""
rail_shape = rail_array.shape
start_node = AStarNode(None, start)
end_node = AStarNode(None, end)
open_nodes = set()
closed_nodes = set()
rail_shape = grid_map.grid.shape
start_node = AStarNode(start, None)
end_node = AStarNode(end, None)
open_nodes = OrderedSet()
closed_nodes = OrderedSet()
open_nodes.add(start_node)
while len(open_nodes) > 0:
......@@ -58,6 +89,7 @@ def a_star(rail_trans, rail_array, start, end):
while current is not None:
path.append(current.pos)
current = current.parent
# return reversed path
return path[::-1]
......@@ -67,17 +99,28 @@ def a_star(rail_trans, rail_array, start, end):
prev_pos = current_node.parent.pos
else:
prev_pos = None
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
# update the "current" pos
node_pos: IntVector2D = Vec2d.add(current_node.pos, new_pos)
# is node_pos inside the grid?
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
continue
# validate positions
if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos):
#
if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos,
end_node.pos) and respect_transition_validity:
continue
# create new node
new_node = AStarNode(current_node, node_pos)
new_node = AStarNode(node_pos, current_node)
# Skip paths through forbidden regions if they are provided
if forbidden_cells is not None:
if node_pos in forbidden_cells and new_node != start_node and new_node != end_node:
continue
children.append(new_node)
# loop through children
......@@ -87,11 +130,12 @@ def a_star(rail_trans, rail_array, start, end):
continue
# create the f, g, and h values
child.g = current_node.g + 1
# this heuristic favors diagonal paths:
# child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + ((child.pos[1] - end_node.pos[1]) ** 2) \# noqa: E800
child.g = current_node.g + 1.0
# this heuristic avoids diagonal paths
child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1])
if avoid_rails:
child.h = a_star_distance_function(child.pos, end_node.pos) + np.clip(grid_map.grid[child.pos], 0, 1)
else:
child.h = a_star_distance_function(child.pos, end_node.pos)
child.f = child.g + child.h
# already in the open list?
......
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2D
def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
"""
Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions.
......@@ -9,66 +12,41 @@ def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
diff_0 = pos2[0] - pos1[0]
diff_1 = pos2[1] - pos1[1]
if diff_0 < 0:
return 0
return Grid4TransitionsEnum.NORTH
if diff_0 > 0:
return 2
return Grid4TransitionsEnum.SOUTH
if diff_1 > 0:
return 1
return Grid4TransitionsEnum.EAST
if diff_1 < 0:
return 3
return Grid4TransitionsEnum.WEST
raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
def mirror(dir):
return (dir + 2) % 4
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]
def get_new_position(position, movement):
return (position[0] + MOVEMENT_ARRAY[movement][0], position[1] + MOVEMENT_ARRAY[movement][1])
def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = rail_array[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
"""
Returns the closest direction orientation of position 2 relative to position 1
:param pos1: position we are interested in
:param pos2: position we want to know it is facing
:return: direction NESW as int N:0 E:1 S:2 W:3
"""
diff_vec = np.array((pos1[0] - pos2[0], pos1[1] - pos2[1]))
axis = np.argmax(np.power(diff_vec, 2))
direction = np.sign(diff_vec[axis])
if axis == 0:
if direction > 0:
return Grid4TransitionsEnum.NORTH
else:
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
return Grid4TransitionsEnum.SOUTH
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if new_pos == end_pos:
# need to validate end pos setup as well
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
if direction > 0:
return Grid4TransitionsEnum.WEST
else:
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
if not rail_trans.is_valid(new_trans_e):
return False
# is transition is valid?
return rail_trans.is_valid(new_trans)
def get_new_position(position, movement):
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
return Grid4TransitionsEnum.EAST
......@@ -20,9 +20,9 @@ class Grid8Transitions(Transitions):
"""
Grid8Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand).
Special case of `Transitions` over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions'
GridTransitions keeps track of valid transitions supplied as `transitions`
list, each represented as a bitmap of 64 bits.
0=North, 1=North-East, etc.
......@@ -82,8 +82,8 @@ class Grid8Transitions(Transitions):
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
......@@ -106,8 +106,8 @@ class Grid8Transitions(Transitions):
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
......@@ -131,8 +131,8 @@ class Grid8Transitions(Transitions):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
oriented in direction `orientation` and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction`
relative to the current cell.
Parameters
......@@ -150,8 +150,8 @@ class Grid8Transitions(Transitions):
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
of `cell_transition' with `new_transitions`, for the appropriate
`orientation`.
"""
if new_transition:
......@@ -172,7 +172,7 @@ class Grid8Transitions(Transitions):
64 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 45, 90, 135, 180,
`cell_transition` by. I.e., rotation={0, 45, 90, 135, 180,
225, 270, 315} degrees.
Returns
......
from math import isnan
from typing import Tuple, Callable, List, Type
import numpy as np
Vector2D: Type = Tuple[float, float]
IntVector2D: Type = Tuple[int, int]
def position_to_coordinate(depth, positions):
"""Converts coordinates to positions:
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
IntVector2DArray: Type = List[IntVector2D]
IntVector2DArrayArray: Type = List[List[IntVector2D]]
Vector2DArray: Type = List[Vector2D]
Vector2DArrayArray: Type = List[List[Vector2D]]
IntVector2DDistance: Type = Callable[[IntVector2D, IntVector2D], float]
class Vec2dOperations:
@staticmethod
def is_equal(node_a: Vector2D, node_b: Vector2D) -> bool:
"""
vector operation : node_a + node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return:
check if node_a and nobe_b are equal
"""
return node_a[0] == node_b[0] and node_a[1] == node_b[1]
@staticmethod
def subtract(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
"""
vector operation : node_a - node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return node_a[0] - node_b[0], node_a[1] - node_b[1]
@staticmethod
def add(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
"""
vector operation : node_a + node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return node_a[0] + node_b[0], node_a[1] + node_b[1]
@staticmethod
def make_orthogonal(node: Vector2D) -> Vector2D:
"""
vector operation : rotates the 2D vector +90°
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return node[1], -node[0]
@staticmethod
def get_norm(node: Vector2D) -> float:
"""
calculates the euclidean norm of the 2d vector
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
:param node: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return np.sqrt(node[0] * node[0] + node[1] * node[1])
@staticmethod
def get_euclidean_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the euclidean norm of the 2d vector
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
Euclidean distance
"""
return Vec2dOperations.get_norm(Vec2dOperations.subtract(node_b, node_a))
@staticmethod
def get_manhattan_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the manhattan distance of the 2d vector
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
Mahnhattan distance
"""
delta = (Vec2dOperations.subtract(node_b, node_a))
return np.abs(delta[0]) + np.abs(delta[1])
@staticmethod
def get_chebyshev_distance(node_a: Vector2D, node_b: Vector2D) -> float:
"""
calculates the chebyshev norm of the 2d vector
[see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/]
Parameters
----------
node_a
tuple with coordinate (x,y) or 2d vector
node_b
tuple with coordinate (x,y) or 2d vector
Returns
-------
float
the chebyshev distance
"""
delta = (Vec2dOperations.subtract(node_b, node_a))
return max(np.abs(delta[0]), np.abs(delta[1]))
@staticmethod
def normalize(node: Vector2D) -> Tuple[float, float]:
"""
normalize the 2d vector = `v/|v|`
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
n = Vec2dOperations.get_norm(node)
if n > 0.0:
n = 1 / n
return Vec2dOperations.scale(node, n)
@staticmethod
def scale(node: Vector2D, scale: float) -> Vector2D:
"""
scales the 2d vector = node * scale
:param node: tuple with coordinate (x,y) or 2d vector
:param scale: scalar to scale
:return: tuple with coordinate (x,y) or 2d vector
"""
return node[0] * scale, node[1] * scale
@staticmethod
def round(node: Vector2D) -> IntVector2D:
"""
rounds the x and y coordinate and convert them to an integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return: tuple with coordinate (x,y) or 2d vector
"""
return int(np.round(node[0])), int(np.round(node[1]))
@staticmethod
def ceil(node: Vector2D) -> IntVector2D:
"""
ceiling the x and y coordinate and convert them to an integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return int(np.ceil(node[0])), int(np.ceil(node[1]))
@staticmethod
def floor(node: Vector2D) -> IntVector2D:
"""
floor the x and y coordinate and convert them to an integer values
:param node: tuple with coordinate (x,y) or 2d vector
:return:
tuple with coordinate (x,y) or 2d vector
"""
return int(np.floor(node[0])), int(np.floor(node[1]))
@staticmethod
def bound(node: Vector2D, min_value: float, max_value: float) -> Vector2D:
"""
force the values x and y to be between min_value and max_value
:param node: tuple with coordinate (x,y) or 2d vector
:param min_value: scalar value
:param max_value: scalar value
:return:
tuple with coordinate (x,y) or 2d vector
"""
return max(min_value, min(max_value, node[0])), max(min_value, min(max_value, node[1]))
@staticmethod
def rotate(node: Vector2D, rot_in_degree: float) -> Vector2D:
"""
rotate the 2d vector with given angle in degree
:param node: tuple with coordinate (x,y) or 2d vector
:param rot_in_degree: angle in degree
:return:
tuple with coordinate (x,y) or 2d vector
"""
alpha = rot_in_degree / 180.0 * np.pi
x0 = node[0]
y0 = node[1]
x1 = x0 * np.cos(alpha) - y0 * np.sin(alpha)
y1 = x0 * np.sin(alpha) + y0 * np.cos(alpha)
return x1, y1
def position_to_coordinate(depth: int, positions: List[int]):
"""Converts coordinates to positions::
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
-->
[ 0 d .. (w-1)*d
1 d+1
...
d-1 2d-1 w*d-1
]
[ 0 d .. (w-1)*d
1 d+1
...
d-1 2d-1 w*d-1
]
:param depth:
:param positions:
:return:
Parameters
----------
depth : int
positions : List[Tuple[int,int]]
"""
coords = ()
for p in positions:
......@@ -29,7 +264,8 @@ def position_to_coordinate(depth, positions):
def coordinate_to_position(depth, coords):
"""
Converts positions to coordinates:
Converts positions to coordinates::
[ 0 d .. (w-1)*d
1 d+1
...
......@@ -46,13 +282,17 @@ def coordinate_to_position(depth, coords):
:param coords:
:return:
"""
position = np.empty(len(coords), dtype=int)
idx = 0
for t in coords:
position[idx] = int(t[1] * depth + t[0])
idx += 1
position = list(range(len(coords)))
for index, t in enumerate(coords):
if isnan(t[0]):
position[index] = -1
else:
position[index] = int(t[1] * depth + t[0])
return position
def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
def distance_on_rail(pos1, pos2, metric="Euclidean"):
if metric == "Euclidean":
return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
if metric == "Manhattan":
return np.abs(pos1[0] - pos2[0]) + np.abs(pos1[1] - pos2[1])
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.utils.ordered_set import OrderedSet
class RailEnvTransitions(Grid4Transitions):
"""
Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
Special case of `GridTransitions` over a 2D-grid, with a pre-defined set
of transitions mimicking the types of real Swiss rail connections.
--------------------------------------------------------------------------
As no diagonal transitions are allowed in the RailEnv environment, the
possible transitions for RailEnv from a cell to its neighboring ones
are represented over 16 bits.
......@@ -44,7 +43,7 @@ class RailEnvTransitions(Grid4Transitions):
)
# create this to make validation faster
self.transitions_all = set()
self.transitions_all = OrderedSet()
for index, trans in enumerate(self.transitions):
self.transitions_all.add(trans)
if index in (2, 4, 6, 7, 8, 9, 10):
......
......@@ -7,9 +7,15 @@ from importlib_resources import path
from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
# TODO are these general classes or for grid4 only?
class TransitionMap:
"""
Base TransitionMap class.
......@@ -21,7 +27,7 @@ class TransitionMap:
def get_transitions(self, cell_id):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
`cell_id` (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
......@@ -41,8 +47,8 @@ class TransitionMap:
def set_transitions(self, cell_id, new_transitions):
"""
Replaces the available transitions in cell `cell_id' with the tuple
`new_transitions'. `new_transitions' must have
Replaces the available transitions in cell `cell_id` with the tuple
`new_transitions'. `new_transitions` must have
one element for each possible transition.
Parameters
......@@ -58,8 +64,8 @@ class TransitionMap:
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id' can perform a
movement along transition `transition_index (e.g., the NESW direction
Return the status of whether an agent in cell `cell_id` can perform a
movement along transition `transition_index` (e.g., the NESW direction
of movement, for agents on a grid).
Parameters
......@@ -83,8 +89,8 @@ class TransitionMap:
def set_transition(self, cell_id, transition_index, new_transition):
"""
Replaces the validity of transition to `transition_index' in cell
`cell_id' with the new `new_transition'.
Replaces the validity of transition to `transition_index` in cell
`cell_id' with the new `new_transition`.
Parameters
......@@ -111,7 +117,7 @@ class GridTransitionMap(TransitionMap):
GridTransitionMap implements utility functions.
"""
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([])):
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]), random_seed=None):
"""
Builder for GridTransitionMap object.
......@@ -130,7 +136,11 @@ class GridTransitionMap(TransitionMap):
self.width = width
self.height = height
self.transitions = transitions
self.random_generator = np.random.RandomState()
if random_seed is None:
self.random_generator.seed(12)
else:
self.random_generator.seed(random_seed)
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def get_full_transitions(self, row, column):
......@@ -154,7 +164,7 @@ class GridTransitionMap(TransitionMap):
def get_transitions(self, row, column, orientation):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
`cell_id` (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
......@@ -176,8 +186,8 @@ class GridTransitionMap(TransitionMap):
def set_transitions(self, cell_id, new_transitions):
"""
Replaces the available transitions in cell `cell_id' with the tuple
`new_transitions'. `new_transitions' must have
Replaces the available transitions in cell `cell_id` with the tuple
`new_transitions'. `new_transitions` must have
one element for each possible transition.
Parameters
......@@ -202,8 +212,8 @@ class GridTransitionMap(TransitionMap):
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id' can perform a
movement along transition `transition_index (e.g., the NESW direction
Return the status of whether an agent in cell `cell_id` can perform a
movement along transition `transition_index` (e.g., the NESW direction
of movement, for agents on a grid).
Parameters
......@@ -230,8 +240,8 @@ class GridTransitionMap(TransitionMap):
def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False):
"""
Replaces the validity of transition to `transition_index' in cell
`cell_id' with the new `new_transition'.
Replaces the validity of transition to `transition_index` in cell
`cell_id' with the new `new_transition`.
Parameters
......@@ -259,7 +269,7 @@ class GridTransitionMap(TransitionMap):
def save_transition_map(self, filename):
"""
Save the transitions grid as `filename', in npy format.
Save the transitions grid as `filename`, in npy format.
Parameters
----------
......@@ -271,9 +281,9 @@ class GridTransitionMap(TransitionMap):
def load_transition_map(self, package, resource):
"""
Load the transitions grid from `filename' (npy format).
Load the transitions grid from `filename` (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be
initialized with the correct `transitions' object anyway.
initialized with the correct `transitions` object anyway.
Parameters
----------
......@@ -283,7 +293,7 @@ class GridTransitionMap(TransitionMap):
Name of the file from which to load the transitions grid within the package.
override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
of the map loaded from `filename`. If override_gridsize=False, the transitions grid is either cropped (if
the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than
(height,width) )
......@@ -298,26 +308,82 @@ class GridTransitionMap(TransitionMap):
self.height = new_height
self.grid = new_grid
def is_dead_end(self, rcPos: IntVector2DArray):
"""
Check if the cell is a dead-end.
def is_dead_end(self,rcPos):
Parameters
----------
rcPos: Tuple[int,int]
tuple(row, column) with grid coordinate
Returns
-------
boolean
True if and only if the cell is a dead-end.
"""
Check if the cell is a dead-end
:param rcPos: tuple(row, column) with grid coordinate
:return: False : if not a dead-end else True
cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
return Grid4Transitions.has_deadend(cell_transition)
def is_simple_turn(self, rcPos: IntVector2DArray):
"""
Check if the cell is a left/right simple turn
Parameters
----------
rcPos: Tuple[int,int]
tuple(row, column) with grid coordinate
Returns
-------
boolean
True if and only if the cell is a left/right simple turn.
"""
nbits = 0
tmp = self.get_full_transitions(rcPos[0], rcPos[1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
return nbits==1
def cell_neighbours_valid(self, rcPos, check_this_cell=False):
def is_simple_turn(trans):
all_simple_turns = OrderedSet()
for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right
int('0001001000000000', 2) # Case 1c (9) - simple turn left]:
]:
for _ in range(3):
trans = self.transitions.rotate_transition(trans, rotation=90)
all_simple_turns.add(trans)
return trans in all_simple_turns
return is_simple_turn(tmp)
def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
"""
Breath first search for a possible path from one node with a certain orientation to a target node.
:param start: Start cell rom where we want to check the path
:param direction: Start direction for the path we are testing
:param end: Cell that we try to reach from the start cell
:return: True if a path exists, False otherwise
"""
visited = OrderedSet()
stack = [(start, direction)]
while stack:
node = stack.pop()
node_position = node[0]
node_direction = node[1]
if Vec2d.is_equal(node_position, end):
return True
if node not in visited:
visited.add(node)
moves = self.get_transitions(node_position[0], node_position[1], node_direction)
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node_position, move_index),
move_index))
return False
def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
- surrounding cells have inbound transitions for all the
outbound transitions of this cell.
- surrounding cells have inbound transitions for all the outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
......@@ -361,15 +427,36 @@ class GridTransitionMap(TransitionMap):
continue
else:
return False
# If the cell is empty but has incoming connections we return false
if binTrans < 1:
connected = 0
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then ignore it for the count of incoming connections
if np.any(gPos2 < 0):
continue
if np.any(gPos2 >= grcMax):
continue
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
return False
return True
def fix_neighbours(self, rcPos, check_this_cell=False):
def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
- surrounding cells have inbound transitions for all the
outbound transitions of this cell.
- surrounding cells have inbound transitions for all the outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
......@@ -417,16 +504,22 @@ class GridTransitionMap(TransitionMap):
return True
def fix_transitions(self, rcPos):
def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
"""
Fixes broken transitions
"""
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
# Transition elements
transitions = RailEnvTransitions()
cells = transitions.transition_list
simple_switch_east_south = transitions.rotate_transition(cells[10], 90)
simple_switch_west_south = transitions.rotate_transition(cells[2], 270)
symmetrical = cells[6]
double_slip = cells[5]
three_way_transitions = [simple_switch_east_south, simple_switch_west_south]
# loop over available outbound directions (indices) for rcPos
self.set_transitions(rcPos, 0)
incoming_connections = np.zeros(4)
for iDirOut in np.arange(4):
......@@ -449,38 +542,97 @@ class GridTransitionMap(TransitionMap):
incoming_connections[iDirOut] = 1
number_of_incoming = np.sum(incoming_connections)
# Only one incoming direction --> Straight line
# Only one incoming direction --> Straight line set deadend
if number_of_incoming == 1:
for direction in range(4):
if incoming_connections[direction] > 0:
self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
if self.get_full_transitions(*rcPos) == 0:
self.set_transitions(rcPos, 0)
else:
self.set_transitions(rcPos, 0)
for direction in range(4):
if incoming_connections[direction] > 0:
self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
# Connect all incoming connections
if number_of_incoming == 2:
self.set_transitions(rcPos, 0)
connect_directions = np.argwhere(incoming_connections > 0)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
# Find feasible connection fro three entries
# Find feasible connection for three entries
if number_of_incoming == 3:
self.set_transitions(rcPos, 0)
hole = np.argwhere(incoming_connections < 1)[0][0]
connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4]
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1)
# Make a cross
if direction >= 0:
switch_type_idx = (direction - hole + 3) % 4
if switch_type_idx == 0:
transition = simple_switch_west_south
elif switch_type_idx == 2:
transition = simple_switch_east_south
else:
transition = self.random_generator.choice(three_way_transitions, 1)[0]
else:
transition = self.random_generator.choice(three_way_transitions, 1)[0]
transition = transitions.rotate_transition(transition, int(hole * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
# Make a double slip switch
if number_of_incoming == 4:
connect_directions = np.arange(4)
self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[0], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[0], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[2], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[3], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[2], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[3], 1)
rotation = self.random_generator.randint(2)
transition = transitions.rotate_transition(double_slip, int(rotation * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
return True
def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
new_pos: IntVector2D, end_pos: IntVector2D):
"""
Utility function to test that a path drawn by a-start algorithm uses valid transition objects.
We us this to quide a-star as there are many transition elements that are not allowed in RailEnv
:param prev_pos: The previous position we were checking
:param current_pos: The current position we are checking
:param new_pos: Possible child position we move into
:param end_pos: End cell of path we are drawing
:return: True if the transition is valid, False if transition element is illegal
"""
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = self.grid[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# check if matches existing layout
new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if Vec2d.is_equal(new_pos, end_pos):
# need to validate end pos setup as well
new_trans_e = self.grid[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# check if matches existing layout
new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1)
if not self.transitions.is_valid(new_trans_e):
return False
# is transition is valid?
return self.transitions.is_valid(new_trans)
def mirror(dir):
return (dir + 2) % 4
......
......@@ -12,7 +12,7 @@ class Transitions:
Generic class that implements checks to control whether a
certain transition is allowed (agent facing a direction
`orientation' and moving into direction `orientation')
`orientation' and moving into direction `orientation`)
"""
def get_type(self):
......@@ -21,7 +21,7 @@ class Transitions:
def get_transitions(self, cell_transition, orientation):
"""
Return a tuple of transitions available in a cell specified by
`cell_transition' for an agent facing direction `orientation'
`cell_transition' for an agent facing direction `orientation`
(e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
......@@ -45,9 +45,9 @@ class Transitions:
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Return a `cell_transition' specification where the transitions
available for an agent facing direction `orientation' are replaced
with the tuple `new_transitions'. `new_orientations' must have
Return a `cell_transition` specification where the transitions
available for an agent facing direction `orientation` are replaced
with the tuple `new_transitions'. `new_orientations` must have
one element for each possible transition.
Parameters
......@@ -65,8 +65,8 @@ class Transitions:
-------
[cell-content]
An updated class-specific object that replaces the original
transitions validity of `cell_transition' with `new_transitions',
for the appropriate `orientation'.
transitions validity of `cell_transition' with `new_transitions`,
for the appropriate `orientation`.
"""
raise NotImplementedError()
......@@ -74,8 +74,8 @@ class Transitions:
def get_transition(self, cell_transition, orientation, direction):
"""
Return the status of whether an agent oriented in directions
`orientation' and inside a cell with transitions `cell_transition'
can move to the cell in direction `direction' relative
`orientation' and inside a cell with transitions `cell_transition`
can move to the cell in direction `direction` relative
to the current cell.
Parameters
......@@ -101,11 +101,11 @@ class Transitions:
def set_transition(self, cell_transition, orientation, direction,
new_transition):
"""
Return a `cell_transition' specification where the status of
whether an agent oriented in direction `orientation' and inside
a cell with transitions `cell_transition' can move to the cell
in direction `direction' relative to the current cell is set
to `new_transition'.
Return a `cell_transition` specification where the status of
whether an agent oriented in direction `orientation` and inside
a cell with transitions `cell_transition` can move to the cell
in direction `direction` relative to the current cell is set
to `new_transition`.
Parameters
----------
......@@ -125,8 +125,8 @@ class Transitions:
-------
[cell-content]
An updated class-specific object that replaces the original
transitions validity of `cell_transition' with `new_transitions',
for the appropriate `orientation' to `direction'.
transitions validity of `cell_transition' with `new_transitions`,
for the appropriate `orientation' to `direction`.
"""
raise NotImplementedError()
......
import networkx as nx
import numpy as np
from typing import List, Tuple
import graphviz as gv
class MotionCheck(object):
""" Class to find chains of agents which are "colliding" with a stopped agent.
This is to allow close-packed chains of agents, ie a train of agents travelling
at the same speed with no gaps between them,
"""
def __init__(self):
self.G = nx.DiGraph()
self.nDeadlocks = 0
self.svDeadlocked = set()
def addAgent(self, iAg, rc1, rc2, xlabel=None):
""" add an agent and its motion as row,col tuples of current and next position.
The agent's current position is given an "agent" attribute recording the agent index.
If an agent does not want to move this round (rc1 == rc2) then a self-loop edge is created.
xlabel is used for test cases to give a label (see graphviz)
"""
# Agents which have not yet entered the env have position None.
# Substitute this for the row = -1, column = agent index
if rc1 is None:
rc1 = (-1, iAg)
if rc2 is None:
rc2 = (-1, iAg)
self.G.add_node(rc1, agent=iAg)
if xlabel:
self.G.nodes[rc1]["xlabel"] = xlabel
self.G.add_edge(rc1, rc2)
def find_stops(self):
""" find all the stopped agents as a set of rc position nodes
A stopped agent is a self-loop on a cell node.
"""
# get the (sparse) adjacency matrix
spAdj = nx.linalg.adjacency_matrix(self.G)
# the stopped agents appear as 1s on the diagonal
# the where turns this into a list of indices of the 1s
giStops = np.where(spAdj.diagonal())[0]
# convert the cell/node indices into the node rc values
lvAll = list(self.G.nodes())
# pick out the stops by their indices
lvStops = [ lvAll[i] for i in giStops ]
# make it into a set ready for a set intersection
svStops = set(lvStops)
return svStops
def find_stops2(self):
""" alternative method to find stopped agents, using a networkx call to find selfloop edges
"""
svStops = { u for u,v in nx.classes.function.selfloop_edges(self.G) }
return svStops
def find_stop_preds(self, svStops=None):
""" Find the predecessors to a list of stopped agents (ie the nodes / vertices)
Returns the set of predecessors.
Includes "chained" predecessors.
"""
if svStops is None:
svStops = self.find_stops2()
# Get all the chains of agents - weakly connected components.
# Weakly connected because it's a directed graph and you can traverse a chain of agents
# in only one direction
lWCC = list(nx.algorithms.components.weakly_connected_components(self.G))
svBlocked = set()
for oWCC in lWCC:
#print("Component:", oWCC)
# Get the node details for this WCC in a subgraph
Gwcc = self.G.subgraph(oWCC)
# Find all the stops in this chain or tree
svCompStops = svStops.intersection(Gwcc)
#print(svCompStops)
if len(svCompStops) > 0:
# We need to traverse it in reverse - back up the movement edges
Gwcc_rev = Gwcc.reverse()
for vStop in svCompStops:
# Find all the agents stopped by vStop by following the (reversed) edges
# This traverses a tree - dfs = depth first seearch
iter_stops = nx.algorithms.traversal.dfs_postorder_nodes(Gwcc_rev, vStop)
lStops = list(iter_stops)
svBlocked.update(lStops)
# the set of all the nodes/agents blocked by this set of stopped nodes
return svBlocked
def find_swaps(self):
""" find all the swap conflicts where two agents are trying to exchange places.
These appear as simple cycles of length 2.
These agents are necessarily deadlocked (since they can't change direction in flatland) -
meaning they will now be stuck for the rest of the episode.
"""
#svStops = self.find_stops2()
llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
return svSwaps
def find_same_dest(self):
""" find groups of agents which are trying to land on the same cell.
ie there is a gap of one cell between them and they are both landing on it.
"""
pass
def block_preds(self, svStops, color="red"):
""" Take a list of stopped agents, and apply a stop color to any chains/trees
of agents trying to head toward those cells.
Count the number of agents blocked, ignoring those which are already marked.
(Otherwise it can double count swaps)
"""
iCount = 0
svBlocked = set()
# The reversed graph allows us to follow directed edges to find affected agents.
Grev = self.G.reverse()
for v in svStops:
# Use depth-first-search to find a tree of agents heading toward the blocked cell.
lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v))
svBlocked |= set(lvPred)
svBlocked.add(v)
#print("node:", v, "set", svBlocked)
# only count those not already marked
for v2 in [v]+lvPred:
if self.G.nodes[v2].get("color") != color:
self.G.nodes[v2]["color"] = color
iCount += 1
return svBlocked
def find_conflicts(self):
svStops = self.find_stops2() # voluntarily stopped agents - have self-loops
svSwaps = self.find_swaps() # deadlocks - adjacent head-on collisions
# Block all swaps and their tree of predessors
self.svDeadlocked = self.block_preds(svSwaps, color="purple")
# Take the union of the above, and find all the predecessors
#svBlocked = self.find_stop_preds(svStops.union(svSwaps))
# Just look for the the tree of preds for each voluntarily stopped agent
svBlocked = self.find_stop_preds(svStops)
# iterate the nodes v with their predecessors dPred (dict of nodes->{})
for (v, dPred) in self.G.pred.items():
# mark any swaps with purple - these are directly deadlocked
#if v in svSwaps:
# self.G.nodes[v]["color"] = "purple"
# If they are not directly deadlocked, but are in the union of stopped + deadlocked
#elif v in svBlocked:
# if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting
if v in svBlocked:
self.G.nodes[v]["color"] = "red"
# not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node
elif len(dPred)>1:
# if this agent is already red/blocked, ignore. CHECK: why?
# certainly we want to ignore purple so we don't overwrite with red.
if self.G.nodes[v].get("color") in ("red", "purple"):
continue
# if this node has no agent, and >=2 want to enter it.
if self.G.nodes[v].get("agent") is None:
self.G.nodes[v]["color"] = "blue"
# this node has an agent and >=2 want to enter
else:
self.G.nodes[v]["color"] = "magenta"
# predecessors of a contended cell: {agent index -> node}
diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred}
# remove the agent with the lowest index, who wins
iAgWinner = min(diAgCell)
diAgCell.pop(iAgWinner)
# Block all the remaining predessors, and their tree of preds
#for iAg, v in diAgCell.items():
# self.G.nodes[v]["color"] = "red"
# for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
# self.G.nodes[vPred]["color"] = "red"
self.block_preds(diAgCell.values(), "red")
def check_motion(self, iAgent, rcPos):
""" Returns tuple of boolean can the agent move, and the cell it will move into.
If agent position is None, we use a dummy position of (-1, iAgent)
"""
if rcPos is None:
rcPos = (-1, iAgent)
dAttr = self.G.nodes.get(rcPos)
#print("pos:", rcPos, "dAttr:", dAttr)
if dAttr is None:
dAttr = {}
# If it's been marked red or purple then it can't move
if "color" in dAttr:
sColor = dAttr["color"]
if sColor in [ "red", "purple" ]:
return False
dSucc = self.G.succ[rcPos]
# This should never happen - only the next cell of an agent has no successor
if len(dSucc)==0:
print(f"error condition - agent {iAgent} node {rcPos} has no successor")
return False
# This agent has a successor
rcNext = self.G.successors(rcPos).__next__()
if rcNext == rcPos: # the agent didn't want to move
return False
# The agent wanted to move, and it can
return True
def render(omc:MotionCheck, horizontal=True):
try:
oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
oAG.layout("dot")
sDot = oAG.to_string()
if horizontal:
sDot = sDot.replace('{', '{ rankdir="LR" ')
#return oAG.draw(format="png")
# This returns a graphviz object which implements __repr_svg
return gv.Source(sDot)
except ImportError as oError:
print("Flatland agent_chains ignoring ImportError - install pygraphviz to render graphs")
return None
class ChainTestEnv(object):
""" Just for testing agent chains
"""
def __init__(self, omc:MotionCheck):
self.iAgNext = 0
self.iRowNext = 1
self.omc = omc
def addAgent(self, rc1, rc2, xlabel=None):
self.omc.addAgent(self.iAgNext, rc1, rc2, xlabel=xlabel)
self.iAgNext+=1
def addAgentToRow(self, c1, c2, xlabel=None):
self.addAgent((self.iRowNext, c1), (self.iRowNext, c2), xlabel=xlabel)
def create_test_chain(self,
nAgents:int,
rcVel:Tuple[int] = (0,1),
liStopped:List[int]=[],
xlabel=None):
""" create a chain of agents
"""
lrcAgPos = [ (self.iRowNext, i * rcVel[1]) for i in range(nAgents) ]
for iAg, rcPos in zip(range(nAgents), lrcAgPos):
if iAg in liStopped:
rcVel1 = (0,0)
else:
rcVel1 = rcVel
self.omc.addAgent(iAg+self.iAgNext, rcPos, (rcPos[0] + rcVel1[0], rcPos[1] + rcVel1[1]) )
if xlabel:
self.omc.G.nodes[lrcAgPos[0]]["xlabel"] = xlabel
self.iAgNext += nAgents
self.iRowNext += 1
def nextRow(self):
self.iRowNext+=1
def create_test_agents(omc:MotionCheck):
# blocked chain
omc.addAgent(1, (1,2), (1,3))
omc.addAgent(2, (1,3), (1,4))
omc.addAgent(3, (1,4), (1,5))
omc.addAgent(31, (1,5), (1,5))
# unblocked chain
omc.addAgent(4, (2,1), (2,2))
omc.addAgent(5, (2,2), (2,3))
# blocked short chain
omc.addAgent(6, (3,1), (3,2))
omc.addAgent(7, (3,2), (3,2))
# solitary agent
omc.addAgent(8, (4,1), (4,2))
# solitary stopped agent
omc.addAgent(9, (5,1), (5,1))
# blocked short chain (opposite direction)
omc.addAgent(10, (6,4), (6,3))
omc.addAgent(11, (6,3), (6,3))
# swap conflict
omc.addAgent(12, (7,1), (7,2))
omc.addAgent(13, (7,2), (7,1))
def create_test_agents2(omc:MotionCheck):
# blocked chain
cte = ChainTestEnv(omc)
cte.create_test_chain(4, liStopped=[3], xlabel="stopped\nchain")
cte.create_test_chain(4, xlabel="running\nchain")
cte.create_test_chain(2, liStopped = [1], xlabel="stopped \nshort\n chain")
cte.addAgentToRow(1, 2, "swap")
cte.addAgentToRow(2, 1)
cte.nextRow()
cte.addAgentToRow(1, 2, "chain\nswap")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 2)
cte.nextRow()
cte.addAgentToRow(1, 2, "midchain\nstop")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(4, 4)
cte.addAgentToRow(5, 6)
cte.addAgentToRow(6, 7)
cte.nextRow()
cte.addAgentToRow(1, 2, "midchain\nswap")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(4, 3)
cte.addAgentToRow(5, 4)
cte.addAgentToRow(6, 5)
cte.nextRow()
cte.addAgentToRow(1, 2, "Land on\nSame")
cte.addAgentToRow(3, 2)
cte.nextRow()
cte.addAgentToRow(1, 2, "chains\nonto\nsame")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(5, 4)
cte.addAgentToRow(6, 5)
cte.addAgentToRow(7, 6)
cte.nextRow()
cte.addAgentToRow(1, 2, "3-way\nsame")
cte.addAgentToRow(3, 2)
cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
cte.nextRow()
if False:
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "4-way\nsame")
cte.addAgentToRow(3, 2)
cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
cte.addAgent((cte.iRowNext-1, 2), (cte.iRowNext, 2))
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "Tee")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgent((cte.iRowNext+1, 3), (cte.iRowNext, 3))
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "Tree")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
r1 = cte.iRowNext
r2 = cte.iRowNext+1
r3 = cte.iRowNext+2
cte.addAgent((r2, 3), (r1, 3))
cte.addAgent((r2, 2), (r2, 3))
cte.addAgent((r3, 2), (r2, 3))
cte.nextRow()
def test_agent_following():
omc = MotionCheck()
create_test_agents2(omc)
svStops = omc.find_stops()
svBlocked = omc.find_stop_preds()
llvSwaps = omc.find_swaps()
svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
print(list(svBlocked))
lvCells = omc.G.nodes()
lColours = [ "magenta" if v in svStops
else "red" if v in svBlocked
else "purple" if v in svSwaps
else "lightblue"
for v in lvCells ]
dPos = dict(zip(lvCells, lvCells))
nx.draw(omc.G,
with_labels=True, arrowsize=20,
pos=dPos,
node_color = lColours)
def main():
test_agent_following()
if __name__=="__main__":
main()
from itertools import starmap
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
from attr import attrs, attrib, Factory
import warnings
from typing import Tuple, Optional, NamedTuple, List
from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
('direction', Grid4TransitionsEnum),
('target', Tuple[int, int]),
('moving', bool),
('earliest_departure', int),
('latest_arrival', int),
('handle', int),
('position', Tuple[int, int]),
('arrival_time', int),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
def load_env_agent(agent_tuple: Agent):
return EnvAgent(
initial_position = agent_tuple.initial_position,
initial_direction = agent_tuple.initial_direction,
direction = agent_tuple.direction,
target = agent_tuple.target,
moving = agent_tuple.moving,
earliest_departure = agent_tuple.earliest_departure,
latest_arrival = agent_tuple.latest_arrival,
handle = agent_tuple.handle,
position = agent_tuple.position,
arrival_time = agent_tuple.arrival_time,
old_direction = agent_tuple.old_direction,
old_position = agent_tuple.old_position,
speed_counter = agent_tuple.speed_counter,
action_saver = agent_tuple.action_saver,
state_machine = agent_tuple.state_machine,
malfunction_handler = agent_tuple.malfunction_handler,
)
@attrs
class EnvAgentStatic(object):
""" EnvAgentStatic - Stores initial position, direction and target.
This is like static data for the environment - it's where an agent starts,
rather than where it is at the moment.
The target should also be stored here.
"""
position = attrib()
direction = attrib()
target = attrib()
moving = attrib(default=False)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
# N.B. we need to use factory since default arguments are not recreated on each call!
speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of time the agent had to stop, since the last time it broke down
malfunction_data = attrib(
default=Factory(
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0})))
class EnvAgent:
# INIT FROM HERE IN _from_line()
initial_position = attrib(type=Tuple[int, int])
initial_direction = attrib(type=Grid4TransitionsEnum)
direction = attrib(type=Grid4TransitionsEnum)
target = attrib(type=Tuple[int, int])
moving = attrib(default=False, type=bool)
# NEW : EnvAgent - Schedule properties
earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line()
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
speed_datas = []
for i in range(len(positions)):
speed_datas.append({'position_fraction': 0.0,
'speed': speeds[i] if speeds is not None else 1.0,
'transition_action_on_cellexit': 0})
# TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set
# some as broken?
malfunction_datas = []
for i in range(len(positions)):
malfunction_datas.append({'malfunction': 0,
'malfunction_rate': 0,
'next_malfunction': 0,
'nr_malfunctions': 0})
return list(starmap(EnvAgentStatic, zip(positions,
directions,
targets,
[False] * len(positions),
speed_datas,
malfunction_datas)))
def to_list(self):
# I can't find an expression which works on both tuples, lists and ndarrays
# which converts them all to a list of native python ints.
lPos = self.position
if type(lPos) is np.ndarray:
lPos = lPos.tolist()
lTarget = self.target
if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist()
return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data]
handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
# Env step facelift
speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
@attrs
class EnvAgent(EnvAgentStatic):
""" EnvAgent - replace separate agent_* lists with a single list
of agent objects. The EnvAgent represent's the environment's view
of the dynamic agent state.
We are duplicating target in the EnvAgent, which seems simpler than
forcing the env to refer to it in the EnvAgentStatic
"""
handle = attrib(default=None)
position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling
arrival_time = attrib(default=None, type=int)
# used in rendering
old_direction = attrib(default=None)
old_position = attrib(default=None)
def to_list(self):
return [
self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data]
@classmethod
def from_static(cls, oStatic):
""" Create an EnvAgent from the EnvAgentStatic,
copying all the fields, and adding handle with the default 0.
def reset(self):
"""
return EnvAgent(*oStatic.__dict__, handle=0)
Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
"""
self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction
self.old_position = None
self.old_direction = None
self.moving = False
self.arrival_time = None
self.malfunction_handler.reset()
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
self.state_machine.reset()
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position,
initial_direction=self.initial_direction,
direction=self.direction,
target=self.target,
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
handle=self.handle,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
arrival_time=self.arrival_time,
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
def get_shortest_path(self, distance_map) -> List[Waypoint]:
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
shortest_path = self.get_shortest_path(distance_map)
if shortest_path is not None:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_counter.speed
return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
return self.latest_arrival - elapsed_steps
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
'''
+ve if arrival time is projected before latest arrival
-ve if arrival time is projected after latest arrival
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
@classmethod
def list_from_static(cls, lEnvAgentStatic, handles=None):
""" Create an EnvAgent from the EnvAgentStatic,
copying all the fields, and adding handle with the default 0.
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
if handles is None:
handles = range(len(lEnvAgentStatic))
num_agents = len(line.agent_positions)
agent_list = []
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
target = line.agent_targets[i_agent],
moving = False,
earliest_departure = None,
latest_arrival = None,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
agent_list.append(agent)
return agent_list
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
agents = []
for i, static_agent in enumerate(static_agents_data):
if len(static_agent) >= 6:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_counter=SpeedCounter(static_agent[4]['speed']), handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
speed_counter=SpeedCounter(1.0),
handle=i)
agents.append(agent)
return agents
def __str__(self):
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} \n \
initial_direction: {self.initial_direction} \n \
position: {self.position} \n \
direction: {self.direction} \n \
target: {self.target} \n \
old_position: {self.old_position} \n \
old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} \n \
latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_handler: {self.malfunction_handler} \n \
action_saver: {self.action_saver} \n \
speed_counter: {self.speed_counter}"
@property
def state(self):
return self.state_machine.state
@state.setter
def state(self, state):
self._set_state(state)
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
@property
def malfunction_data(self):
raise ValueError("agent.malunction_data is deprecated, please use agent.malfunction_hander instead")
@property
def speed_data(self):
raise ValueError("agent.speed_data is deprecated, please use agent.speed_counter instead")
return [EnvAgent(**oEAS.__dict__, handle=handle)
for handle, oEAS in zip(handles, lEnvAgentStatic)]
from collections import deque
from typing import List, Optional
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
class DistanceMap:
def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
self.env_height = env_height
self.env_width = env_width
self.distance_map = None
self.agents_previous_computation = None
self.reset_was_called = False
self.agents: List[EnvAgent] = agents
self.rail: Optional[GridTransitionMap] = None
def set(self, distance_map: np.ndarray):
"""
Set the distance map
"""
self.distance_map = distance_map
def get(self) -> np.ndarray:
"""
Get the distance map
"""
if self.reset_was_called:
self.reset_was_called = False
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_computation is None and self.distance_map is not None:
compute_distance_map = False
if compute_distance_map:
self._compute(self.agents, self.rail)
elif self.distance_map is None:
self._compute(self.agents, self.rail)
return self.distance_map
def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
Reset the distance map
"""
self.reset_was_called = True
self.agents: List[EnvAgent] = agents
self.rail = rail
self.env_height = rail.height
self.env_width = rail.width
def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
This function computes the distance maps for each unique target. Thus if several targets are the same
we only compute the distance for them once and copy to all targets with same position.
:param agents: All the agents in the environment, independent of their current status
:param rail: The rail transition map
"""
self.agents_previous_computation = self.agents
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
self.env_width,
4))
computed_targets = []
for i, agent in enumerate(agents):
if agent.target not in computed_targets:
self._distance_map_walker(rail, agent.target, i)
else:
# just copy the distance map form other agent with same target (performance)
self.distance_map[i, :, :, :] = np.copy(
self.distance_map[computed_targets.index(agent.target), :, :, :])
computed_targets.append(agent.target)
def _distance_map_walker(self, rail: GridTransitionMap, position, target_nr: int):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
orientation within it) to each agent's target cell.
"""
# Returns max distance to target, from the farthest away node, while filling in distance_map
self.distance_map[target_nr, position[0], position[1], :] = 0
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(rail, position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
while nodes_queue:
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those
# whose new orientation in the current cell would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors(rail, (node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3] + 1)
return max_distance
def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance,
enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
"""
neighbors = []
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = get_new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env_height and new_cell[1] >= 0 and new_cell[1] < self.env_width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
if isNextCellDeadEnd:
movement = (desired_movement_from_new_cell+2) % 4
"""
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
current_distance + 1)
neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
return neighbors
from typing import Tuple
# Adrian Egli / Michel Marti performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int], max_value: Tuple[int, int]) -> bool:
return (
max(min_value[0], min(position[0], max_value[0])),
max(min_value[1], min(position[1], max_value[1]))
)
def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
if possible_transitions[0] == 1:
return 0
if possible_transitions[1] == 1:
return 1
if possible_transitions[2] == 1:
return 2
return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None and pos_2 is None:
return True
if pos_1 is None or pos_2 is None:
return False
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
def fast_delete(lis: list, index) -> list:
new_list = lis.copy()
new_list.pop(index)
return new_list
def fast_where(binary_iterable):
return [index for index, element in enumerate(binary_iterable) if element != 0]