Compare revisions
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero


def get_agent_positions(env):
    # Grid of agent handles indexed by (row, col); -1 marks a free cell.
    agent_positions: np.ndarray = np.full((env.height, env.width), -1)
    for agent_handle in env.get_agent_handles():
        agent = env.agents[agent_handle]
        if agent.status == RailAgentStatus.ACTIVE:
            position = agent.position
            if position is None:
                position = agent.initial_position
            agent_positions[position] = agent_handle
    return agent_positions


def get_agent_targets(env):
    # Targets of all currently active agents.
    agent_targets = []
    for agent_handle in env.get_agent_handles():
        agent = env.agents[agent_handle]
        if agent.status == RailAgentStatus.ACTIVE:
            agent_targets.append(agent.target)
    return agent_targets


def check_for_deadlock(handle, env, agent_positions, check_position=None, check_direction=None):
    agent = env.agents[handle]
    if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
        return False

    position = agent.position
    if position is None:
        position = agent.initial_position
    if check_position is not None:
        position = check_position
    direction = agent.direction
    if check_direction is not None:
        direction = check_direction

    possible_transitions = env.rail.get_transitions(*position, direction)
    num_transitions = fast_count_nonzero(possible_transitions)
    for dir_loop in range(4):
        if possible_transitions[dir_loop] == 1:
            new_position = get_new_position(position, dir_loop)
            opposite_agent = agent_positions[new_position]
            if opposite_agent != handle and opposite_agent != -1:
                # Outgoing cell is occupied by another agent -> this exit is blocked.
                num_transitions -= 1
            else:
                # At least one reachable neighbour cell is free -> no deadlock.
                return False

    # Deadlocked if every allowed transition leads into an occupied cell.
    is_deadlock = num_transitions <= 0
    return is_deadlock


def check_if_all_blocked(env):
    ...
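check_for_deadlock flags an agent as deadlocked when every transition allowed from its current cell and heading leads into a cell occupied by another agent. A minimal usage sketch follows, assuming a standard flatland RailEnv and that this module lives at utils.deadlock_check (the path used by the imports further below); find_deadlocked_agents is a hypothetical helper, not part of this revision:

# Hypothetical helper (sketch only): scan all agents once per step.
from utils.deadlock_check import check_for_deadlock, get_agent_positions

def find_deadlocked_agents(env):
    agent_positions = get_agent_positions(env)  # occupancy grid, -1 = free cell
    return [handle
            for handle in env.get_agent_handles()
            if check_for_deadlock(handle, env, agent_positions)]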
from typing import List, Optional, Any

import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions

from utils.agent_action_config import get_flatland_full_action_size
from utils.agent_can_choose_helper import AgentCanChooseHelper
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.deadlock_check import get_agent_positions, get_agent_targets

"""
LICENCE for the FastTreeObs Observation Builder
...
@@ -21,104 +26,15 @@ Author: Adrian Egli (adrian.egli@gmail.com)
class FastTreeObs(ObservationBuilder):

-    def __init__(self, max_depth):
+    def __init__(self, max_depth: Any):
        self.max_depth = max_depth
-        self.observation_dim = 27
-    def build_data(self):
-        if self.env is not None:
-            self.env.dev_obs_dict = {}
-        self.switches = {}
-        self.switches_neighbours = {}
-        self.debug_render_list = []
-        self.debug_render_path_list = []
-        if self.env is not None:
-            self.find_all_cell_where_agent_can_choose()
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env)
-        else:
-            self.dead_lock_avoidance_agent = None
-
-    def find_all_cell_where_agent_can_choose(self):
-        switches = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                pos = (h, w)
-                for dir in range(4):
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    num_transitions = fast_count_nonzero(possible_transitions)
-                    if num_transitions > 1:
-                        if pos not in switches.keys():
-                            switches.update({pos: [dir]})
-                        else:
-                            switches[pos].append(dir)
-
-        switches_neighbours = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                # look one step forward
-                for dir in range(4):
-                    pos = (h, w)
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    for d in range(4):
-                        if possible_transitions[d] == 1:
-                            new_cell = get_new_position(pos, d)
-                            if new_cell in switches.keys() and pos not in switches.keys():
-                                if pos not in switches_neighbours.keys():
-                                    switches_neighbours.update({pos: [dir]})
-                                else:
-                                    switches_neighbours[pos].append(dir)
-
-        self.switches = switches
-        self.switches_neighbours = switches_neighbours
-
-    def check_agent_decision(self, position, direction):
-        switches = self.switches
-        switches_neighbours = self.switches_neighbours
-        agents_on_switch = False
-        agents_on_switch_all = False
-        agents_near_to_switch = False
-        agents_near_to_switch_all = False
-        if position in switches.keys():
-            agents_on_switch = direction in switches[position]
-            agents_on_switch_all = True
-
-        if position in switches_neighbours.keys():
-            new_cell = get_new_position(position, direction)
-            if new_cell in switches.keys():
-                if not direction in switches[new_cell]:
-                    agents_near_to_switch = direction in switches_neighbours[position]
-            else:
-                agents_near_to_switch = direction in switches_neighbours[position]
-
-            agents_near_to_switch_all = direction in switches_neighbours[position]
-
-        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
-
-    def required_agent_decision(self):
-        agents_can_choose = {}
-        agents_on_switch = {}
-        agents_on_switch_all = {}
-        agents_near_to_switch = {}
-        agents_near_to_switch_all = {}
-        for a in range(self.env.get_num_agents()):
-            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
-                self.check_agent_decision(
-                    self.env.agents[a].position,
-                    self.env.agents[a].direction)
-            agents_on_switch.update({a: ret_agents_on_switch})
-            agents_on_switch_all.update({a: ret_agents_on_switch_all})
-            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
-            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
-            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
-            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
-        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+        self.observation_dim = 35
+        self.agent_can_choose_helper = None
+        self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(None, get_flatland_full_action_size())
    def debug_render(self, env_renderer):
        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
-            self.required_agent_decision()
+            self.agent_can_choose_helper.required_agent_decision()
        self.env.dev_obs_dict = {}
        for a in range(max(3, self.env.get_num_agents())):
            self.env.dev_obs_dict.update({a: []})
@@ -141,32 +57,28 @@ class FastTreeObs(ObservationBuilder):
        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")

        self.env.dev_obs_dict[0] = self.debug_render_list
-        self.env.dev_obs_dict[1] = self.switches.keys()
-        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
+        self.env.dev_obs_dict[1] = self.agent_can_choose_helper.switches.keys()
+        self.env.dev_obs_dict[2] = self.agent_can_choose_helper.switches_neighbours.keys()
        self.env.dev_obs_dict[3] = self.debug_render_path_list
    def reset(self):
-        self.build_data()
-        return
-
-    def fast_argmax(self, array):
-        if array[0] == 1:
-            return 0
-        if array[1] == 1:
-            return 1
-        if array[2] == 1:
-            return 2
-        return 3
-
-    def _explore(self, handle, new_position, new_direction, depth=0):
+        if self.agent_can_choose_helper is None:
+            self.agent_can_choose_helper = AgentCanChooseHelper()
+        self.agent_can_choose_helper.build_data(self.env)
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+
+    def _explore(self, handle, new_position, new_direction, distance_map, depth=0):
        has_opp_agent = 0
        has_same_agent = 0
        has_target = 0
+        has_opp_target = 0
        visited = []
+        min_dist = distance_map[handle, new_position[0], new_position[1], new_direction]

        # stop exploring (max_depth reached)
        if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_target, visited
+            return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist

        # max_explore_steps = 100 -> just to ensure that the exploration ends
        cnt = 0
@@ -179,7 +91,7 @@ class FastTreeObs(ObservationBuilder):
                if self.env.agents[opp_a].direction != new_direction:
                    # opp agent found -> stop exploring. This would be a strong signal.
                    has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_target, visited
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
                else:
                    # same agent found
                    # the agent can follow the agent, because this agent is still moving ahead and there shouldn't
@@ -188,21 +100,26 @@ class FastTreeObs(ObservationBuilder):
                    # target on this branch -> thus the agents should scan further whether there will be an opposite
                    # agent walking on same track
                    has_same_agent = 1
-                    # !NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited
+                    # !NOT stop exploring!
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
            # agents_on_switch == TRUE -> Current cell is a switch where the agent can decide (branch) in exploration
            # agent_near_to_switch == TRUE -> One cell before the switch, where the agent can decide
            #
            agents_on_switch, agents_near_to_switch, _, _ = \
-                self.check_agent_decision(new_position, new_direction)
+                self.agent_can_choose_helper.check_agent_decision(new_position, new_direction)

            if agents_near_to_switch:
                # The exploration was walking on a path where the agent can not decide
                # Best option would be MOVE_FORWARD -> Skip exploring - just walking
-                return has_opp_agent, has_same_agent, has_target, visited
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist

+            if self.env.agents[handle].target in self.agents_target:
+                has_opp_target = 1
+
+            if self.env.agents[handle].target == new_position:
+                has_target = 1
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist

            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
            if agents_on_switch:
@@ -217,22 +134,36 @@ class FastTreeObs(ObservationBuilder):
                    # --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
                    # we did in the TreeObservation (FLATLAND) ?
                    if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, ht, v = self._explore(handle,
-                                                        get_new_position(new_position, dir_loop),
-                                                        dir_loop,
-                                                        depth + 1)
+                        hoa, hsa, ht, hot, v, m_dist = self._explore(handle,
+                                                                     get_new_position(new_position, dir_loop),
+                                                                     dir_loop,
+                                                                     distance_map,
+                                                                     depth + 1)
                        visited.append(v)
-                        has_opp_agent += hoa * 2 ** (-1 - depth)
-                        has_same_agent += hsa * 2 ** (-1 - depth)
+                        has_opp_agent = max(hoa, has_opp_agent)
+                        has_same_agent = max(hsa, has_same_agent)
                        has_target = max(has_target, ht)
-                return has_opp_agent, has_same_agent, has_target, visited
+                        has_opp_target = max(has_opp_target, hot)
+                        min_dist = min(min_dist, m_dist)
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
            else:
                new_direction = fast_argmax(possible_transitions)
                new_position = get_new_position(new_position, new_direction)
-        return has_opp_agent, has_same_agent, has_target, visited
+                min_dist = min(min_dist, distance_map[handle, new_position[0], new_position[1], new_direction])

-    def get(self, handle):
+        return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+    def get_many(self, handles: Optional[List[int]] = None):
+        self.dead_lock_avoidance_agent.reset(self.env)
+        self.dead_lock_avoidance_agent.start_step(False)
+        self.agent_positions = get_agent_positions(self.env)
+        self.agents_target = get_agent_targets(self.env)
+        observations = super().get_many(handles)
+        self.dead_lock_avoidance_agent.end_step(False)
+        return observations
+
+    def get(self, handle: int = 0):
        # all values are [0,1]
        # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
        # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
@@ -260,10 +191,6 @@ class FastTreeObs(ObservationBuilder):
        # observation[23] : If there is a switch on the path which the agent can not use -> 1
        # observation[24] : If there is a switch on the path which the agent can not use -> 1
        # observation[25] : If there is a switch on the path which the agent can not use -> 1
-        # observation[26] : If the dead-lock avoidance agent predicts a deadlock -> 1
-
-        if handle == 0:
-            self.dead_lock_avoidance_agent.start_step()

        observation = np.zeros(self.observation_dim)
        visited = []
@@ -301,25 +228,40 @@ class FastTreeObs(ObservationBuilder):
            if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                observation[dir_loop] = int(new_cell_dist < current_cell_dist)

-            has_opp_agent, has_same_agent, has_target, v = self._explore(handle, new_position, branch_direction)
+            has_opp_agent, has_same_agent, has_target, has_opp_target, v, min_dist = self._explore(handle,
+                                                                                                   new_position,
+                                                                                                   branch_direction,
+                                                                                                   distance_map)
            visited.append(v)

-            observation[10 + dir_loop] = int(not np.math.isinf(new_cell_dist))
-            observation[14 + dir_loop] = has_opp_agent
-            observation[18 + dir_loop] = has_same_agent
-            observation[22 + dir_loop] = has_target
-
-        agents_on_switch, \
-        agents_near_to_switch, \
-        agents_near_to_switch_all, \
-        agents_on_switch_all = \
-            self.check_agent_decision(agent_virtual_position, agent.direction)
-        observation[7] = int(agents_on_switch)
-        observation[8] = int(agents_near_to_switch)
-        observation[9] = int(agents_near_to_switch_all)
-
-        action = self.dead_lock_avoidance_agent.act([handle], 0.0)
-        observation[26] = int(action == RailEnvActions.STOP_MOVING)
+            if not (np.math.isinf(min_dist) and np.math.isinf(current_cell_dist)):
+                observation[11 + dir_loop] = int(min_dist < current_cell_dist)
+            observation[15 + dir_loop] = has_opp_agent
+            observation[19 + dir_loop] = has_same_agent
+            observation[23 + dir_loop] = has_target
+            observation[27 + dir_loop] = has_opp_target
+
+        agents_on_switch, \
+        agents_near_to_switch, \
+        agents_near_to_switch_all, \
+        agents_on_switch_all = \
+            self.agent_can_choose_helper.check_agent_decision(agent_virtual_position, agent.direction)
+
+        observation[7] = int(agents_on_switch)
+        observation[8] = int(agents_on_switch_all)
+        observation[9] = int(agents_near_to_switch)
+        observation[10] = int(agents_near_to_switch_all)
+
+        action = self.dead_lock_avoidance_agent.act(handle, None, eps=0)
+        observation[30] = action == RailEnvActions.DO_NOTHING
+        observation[31] = action == RailEnvActions.MOVE_LEFT
+        observation[32] = action == RailEnvActions.MOVE_FORWARD
+        observation[33] = action == RailEnvActions.MOVE_RIGHT
+        observation[34] = action == RailEnvActions.STOP_MOVING

        self.env.dev_obs_dict.update({handle: visited})

        observation[np.isinf(observation)] = -1
        observation[np.isnan(observation)] = -1

        return observation
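FastTreeObs packs, per branch direction, a handful of binary features plus the dead-lock avoidance agent's preferred action into a flat 35-dimensional vector, instead of the full recursive tree that flatland's TreeObsForRailEnv builds. A minimal wiring sketch follows, assuming flatland 2.x, that this file lives at utils.fast_tree_obs, and purely illustrative generator parameters:

# Sketch only: plug FastTreeObs into a RailEnv as its observation builder.
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from utils.fast_tree_obs import FastTreeObs

env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(max_num_cities=3),
              number_of_agents=5,
              obs_builder_object=FastTreeObs(max_depth=2))
obs, info = env.reset()  # obs[handle] is a flat vector with observation_dim (35) entries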
import numpy as np

from flatland.envs.rail_env import RailEnvActions

from reinforcement_learning.policy import Policy


class ShortestPathWalkerHeuristicPolicy(Policy):
    def step(self, state, action, reward, next_state, done):
        pass

    def act(self, handle, node, eps=0.):
        left_node = node.childs.get('L')
        forward_node = node.childs.get('F')
        right_node = node.childs.get('R')

        # Distance-to-target per action; np.inf marks actions that are not
        # available or whose branch contains an oncoming agent.
        dist_map = np.zeros(5)
        dist_map[RailEnvActions.DO_NOTHING] = np.inf
        # Large but finite, so STOP_MOVING is picked only when every move is blocked.
        dist_map[RailEnvActions.STOP_MOVING] = 100000

        # left
        if left_node == -np.inf:
            dist_map[RailEnvActions.MOVE_LEFT] = np.inf
        else:
            if left_node.num_agents_opposite_direction == 0:
                dist_map[RailEnvActions.MOVE_LEFT] = left_node.dist_min_to_target
            else:
                dist_map[RailEnvActions.MOVE_LEFT] = np.inf

        # forward
        if forward_node == -np.inf:
            dist_map[RailEnvActions.MOVE_FORWARD] = np.inf
        else:
            if forward_node.num_agents_opposite_direction == 0:
                dist_map[RailEnvActions.MOVE_FORWARD] = forward_node.dist_min_to_target
            else:
                dist_map[RailEnvActions.MOVE_FORWARD] = np.inf

        # right
        if right_node == -np.inf:
            dist_map[RailEnvActions.MOVE_RIGHT] = np.inf
        else:
            if right_node.num_agents_opposite_direction == 0:
                dist_map[RailEnvActions.MOVE_RIGHT] = right_node.dist_min_to_target
            else:
                dist_map[RailEnvActions.MOVE_RIGHT] = np.inf

        # Greedily take the action whose branch has the shortest remaining distance.
        return np.argmin(dist_map)

    def save(self, filename):
        pass

    def load(self, filename):
        pass


policy = ShortestPathWalkerHeuristicPolicy()


def normalize_observation(observation, tree_depth: int, observation_radius=0):
    # The heuristic consumes raw tree nodes, so no normalization is applied.
    return observation
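The heuristic policy reads the raw root node of flatland's standard TreeObsForRailEnv (node.childs keyed 'L'/'F'/'R', each child carrying num_agents_opposite_direction and dist_min_to_target), which is why normalize_observation above is the identity. A minimal episode loop, assuming an env built with TreeObsForRailEnv as its observation builder:

# Sketch only: drive one episode with the heuristic policy defined above.
obs, info = env.reset()
done = {"__all__": False}
while not done["__all__"]:
    actions = {handle: policy.act(handle, obs[handle], eps=0.0)
               for handle in env.get_agent_handles()
               if obs[handle] is not None}  # finished agents may have no observation
    obs, rewards, done, info = env.step(actions)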