Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 2402 additions and 255 deletions
import random
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator
from flatland.utils.rendertools import RenderTool
random.seed(100)
np.random.seed(100)
# Relative weights of each cell type to be used by the random rail generators.
transition_probability = [1.0, # empty cell - Case 0
1.0, # Case 1 - straight
1.0, # Case 2 - simple switch
0.3, # Case 3 - diamond drossing
0.5, # Case 4 - single slip
0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical
0.0, # Case 7 - dead end
0.2, # Case 8 - turn left
0.2, # Case 9 - turn right
1.0] # Case 10 - mirrored switch
# Example generate a random rail
env = RailEnv(width=10,
height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=3)
env.reset()
env_renderer = RenderTool(env, gl="PIL")
env_renderer.render_env(show=True)
# uncomment to keep the renderer open
#input("Press Enter to continue...")
import getopt
import sys
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.misc import str2bool
from flatland.utils.rendertools import RenderTool
np.random.seed(1)
# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=1),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObservation,
number_of_agents=3)
env.reset()
env_renderer = RenderTool(env, gl="PILSVG", )
def create_env():
nAgents = 1
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0
env = RailEnv(
width=30,
height=40,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
)
return env
# Import your own Agent or use RLlib to train agents on Flatland
......@@ -60,42 +70,85 @@ class RandomAgent:
return
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 5)
n_trials = 5
# Empty dictionary for all agent action
action_dict = dict()
print("Starting Training...")
for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
obs, info = env.reset()
for idx in range(env.get_num_agents()):
tmp_agent = env.agents[idx]
tmp_agent.speed_data["speed"] = 1 / (idx + 1)
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
for step in range(500):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
print('Episode Nr. {}\t Score = {}'.format(trials, score))
def training_example(sleep_for_animation, do_rendering):
np.random.seed(1)
# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
env = create_env()
env.reset()
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env)
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 5)
n_trials = 5
# Empty dictionary for all agent action
action_dict = dict()
print("Starting Training...")
for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
obs, info = env.reset()
if env_renderer is not None:
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
for step in range(500):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
print('Episode Nr. {}\t Score = {}'.format(trials, score))
if env_renderer is not None:
env_renderer.close_window()
def main(args):
try:
opts, args = getopt.getopt(args, "", ["sleep-for-animation=", "do_rendering=", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
sleep_for_animation = True
do_rendering = True
for o, a in opts:
if o in ("--sleep-for-animation"):
sleep_for_animation = str2bool(a)
elif o in ("--do_rendering"):
do_rendering = str2bool(a)
else:
assert False, "unhandled option"
# execute example
training_example(sleep_for_animation, do_rendering)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
......@@ -4,4 +4,4 @@
__author__ = """S.P. Mohanty"""
__email__ = 'mohanty@aicrowd.com'
__version__ = '2.1.8'
__version__ = '3.0.15'
File moved
import pprint
from typing import Dict, List, Optional, NamedTuple
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_action_for_move
from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint
# ---- ActionPlan ---------------
# an action plan element represents the actions to be taken by an agent at the given time step
ActionPlanElement = NamedTuple('ActionPlanElement', [
('scheduled_at', int),
('action', RailEnvActions)
])
# an action plan gathers all the the actions to be taken by a single agent at the corresponding time steps
ActionPlan = List[ActionPlanElement]
# An action plan dict gathers all the actions for every agent identified by the dictionary key = agent_handle
ActionPlanDict = Dict[int, ActionPlan]
class ControllerFromTrainruns():
"""Takes train runs, derives the actions from it and re-acts them."""
pp = pprint.PrettyPrinter(indent=4)
def __init__(self,
env: RailEnv,
trainrun_dict: Dict[int, Trainrun]):
self.env: RailEnv = env
self.trainrun_dict: Dict[int, Trainrun] = trainrun_dict
self.action_plan: ActionPlanDict = [self._create_action_plan_for_agent(agent_id, chosen_path)
for agent_id, chosen_path in trainrun_dict.items()]
def get_waypoint_before_or_at_step(self, agent_id: int, step: int) -> Waypoint:
"""
Get the way point point from which the current position can be extracted.
Parameters
----------
agent_id
step
Returns
-------
WalkingElement
"""
trainrun = self.trainrun_dict[agent_id]
entry_time_step = trainrun[0].scheduled_at
# the agent has no position before and at choosing to enter the grid (one tick elapses before the agent enters the grid)
if step <= entry_time_step:
return Waypoint(position=None, direction=self.env.agents[agent_id].initial_direction)
# the agent has no position as soon as the target is reached
exit_time_step = trainrun[-1].scheduled_at
if step >= exit_time_step:
# agent loses position as soon as target cell is reached
return Waypoint(position=None, direction=trainrun[-1].waypoint.direction)
waypoint = None
for trainrun_waypoint in trainrun:
if step < trainrun_waypoint.scheduled_at:
return waypoint
if step >= trainrun_waypoint.scheduled_at:
waypoint = trainrun_waypoint.waypoint
assert waypoint is not None
return waypoint
def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
"""
Get the current action if any is defined in the `ActionPlan`.
ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!
Parameters
----------
agent_id
current_step
Returns
-------
WalkingElement, optional
"""
for action_plan_element in self.action_plan[agent_id]:
scheduled_at = action_plan_element.scheduled_at
if scheduled_at > current_step:
return None
elif current_step == scheduled_at:
return action_plan_element.action
return None
def act(self, current_step: int) -> Dict[int, RailEnvActions]:
"""
Get the action dictionary to be replayed at the current step.
Returns only action where required (no action for done agents or those not at the beginning of the cell).
ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!
Parameters
----------
current_step: int
Returns
-------
Dict[int, RailEnvActions]
"""
action_dict = {}
for agent_id in range(len(self.env.agents)):
action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step)
if action is not None:
action_dict[agent_id] = action
return action_dict
def print_action_plan(self):
"""Pretty-prints `ActionPlanDict` of this `ControllerFromTrainruns` to stdout."""
self.__class__.print_action_plan_dict(self.action_plan)
@staticmethod
def print_action_plan_dict(action_plan: ActionPlanDict):
"""Pretty-prints `ActionPlanDict` to stdout."""
for agent_id, plan in enumerate(action_plan):
print("{}: ".format(agent_id))
for step in plan:
print(" {}".format(step))
@staticmethod
def assert_actions_plans_equal(expected_action_plan: ActionPlanDict, actual_action_plan: ActionPlanDict):
assert len(expected_action_plan) == len(actual_action_plan)
for k in range(len(expected_action_plan)):
assert len(expected_action_plan[k]) == len(actual_action_plan[k]), \
"len for agent {} should be the same.\n\n expected ({}) = {}\n\n actual ({}) = {}".format(
k,
len(expected_action_plan[k]),
ControllerFromTrainruns.pp.pformat(expected_action_plan[k]),
len(actual_action_plan[k]),
ControllerFromTrainruns.pp.pformat(actual_action_plan[k]))
for i in range(len(expected_action_plan[k])):
assert expected_action_plan[k][i] == actual_action_plan[k][i], \
"not the same at agent {} at step {}\n\n expected = {}\n\n actual = {}".format(
k, i,
ControllerFromTrainruns.pp.pformat(expected_action_plan[k][i]),
ControllerFromTrainruns.pp.pformat(actual_action_plan[k][i]))
assert expected_action_plan == actual_action_plan, \
"expected {}, found {}".format(expected_action_plan, actual_action_plan)
def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
action_plan = []
agent = self.env.agents[agent_id]
minimum_cell_time = agent.speed_counter.max_count + 1
for path_loop, trainrun_waypoint in enumerate(trainrun):
trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
position = trainrun_waypoint.waypoint.position
if Vec2d.is_equal(agent.target, position):
break
next_trainrun_waypoint: TrainrunWaypoint = trainrun[path_loop + 1]
next_position = next_trainrun_waypoint.waypoint.position
if path_loop == 0:
self._add_action_plan_elements_for_first_path_element_of_agent(
action_plan,
trainrun_waypoint,
next_trainrun_waypoint,
minimum_cell_time
)
continue
just_before_target = Vec2d.is_equal(agent.target, next_position)
self._add_action_plan_elements_for_current_path_element(
action_plan,
minimum_cell_time,
trainrun_waypoint,
next_trainrun_waypoint)
# add a final element
if just_before_target:
self._add_action_plan_elements_for_target_at_path_element_just_before_target(
action_plan,
minimum_cell_time,
trainrun_waypoint,
next_trainrun_waypoint)
return action_plan
def _add_action_plan_elements_for_current_path_element(self,
action_plan: ActionPlan,
minimum_cell_time: int,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint):
scheduled_at = trainrun_waypoint.scheduled_at
next_entry_value = next_trainrun_waypoint.scheduled_at
position = trainrun_waypoint.waypoint.position
direction = trainrun_waypoint.waypoint.direction
next_position = next_trainrun_waypoint.waypoint.position
next_direction = next_trainrun_waypoint.waypoint.direction
next_action = get_action_for_move(position,
direction,
next_position,
next_direction,
self.env.rail)
# if the next entry is later than minimum_cell_time, then stop here and
# move minimum_cell_time before the exit
# we have to do this since agents in the RailEnv are processed in the step() in the order of their handle
if next_entry_value > scheduled_at + minimum_cell_time:
action = ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING)
action_plan.append(action)
action = ActionPlanElement(next_entry_value - minimum_cell_time, next_action)
action_plan.append(action)
else:
action = ActionPlanElement(scheduled_at, next_action)
action_plan.append(action)
def _add_action_plan_elements_for_target_at_path_element_just_before_target(self,
action_plan: ActionPlan,
minimum_cell_time: int,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint):
scheduled_at = trainrun_waypoint.scheduled_at
action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING)
action_plan.append(action)
def _add_action_plan_elements_for_first_path_element_of_agent(self,
action_plan: ActionPlan,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint,
minimum_cell_time: int):
scheduled_at = trainrun_waypoint.scheduled_at
position = trainrun_waypoint.waypoint.position
direction = trainrun_waypoint.waypoint.direction
next_position = next_trainrun_waypoint.waypoint.position
next_direction = next_trainrun_waypoint.waypoint.direction
# add intial do nothing if we do not enter immediately, actually not necessary
if scheduled_at > 0:
action = ActionPlanElement(0, RailEnvActions.DO_NOTHING)
action_plan.append(action)
# add action to enter the grid
action = ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD)
action_plan.append(action)
next_action = get_action_for_move(position,
direction,
next_position,
next_direction,
self.env.rail)
# if the agent is blocked in the cell, we have to call stop upon entering!
if next_trainrun_waypoint.scheduled_at > scheduled_at + 1 + minimum_cell_time:
action = ActionPlanElement(scheduled_at + 1, RailEnvActions.STOP_MOVING)
action_plan.append(action)
# execute the action exactly minimum_cell_time before the entry into the next cell
action = ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action)
action_plan.append(action)
from typing import Callable
from flatland.action_plan.action_plan import ControllerFromTrainruns
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_trainrun_data_structures import Waypoint
ControllerFromTrainrunsReplayerRenderCallback = Callable[[RailEnv], None]
class ControllerFromTrainrunsReplayer():
"""Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction."""
@staticmethod
def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv,
call_back: ControllerFromTrainrunsReplayerRenderCallback = lambda *a, **k: None):
"""Replays this deterministic `ActionPlan` and verifies whether it is feasible.
Parameters
----------
ctl
env
call_back
Called before/after each step() call. The env is passed to it.
"""
call_back(env)
i = 0
while not env.dones['__all__'] and i <= env._max_episode_steps:
for agent_id, agent in enumerate(env.agents):
waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, i)
assert agent.position == waypoint.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
waypoint.position)
actions = ctl.act(i)
obs, all_rewards, done, _ = env.step(actions)
call_back(env)
i += 1
......@@ -9,9 +9,9 @@ import numpy as np
import redis
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.evaluators.service import FlatlandRemoteEvaluationService
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.evaluators.service import FlatlandRemoteEvaluationService, FLATLAND_RL_SERVICE_ID
from flatland.utils.rendertools import RenderTool
......@@ -19,39 +19,41 @@ from flatland.utils.rendertools import RenderTool
def demo(args=None):
"""Demo script to check installation"""
env = RailEnv(
width=15,
height=15,
rail_generator=complex_rail_generator(
nr_start_goal=10,
nr_extra=1,
min_dist=8,
max_dist=99999),
schedule_generator=complex_schedule_generator(),
width=30,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=3,
grid_mode=False,
max_rails_between_cities=4,
max_rail_pairs_in_city=2,
seed=0
),
line_generator=sparse_line_generator(),
number_of_agents=5)
env._max_episode_steps = int(15 * (env.width + env.height))
env_renderer = RenderTool(env)
while True:
obs, info = env.reset()
_done = False
# Run a single episode here
step = 0
while not _done:
# Compute Action
_action = {}
for _idx, _ in enumerate(env.agents):
_action[_idx] = np.random.randint(0, 5)
obs, all_rewards, done, _ = env.step(_action)
_done = done['__all__']
step += 1
env_renderer.render_env(
show=True,
frames=False,
show_observations=False,
show_predictions=False
)
time.sleep(0.3)
obs, info = env.reset()
_done = False
# Run a single episode here
step = 0
while not _done:
# Compute Action
_action = {}
for _idx, _ in enumerate(env.agents):
_action[_idx] = np.random.randint(0, 5)
obs, all_rewards, done, _ = env.step(_action)
_done = done['__all__']
step += 1
env_renderer.render_env(
show=True,
frames=False,
show_observations=False,
show_predictions=False
)
time.sleep(0.1)
return 0
......@@ -62,11 +64,28 @@ def demo(args=None):
required=True
)
@click.option('--service_id',
default="FLATLAND_RL_SERVICE_ID",
default=FLATLAND_RL_SERVICE_ID,
help="Evaluation Service ID. This has to match the service id on the client.",
required=False
)
def evaluator(tests, service_id):
@click.option('--shuffle',
type=bool,
default=False,
help="Shuffle the environments before starting evaluation.",
required=False
)
@click.option('--disable_timeouts',
default=False,
help="Disable all evaluation timeouts.",
required=False
)
@click.option('--results_path',
type=click.Path(exists=False),
default=None,
help="Path where the evaluator should write the results metadata.",
required=False
)
def evaluator(tests, service_id, shuffle, disable_timeouts, results_path):
try:
redis_connection = redis.Redis()
redis_connection.ping()
......@@ -80,7 +99,10 @@ def evaluator(tests, service_id):
test_env_folder=tests,
flatland_rl_service_id=service_id,
visualize=False,
verbose=False
result_output_path=results_path,
verbose=False,
shuffle=shuffle,
disable_timeouts=disable_timeouts
)
grader.run()
......
import os
import math
import numpy as np
import gym
from gym.utils import seeding
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.utils import wrappers
from gym.utils import EzPickle
from pettingzoo.utils.conversions import to_parallel_wrapper
from flatland.envs.rail_env import RailEnv
from mava.wrappers.flatland import infer_observation_space, normalize_observation
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""
def parallel_wrapper_fn(env_fn):
def par_fn(**kwargs):
env = env_fn(**kwargs)
env = custom_parallel_wrapper(env)
return env
return par_fn
def env(**kwargs):
env = raw_env(**kwargs)
# env = wrappers.AssertOutOfBoundsWrapper(env)
# env = wrappers.OrderEnforcingWrapper(env)
return env
parallel_env = parallel_wrapper_fn(env)
class custom_parallel_wrapper(to_parallel_wrapper):
def step(self, actions):
rewards = {a: 0 for a in self.aec_env.agents}
dones = {}
infos = {}
observations = {}
for agent in self.aec_env.agents:
try:
assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
except Exception as e:
# print(e)
print(self.aec_env.dones.values())
raise e
obs, rew, done, info = self.aec_env.last()
self.aec_env.step(actions.get(agent,0))
for agent in self.aec_env.agents:
rewards[agent] += self.aec_env.rewards[agent]
dones = dict(**self.aec_env.dones)
infos = dict(**self.aec_env.infos)
self.agents = self.aec_env.agents
observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
return observations, rewards, dones, infos
class raw_env(AECEnv, gym.Env):
metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
'video.frames_per_second': 10,
'semantics.autoreset': False }
def __init__(self, environment = False, preprocessor = False, agent_info = False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs)
self._environment = environment
n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)]
self._possible_agents = self.agents[:]
self._reset_next_step = True
self._agent_selector = agent_selector(self.agents)
self.num_actions = 5
self.action_spaces = {
agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
}
self.seed()
# preprocessor must be for observation builders other than global obs
# treeobs builders would use the default preprocessor if none is
# supplied
self.preprocessor = self._obtain_preprocessor(preprocessor)
self._include_agent_info = agent_info
# observation space:
# flatland defines no observation space for an agent. Here we try
# to define the observation space. All agents are identical and would
# have the same observation space.
# Infer observation space based on returned observation
obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False)
obs = self.preprocessor(obs)
self.observation_spaces = {
i: infer_observation_space(ob) for i, ob in obs.items()
}
@property
def environment(self) -> RailEnv:
"""Returns the wrapped environment."""
return self._environment
@property
def dones(self):
dones = self._environment.dones
# remove_all = dones.pop("__all__", None)
return {get_agent_keys(key): value for key, value in dones.items()}
@property
def obs_builder(self):
return self._environment.obs_builder
@property
def width(self):
return self._environment.width
@property
def height(self):
return self._environment.height
@property
def agents_data(self):
"""Rail Env Agents data."""
return self._environment.agents
@property
def num_agents(self) -> int:
"""Returns the number of trains/agents in the flatland environment"""
return int(self._environment.number_of_agents)
# def __getattr__(self, name):
# """Expose any other attributes of the underlying environment."""
# return getattr(self._environment, name)
@property
def agents(self):
return self._agents
@property
def possible_agents(self):
return self._possible_agents
def env_done(self):
return self._environment.dones["__all__"] or not self.agents
def observe(self,agent):
return self.obs.get(agent)
def last(self, observe=True):
'''
returns observation, reward, done, info for the current agent (specified by self.agent_selection)
'''
agent = self.agent_selection
observation = self.observe(agent) if observe else None
return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)
def seed(self, seed: int = None) -> None:
self._environment._seed(seed)
def state(self):
'''
Returns an observation of the global environment
'''
return None
def _clear_rewards(self):
'''
clears all items in .rewards
'''
# pass
for agent in self.rewards:
self.rewards[agent] = 0
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents)
self.agent_selection = self._agent_selector.next()
self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}
return observations
def step(self, action):
if self.env_done():
self._agents = []
self._reset_next_step = True
return self.last()
agent = self.agent_selection
self.action_dict[get_agent_handle(agent)] = action
if self.dones[agent]:
# Disabled.. In case we want to remove agents once done
# if self.remove_agents:
# self.agents.remove(agent)
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
else:
self._clear_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
else:
self._clear_rewards()
# self._cumulative_rewards[agent] = 0
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
# collate agent info and observation into a tuple, making the agents obervation to
# be a tuple of the observation from the env and the agent info
def _collate_obs_and_info(self, observes, info):
observations = {}
infos = {}
observes = self.preprocessor(observes)
for agent, obs in observes.items():
all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
agent_info = np.array(
list(all_infos.values()), dtype=np.float32
)
infos[agent] = all_infos
obs = (obs, agent_info) if self._include_agent_info else obs
observations[agent] = obs
self.infos = infos
self.obs = observations
return observations
def set_probs(self, probs):
self.probs = probs
def render(self, mode='rgb_array'):
"""
This methods provides the option to render the
environment's behavior as an image or to a window.
"""
if mode == "rgb_array":
env_rgb_array = self._environment.render(mode)
if not hasattr(self, "image_shape "):
self.image_shape = env_rgb_array.shape
if not hasattr(self, "probs "):
self.probs = [[0., 0., 0., 0.]]
fig, ax = plt.subplots(figsize=(self.image_shape[1]/100, self.image_shape[0]/100),
constrained_layout=True, dpi=100)
df = pd.DataFrame(np.array(self.probs).T)
sns.barplot(x=df.index, y=0, data=df, ax=ax)
ax.set(xlabel='actions', ylabel='probs')
fig.canvas.draw()
X = np.array(fig.canvas.renderer.buffer_rgba())
Image.fromarray(X)
# Image.fromarray(X)
rgb_image = np.array(Image.fromarray(X).convert('RGB'))
plt.close(fig)
q_value_rgb_array = rgb_image
return np.append(env_rgb_array, q_value_rgb_array, axis=1)
else:
return self._environment.render(mode)
def close(self):
self._environment.close()
def _obtain_preprocessor(self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv):
_preprocessor = preprocessor if preprocessor else lambda x: x
if isinstance(self.obs_builder, TreeObsForRailEnv):
_preprocessor = (
partial(
normalize_observation, tree_depth=self.obs_builder.max_depth
)
if not preprocessor
else preprocessor
)
assert _preprocessor is not None
else:
def _preprocessor(x):
return x
def returned_preprocessor(obs):
temp_obs = {}
for agent_id, ob in obs.items():
temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
return temp_obs
return returned_preprocessor
# Utility functions
def convert_np_type(dtype, value):
return np.dtype(dtype).type(value)
def get_agent_handle(id):
"""Obtain an agents handle given its id"""
return int(id)
def get_agent_keys(id):
"""Obtain an agents handle given its id"""
return str(id)
\ No newline at end of file
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
seaborn
matplotlib
pandas
\ No newline at end of file
from ray import tune
from ray.tune.registry import register_env
# from ray.rllib.utils import try_import_tf
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import numpy as np
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# Custom observation builder with predictor, uncomment line below if you want to try this one
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
def env_creator(args):
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
return env
if __name__ == "__main__":
env_name = "flatland_pettyzoo"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
test_env = ParallelPettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
def gen_policy(i):
config = {
"gamma": 0.99,
}
return (None, obs_space, act_space, config)
policies = {"policy_0": gen_policy(0)}
policy_ids = list(policies.keys())
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 5000000},
checkpoint_freq=10,
local_dir="~/ray_results/"+env_name,
config={
# Environment specific
"env": env_name,
# https://github.com/ray-project/ray/issues/10761
"no_done_at_end": True,
# "soft_horizon" : True,
"num_gpus": 0,
"num_workers": 2,
"num_envs_per_worker": 1,
"compress_observations": False,
"batch_mode": 'truncate_episodes',
"clip_rewards": False,
"vf_clip_param": 500.0,
"entropy_coeff": 0.01,
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
"train_batch_size": 1000, # 5000
"rollout_fragment_length": 50, # 100
"sgd_minibatch_size": 100, # 500
"vf_share_layers": False
},
)
# __sphinx_doc_end__
import numpy as np
import os
import PIL
import shutil
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3 import PPO
import supersuit as ss
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import fnmatch
import wandb
"""
https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
"""
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# env = flatland_env.env(environment = rail_env, use_renderer = False)
if wandb_log:
run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True,
config={}, name=experiment_name, save_code=True)
env_steps = 1000 # 2 * env.width * env.height # Code uses 1.5 to calculate max_steps
rollout_fragment_length = 50
env = ss.pettingzoo_env_to_vec_env_v0(env)
# env.black_death = True
env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
n_steps=rollout_fragment_length, ent_coef=0.01,
learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3,
batch_size=150, seed=seed)
# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
train_timesteps = 100000
model.learn(total_timesteps=train_timesteps)
model.save(f"policy_flatland_{train_timesteps}")
# __sphinx_doc_end__
model = PPO.load(f"policy_flatland_{train_timesteps}")
env = flatland_env.env(environment=rail_env, use_renderer=True)
if wandb_log:
artifact = wandb.Artifact('model', type='model')
artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
run.log_artifact(artifact)
# Model Interference
seed = 100
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < 1:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = model.predict(obs, deterministic=True)[0] if not done else None
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
if step % 100 == 0:
print(f"env step:{step} and action taken:{act}")
completion = env_generators.perc_completion(env)
print("Agents Completed:", completion)
completion = env_generators.perc_completion(env)
print("Final Agents Completed:", completion)
ep_no += 1
frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
def find(pattern, path):
result = []
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result.append(os.path.join(root, name))
return result
if wandb_log:
extn = "gif"
_video_file = f'*.{extn}'
_found_videos = find(_video_file, experiment_name)
print(_found_videos)
for _found_video in _found_videos:
wandb.log({_found_video: wandb.Video(_found_video, format=extn)})
run.join()
from typing import List
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.step_utils import env_utils
class Deadlock_Checker:
def __init__(self, env):
self.env = env
self.deadlocked_agents = []
self.immediate_deadlocked = []
def reset(self) -> None:
self.deadlocked_agents = []
self.immediate_deadlocked = []
# an immediate deadlock consists of two trains "trying to pass through each other".
# An agent may have a free possible transition, but took a bad action and "ran into another train". This is now a deadlock, and the other free
# direction can not be chosen anymore!
def check_immediate_deadlocks(self, action_dict) -> List[EnvAgent]:
"""
output: list of agents who are in immediate deadlocks
"""
env = self.env
newly_deadlocked_agents = []
# TODO: check restrictions to relevant agents (status ACTIVE, etc.)
relevant_agents = [agent for agent in env.agents if agent.state != TrainState.DONE and agent.position is not None]
for agent in relevant_agents:
other_agents = [other_agent for other_agent in env.agents if other_agent != agent] # check if this is a good test for inequality. Maybe use handles...
# get the transitions the agent can take from his current position and orientation
# an indicator array of the form e.g. (0,1,1,0) meaning that he can only go to east and south, not to north and west.
possible_transitions = env.rail.get_transitions(*agent.position, agent.direction)
#print(f"possible transitions: {possible_transitions}")
# the directions are: 0(north), 1(east), 2(south) and 3(west)
#possible_directions = [direction for direction, flag in enumerate(possible_transitions) if flag == 1]
#print(f"possible directions: {possible_directions}")
################### only consider direction for actually chosen action ###############################
new_position, new_direction = env_utils.apply_action_independent(action=action_dict[agent.handle], rail=env.rail, position=agent.position, direction=agent.direction)
#assert new_direction in possible_directions, "Error, action leads to impossible direction"
assert new_position == get_new_position(agent.position, new_direction), "Error, something is wrong with new position"
opposed_agent_id = env.agent_positions[new_position] # TODO: check that agent_positions now works correctly in flatland V3 (i.e. gets correctly updated...)
# agent_positions[cell] is an agent_id if an agent is there, otherwise -1.
if opposed_agent_id != -1:
opposed_agent = env.agents[opposed_agent_id]
# other agent with opposing direction is in the way --> deadlock
# an opposing direction means having a different direction than our agent would have if he moved to the new cell. (180 degrees or 90 degrees to our agent)
if opposed_agent.direction != new_direction:
if agent not in newly_deadlocked_agents: # to avoid duplicates
newly_deadlocked_agents.append(agent)
if opposed_agent not in newly_deadlocked_agents: # to avoid duplicates
newly_deadlocked_agents.append(opposed_agent)
self.immediate_deadlocked = newly_deadlocked_agents
return newly_deadlocked_agents
# main method to check for all deadlocks
def check_deadlocks(self, action_dict) -> List[EnvAgent]:
env = self.env
relevant_agents = [agent for agent in env.agents if agent.state != TrainState.DONE and agent.position is not None]
immediate_deadlocked = self.check_immediate_deadlocks(action_dict)
self.immediate_deadlocked = immediate_deadlocked
deadlocked = immediate_deadlocked[:]
# now we have to "close": each train which is blocked by another deadlocked train becomes deadlocked itself.
still_changing = True
while still_changing:
still_changing = False # will be overwritten below if a change did occur
# check if for any agent, there is a new deadlock found
for agent in relevant_agents:
#possible_transitions = env.rail.get_transitions(*agent.position, agent.direction)
#print(f"possible transitions: {possible_transitions}")
# the directions are: 0 (north), 1(east), 2(south) and 3(west)
#possible_directions = [direction for direction, flag in enumerate(possible_transitions) if flag == 1]
#print(f"possible directions: {possible_directions}")
new_position, new_direction = env_utils.apply_action_independent(action=action_dict[agent.handle], rail=env.rail, position=agent.position, direction=agent.direction)
#assert new_direction in possible_directions, "Error, action leads to impossible direction"
assert new_position == get_new_position(agent.position, new_direction), "Error, something is wrong with new position"
opposed_agent_id = env.agent_positions[new_position]
if opposed_agent_id != -1: # there is an opposed agent there
opposed_agent = env.agents[opposed_agent_id]
if opposed_agent in deadlocked:
if agent not in deadlocked: # to avoid duplicates
deadlocked.append(agent)
still_changing = True
self.deadlocked_agents = deadlocked
return deadlocked
\ No newline at end of file
import logging
import random
import numpy as np
from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import TrainState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.fast_methods import fast_count_nonzero, fast_argmax
MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
def get_shortest_path_action(env,handle):
distance_map = env.distance_map.get()
agent = env.agents[handle]
if agent.status in [TrainState.WAITING, TrainState.READY_TO_DEPART,
TrainState.MALFUNCTION_OFF_MAP]:
agent_virtual_position = agent.initial_position
elif agent.status in [TrainState.MALFUNCTION, TrainState.MOVING, TrainState.STOPPED]:
agent_virtual_position = agent.position
elif agent.status == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
if agent.position:
possible_transitions = env.rail.get_transitions(
*agent.position, agent.direction)
else:
possible_transitions = env.rail.get_transitions(
*agent.initial_position, agent.direction)
num_transitions = fast_count_nonzero(possible_transitions)
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(
agent_virtual_position, direction)
min_distances.append(
distance_map[handle, new_position[0],
new_position[1], direction])
else:
min_distances.append(np.inf)
if num_transitions == 1:
observation = [0, 1, 0]
elif num_transitions == 2:
idx = np.argpartition(np.array(min_distances), 2)
observation = [0, 0, 0]
observation[idx[0]] = 1
return fast_argmax(observation) + 1
def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
random.seed(random_seed)
width = 30
height = 30
nr_trains = 5
max_num_cities = 4
grid_mode = False
max_rails_between_cities = 2
max_rails_in_city = 3
malfunction_rate = 0
malfunction_min_duration = 0
malfunction_max_duration = 0
rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
speed_ratio_map = None
line_generator = sparse_line_generator(speed_ratio_map)
malfunction_generator = no_malfunction_generator()
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator_and_process_data=malfunction_generator,
obs_builder_object=observation_builder, remove_agents_at_target=False)
print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
random_seed, width, height, max_num_cities, nr_trains, max_rails_between_cities,
max_rails_in_city, malfunction_rate, malfunction_min_duration, malfunction_max_duration
))
return env
except ValueError as e:
logging.error(f"Error: {e}")
width += 5
height += 5
logging.info("Try again with larger env: (w,h):", width, height)
logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_height}, max_height={max_height}")
return None
def random_sparse_env_small(random_seed, observation_builder, max_width = 45, max_height = 45):
random.seed(random_seed)
size = random.randint(0, 5)
width = 20 + size * 5
height = 20 + size * 5
nr_cities = 2 + size // 2 + random.randint(0, 2)
nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5)) # , 10 + random.randint(0, 10))
max_rails_between_cities = 2
max_rails_in_cities = 3 + random.randint(0, size)
malfunction_rate = 30 + random.randint(0, 100)
malfunction_min_duration = 3 + random.randint(0, 7)
malfunction_max_duration = 20 + random.randint(0, 80)
rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_cities)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator=ParamMalfunctionGen(stochastic_data),
obs_builder_object=observation_builder, remove_agents_at_target=False)
print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration
))
return env
except ValueError as e:
logging.error(f"Error: {e}")
width += 5
height += 5
logging.info("Try again with larger env: (w,h):", width, height)
logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_height}, max_height={max_height}")
return None
def sparse_env_small(random_seed, observation_builder):
width = 30 # With of map
height = 30 # Height of map
nr_trains = 2 # Number of trains that have an assigned task in the env
cities_in_map = 3 # Number of cities where agents can start or end
seed = 10 # Random seed
grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed
max_rails_between_cities = 2 # Max number of tracks allowed between cities. This is number of entry point to a city
max_rail_in_cities = 6 # Max number of parallel tracks within a city, representing a realistic trainstation
rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
seed=seed,
grid_mode=grid_distribution_of_cities,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_in_cities,
)
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
# We can now initiate the schedule generator with the given speed profiles
line_generator = sparse_rail_generator(speed_ration_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
stochastic_data = MalfunctionParameters(malfunction_rate=1/10000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
rail_env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator=ParamMalfunctionGen(stochastic_data),
remove_agents_at_target=True)
return rail_env
def _after_step(self, observation, reward, done, info):
if not self.enabled: return done
if type(done)== dict:
_done_check = done['__all__']
else:
_done_check = done
if _done_check and self.env_semantics_autoreset:
# For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
self.reset_video_recorder()
self.episode_id += 1
self._flush()
# Record stats - Disabled as it causes error in multi-agent set up
# self.stats_recorder.after_step(observation, reward, done, info)
# Record video
self.video_recorder.capture_frame()
return done
def perc_completion(env):
tasks_finished = 0
if hasattr(env, "agents_data"):
agent_data = env.agents_data
else:
agent_data = env.agents
for current_agent in agent_data:
if current_agent.status == TrainState.DONE:
tasks_finished += 1
return 100 * np.mean(tasks_finished / max(
1, len(agent_data)))
from collections import defaultdict
from typing import Dict, Tuple
from flatland.contrib.utils.deadlock_checker import Deadlock_Checker
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.fast_methods import fast_count_nonzero
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.step_utils.states import TrainState
def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
agent = env.agents[handle]
if agent.state == TrainState.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
else:
print("no action possible!")
print("agent state: ", agent.state)
# NEW: if agent is at target, DO_NOTHING, and distance is zero.
# NEW: (needs to be tested...)
return [(RailEnvActions.DO_NOTHING, 0)] * 2
possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
print(f"possible transitions: {possible_transitions}")
distance_map = env.distance_map.get()[handle]
possible_steps = []
for movement in list(range(4)):
if possible_transitions[movement]:
if movement == agent.direction:
action = RailEnvActions.MOVE_FORWARD
elif movement == (agent.direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif movement == (agent.direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}")
if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4:
print("it seems that we are turning by 180 degrees. Turning in a dead end?")
action = RailEnvActions.MOVE_FORWARD
distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
possible_steps.append((action, distance))
possible_steps = sorted(possible_steps, key=lambda step: step[1])
# if there is only one path to target, this is both the shortest one and the second shortest path.
if len(possible_steps) == 1:
return possible_steps * 2
else:
return possible_steps
class RailEnvWrapper:
def __init__(self, env:RailEnv):
self.env = env
assert self.env is not None
assert self.env.rail is not None, "Reset original environment first!"
assert self.env.agents is not None, "Reset original environment first!"
assert len(self.env.agents) > 0, "Reset original environment first!"
# @property
# def number_of_agents(self):
# return self.env.number_of_agents
# @property
# def agents(self):
# return self.env.agents
# @property
# def _seed(self):
# return self.env._seed
# @property
# def obs_builder(self):
# return self.env.obs_builder
def __getattr__(self, name):
try:
return super().__getattr__(self,name)
except:
"""Expose any other attributes of the underlying environment."""
return getattr(self.env, name)
@property
def rail(self):
return self.env.rail
@property
def width(self):
return self.env.width
@property
def height(self):
return self.env.height
@property
def agent_positions(self):
return self.env.agent_positions
def get_num_agents(self):
return self.env.get_num_agents()
def get_agent_handles(self):
return self.env.get_agent_handles()
def step(self, action_dict: Dict[int, RailEnvActions]):
return self.env.step(action_dict)
def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
return obs, info
class ShortestPathActionWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv):
super().__init__(env)
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
# input: action dict with actions in [0, 1, 2].
transformed_action_dict = {}
for agent_id, action in action_dict.items():
if action == 0:
transformed_action_dict[agent_id] = action
else:
#assert action in [1, 2]
#assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
#assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
obs, rewards, dones, info = self.env.step(transformed_action_dict)
return obs, rewards, dones, info
def find_all_cells_where_agent_can_choose(env: RailEnv):
"""
input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv),
WHICH HAS BEEN RESET ALREADY!
(o.w., we call env.rail, which is None before reset(), and crash.)
"""
switches = []
switches_neighbors = []
directions = list(range(4))
for h in range(env.height):
for w in range(env.width):
pos = (h, w)
is_switch = False
# Check for switch: if there is more than one outgoing transition
for orientation in directions:
possible_transitions = env.rail.get_transitions(*pos, orientation)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
switches.append(pos)
is_switch = True
break
if is_switch:
# Add all neighbouring rails, if pos is a switch
for orientation in directions:
possible_transitions = env.rail.get_transitions(*pos, orientation)
for movement in directions:
if possible_transitions[movement]:
switches_neighbors.append(get_new_position(pos, movement))
decision_cells = switches + switches_neighbors
return tuple(map(set, (switches, switches_neighbors, decision_cells)))
class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# env can be a real RailEnv, or anything that shares the same interface
# e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
super().__init__(env)
# save these so they can be inspected easier.
self.accumulate_skipped_rewards = accumulate_skipped_rewards
self.discounting = discounting
self.switches = None
self.switches_neighbors = None
self.decision_cells = None
self.skipped_rewards = defaultdict(list)
# sets initial values for switches, decision_cells, etc.
self.reset_cells()
def on_decision_cell(self, agent: EnvAgent) -> bool:
return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
def on_switch(self, agent: EnvAgent) -> bool:
return agent.position in self.switches
def next_to_switch(self, agent: EnvAgent) -> bool:
return agent.position in self.switches_neighbors
def reset_cells(self) -> None:
self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
o, r, d, i = {}, {}, {}, {}
# need to initialize i["..."]
# as we will access i["..."][agent_id]
i["action_required"] = dict()
i["malfunction"] = dict()
i["speed"] = dict()
i["state"] = dict()
while len(o) == 0:
obs, reward, done, info = self.env.step(action_dict)
for agent_id, agent_obs in obs.items():
if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
o[agent_id] = agent_obs
r[agent_id] = reward[agent_id]
d[agent_id] = done[agent_id]
i["action_required"][agent_id] = info["action_required"][agent_id]
i["malfunction"][agent_id] = info["malfunction"][agent_id]
i["speed"][agent_id] = info["speed"][agent_id]
i["state"][agent_id] = info["state"][agent_id]
if self.accumulate_skipped_rewards:
discounted_skipped_reward = r[agent_id]
for skipped_reward in reversed(self.skipped_rewards[agent_id]):
discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
r[agent_id] = discounted_skipped_reward
self.skipped_rewards[agent_id] = []
elif self.accumulate_skipped_rewards:
self.skipped_rewards[agent_id].append(reward[agent_id])
# end of for-loop
d['__all__'] = done['__all__']
action_dict = {}
# end of while-loop
return o, r, d, i
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
obs, info = self.env.reset(**kwargs)
# resets decision cells, switches, etc. These can change with an env.reset(...)!
# needs to be done after env.reset().
self.reset_cells()
return obs, info
class DeadlockWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv, deadlock_reward=-100) -> None:
super().__init__(env)
self.deadlock_reward = deadlock_reward
self.deadlock_checker = Deadlock_Checker(env=self.env)
@property
def deadlocked_agents(self):
return self.deadlock_checker.deadlocked_agents
@property
def immediate_deadlocks(self):
return [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]
# make sure to assign the deadlock reward only once to each deadlocked agent...
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
# agents which are already deadlocked from previous steps
already_deadlocked_ids = [agent.handle for agent in self.deadlocked_agents]
# step environment
obs, rewards, dones, info = self.env.step(action_dict)
# compute new list of deadlocked agents (ids) after stepping the environment
deadlocked_agents = self.deadlock_checker.check_deadlocks(action_dict) # also stored in self.deadlocked_checker.deadlocked_agents
deadlocked_agents_ids = [agent.handle for agent in deadlocked_agents]
# immediate deadlocked ids only used for prints
immediate_deadlocked_ids = [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]
print(f"immediate deadlocked: {immediate_deadlocked_ids}")
print(f"total deadlocked: {deadlocked_agents_ids}")
newly_deadlocked_agents_ids = [agent_id for agent_id in deadlocked_agents_ids if agent_id not in already_deadlocked_ids]
# assign deadlock rewards
for agent_id in newly_deadlocked_agents_ids:
print(f"assigning deadlock reward of {self.deadlock_reward} to agent {agent_id}")
rewards[agent_id] = self.deadlock_reward
return obs, rewards, dones, info
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
self.deadlock_checker.reset() # sets all lists of deadlocked agents to empty list
obs, info = super().reset(**kwargs)
return obs, info
from enum import IntEnum
from functools import lru_cache
from typing import Type, List
import numpy as np
......@@ -6,6 +7,63 @@ import numpy as np
from flatland.core.transitions import Transitions
# maxsize=None can be used because the number of possible transition is limited (16 bit encoded) and the
# direction/orientation is also limited (2bit). Where the 16bit are only sparse used = number of rail types
# Those methods can be cached -> the are independant of the railways (env)
@lru_cache(maxsize=128)
def fast_grid4_get_transitions(cell_transition, orientation):
bits = (cell_transition >> ((3 - orientation) * 4))
return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
@lru_cache(maxsize=128)
def fast_grid4_get_transition(cell_transition, orientation, direction):
return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
@lru_cache(maxsize=128)
def fast_grid4_set_transitions(cell_transition, orientation, new_transitions):
mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 3 | \
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
return cell_transition
@lru_cache(maxsize=128)
def fast_grid4_remove_deadends(cell_transition):
"""
Remove all turn-arounds (e.g. N-S, S-N, E-W,...).
"""
maskDeadEnds = Grid4Transitions.maskDeadEnds()
cell_transition &= cell_transition & (~maskDeadEnds) & 0xffff
return cell_transition
@lru_cache(maxsize=128)
def fast_grid4_rotate_transition(cell_transition, rotation=0):
value = cell_transition
rotation = rotation // 90
for i in range(4):
block_tuple = fast_grid4_get_transitions(value, i)
block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
value = fast_grid4_set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (
value >> (rotation * 4))
cell_transition = value
return cell_transition
class Grid4TransitionsEnum(IntEnum):
NORTH = 0
EAST = 1
......@@ -57,8 +115,11 @@ class Grid4Transitions(Transitions):
# row,col delta for each direction
self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
# These bits represent all the possible dead ends
self.maskDeadEnds = 0b0010000110000100
# These bits represent all the possible dead ends
@staticmethod
@lru_cache()
def maskDeadEnds():
return 0b0010000110000100
def get_type(self):
return np.uint16
......@@ -83,8 +144,7 @@ class Grid4Transitions(Transitions):
List of the validity of transitions in the cell.
"""
bits = (cell_transition >> ((3 - orientation) * 4))
return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
return fast_grid4_get_transitions(cell_transition, orientation)
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
......@@ -111,18 +171,7 @@ class Grid4Transitions(Transitions):
`orientation`.
"""
mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 3 | \
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
return cell_transition
return fast_grid4_set_transitions(cell_transition, orientation, new_transitions)
def get_transition(self, cell_transition, orientation, direction):
"""
......@@ -146,9 +195,10 @@ class Grid4Transitions(Transitions):
Validity of the requested transition: 0/1 allowed/not allowed.
"""
return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
return fast_grid4_get_transition(cell_transition, orientation, direction)
def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
def set_transition(self, cell_transition, orientation, direction, new_transition,
remove_deadends=False):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation` and inside a cell with transitions
......@@ -181,7 +231,7 @@ class Grid4Transitions(Transitions):
cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
if remove_deadends:
cell_transition = self.remove_deadends(cell_transition)
cell_transition = fast_grid4_remove_deadends(cell_transition)
return cell_transition
......@@ -206,27 +256,18 @@ class Grid4Transitions(Transitions):
"""
# Rotate the individual bits in each block
value = cell_transition
rotation = rotation // 90
for i in range(4):
block_tuple = self.get_transitions(value, i)
block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
value = self.set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
cell_transition = value
return cell_transition
return fast_grid4_rotate_transition(cell_transition, rotation)
def get_direction_enum(self) -> Type[Grid4TransitionsEnum]:
return Grid4TransitionsEnum
def has_deadend(self, cell_transition):
@staticmethod
@lru_cache()
def has_deadend(cell_transition):
"""
Checks if one entry can only by exited by a turn-around.
"""
if cell_transition & self.maskDeadEnds > 0:
if cell_transition & Grid4Transitions.maskDeadEnds() > 0:
return True
else:
return False
......@@ -235,9 +276,9 @@ class Grid4Transitions(Transitions):
"""
Remove all turn-arounds (e.g. N-S, S-N, E-W,...).
"""
cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
return cell_transition
return fast_grid4_remove_deadends(cell_transition)
@staticmethod
@lru_cache()
def get_entry_directions(cell_transition) -> List[int]:
return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
......@@ -25,17 +25,9 @@ def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
def mirror(dir):
return (dir + 2) % 4
MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)]
def get_new_position(position, movement):
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
return (position[0] + MOVEMENT_ARRAY[movement][0], position[1] + MOVEMENT_ARRAY[movement][1])
def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
......
from math import isnan
from typing import Tuple, Callable, List, Type
import numpy as np
......@@ -281,15 +282,12 @@ def coordinate_to_position(depth, coords):
:param coords:
:return:
"""
position = np.empty(len(coords), dtype=int)
idx = 0
for t in coords:
# Set None type coordinates off the grid
if np.isnan(t[0]):
position[idx] = -1
position = list(range(len(coords)))
for index, t in enumerate(coords):
if isnan(t[0]):
position[index] = -1
else:
position[idx] = int(t[1] * depth + t[0])
idx += 1
position[index] = int(t[1] * depth + t[0])
return position
......
......@@ -117,7 +117,7 @@ class GridTransitionMap(TransitionMap):
GridTransitionMap implements utility functions.
"""
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([])):
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]), random_seed=None):
"""
Builder for GridTransitionMap object.
......@@ -136,7 +136,11 @@ class GridTransitionMap(TransitionMap):
self.width = width
self.height = height
self.transitions = transitions
self.random_generator = np.random.RandomState()
if random_seed is None:
self.random_generator.seed(12)
else:
self.random_generator.seed(random_seed)
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def get_full_transitions(self, row, column):
......@@ -317,12 +321,8 @@ class GridTransitionMap(TransitionMap):
boolean
True if and only if the cell is a dead-end.
"""
nbits = 0
tmp = self.get_full_transitions(rcPos[0], rcPos[1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
return nbits == 1
cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
return Grid4Transitions.has_deadend(cell_transition)
def is_simple_turn(self, rcPos: IntVector2DArray):
"""
......@@ -511,7 +511,6 @@ class GridTransitionMap(TransitionMap):
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
# Transition elements
transitions = RailEnvTransitions()
cells = transitions.transition_list
......@@ -572,15 +571,15 @@ class GridTransitionMap(TransitionMap):
elif switch_type_idx == 2:
transition = simple_switch_east_south
else:
transition = np.random.choice(three_way_transitions, 1)
transition = self.random_generator.choice(three_way_transitions, 1)[0]
else:
transition = np.random.choice(three_way_transitions, 1)
transition = self.random_generator.choice(three_way_transitions, 1)[0]
transition = transitions.rotate_transition(transition, int(hole * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
# Make a double slip switch
if number_of_incoming == 4:
rotation = np.random.randint(2)
rotation = self.random_generator.randint(2)
transition = transitions.rotate_transition(double_slip, int(rotation * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
return True
......
import networkx as nx
import numpy as np
from typing import List, Tuple
import graphviz as gv
class MotionCheck(object):
""" Class to find chains of agents which are "colliding" with a stopped agent.
This is to allow close-packed chains of agents, ie a train of agents travelling
at the same speed with no gaps between them,
"""
def __init__(self):
self.G = nx.DiGraph()
self.nDeadlocks = 0
self.svDeadlocked = set()
def addAgent(self, iAg, rc1, rc2, xlabel=None):
""" add an agent and its motion as row,col tuples of current and next position.
The agent's current position is given an "agent" attribute recording the agent index.
If an agent does not want to move this round (rc1 == rc2) then a self-loop edge is created.
xlabel is used for test cases to give a label (see graphviz)
"""
# Agents which have not yet entered the env have position None.
# Substitute this for the row = -1, column = agent index
if rc1 is None:
rc1 = (-1, iAg)
if rc2 is None:
rc2 = (-1, iAg)
self.G.add_node(rc1, agent=iAg)
if xlabel:
self.G.nodes[rc1]["xlabel"] = xlabel
self.G.add_edge(rc1, rc2)
def find_stops(self):
""" find all the stopped agents as a set of rc position nodes
A stopped agent is a self-loop on a cell node.
"""
# get the (sparse) adjacency matrix
spAdj = nx.linalg.adjacency_matrix(self.G)
# the stopped agents appear as 1s on the diagonal
# the where turns this into a list of indices of the 1s
giStops = np.where(spAdj.diagonal())[0]
# convert the cell/node indices into the node rc values
lvAll = list(self.G.nodes())
# pick out the stops by their indices
lvStops = [ lvAll[i] for i in giStops ]
# make it into a set ready for a set intersection
svStops = set(lvStops)
return svStops
def find_stops2(self):
""" alternative method to find stopped agents, using a networkx call to find selfloop edges
"""
svStops = { u for u,v in nx.classes.function.selfloop_edges(self.G) }
return svStops
def find_stop_preds(self, svStops=None):
""" Find the predecessors to a list of stopped agents (ie the nodes / vertices)
Returns the set of predecessors.
Includes "chained" predecessors.
"""
if svStops is None:
svStops = self.find_stops2()
# Get all the chains of agents - weakly connected components.
# Weakly connected because it's a directed graph and you can traverse a chain of agents
# in only one direction
lWCC = list(nx.algorithms.components.weakly_connected_components(self.G))
svBlocked = set()
for oWCC in lWCC:
#print("Component:", oWCC)
# Get the node details for this WCC in a subgraph
Gwcc = self.G.subgraph(oWCC)
# Find all the stops in this chain or tree
svCompStops = svStops.intersection(Gwcc)
#print(svCompStops)
if len(svCompStops) > 0:
# We need to traverse it in reverse - back up the movement edges
Gwcc_rev = Gwcc.reverse()
for vStop in svCompStops:
# Find all the agents stopped by vStop by following the (reversed) edges
# This traverses a tree - dfs = depth first seearch
iter_stops = nx.algorithms.traversal.dfs_postorder_nodes(Gwcc_rev, vStop)
lStops = list(iter_stops)
svBlocked.update(lStops)
# the set of all the nodes/agents blocked by this set of stopped nodes
return svBlocked
def find_swaps(self):
""" find all the swap conflicts where two agents are trying to exchange places.
These appear as simple cycles of length 2.
These agents are necessarily deadlocked (since they can't change direction in flatland) -
meaning they will now be stuck for the rest of the episode.
"""
#svStops = self.find_stops2()
llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
return svSwaps
def find_same_dest(self):
""" find groups of agents which are trying to land on the same cell.
ie there is a gap of one cell between them and they are both landing on it.
"""
pass
def block_preds(self, svStops, color="red"):
""" Take a list of stopped agents, and apply a stop color to any chains/trees
of agents trying to head toward those cells.
Count the number of agents blocked, ignoring those which are already marked.
(Otherwise it can double count swaps)
"""
iCount = 0
svBlocked = set()
# The reversed graph allows us to follow directed edges to find affected agents.
Grev = self.G.reverse()
for v in svStops:
# Use depth-first-search to find a tree of agents heading toward the blocked cell.
lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v))
svBlocked |= set(lvPred)
svBlocked.add(v)
#print("node:", v, "set", svBlocked)
# only count those not already marked
for v2 in [v]+lvPred:
if self.G.nodes[v2].get("color") != color:
self.G.nodes[v2]["color"] = color
iCount += 1
return svBlocked
def find_conflicts(self):
svStops = self.find_stops2() # voluntarily stopped agents - have self-loops
svSwaps = self.find_swaps() # deadlocks - adjacent head-on collisions
# Block all swaps and their tree of predessors
self.svDeadlocked = self.block_preds(svSwaps, color="purple")
# Take the union of the above, and find all the predecessors
#svBlocked = self.find_stop_preds(svStops.union(svSwaps))
# Just look for the the tree of preds for each voluntarily stopped agent
svBlocked = self.find_stop_preds(svStops)
# iterate the nodes v with their predecessors dPred (dict of nodes->{})
for (v, dPred) in self.G.pred.items():
# mark any swaps with purple - these are directly deadlocked
#if v in svSwaps:
# self.G.nodes[v]["color"] = "purple"
# If they are not directly deadlocked, but are in the union of stopped + deadlocked
#elif v in svBlocked:
# if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting
if v in svBlocked:
self.G.nodes[v]["color"] = "red"
# not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node
elif len(dPred)>1:
# if this agent is already red/blocked, ignore. CHECK: why?
# certainly we want to ignore purple so we don't overwrite with red.
if self.G.nodes[v].get("color") in ("red", "purple"):
continue
# if this node has no agent, and >=2 want to enter it.
if self.G.nodes[v].get("agent") is None:
self.G.nodes[v]["color"] = "blue"
# this node has an agent and >=2 want to enter
else:
self.G.nodes[v]["color"] = "magenta"
# predecessors of a contended cell: {agent index -> node}
diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred}
# remove the agent with the lowest index, who wins
iAgWinner = min(diAgCell)
diAgCell.pop(iAgWinner)
# Block all the remaining predessors, and their tree of preds
#for iAg, v in diAgCell.items():
# self.G.nodes[v]["color"] = "red"
# for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
# self.G.nodes[vPred]["color"] = "red"
self.block_preds(diAgCell.values(), "red")
def check_motion(self, iAgent, rcPos):
""" Returns tuple of boolean can the agent move, and the cell it will move into.
If agent position is None, we use a dummy position of (-1, iAgent)
"""
if rcPos is None:
rcPos = (-1, iAgent)
dAttr = self.G.nodes.get(rcPos)
#print("pos:", rcPos, "dAttr:", dAttr)
if dAttr is None:
dAttr = {}
# If it's been marked red or purple then it can't move
if "color" in dAttr:
sColor = dAttr["color"]
if sColor in [ "red", "purple" ]:
return False
dSucc = self.G.succ[rcPos]
# This should never happen - only the next cell of an agent has no successor
if len(dSucc)==0:
print(f"error condition - agent {iAgent} node {rcPos} has no successor")
return False
# This agent has a successor
rcNext = self.G.successors(rcPos).__next__()
if rcNext == rcPos: # the agent didn't want to move
return False
# The agent wanted to move, and it can
return True
def render(omc:MotionCheck, horizontal=True):
try:
oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
oAG.layout("dot")
sDot = oAG.to_string()
if horizontal:
sDot = sDot.replace('{', '{ rankdir="LR" ')
#return oAG.draw(format="png")
# This returns a graphviz object which implements __repr_svg
return gv.Source(sDot)
except ImportError as oError:
print("Flatland agent_chains ignoring ImportError - install pygraphviz to render graphs")
return None
class ChainTestEnv(object):
""" Just for testing agent chains
"""
def __init__(self, omc:MotionCheck):
self.iAgNext = 0
self.iRowNext = 1
self.omc = omc
def addAgent(self, rc1, rc2, xlabel=None):
self.omc.addAgent(self.iAgNext, rc1, rc2, xlabel=xlabel)
self.iAgNext+=1
def addAgentToRow(self, c1, c2, xlabel=None):
self.addAgent((self.iRowNext, c1), (self.iRowNext, c2), xlabel=xlabel)
def create_test_chain(self,
nAgents:int,
rcVel:Tuple[int] = (0,1),
liStopped:List[int]=[],
xlabel=None):
""" create a chain of agents
"""
lrcAgPos = [ (self.iRowNext, i * rcVel[1]) for i in range(nAgents) ]
for iAg, rcPos in zip(range(nAgents), lrcAgPos):
if iAg in liStopped:
rcVel1 = (0,0)
else:
rcVel1 = rcVel
self.omc.addAgent(iAg+self.iAgNext, rcPos, (rcPos[0] + rcVel1[0], rcPos[1] + rcVel1[1]) )
if xlabel:
self.omc.G.nodes[lrcAgPos[0]]["xlabel"] = xlabel
self.iAgNext += nAgents
self.iRowNext += 1
def nextRow(self):
self.iRowNext+=1
def create_test_agents(omc:MotionCheck):
# blocked chain
omc.addAgent(1, (1,2), (1,3))
omc.addAgent(2, (1,3), (1,4))
omc.addAgent(3, (1,4), (1,5))
omc.addAgent(31, (1,5), (1,5))
# unblocked chain
omc.addAgent(4, (2,1), (2,2))
omc.addAgent(5, (2,2), (2,3))
# blocked short chain
omc.addAgent(6, (3,1), (3,2))
omc.addAgent(7, (3,2), (3,2))
# solitary agent
omc.addAgent(8, (4,1), (4,2))
# solitary stopped agent
omc.addAgent(9, (5,1), (5,1))
# blocked short chain (opposite direction)
omc.addAgent(10, (6,4), (6,3))
omc.addAgent(11, (6,3), (6,3))
# swap conflict
omc.addAgent(12, (7,1), (7,2))
omc.addAgent(13, (7,2), (7,1))
def create_test_agents2(omc:MotionCheck):
# blocked chain
cte = ChainTestEnv(omc)
cte.create_test_chain(4, liStopped=[3], xlabel="stopped\nchain")
cte.create_test_chain(4, xlabel="running\nchain")
cte.create_test_chain(2, liStopped = [1], xlabel="stopped \nshort\n chain")
cte.addAgentToRow(1, 2, "swap")
cte.addAgentToRow(2, 1)
cte.nextRow()
cte.addAgentToRow(1, 2, "chain\nswap")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 2)
cte.nextRow()
cte.addAgentToRow(1, 2, "midchain\nstop")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(4, 4)
cte.addAgentToRow(5, 6)
cte.addAgentToRow(6, 7)
cte.nextRow()
cte.addAgentToRow(1, 2, "midchain\nswap")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(4, 3)
cte.addAgentToRow(5, 4)
cte.addAgentToRow(6, 5)
cte.nextRow()
cte.addAgentToRow(1, 2, "Land on\nSame")
cte.addAgentToRow(3, 2)
cte.nextRow()
cte.addAgentToRow(1, 2, "chains\nonto\nsame")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgentToRow(5, 4)
cte.addAgentToRow(6, 5)
cte.addAgentToRow(7, 6)
cte.nextRow()
cte.addAgentToRow(1, 2, "3-way\nsame")
cte.addAgentToRow(3, 2)
cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
cte.nextRow()
if False:
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "4-way\nsame")
cte.addAgentToRow(3, 2)
cte.addAgent((cte.iRowNext+1, 2), (cte.iRowNext, 2))
cte.addAgent((cte.iRowNext-1, 2), (cte.iRowNext, 2))
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "Tee")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
cte.addAgent((cte.iRowNext+1, 3), (cte.iRowNext, 3))
cte.nextRow()
cte.nextRow()
cte.addAgentToRow(1, 2, "Tree")
cte.addAgentToRow(2, 3)
cte.addAgentToRow(3, 4)
r1 = cte.iRowNext
r2 = cte.iRowNext+1
r3 = cte.iRowNext+2
cte.addAgent((r2, 3), (r1, 3))
cte.addAgent((r2, 2), (r2, 3))
cte.addAgent((r3, 2), (r2, 3))
cte.nextRow()
def test_agent_following():
omc = MotionCheck()
create_test_agents2(omc)
svStops = omc.find_stops()
svBlocked = omc.find_stop_preds()
llvSwaps = omc.find_swaps()
svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
print(list(svBlocked))
lvCells = omc.G.nodes()
lColours = [ "magenta" if v in svStops
else "red" if v in svBlocked
else "purple" if v in svSwaps
else "lightblue"
for v in lvCells ]
dPos = dict(zip(lvCells, lvCells))
nx.draw(omc.G,
with_labels=True, arrowsize=20,
pos=dPos,
node_color = lColours)
def main():
test_agent_following()
if __name__=="__main__":
main()
from enum import IntEnum
from itertools import starmap
from typing import Tuple, Optional, NamedTuple
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
import warnings
from attr import attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.schedule_utils import Schedule
from typing import Tuple, Optional, NamedTuple, List
from attr import attr, attrs, attrib, Factory
class RailAgentStatus(IntEnum):
READY_TO_DEPART = 0 # not in grid yet (position is None) -> prediction as if it were at initial position
ACTIVE = 1 # in grid (position is not None), not done -> prediction is remaining path
DONE = 2 # in grid (position is not None), but done -> prediction is stay at target forever
DONE_REMOVED = 3 # removed from grid (position is None) -> prediction is None
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
('direction', Grid4TransitionsEnum),
('target', Tuple[int, int]),
('moving', bool),
('speed_data', dict),
('malfunction_data', dict),
('earliest_departure', int),
('latest_arrival', int),
('handle', int),
('status', RailAgentStatus),
('position', Tuple[int, int]),
('arrival_time', int),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int])])
('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
def load_env_agent(agent_tuple: Agent):
return EnvAgent(
initial_position = agent_tuple.initial_position,
initial_direction = agent_tuple.initial_direction,
direction = agent_tuple.direction,
target = agent_tuple.target,
moving = agent_tuple.moving,
earliest_departure = agent_tuple.earliest_departure,
latest_arrival = agent_tuple.latest_arrival,
handle = agent_tuple.handle,
position = agent_tuple.position,
arrival_time = agent_tuple.arrival_time,
old_direction = agent_tuple.old_direction,
old_position = agent_tuple.old_position,
speed_counter = agent_tuple.speed_counter,
action_saver = agent_tuple.action_saver,
state_machine = agent_tuple.state_machine,
malfunction_handler = agent_tuple.malfunction_handler,
)
@attrs
class EnvAgent:
# INIT FROM HERE IN _from_line()
initial_position = attrib(type=Tuple[int, int])
initial_direction = attrib(type=Grid4TransitionsEnum)
direction = attrib(type=Grid4TransitionsEnum)
target = attrib(type=Tuple[int, int])
moving = attrib(default=False, type=bool)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
# N.B. we need to use factory since default arguments are not recreated on each call!
speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of time the agent had to stop, since the last time it broke down
malfunction_data = attrib(
default=Factory(
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
'moving_before_malfunction': False})))
# NEW : EnvAgent - Schedule properties
earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line()
handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
# Env step facelift
speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling
arrival_time = attrib(default=None, type=int)
# used in rendering
old_direction = attrib(default=None)
old_position = attrib(default=None)
def reset(self):
"""
Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
"""
self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction
self.status = RailAgentStatus.READY_TO_DEPART
self.old_position = None
self.old_direction = None
self.moving = False
self.arrival_time = None
self.malfunction_handler.reset()
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
self.state_machine.reset()
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction,
direction=self.direction, target=self.target, moving=self.moving, speed_data=self.speed_data,
malfunction_data=self.malfunction_data, handle=self.handle, status=self.status,
position=self.position, old_direction=self.old_direction, old_position=self.old_position)
return Agent(initial_position=self.initial_position,
initial_direction=self.initial_direction,
direction=self.direction,
target=self.target,
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
handle=self.handle,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
arrival_time=self.arrival_time,
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
def get_shortest_path(self, distance_map) -> List[Waypoint]:
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
shortest_path = self.get_shortest_path(distance_map)
if shortest_path is not None:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_counter.speed
return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
return self.latest_arrival - elapsed_steps
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
'''
+ve if arrival time is projected before latest arrival
-ve if arrival time is projected after latest arrival
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
@classmethod
def from_schedule(cls, schedule: Schedule):
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
speed_datas = []
for i in range(len(schedule.agent_positions)):
speed_datas.append({'position_fraction': 0.0,
'speed': schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0,
'transition_action_on_cellexit': 0})
malfunction_datas = []
for i in range(len(schedule.agent_positions)):
malfunction_datas.append({'malfunction': 0,
'malfunction_rate': schedule.agent_malfunction_rates[
i] if schedule.agent_malfunction_rates is not None else 0.,
'next_malfunction': 0,
'nr_malfunctions': 0})
return list(starmap(EnvAgent, zip(schedule.agent_positions,
schedule.agent_directions,
schedule.agent_directions,
schedule.agent_targets,
[False] * len(schedule.agent_positions),
speed_datas,
malfunction_datas,
range(len(schedule.agent_positions)))))
num_agents = len(line.agent_positions)
agent_list = []
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
target = line.agent_targets[i_agent],
moving = False,
earliest_departure = None,
latest_arrival = None,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
agent_list.append(agent)
return agent_list
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
agents = []
for i, static_agent in enumerate(static_agents_data):
if len(static_agent) >= 6:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_counter=SpeedCounter(static_agent[4]['speed']), handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
speed_counter=SpeedCounter(1.0),
handle=i)
agents.append(agent)
return agents
def __str__(self):
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} \n \
initial_direction: {self.initial_direction} \n \
position: {self.position} \n \
direction: {self.direction} \n \
target: {self.target} \n \
old_position: {self.old_position} \n \
old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} \n \
latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_handler: {self.malfunction_handler} \n \
action_saver: {self.action_saver} \n \
speed_counter: {self.speed_counter}"
@property
def state(self):
return self.state_machine.state
@state.setter
def state(self, state):
self._set_state(state)
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
@property
def malfunction_data(self):
raise ValueError("agent.malunction_data is deprecated, please use agent.malfunction_hander instead")
@property
def speed_data(self):
raise ValueError("agent.speed_data is deprecated, please use agent.speed_counter instead")