Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Showing with 410 additions and 609 deletions
runs_bench/Screenshots/full.png (139 KiB)
runs_bench/Screenshots/reduced.png (178 KiB)

from flatland.envs.rail_env import RailEnvActions

# global action size
global _agent_action_config_action_size
_agent_action_config_action_size = 5


def get_flatland_full_action_size():
    # The action space of flatland is 5 discrete actions
    return 5


def set_action_size_full():
    global _agent_action_config_action_size
    # The agents (DDDQN, PPO, ...) use this action space
    _agent_action_config_action_size = 5


def set_action_size_reduced():
    global _agent_action_config_action_size
    # The agents (DDDQN, PPO, ...) use this action space
    _agent_action_config_action_size = 4


def get_action_size():
    global _agent_action_config_action_size
    # The agents (DDDQN, PPO, ...) use this action space
    return _agent_action_config_action_size


def map_actions(actions):
    # Map the agents' (possibly reduced) actions to Flatland actions
    if get_action_size() != get_flatland_full_action_size():
        for key in actions:
            value = actions.get(key, 0)
            actions.update({key: map_action(value)})
    return actions


def map_action_policy(action):
    if get_action_size() != get_flatland_full_action_size():
        return action - 1
    return action


def map_action(action):
    if get_action_size() == get_flatland_full_action_size():
        return action
    if action == 0:
        return RailEnvActions.MOVE_LEFT
    if action == 1:
        return RailEnvActions.MOVE_FORWARD
    if action == 2:
        return RailEnvActions.MOVE_RIGHT
    if action == 3:
        return RailEnvActions.STOP_MOVING


def map_rail_env_action(action):
    if get_action_size() == get_flatland_full_action_size():
        return action
    if action == RailEnvActions.MOVE_LEFT:
        return 0
    elif action == RailEnvActions.MOVE_FORWARD:
        return 1
    elif action == RailEnvActions.MOVE_RIGHT:
        return 2
    elif action == RailEnvActions.STOP_MOVING:
        return 3
    # action == RailEnvActions.DO_NOTHING:
    return 3
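These helpers let the learners (DDDQN, PPO, ...) act in a reduced 4-action space (LEFT, FORWARD, RIGHT, STOP) while the environment still receives Flatland's 5 discrete actions. A minimal round-trip sketch, assuming the starter kit's utils.agent_action_config module is importable as in the rest of the repository:

# Sketch: round trip between the reduced 4-action policy space and Flatland's 5 actions.
from flatland.envs.rail_env import RailEnvActions
from utils.agent_action_config import (get_action_size, map_action, map_action_policy,
                                       map_rail_env_action, set_action_size_reduced)

set_action_size_reduced()
assert get_action_size() == 4

# Policy index 1 (reduced space) is MOVE_FORWARD in the environment ...
assert map_action(1) == RailEnvActions.MOVE_FORWARD
# ... the environment action maps back to index 1 when storing experience ...
assert map_rail_env_action(RailEnvActions.MOVE_FORWARD) == 1
# ... and a full-space action (e.g. from an expert) is shifted by one for the reduced policy.
assert map_action_policy(RailEnvActions.MOVE_FORWARD) == 1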
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero
class AgentCanChooseHelper:
def __init__(self):
pass
def build_data(self, env):
self.env = env
if self.env is not None:
self.env.dev_obs_dict = {}
self.switches = {}
self.switches_neighbours = {}
if self.env is not None:
self.find_all_cell_where_agent_can_choose()
def find_all_switches(self):
# Search the environment (rail grid) for all switch cells. A switch is a cell where more than one transition
# exists; collect all directions in which the cell acts as a switch.
self.switches = {}
for h in range(self.env.height):
for w in range(self.env.width):
pos = (h, w)
for dir in range(4):
possible_transitions = self.env.rail.get_transitions(*pos, dir)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
if pos not in self.switches.keys():
self.switches.update({pos: [dir]})
else:
self.switches[pos].append(dir)
def find_all_switch_neighbours(self):
# Collect all cells that are neighbours of a switch cell. A cell is a neighbour if the agent can reach a
# switch within a single step. A switch is a cell where the agent has more than one transition.
self.switches_neighbours = {}
for h in range(self.env.height):
for w in range(self.env.width):
# look one step forward
for dir in range(4):
pos = (h, w)
possible_transitions = self.env.rail.get_transitions(*pos, dir)
for d in range(4):
if possible_transitions[d] == 1:
new_cell = get_new_position(pos, d)
if new_cell in self.switches.keys() and pos not in self.switches.keys():
if pos not in self.switches_neighbours.keys():
self.switches_neighbours.update({pos: [dir]})
else:
self.switches_neighbours[pos].append(dir)
def find_all_cell_where_agent_can_choose(self):
# prepare the memory - collect all cells where the agent can choose more than FORWARD/STOP.
self.find_all_switches()
self.find_all_switch_neighbours()
def check_agent_decision(self, position, direction):
# Decide whether the agent is
# - on a switch
# - at a switch neighbour (close to a switch). The switch must be one where the agent has more options than
#   FORWARD/STOP
# - on any switch: it does not matter whether the agent has more options than FORWARD/STOP
# - at any switch neighbour: it does not matter whether the agent has more than one option (transition) when it
#   reaches the switch
agents_on_switch = False
agents_on_switch_all = False
agents_near_to_switch = False
agents_near_to_switch_all = False
if position in self.switches.keys():
agents_on_switch = direction in self.switches[position]
agents_on_switch_all = True
if position in self.switches_neighbours.keys():
new_cell = get_new_position(position, direction)
if new_cell in self.switches.keys():
if not direction in self.switches[new_cell]:
agents_near_to_switch = direction in self.switches_neighbours[position]
else:
agents_near_to_switch = direction in self.switches_neighbours[position]
agents_near_to_switch_all = direction in self.switches_neighbours[position]
return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
def required_agent_decision(self):
agents_can_choose = {}
agents_on_switch = {}
agents_on_switch_all = {}
agents_near_to_switch = {}
agents_near_to_switch_all = {}
for a in range(self.env.get_num_agents()):
ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
self.check_agent_decision(
self.env.agents[a].position,
self.env.agents[a].direction)
agents_on_switch.update({a: ret_agents_on_switch})
agents_on_switch_all.update({a: ret_agents_on_switch_all})
ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
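A short sketch of how the helper might be used once an environment exists. The RailEnv construction below is illustrative only (the flatland-rl 2.2.x API used by this starter kit is assumed); the helper is imported from utils.agent_can_choose_helper as in the imports shown further down:

# Sketch: classify agents by whether they currently need a decision.
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from utils.agent_can_choose_helper import AgentCanChooseHelper

env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(max_num_cities=3, seed=1),
              schedule_generator=sparse_schedule_generator(),
              number_of_agents=5)
env.reset()

helper = AgentCanChooseHelper()
helper.build_data(env)  # scan the grid once for switches and their neighbours

agents_can_choose, on_switch, near_switch, near_switch_all, on_switch_all = \
    helper.required_agent_decision()
print("agents that currently need a decision:",
      [a for a, flag in agents_can_choose.items() if flag])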
......@@ -6,7 +6,8 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
from reinforcement_learning.policy import Policy
from reinforcement_learning.policy import HeuristicPolicy, DummyMemory
from utils.agent_action_config import map_rail_env_action
from utils.shortest_distance_walker import ShortestDistanceWalker
......@@ -65,27 +66,40 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
class DeadLockAvoidanceAgent(Policy):
def __init__(self, env: RailEnv, show_debug_plot=False):
class DeadLockAvoidanceAgent(HeuristicPolicy):
def __init__(self, env: RailEnv, action_size, enable_eps=False, show_debug_plot=False):
print(">> DeadLockAvoidance")
self.env = env
self.memory = None
self.memory = DummyMemory()
self.loss = 0
self.action_size = action_size
self.agent_can_move = {}
self.agent_can_move_value = {}
self.switches = {}
self.show_debug_plot = show_debug_plot
self.enable_eps = enable_eps
def step(self, state, action, reward, next_state, done):
def step(self, handle, state, action, reward, next_state, done):
pass
def act(self, state, eps=0.):
def act(self, handle, state, eps=0.):
# Epsilon-greedy action selection
if self.enable_eps:
if np.random.random() < eps:
return np.random.choice(np.arange(self.action_size))
# agent = self.env.agents[state[0]]
check = self.agent_can_move.get(state[0], None)
if check is None:
return RailEnvActions.STOP_MOVING
return check[3]
check = self.agent_can_move.get(handle, None)
act = RailEnvActions.STOP_MOVING
if check is not None:
act = check[3]
return map_rail_env_action(act)
def reset(self):
def get_agent_can_move_value(self, handle):
return self.agent_can_move_value.get(handle, np.inf)
def reset(self, env):
self.env = env
self.agent_positions = None
self.shortest_distance_walker = None
self.switches = {}
......@@ -101,12 +115,12 @@ class DeadLockAvoidanceAgent(Policy):
else:
self.switches[pos].append(dir)
def start_step(self):
def start_step(self, train):
self.build_agent_position_map()
self.shortest_distance_mapper()
self.extract_agent_can_move()
def end_step(self):
def end_step(self, train):
pass
def get_actions(self):
......@@ -136,7 +150,8 @@ class DeadLockAvoidanceAgent(Policy):
for handle in range(self.env.get_num_agents()):
agent = self.env.agents[handle]
if agent.status < RailAgentStatus.DONE:
next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle],
next_step_ok = self.check_agent_can_move(handle,
shortest_distance_agent_map[handle],
self.shortest_distance_walker.same_agent_map.get(handle, []),
self.shortest_distance_walker.opp_agent_map.get(handle, []),
full_shortest_distance_agent_map)
......@@ -154,6 +169,7 @@ class DeadLockAvoidanceAgent(Policy):
plt.pause(0.01)
def check_agent_can_move(self,
handle,
my_shortest_walking_path,
same_agents,
opp_agents,
......@@ -166,6 +182,9 @@ class DeadLockAvoidanceAgent(Policy):
delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int)
if np.sum(delta) < (3 + len(opp_agents)):
next_step_ok = False
v = self.agent_can_move_value.get(handle, np.inf)
v = min(v, np.sum(delta))
self.agent_can_move_value.update({handle: v})
return next_step_ok
def save(self, filename):
......
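The heuristic behind check_agent_can_move only lets an agent advance while enough free cells remain on its shortest path after subtracting the paths of opposing agents and the currently occupied cells (at least 3 + len(opp_agents)); otherwise the agent is told to stop. A sketch of driving an environment with this policy alone, assuming an existing flatland RailEnv named env and the constructor signature shown above:

# Sketch: run the deadlock-avoidance heuristic on its own (env is an existing RailEnv).
from utils.agent_action_config import get_flatland_full_action_size, set_action_size_full
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent

set_action_size_full()  # the heuristic emits raw RailEnvActions, so keep the full 5-action space
policy = DeadLockAvoidanceAgent(env, get_flatland_full_action_size(), enable_eps=False)

obs, info = env.reset()
policy.reset(env)
done = {"__all__": False}
while not done["__all__"]:
    policy.start_step(train=False)  # rebuild the agent position map and shortest paths
    actions = {handle: policy.act(handle, None, eps=0.0)
               for handle in env.get_agent_handles()}
    obs, rewards, done, info = env.step(actions)
    policy.end_step(train=False)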
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero
def get_agent_positions(env):
agent_positions: np.ndarray = np.full((env.height, env.width), -1)
for agent_handle in env.get_agent_handles():
agent = env.agents[agent_handle]
if agent.status == RailAgentStatus.ACTIVE:
position = agent.position
if position is None:
position = agent.initial_position
agent_positions[position] = agent_handle
return agent_positions
def get_agent_targets(env):
agent_targets = []
for agent_handle in env.get_agent_handles():
agent = env.agents[agent_handle]
if agent.status == RailAgentStatus.ACTIVE:
agent_targets.append(agent.target)
return agent_targets
def check_for_deadlock(handle, env, agent_positions, check_position=None, check_direction=None):
agent = env.agents[handle]
if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
return False
position = agent.position
if position is None:
position = agent.initial_position
if check_position is not None:
position = check_position
direction = agent.direction
if check_direction is not None:
direction = check_direction
possible_transitions = env.rail.get_transitions(*position, direction)
num_transitions = fast_count_nonzero(possible_transitions)
for dir_loop in range(4):
if possible_transitions[dir_loop] == 1:
new_position = get_new_position(position, dir_loop)
opposite_agent = agent_positions[new_position]
if opposite_agent != handle and opposite_agent != -1:
num_transitions -= 1
else:
return False
is_deadlock = num_transitions <= 0
return is_deadlock
def check_if_all_blocked(env):
......
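check_for_deadlock declares an agent deadlocked when every transition out of its cell leads into a cell occupied by another agent; it returns False as soon as one free neighbouring cell is found. A small usage sketch, assuming an existing RailEnv named env:

# Sketch: count the agents that are currently deadlocked.
from utils.deadlock_check import check_for_deadlock, get_agent_positions

agent_positions = get_agent_positions(env)
deadlocked = [handle for handle in env.get_agent_handles()
              if check_for_deadlock(handle, env, agent_positions)]
print(f"{len(deadlocked)} of {env.get_num_agents()} agents are deadlocked:", deadlocked)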
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnvActions, fast_argmax, fast_count_nonzero
from reinforcement_learning.policy import Policy
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent, DeadlockAvoidanceShortestDistanceWalker
class ExtraPolicy(Policy):
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
self.memory = []
self.loss = 0
def load(self, filename):
pass
def save(self, filename):
pass
def step(self, handle, state, action, reward, next_state, done):
pass
def act(self, handle, state, eps=0.):
a = 0
b = 4
action = RailEnvActions.STOP_MOVING
if state[2] == 1 and state[10 + a] == 0:
action = RailEnvActions.MOVE_LEFT
elif state[3] == 1 and state[11 + a] == 0:
action = RailEnvActions.MOVE_FORWARD
elif state[4] == 1 and state[12 + a] == 0:
action = RailEnvActions.MOVE_RIGHT
elif state[5] == 1 and state[13 + a] == 0:
action = RailEnvActions.MOVE_FORWARD
elif state[6] == 1 and state[10 + b] == 0:
action = RailEnvActions.MOVE_LEFT
elif state[7] == 1 and state[11 + b] == 0:
action = RailEnvActions.MOVE_FORWARD
elif state[8] == 1 and state[12 + b] == 0:
action = RailEnvActions.MOVE_RIGHT
elif state[9] == 1 and state[13 + b] == 0:
action = RailEnvActions.MOVE_FORWARD
return action
def test(self):
pass
class Extra(ObservationBuilder):
def __init__(self, max_depth):
self.max_depth = max_depth
self.observation_dim = 31
def build_data(self):
self.dead_lock_avoidance_agent = None
if self.env is not None:
self.env.dev_obs_dict = {}
self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, None, None)
self.switches = {}
self.switches_neighbours = {}
self.debug_render_list = []
self.debug_render_path_list = []
if self.env is not None:
self.find_all_cell_where_agent_can_choose()
def find_all_cell_where_agent_can_choose(self):
switches = {}
for h in range(self.env.height):
for w in range(self.env.width):
pos = (h, w)
for dir in range(4):
possible_transitions = self.env.rail.get_transitions(*pos, dir)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
if pos not in switches.keys():
switches.update({pos: [dir]})
else:
switches[pos].append(dir)
switches_neighbours = {}
for h in range(self.env.height):
for w in range(self.env.width):
# look one step forward
for dir in range(4):
pos = (h, w)
possible_transitions = self.env.rail.get_transitions(*pos, dir)
for d in range(4):
if possible_transitions[d] == 1:
new_cell = get_new_position(pos, d)
if new_cell in switches.keys() and pos not in switches.keys():
if pos not in switches_neighbours.keys():
switches_neighbours.update({pos: [dir]})
else:
switches_neighbours[pos].append(dir)
self.switches = switches
self.switches_neighbours = switches_neighbours
def check_agent_descision(self, position, direction):
switches = self.switches
switches_neighbours = self.switches_neighbours
agents_on_switch = False
agents_on_switch_all = False
agents_near_to_switch = False
agents_near_to_switch_all = False
if position in switches.keys():
agents_on_switch = direction in switches[position]
agents_on_switch_all = True
if position in switches_neighbours.keys():
new_cell = get_new_position(position, direction)
if new_cell in switches.keys():
if not direction in switches[new_cell]:
agents_near_to_switch = direction in switches_neighbours[position]
else:
agents_near_to_switch = direction in switches_neighbours[position]
agents_near_to_switch_all = direction in switches_neighbours[position]
return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
def required_agent_descision(self):
agents_can_choose = {}
agents_on_switch = {}
agents_on_switch_all = {}
agents_near_to_switch = {}
agents_near_to_switch_all = {}
for a in range(self.env.get_num_agents()):
ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
self.check_agent_descision(
self.env.agents[a].position,
self.env.agents[a].direction)
agents_on_switch.update({a: ret_agents_on_switch})
agents_on_switch_all.update({a: ret_agents_on_switch_all})
ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
def debug_render(self, env_renderer):
agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
self.required_agent_descision()
self.env.dev_obs_dict = {}
for a in range(max(3, self.env.get_num_agents())):
self.env.dev_obs_dict.update({a: []})
selected_agent = None
if agents_can_choose[0]:
if self.env.agents[0].position is not None:
self.debug_render_list.append(self.env.agents[0].position)
else:
self.debug_render_list.append(self.env.agents[0].initial_position)
if self.env.agents[0].position is not None:
self.debug_render_path_list.append(self.env.agents[0].position)
else:
self.debug_render_path_list.append(self.env.agents[0].initial_position)
env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
self.env.dev_obs_dict[0] = self.debug_render_list
self.env.dev_obs_dict[1] = self.switches.keys()
self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
self.env.dev_obs_dict[3] = self.debug_render_path_list
def reset(self):
self.build_data()
return
def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
_, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
local_walker = DeadlockAvoidanceShortestDistanceWalker(
self.env,
self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
local_walker.walk_to_target(handle, new_position, branch_direction)
shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
my_shortest_path_to_check = shortest_distance_agent_map[handle]
next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
opp_agents,
full_shortest_distance_agent_map)
return next_step_ok
def _explore(self, handle, new_position, new_direction, depth=0):
has_opp_agent = 0
has_same_agent = 0
has_switch = 0
visited = []
# stop exploring (max_depth reached)
if depth >= self.max_depth:
return has_opp_agent, has_same_agent, has_switch, visited
# max_explore_steps = 100
cnt = 0
while cnt < 100:
cnt += 1
visited.append(new_position)
opp_a = self.env.agent_positions[new_position]
if opp_a != -1 and opp_a != handle:
if self.env.agents[opp_a].direction != new_direction:
# opp agent found
has_opp_agent = 1
return has_opp_agent, has_same_agent, has_switch, visited
else:
has_same_agent = 1
return has_opp_agent, has_same_agent, has_switch, visited
# convert one-hot encoding to 0,1,2,3
agents_on_switch, \
agents_near_to_switch, \
agents_near_to_switch_all, \
agents_on_switch_all = \
self.check_agent_descision(new_position, new_direction)
if agents_near_to_switch:
return has_opp_agent, has_same_agent, has_switch, visited
possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
if agents_on_switch:
f = 0
for dir_loop in range(4):
if possible_transitions[dir_loop] == 1:
f += 1
hoa, hsa, hs, v = self._explore(handle,
get_new_position(new_position, dir_loop),
dir_loop,
depth + 1)
visited.append(v)
has_opp_agent += hoa
has_same_agent += hsa
has_switch += hs
f = max(f, 1.0)
return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
else:
new_direction = fast_argmax(possible_transitions)
new_position = get_new_position(new_position, new_direction)
return has_opp_agent, has_same_agent, has_switch, visited
def get(self, handle):
if handle == 0:
self.dead_lock_avoidance_agent.start_step()
# all values are [0,1]
# observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
# observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
# observation[2] : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
# observation[3] : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
# observation[4] : int(agent.status == RailAgentStatus.READY_TO_DEPART)
# observation[5] : int(agent.status == RailAgentStatus.ACTIVE)
# observation[6] : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
# observation[7] : current agent is located at a switch, where it can take a routing decision
# observation[8] : current agent is located at a cell, where it has to take a stop-or-go decision
# observation[9] : current agent is located one step before/after a switch
# observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
# observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
# observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
# observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
# observation[14] : If there is a path with a step (direction 0) and there is an agent with opposite direction -> 1
# observation[15] : If there is a path with a step (direction 1) and there is an agent with opposite direction -> 1
# observation[16] : If there is a path with a step (direction 2) and there is an agent with opposite direction -> 1
# observation[17] : If there is a path with a step (direction 3) and there is an agent with opposite direction -> 1
# observation[18] : If there is a path with a step (direction 0) and there is an agent with same direction -> 1
# observation[19] : If there is a path with a step (direction 1) and there is an agent with same direction -> 1
# observation[20] : If there is a path with a step (direction 2) and there is an agent with same direction -> 1
# observation[21] : If there is a path with a step (direction 3) and there is an agent with same direction -> 1
# observation[22] : If there is a switch on the path (direction 0) which the agent cannot use -> 1
# observation[23] : If there is a switch on the path (direction 1) which the agent cannot use -> 1
# observation[24] : If there is a switch on the path (direction 2) which the agent cannot use -> 1
# observation[25] : If there is a switch on the path (direction 3) which the agent cannot use -> 1
# observation[26] : Is there a deadlock signal on shortest path walk(s) (direction 0) -> 1
# observation[27] : Is there a deadlock signal on shortest path walk(s) (direction 1) -> 1
# observation[28] : Is there a deadlock signal on shortest path walk(s) (direction 2) -> 1
# observation[29] : Is there a deadlock signal on shortest path walk(s) (direction 3) -> 1
# observation[30] : Is there a deadlock signal on shortest path walk(s) (current position check) -> 1
observation = np.zeros(self.observation_dim)
visited = []
agent = self.env.agents[handle]
agent_done = False
if agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
observation[4] = 1
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
observation[5] = 1
else:
observation[6] = 1
agent_virtual_position = (-1, -1)
agent_done = True
if not agent_done:
visited.append(agent_virtual_position)
distance_map = self.env.distance_map.get()
current_cell_dist = distance_map[handle,
agent_virtual_position[0], agent_virtual_position[1],
agent.direction]
possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
orientation = agent.direction
if fast_count_nonzero(possible_transitions) == 1:
orientation = fast_argmax(possible_transitions)
for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
if possible_transitions[branch_direction]:
new_position = get_new_position(agent_virtual_position, branch_direction)
new_cell_dist = distance_map[handle,
new_position[0], new_position[1],
branch_direction]
if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
observation[dir_loop] = int(new_cell_dist < current_cell_dist)
has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
visited.append(v)
observation[10 + dir_loop] = 1
observation[14 + dir_loop] = has_opp_agent
observation[18 + dir_loop] = has_same_agent
observation[22 + dir_loop] = has_switch
next_step_ok = self._check_dead_lock_at_branching_position(handle, new_position, branch_direction)
if next_step_ok:
observation[26 + dir_loop] = 1
agents_on_switch, \
agents_near_to_switch, \
agents_near_to_switch_all, \
agents_on_switch_all = \
self.check_agent_descision(agent_virtual_position, agent.direction)
observation[7] = int(agents_on_switch)
observation[8] = int(agents_near_to_switch)
observation[9] = int(agents_near_to_switch_all)
observation[30] = int(self.dead_lock_avoidance_agent.act(handle, None, 0) == RailEnvActions.STOP_MOVING)
self.env.dev_obs_dict.update({handle: visited})
return observation
@staticmethod
def agent_can_choose(observation):
return observation[7] == 1 or observation[8] == 1
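Because the 31 slots are easy to misindex once the vector is flattened, a small helper such as the hypothetical summarize_extra_obs below (not part of the repository) can group them back into named fields. The direction slots follow the branch order used in get() (left, forward, right, behind relative to the agent's orientation), and note that, per the code above, slots 26-29 are set to 1 when a branch passes the deadlock check, even though the slot comments read them as a deadlock signal:

import numpy as np

def summarize_extra_obs(observation):
    # Hypothetical helper: group the 31 Extra observation slots into named fields.
    return {
        "shorter_path_dirs":  np.flatnonzero(observation[0:4]).tolist(),
        "ready_to_depart":    bool(observation[4]),
        "active":             bool(observation[5]),
        "done":               bool(observation[6]),
        "on_decision_switch": bool(observation[7]),
        "stop_or_go_cell":    bool(observation[8]),
        "near_switch":        bool(observation[9]),
        "branch_exists":      np.flatnonzero(observation[10:14]).tolist(),
        "opp_agent_dirs":     np.flatnonzero(observation[14:18]).tolist(),
        "same_agent_dirs":    np.flatnonzero(observation[18:22]).tolist(),
        "unusable_switch":    np.flatnonzero(observation[22:26]).tolist(),
        "deadlock_free_dirs": np.flatnonzero(observation[26:30]).tolist(),  # 1 == branch passed the deadlock check
        "stop_suggested":     bool(observation[30]),
    }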
from typing import List, Optional, Any
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
from utils.agent_action_config import get_flatland_full_action_size
from utils.agent_can_choose_helper import AgentCanChooseHelper
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.deadlock_check import get_agent_positions, get_agent_targets
"""
LICENCE for the FastTreeObs Observation Builder
......@@ -21,104 +26,15 @@ Author: Adrian Egli (adrian.egli@gmail.com)
class FastTreeObs(ObservationBuilder):
def __init__(self, max_depth):
def __init__(self, max_depth: Any):
self.max_depth = max_depth
self.observation_dim = 27
def build_data(self):
if self.env is not None:
self.env.dev_obs_dict = {}
self.switches = {}
self.switches_neighbours = {}
self.debug_render_list = []
self.debug_render_path_list = []
if self.env is not None:
self.find_all_cell_where_agent_can_choose()
self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env)
else:
self.dead_lock_avoidance_agent = None
def find_all_cell_where_agent_can_choose(self):
switches = {}
for h in range(self.env.height):
for w in range(self.env.width):
pos = (h, w)
for dir in range(4):
possible_transitions = self.env.rail.get_transitions(*pos, dir)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
if pos not in switches.keys():
switches.update({pos: [dir]})
else:
switches[pos].append(dir)
switches_neighbours = {}
for h in range(self.env.height):
for w in range(self.env.width):
# look one step forward
for dir in range(4):
pos = (h, w)
possible_transitions = self.env.rail.get_transitions(*pos, dir)
for d in range(4):
if possible_transitions[d] == 1:
new_cell = get_new_position(pos, d)
if new_cell in switches.keys() and pos not in switches.keys():
if pos not in switches_neighbours.keys():
switches_neighbours.update({pos: [dir]})
else:
switches_neighbours[pos].append(dir)
self.switches = switches
self.switches_neighbours = switches_neighbours
def check_agent_decision(self, position, direction):
switches = self.switches
switches_neighbours = self.switches_neighbours
agents_on_switch = False
agents_on_switch_all = False
agents_near_to_switch = False
agents_near_to_switch_all = False
if position in switches.keys():
agents_on_switch = direction in switches[position]
agents_on_switch_all = True
if position in switches_neighbours.keys():
new_cell = get_new_position(position, direction)
if new_cell in switches.keys():
if not direction in switches[new_cell]:
agents_near_to_switch = direction in switches_neighbours[position]
else:
agents_near_to_switch = direction in switches_neighbours[position]
agents_near_to_switch_all = direction in switches_neighbours[position]
return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
def required_agent_decision(self):
agents_can_choose = {}
agents_on_switch = {}
agents_on_switch_all = {}
agents_near_to_switch = {}
agents_near_to_switch_all = {}
for a in range(self.env.get_num_agents()):
ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
self.check_agent_decision(
self.env.agents[a].position,
self.env.agents[a].direction)
agents_on_switch.update({a: ret_agents_on_switch})
agents_on_switch_all.update({a: ret_agents_on_switch_all})
ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
self.observation_dim = 35
self.agent_can_choose_helper = None
self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(None, get_flatland_full_action_size())
def debug_render(self, env_renderer):
agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
self.required_agent_decision()
self.agent_can_choose_helper.required_agent_decision()
self.env.dev_obs_dict = {}
for a in range(max(3, self.env.get_num_agents())):
self.env.dev_obs_dict.update({a: []})
......@@ -141,32 +57,28 @@ class FastTreeObs(ObservationBuilder):
env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
self.env.dev_obs_dict[0] = self.debug_render_list
self.env.dev_obs_dict[1] = self.switches.keys()
self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
self.env.dev_obs_dict[1] = self.agent_can_choose_helper.switches.keys()
self.env.dev_obs_dict[2] = self.agent_can_choose_helper.switches_neighbours.keys()
self.env.dev_obs_dict[3] = self.debug_render_path_list
def reset(self):
self.build_data()
return
def fast_argmax(self, array):
if array[0] == 1:
return 0
if array[1] == 1:
return 1
if array[2] == 1:
return 2
return 3
def _explore(self, handle, new_position, new_direction, depth=0):
if self.agent_can_choose_helper is None:
self.agent_can_choose_helper = AgentCanChooseHelper()
self.agent_can_choose_helper.build_data(self.env)
self.debug_render_list = []
self.debug_render_path_list = []
def _explore(self, handle, new_position, new_direction, distance_map, depth=0):
has_opp_agent = 0
has_same_agent = 0
has_target = 0
has_opp_target = 0
visited = []
min_dist = distance_map[handle, new_position[0], new_position[1], new_direction]
# stop exploring (max_depth reached)
if depth >= self.max_depth:
return has_opp_agent, has_same_agent, has_target, visited
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
# max_explore_steps = 100 -> just to ensure that the exploration ends
cnt = 0
......@@ -179,7 +91,7 @@ class FastTreeObs(ObservationBuilder):
if self.env.agents[opp_a].direction != new_direction:
# opp agent found -> stop exploring. This would be a strong signal.
has_opp_agent = 1
return has_opp_agent, has_same_agent, has_target, visited
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
else:
# same agent found
# the agent can follow the agent, because this agent is still moving ahead and there shouldn't
......@@ -188,21 +100,26 @@ class FastTreeObs(ObservationBuilder):
# target on this branch -> thus the agents should scan further whether there will be an opposite
# agent walking on same track
has_same_agent = 1
# !NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited
# !NOT stop exploring!
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
# agents_on_switch == TRUE -> Current cell is a switch where the agent can decide (branch) in exploration
# agent_near_to_switch == TRUE -> One cell before the switch, where the agent can decide
#
agents_on_switch, agents_near_to_switch, _, _ = \
self.check_agent_decision(new_position, new_direction)
self.agent_can_choose_helper.check_agent_decision(new_position, new_direction)
if agents_near_to_switch:
# The exploration was walking on a path where the agent cannot decide.
# The best option would be MOVE_FORWARD -> skip exploring, just keep walking.
return has_opp_agent, has_same_agent, has_target, visited
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
if self.env.agents[handle].target in self.agents_target:
has_opp_target = 1
if self.env.agents[handle].target == new_position:
has_target = 1
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
if agents_on_switch:
......@@ -217,22 +134,36 @@ class FastTreeObs(ObservationBuilder):
# --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
# we did in the TreeObservation (FLATLAND) ?
if possible_transitions[dir_loop] == 1:
hoa, hsa, ht, v = self._explore(handle,
get_new_position(new_position, dir_loop),
dir_loop,
depth + 1)
hoa, hsa, ht, hot, v, m_dist = self._explore(handle,
get_new_position(new_position, dir_loop),
dir_loop,
distance_map,
depth + 1)
visited.append(v)
has_opp_agent += hoa * 2 ** (-1 - depth)
has_same_agent += hsa * 2 ** (-1 - depth)
has_opp_agent = max(hoa, has_opp_agent)
has_same_agent = max(hsa, has_same_agent)
has_target = max(has_target, ht)
return has_opp_agent, has_same_agent, has_target, visited
has_opp_target = max(has_opp_target, hot)
min_dist = min(min_dist, m_dist)
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
else:
new_direction = fast_argmax(possible_transitions)
new_position = get_new_position(new_position, new_direction)
return has_opp_agent, has_same_agent, has_target, visited
min_dist = min(min_dist, distance_map[handle, new_position[0], new_position[1], new_direction])
def get(self, handle):
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
def get_many(self, handles: Optional[List[int]] = None):
self.dead_lock_avoidance_agent.reset(self.env)
self.dead_lock_avoidance_agent.start_step(False)
self.agent_positions = get_agent_positions(self.env)
self.agents_target = get_agent_targets(self.env)
observations = super().get_many(handles)
self.dead_lock_avoidance_agent.end_step(False)
return observations
def get(self, handle: int = 0):
# all values are [0,1]
# observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
# observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
......@@ -260,10 +191,6 @@ class FastTreeObs(ObservationBuilder):
# observation[23] : If there is a switch on the path which the agent cannot use -> 1
# observation[24] : If there is a switch on the path which the agent cannot use -> 1
# observation[25] : If there is a switch on the path which the agent cannot use -> 1
# observation[26] : If the dead-lock avoidance agent predicts a deadlock -> 1
if handle == 0:
self.dead_lock_avoidance_agent.start_step()
observation = np.zeros(self.observation_dim)
visited = []
......@@ -301,25 +228,40 @@ class FastTreeObs(ObservationBuilder):
if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
observation[dir_loop] = int(new_cell_dist < current_cell_dist)
has_opp_agent, has_same_agent, has_target, v = self._explore(handle, new_position, branch_direction)
has_opp_agent, has_same_agent, has_target, has_opp_target, v, min_dist = self._explore(handle,
new_position,
branch_direction,
distance_map)
visited.append(v)
observation[10 + dir_loop] = int(not np.math.isinf(new_cell_dist))
observation[14 + dir_loop] = has_opp_agent
observation[18 + dir_loop] = has_same_agent
observation[22 + dir_loop] = has_target
agents_on_switch, \
agents_near_to_switch, \
agents_near_to_switch_all, \
agents_on_switch_all = \
self.check_agent_decision(agent_virtual_position, agent.direction)
observation[7] = int(agents_on_switch)
observation[8] = int(agents_near_to_switch)
observation[9] = int(agents_near_to_switch_all)
action = self.dead_lock_avoidance_agent.act([handle], 0.0)
observation[26] = int(action == RailEnvActions.STOP_MOVING)
if not (np.math.isinf(min_dist) and np.math.isinf(current_cell_dist)):
observation[11 + dir_loop] = int(min_dist < current_cell_dist)
observation[15 + dir_loop] = has_opp_agent
observation[19 + dir_loop] = has_same_agent
observation[23 + dir_loop] = has_target
observation[27 + dir_loop] = has_opp_target
agents_on_switch, \
agents_near_to_switch, \
agents_near_to_switch_all, \
agents_on_switch_all = \
self.agent_can_choose_helper.check_agent_decision(agent_virtual_position, agent.direction)
observation[7] = int(agents_on_switch)
observation[8] = int(agents_on_switch_all)
observation[9] = int(agents_near_to_switch)
observation[10] = int(agents_near_to_switch_all)
action = self.dead_lock_avoidance_agent.act(handle, None, eps=0)
observation[30] = action == RailEnvActions.DO_NOTHING
observation[31] = action == RailEnvActions.MOVE_LEFT
observation[32] = action == RailEnvActions.MOVE_FORWARD
observation[33] = action == RailEnvActions.MOVE_RIGHT
observation[34] = action == RailEnvActions.STOP_MOVING
self.env.dev_obs_dict.update({handle: visited})
observation[np.isinf(observation)] = -1
observation[np.isnan(observation)] = -1
return observation
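A sketch of wiring the refactored 35-dimensional FastTreeObs into an environment and reading back the deadlock-avoidance suggestion encoded in slots 30-34. The environment parameters are illustrative, the flatland-rl 2.2.x API is assumed, and utils.fast_tree_obs is an assumed module path:

# Sketch: build an environment with the refactored FastTreeObs and inspect one observation.
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from utils.fast_tree_obs import FastTreeObs  # assumed module path

obs_builder = FastTreeObs(max_depth=2)
env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(max_num_cities=3, seed=7),
              schedule_generator=sparse_schedule_generator(),
              number_of_agents=5,
              obs_builder_object=obs_builder)
obs, info = env.reset()

sample = obs[0]
assert sample.shape == (obs_builder.observation_dim,)  # 35 values, inf/nan already replaced by -1
# Slots 30-34 one-hot encode the action suggested by the internal DeadLockAvoidanceAgent
# (DO_NOTHING, MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING in that order).
suggested = int(np.argmax(sample[30:35]))
print("deadlock-avoidance suggestion for agent 0:", suggested)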
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
class ShortestDistanceWalker:
def __init__(self, env: RailEnv):
self.env = env
def walk(self, handle, position, direction):
possible_transitions = self.env.rail.get_transitions(*position, direction)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions == 1:
new_direction = fast_argmax(possible_transitions)
new_position = get_new_position(position, new_direction)
dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
else:
min_distances = []
positions = []
directions = []
for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(position, new_direction)
min_distances.append(
self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
positions.append(new_position)
directions.append(new_direction)
else:
min_distances.append(np.inf)
positions.append(None)
directions.append(None)
a = self.get_action(handle, min_distances)
return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
def get_action(self, handle, min_distances):
return np.argmin(min_distances)
def callback(self, handle, agent, position, direction, action, possible_transitions):
pass
def get_agent_position_and_direction(self, handle):
agent = self.env.agents[handle]
if agent.position is not None:
position = agent.position
else:
position = agent.initial_position
direction = agent.direction
return position, direction
def walk_to_target(self, handle, position=None, direction=None, max_step=500):
if position is None and direction is None:
position, direction = self.get_agent_position_and_direction(handle)
elif position is None:
position, _ = self.get_agent_position_and_direction(handle)
elif direction is None:
_, direction = self.get_agent_position_and_direction(handle)
agent = self.env.agents[handle]
step = 0
while (position != agent.target) and (step < max_step):
position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
if position is None:
break
self.callback(handle, agent, position, direction, action, possible_transitions)
step += 1
def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
pass
def walk_one_step(self, handle):
agent = self.env.agents[handle]
if agent.position is not None:
position = agent.position
else:
position = agent.initial_position
direction = agent.direction
possible_transitions = (0, 1, 0, 0)
if (position != agent.target):
new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
if new_position is None:
return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
return new_position, new_direction, action, possible_transitions
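ShortestDistanceWalker is meant to be subclassed: walk_to_target repeatedly calls walk and reports every visited cell through callback. A sketch of a hypothetical subclass that records the shortest path of one agent, assuming an existing RailEnv named env with a computed distance map:

# Sketch: record an agent's shortest path by overriding callback().
class PathRecordingWalker(ShortestDistanceWalker):
    def __init__(self, env):
        super().__init__(env)
        self.paths = {}

    def callback(self, handle, agent, position, direction, action, possible_transitions):
        # Called for every cell visited while walk_to_target() follows the distance map.
        self.paths.setdefault(handle, []).append((position, direction, action))


walker = PathRecordingWalker(env)  # env is an existing RailEnv
walker.walk_to_target(handle=0)
print("shortest-path length for agent 0:", len(walker.paths.get(0, [])))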
import numpy as np
from flatland.envs.rail_env import RailEnvActions
from reinforcement_learning.policy import Policy
class ShortestPathWalkerHeuristicPolicy(Policy):
def step(self, state, action, reward, next_state, done):
pass
def act(self, handle, node, eps=0.):
left_node = node.childs.get('L')
forward_node = node.childs.get('F')
right_node = node.childs.get('R')
dist_map = np.zeros(5)
dist_map[RailEnvActions.DO_NOTHING] = np.inf
dist_map[RailEnvActions.STOP_MOVING] = 100000
# left
if left_node == -np.inf:
dist_map[RailEnvActions.MOVE_LEFT] = np.inf
else:
if left_node.num_agents_opposite_direction == 0:
dist_map[RailEnvActions.MOVE_LEFT] = left_node.dist_min_to_target
else:
dist_map[RailEnvActions.MOVE_LEFT] = np.inf
# forward
if forward_node == -np.inf:
dist_map[RailEnvActions.MOVE_FORWARD] = np.inf
else:
if forward_node.num_agents_opposite_direction == 0:
dist_map[RailEnvActions.MOVE_FORWARD] = forward_node.dist_min_to_target
else:
dist_map[RailEnvActions.MOVE_FORWARD] = np.inf
# right
if right_node == -np.inf:
dist_map[RailEnvActions.MOVE_RIGHT] = np.inf
else:
if right_node.num_agents_opposite_direction == 0:
dist_map[RailEnvActions.MOVE_RIGHT] = right_node.dist_min_to_target
else:
dist_map[RailEnvActions.MOVE_RIGHT] = np.inf
return np.argmin(dist_map)
def save(self, filename):
pass
def load(self, filename):
pass
policy = ShortestPathWalkerHeuristicPolicy()
def normalize_observation(observation, tree_depth: int, observation_radius=0):
return observation
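For reference, ShortestPathWalkerHeuristicPolicy expects the root node of flatland's TreeObsForRailEnv: DO_NOTHING always costs infinity, STOP_MOVING costs 100000 as a fallback, and each movement branch costs its remaining distance to target unless an opposing agent was seen on that branch. A fragment sketch, where tree_obs and handle are assumed to come from an environment built with TreeObsForRailEnv:

# Fragment sketch: pick an action from a TreeObsForRailEnv root node (tree_obs, handle assumed).
from flatland.envs.rail_env import RailEnvActions

node = tree_obs[handle]                 # root node of the tree observation for one agent
action = int(policy.act(handle, node))  # module-level policy instance defined above
# The result is a RailEnvActions index: 1=LEFT, 2=FORWARD, 3=RIGHT, 4=STOP;
# DO_NOTHING (0) is never chosen because its cost is always infinite while STOP is finite.
assert action != RailEnvActions.DO_NOTHING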