Compare revisions

Changes are shown as if the source revision was being merged into the target revision.
from flatland.envs.rail_env import RailEnvActions

# global action size
global _agent_action_config_action_size
_agent_action_config_action_size = 5


def get_flatland_full_action_size():
    # The action space of Flatland is 5 discrete actions
    return 5


def set_action_size_full():
    global _agent_action_config_action_size
    # The agents (DDDQN, PPO, ...) use this action space
    _agent_action_config_action_size = 5


def set_action_size_reduced():
    global _agent_action_config_action_size
    # The agents (DDDQN, PPO, ...) use this action space
    _agent_action_config_action_size = 4


def get_action_size():
    global _agent_action_config_action_size
    # The agents (DDDQN, PPO, ...) use this action space
    return _agent_action_config_action_size


def map_actions(actions):
    # Map the agents' (reduced) actions to Flatland's full action space, if the reduced action space is in use
    if get_action_size() != get_flatland_full_action_size():
        for key in actions:
            value = actions.get(key, 0)
            actions.update({key: map_action(value)})
    return actions


def map_action_policy(action):
    # Map a Flatland action to the reduced (policy) action space
    if get_action_size() != get_flatland_full_action_size():
        return action - 1
    return action


def map_action(action):
    # Map a reduced (policy) action to the corresponding Flatland action
    if get_action_size() == get_flatland_full_action_size():
        return action
    if action == 0:
        return RailEnvActions.MOVE_LEFT
    if action == 1:
        return RailEnvActions.MOVE_FORWARD
    if action == 2:
        return RailEnvActions.MOVE_RIGHT
    if action == 3:
        return RailEnvActions.STOP_MOVING


def map_rail_env_action(action):
    # Map a Flatland action back to the reduced (policy) action space
    if get_action_size() == get_flatland_full_action_size():
        return action
    if action == RailEnvActions.MOVE_LEFT:
        return 0
    elif action == RailEnvActions.MOVE_FORWARD:
        return 1
    elif action == RailEnvActions.MOVE_RIGHT:
        return 2
    elif action == RailEnvActions.STOP_MOVING:
        return 3
    # action == RailEnvActions.DO_NOTHING:
    return 3
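A minimal usage sketch of the mapping helpers above, assuming the module is importable as utils.agent_action_config (the import path used later in this diff). With the reduced action space, the policy's four actions are shifted onto Flatland's five RailEnvActions, and DO_NOTHING collapses onto STOP_MOVING on the way back:

from flatland.envs.rail_env import RailEnvActions
from utils.agent_action_config import (get_action_size, map_action,
                                       map_rail_env_action, set_action_size_reduced)

set_action_size_reduced()                                      # policies now work with 4 actions
assert get_action_size() == 4
assert map_action(0) == RailEnvActions.MOVE_LEFT               # policy action 0 -> Flatland MOVE_LEFT
assert map_rail_env_action(RailEnvActions.STOP_MOVING) == 3    # Flatland action -> policy action
assert map_rail_env_action(RailEnvActions.DO_NOTHING) == 3     # DO_NOTHING is treated as STOP_MOVING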
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero


class AgentCanChooseHelper:
    def __init__(self):
        pass

    def build_data(self, env):
        self.env = env
        if self.env is not None:
            self.env.dev_obs_dict = {}
        self.switches = {}
        self.switches_neighbours = {}
        if self.env is not None:
            self.find_all_cell_where_agent_can_choose()

    def find_all_switches(self):
        # Search the environment (rail grid) for all switch cells. A switch is a cell where more than one transition
        # exists; collect every direction in which the cell acts as a switch.
        self.switches = {}
        for h in range(self.env.height):
            for w in range(self.env.width):
                pos = (h, w)
                for dir in range(4):
                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
                    num_transitions = fast_count_nonzero(possible_transitions)
                    if num_transitions > 1:
                        if pos not in self.switches.keys():
                            self.switches.update({pos: [dir]})
                        else:
                            self.switches[pos].append(dir)

    def find_all_switch_neighbours(self):
        # Collect all cells that are neighbours of a switch cell. A cell is a neighbour if the agent can reach a
        # switch from it in a single step. A switch is a cell where the agent has more than one transition.
        self.switches_neighbours = {}
        for h in range(self.env.height):
            for w in range(self.env.width):
                # look one step forward
                for dir in range(4):
                    pos = (h, w)
                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
                    for d in range(4):
                        if possible_transitions[d] == 1:
                            new_cell = get_new_position(pos, d)
                            if new_cell in self.switches.keys() and pos not in self.switches.keys():
                                if pos not in self.switches_neighbours.keys():
                                    self.switches_neighbours.update({pos: [dir]})
                                else:
                                    self.switches_neighbours[pos].append(dir)

    def find_all_cell_where_agent_can_choose(self):
        # prepare the memory - collect all cells where the agent can choose more than FORWARD/STOP.
        self.find_all_switches()
        self.find_all_switch_neighbours()

    def check_agent_decision(self, position, direction):
        # Decide whether the agent is
        # - on a switch
        # - at a switch neighbour (close to a switch). The switch must be one where the agent has more options than
        #   FORWARD/STOP
        # - on any switch: it doesn't matter whether the agent has more options than FORWARD/STOP
        # - at any switch neighbour: it doesn't matter whether the agent has more than one option (transition) when
        #   it reaches the switch
        agents_on_switch = False
        agents_on_switch_all = False
        agents_near_to_switch = False
        agents_near_to_switch_all = False
        if position in self.switches.keys():
            agents_on_switch = direction in self.switches[position]
            agents_on_switch_all = True
        if position in self.switches_neighbours.keys():
            new_cell = get_new_position(position, direction)
            if new_cell in self.switches.keys():
                if not direction in self.switches[new_cell]:
                    agents_near_to_switch = direction in self.switches_neighbours[position]
            else:
                agents_near_to_switch = direction in self.switches_neighbours[position]
            agents_near_to_switch_all = direction in self.switches_neighbours[position]
        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all

    def required_agent_decision(self):
        agents_can_choose = {}
        agents_on_switch = {}
        agents_on_switch_all = {}
        agents_near_to_switch = {}
        agents_near_to_switch_all = {}
        for a in range(self.env.get_num_agents()):
            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
                self.check_agent_decision(
                    self.env.agents[a].position,
                    self.env.agents[a].direction)
            agents_on_switch.update({a: ret_agents_on_switch})
            agents_on_switch_all.update({a: ret_agents_on_switch_all})
            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
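A short sketch of how AgentCanChooseHelper might be driven, assuming an already constructed Flatland RailEnv instance named env; the cell position and direction used in the query are illustrative only:

helper = AgentCanChooseHelper()
helper.build_data(env)  # scan the rail grid once for switches and switch neighbours

# single-cell query: is an agent at (row, col) = (3, 4) facing direction 0 on or near a decision point?
on_switch, near_switch, near_switch_all, on_switch_all = helper.check_agent_decision((3, 4), 0)

# whole-environment query: one flag per agent handle
agents_can_choose, _, _, _, _ = helper.required_agent_decision()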
......@@ -6,7 +6,8 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
from reinforcement_learning.policy import Policy
from reinforcement_learning.policy import HeuristicPolicy, DummyMemory
from utils.agent_action_config import map_rail_env_action
from utils.shortest_distance_walker import ShortestDistanceWalker
......@@ -65,36 +66,40 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
class DeadLockAvoidanceAgent(Policy):
def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
class DeadLockAvoidanceAgent(HeuristicPolicy):
def __init__(self, env: RailEnv, action_size, enable_eps=False, show_debug_plot=False):
print(">> DeadLockAvoidance")
self.env = env
self.memory = None
self.memory = DummyMemory()
self.loss = 0
self.action_size = action_size
self.agent_can_move = {}
self.agent_can_move_value = {}
self.switches = {}
self.show_debug_plot = show_debug_plot
self.enable_eps = enable_eps
def step(self, state, action, reward, next_state, done):
def step(self, handle, state, action, reward, next_state, done):
pass
def act(self, state, eps=0.):
def act(self, handle, state, eps=0.):
# Epsilon-greedy action selection
if np.random.random() < eps:
return np.random.choice(np.arange(self.action_size))
if self.enable_eps:
if np.random.random() < eps:
return np.random.choice(np.arange(self.action_size))
# agent = self.env.agents[state[0]]
check = self.agent_can_move.get(state[0], None)
if check is None:
return RailEnvActions.STOP_MOVING
return check[3]
check = self.agent_can_move.get(handle, None)
act = RailEnvActions.STOP_MOVING
if check is not None:
act = check[3]
return map_rail_env_action(act)
def get_agent_can_move_value(self, handle):
return self.agent_can_move_value.get(handle, np.inf)
def reset(self):
def reset(self, env):
self.env = env
self.agent_positions = None
self.shortest_distance_walker = None
self.switches = {}
......
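A sketch of driving the revised DeadLockAvoidanceAgent as a standalone heuristic policy, assuming an existing RailEnv instance env and that start_step/end_step behave as in the repository (they are called below by FastTreeObs but their bodies are not part of this hunk):

from utils.agent_action_config import get_flatland_full_action_size
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent

policy = DeadLockAvoidanceAgent(env, get_flatland_full_action_size(), enable_eps=False)
policy.reset(env)

policy.start_step(train=False)
actions = {}
for handle in env.get_agent_handles():
    # the heuristic ignores the state argument; eps is unused because enable_eps=False
    actions[handle] = policy.act(handle, None, eps=0.0)
policy.end_step(train=False)

obs, rewards, dones, info = env.step(actions)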
......@@ -17,6 +17,15 @@ def get_agent_positions(env):
return agent_positions
def get_agent_targets(env):
agent_targets = []
for agent_handle in env.get_agent_handles():
agent = env.agents[agent_handle]
if agent.status == RailAgentStatus.ACTIVE:
agent_targets.append(agent.target)
return agent_targets
def check_for_deadlock(handle, env, agent_positions, check_position=None, check_direction=None):
agent = env.agents[handle]
if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
......
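The helpers in this hunk feed the observation builder further down; a minimal sketch of the intended call pattern, assuming an existing RailEnv instance env (the body of check_for_deadlock is truncated above, only its signature is shown):

from utils.deadlock_check import check_for_deadlock, get_agent_positions, get_agent_targets

agent_positions = get_agent_positions(env)   # occupied cells of the active agents
agent_targets = get_agent_targets(env)       # target cells of the active agents

for handle in env.get_agent_handles():
    if check_for_deadlock(handle, env, agent_positions):
        print("agent", handle, "appears to be deadlocked")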
from typing import List, Optional
from typing import List, Optional, Any
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
......@@ -6,8 +6,10 @@ from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
from utils.agent_action_config import get_flatland_full_action_size
from utils.agent_can_choose_helper import AgentCanChooseHelper
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.deadlock_check import check_for_deadlock, get_agent_positions
from utils.deadlock_check import get_agent_positions, get_agent_targets
"""
LICENCE for the FastTreeObs Observation Builder
......@@ -24,116 +26,15 @@ Author: Adrian Egli (adrian.egli@gmail.com)
class FastTreeObs(ObservationBuilder):
def __init__(self, max_depth):
def __init__(self, max_depth: Any):
self.max_depth = max_depth
self.observation_dim = 41
def build_data(self):
if self.env is not None:
self.env.dev_obs_dict = {}
self.switches = {}
self.switches_neighbours = {}
self.debug_render_list = []
self.debug_render_path_list = []
if self.env is not None:
self.find_all_cell_where_agent_can_choose()
self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False)
else:
self.dead_lock_avoidance_agent = None
def find_all_switches(self):
# Search the environment (rail grid) for all switch cells. A switch is a cell where more than one transition
# exists; collect every direction in which the cell acts as a switch.
self.switches = {}
for h in range(self.env.height):
for w in range(self.env.width):
pos = (h, w)
for dir in range(4):
possible_transitions = self.env.rail.get_transitions(*pos, dir)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
if pos not in self.switches.keys():
self.switches.update({pos: [dir]})
else:
self.switches[pos].append(dir)
def find_all_switch_neighbours(self):
# Collect all cells that are neighbours of a switch cell. A cell is a neighbour if the agent can reach a
# switch from it in a single step. A switch is a cell where the agent has more than one transition.
self.switches_neighbours = {}
for h in range(self.env.height):
for w in range(self.env.width):
# look one step forward
for dir in range(4):
pos = (h, w)
possible_transitions = self.env.rail.get_transitions(*pos, dir)
for d in range(4):
if possible_transitions[d] == 1:
new_cell = get_new_position(pos, d)
if new_cell in self.switches.keys() and pos not in self.switches.keys():
if pos not in self.switches_neighbours.keys():
self.switches_neighbours.update({pos: [dir]})
else:
self.switches_neighbours[pos].append(dir)
def find_all_cell_where_agent_can_choose(self):
# prepare the memory - collect all cells where the agent can choose more than FORWARD/STOP.
self.find_all_switches()
self.find_all_switch_neighbours()
def check_agent_decision(self, position, direction):
# Decide whether the agent is
# - on a switch
# - at a switch neighbour (close to a switch). The switch must be one where the agent has more options than
# FORWARD/STOP
# - on any switch: it doesn't matter whether the agent has more options than FORWARD/STOP
# - at any switch neighbour: it doesn't matter whether the agent has more than one option (transition) when it
# reaches the switch
agents_on_switch = False
agents_on_switch_all = False
agents_near_to_switch = False
agents_near_to_switch_all = False
if position in self.switches.keys():
agents_on_switch = direction in self.switches[position]
agents_on_switch_all = True
if position in self.switches_neighbours.keys():
new_cell = get_new_position(position, direction)
if new_cell in self.switches.keys():
if not direction in self.switches[new_cell]:
agents_near_to_switch = direction in self.switches_neighbours[position]
else:
agents_near_to_switch = direction in self.switches_neighbours[position]
agents_near_to_switch_all = direction in self.switches_neighbours[position]
return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
def required_agent_decision(self):
agents_can_choose = {}
agents_on_switch = {}
agents_on_switch_all = {}
agents_near_to_switch = {}
agents_near_to_switch_all = {}
for a in range(self.env.get_num_agents()):
ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
self.check_agent_decision(
self.env.agents[a].position,
self.env.agents[a].direction)
agents_on_switch.update({a: ret_agents_on_switch})
agents_on_switch_all.update({a: ret_agents_on_switch_all})
ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
self.observation_dim = 35
self.agent_can_choose_helper = None
self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(None, get_flatland_full_action_size())
def debug_render(self, env_renderer):
agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
self.required_agent_decision()
self.agent_can_choose_helper.required_agent_decision()
self.env.dev_obs_dict = {}
for a in range(max(3, self.env.get_num_agents())):
self.env.dev_obs_dict.update({a: []})
......@@ -156,24 +57,28 @@ class FastTreeObs(ObservationBuilder):
env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
self.env.dev_obs_dict[0] = self.debug_render_list
self.env.dev_obs_dict[1] = self.switches.keys()
self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
self.env.dev_obs_dict[1] = self.agent_can_choose_helper.switches.keys()
self.env.dev_obs_dict[2] = self.agent_can_choose_helper.switches_neighbours.keys()
self.env.dev_obs_dict[3] = self.debug_render_path_list
def reset(self):
self.build_data()
return
if self.agent_can_choose_helper is None:
self.agent_can_choose_helper = AgentCanChooseHelper()
self.agent_can_choose_helper.build_data(self.env)
self.debug_render_list = []
self.debug_render_path_list = []
def _explore(self, handle, new_position, new_direction, distance_map, depth=0):
has_opp_agent = 0
has_same_agent = 0
has_target = 0
has_opp_target = 0
visited = []
min_dist = distance_map[handle, new_position[0], new_position[1], new_direction]
# stop exploring (max_depth reached)
if depth >= self.max_depth:
return has_opp_agent, has_same_agent, has_target, visited, min_dist
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
# max_explore_steps = 100 -> just to ensure that the exploration ends
cnt = 0
......@@ -186,7 +91,7 @@ class FastTreeObs(ObservationBuilder):
if self.env.agents[opp_a].direction != new_direction:
# opp agent found -> stop exploring. This would be a strong signal.
has_opp_agent = 1
return has_opp_agent, has_same_agent, has_target, visited, min_dist
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
else:
# same agent found
# the agent can follow the agent, because this agent is still moving ahead and there shouldn't
......@@ -195,21 +100,26 @@ class FastTreeObs(ObservationBuilder):
# target on this branch -> thus the agents should scan further whether there will be an opposite
# agent walking on same track
has_same_agent = 1
# !NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited,min_dist
# !NOT stop exploring!
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
# agents_on_switch == TRUE -> Current cell is a switch where the agent can decide (branch) in exploration
# agent_near_to_switch == TRUE -> One cell before the switch, where the agent can decide
#
agents_on_switch, agents_near_to_switch, _, _ = \
self.check_agent_decision(new_position, new_direction)
self.agent_can_choose_helper.check_agent_decision(new_position, new_direction)
if agents_near_to_switch:
# The exploration is walking along a path where the agent cannot decide
# Best option would be MOVE_FORWARD -> skip exploring here - just keep walking
return has_opp_agent, has_same_agent, has_target, visited, min_dist
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
if self.env.agents[handle].target in self.agents_target:
has_opp_target = 1
if self.env.agents[handle].target == new_position:
has_target = 1
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
if agents_on_switch:
......@@ -224,33 +134,36 @@ class FastTreeObs(ObservationBuilder):
# --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
# we did in the TreeObservation (FLATLAND) ?
if possible_transitions[dir_loop] == 1:
hoa, hsa, ht, v, m_dist = self._explore(handle,
get_new_position(new_position, dir_loop),
dir_loop,
distance_map,
depth + 1)
hoa, hsa, ht, hot, v, m_dist = self._explore(handle,
get_new_position(new_position, dir_loop),
dir_loop,
distance_map,
depth + 1)
visited.append(v)
has_opp_agent += max(hoa, has_opp_agent)
has_same_agent += max(hsa, has_same_agent)
has_opp_agent = max(hoa, has_opp_agent)
has_same_agent = max(hsa, has_same_agent)
has_target = max(has_target, ht)
has_opp_target = max(has_opp_target, hot)
min_dist = min(min_dist, m_dist)
return has_opp_agent, has_same_agent, has_target, visited, min_dist
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
else:
new_direction = fast_argmax(possible_transitions)
new_position = get_new_position(new_position, new_direction)
min_dist = min(min_dist, distance_map[handle, new_position[0], new_position[1], new_direction])
return has_opp_agent, has_same_agent, has_target, visited, min_dist
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
def get_many(self, handles: Optional[List[int]] = None):
self.dead_lock_avoidance_agent.start_step(train=False)
self.dead_lock_avoidance_agent.reset(self.env)
self.dead_lock_avoidance_agent.start_step(False)
self.agent_positions = get_agent_positions(self.env)
self.agents_target = get_agent_targets(self.env)
observations = super().get_many(handles)
self.dead_lock_avoidance_agent.end_step(train=False)
self.dead_lock_avoidance_agent.end_step(False)
return observations
def get(self, handle):
def get(self, handle: int = 0):
# all values are [0,1]
# observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
# observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
......@@ -278,8 +191,6 @@ class FastTreeObs(ObservationBuilder):
# observation[23] : If there is a switch on the path which agent can not use -> 1
# observation[24] : If there is a switch on the path which agent can not use -> 1
# observation[25] : If there is a switch on the path which agent can not use -> 1
# observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1
# observation[27] : If there the agent can only walk forward or stop -> 1
observation = np.zeros(self.observation_dim)
visited = []
......@@ -317,40 +228,36 @@ class FastTreeObs(ObservationBuilder):
if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
observation[dir_loop] = int(new_cell_dist < current_cell_dist)
has_opp_agent, has_same_agent, has_target, v, min_dist = self._explore(handle,
new_position,
branch_direction,
distance_map)
has_opp_agent, has_same_agent, has_target, has_opp_target, v, min_dist = self._explore(handle,
new_position,
branch_direction,
distance_map)
visited.append(v)
if not (np.math.isinf(min_dist) and np.math.isinf(current_cell_dist)):
observation[31 + dir_loop] = int(min_dist < current_cell_dist)
observation[11 + dir_loop] = int(not np.math.isinf(new_cell_dist))
observation[11 + dir_loop] = int(min_dist < current_cell_dist)
observation[15 + dir_loop] = has_opp_agent
observation[19 + dir_loop] = has_same_agent
observation[23 + dir_loop] = has_target
observation[27 + dir_loop] = int(np.math.isinf(new_cell_dist))
observation[36] = int(check_for_deadlock(handle,
self.env,
self.agent_positions,
new_position,
branch_direction))
observation[27 + dir_loop] = has_opp_target
agents_on_switch, \
agents_near_to_switch, \
agents_near_to_switch_all, \
agents_on_switch_all = \
self.check_agent_decision(agent_virtual_position, agent.direction)
self.agent_can_choose_helper.check_agent_decision(agent_virtual_position, agent.direction)
observation[7] = int(agents_on_switch)
observation[8] = int(agents_on_switch_all)
observation[9] = int(agents_near_to_switch)
observation[10] = int(agents_near_to_switch_all)
action = self.dead_lock_avoidance_agent.act([handle], 0.0)
observation[35] = int(action == RailEnvActions.STOP_MOVING)
observation[40] = int(check_for_deadlock(handle, self.env, self.agent_positions))
action = self.dead_lock_avoidance_agent.act(handle, None, eps=0)
observation[30] = action == RailEnvActions.DO_NOTHING
observation[31] = action == RailEnvActions.MOVE_LEFT
observation[32] = action == RailEnvActions.MOVE_FORWARD
observation[33] = action == RailEnvActions.MOVE_RIGHT
observation[34] = action == RailEnvActions.STOP_MOVING
self.env.dev_obs_dict.update({handle: visited})
......
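A sketch of wiring FastTreeObs into an environment, assuming a flatland-rl 2.x style constructor and that the builder lives in utils.fast_tree_obs as in the repository layout; generator names and arguments differ between library versions, so treat the parameters as illustrative:

from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from utils.fast_tree_obs import FastTreeObs

obs_builder = FastTreeObs(max_depth=2)
env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(max_num_cities=3),
              schedule_generator=sparse_schedule_generator(),
              number_of_agents=5,
              obs_builder_object=obs_builder)
obs, info = env.reset()  # each obs[handle] is a vector of length obs_builder.observation_dim (35 after this change)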
......@@ -8,7 +8,7 @@ class ShortestPathWalkerHeuristicPolicy(Policy):
def step(self, state, action, reward, next_state, done):
pass
def act(self, node, eps=0.):
def act(self, handle, node, eps=0.):
left_node = node.childs.get('L')
forward_node = node.childs.get('F')
......