diff --git a/run.py b/run.py
index 7ad08cd45122f68de9da711022269f4255727a69..fec8bc01cff56b1733a016666eb703ae560353c0 100644
--- a/run.py
+++ b/run.py
@@ -29,6 +29,8 @@
 checkpoint = "./checkpoints/201103160541-1800.pth"
 # Use last action cache
 USE_ACTION_CACHE = True
+# Use the rule-based dead-lock avoidance agent instead of the trained policy
+USE_DEAD_LOCK_AVOIDANCE_AGENT = False
 
 # Observation parameters (must match training parameters!)
 observation_tree_depth = 2
@@ -82,8 +84,6 @@ while True:
     nb_agents = len(local_env.agents)
     max_nb_steps = local_env._max_episode_steps
 
-    policy = DeadLockAvoidanceAgent(local_env)
-
     tree_observation.set_env(local_env)
     tree_observation.reset()
     observation = tree_observation.get_many(list(range(nb_agents)))
@@ -105,6 +105,10 @@ while True:
     agent_last_action = {}
     nb_hit = 0
 
+    # Optionally override the checkpoint policy once the episode's env is known.
+    if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+        policy = DeadLockAvoidanceAgent(local_env)
+
     while True:
         try:
             #####################################################################
@@ -118,7 +122,15 @@ while True:
             time_start = time.time()
             action_dict = {}
             policy.start_step()
+            # The dead-lock avoidance agent ignores tree observations; it only needs (handle, step).
+            if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+                observation = np.zeros((local_env.get_num_agents(), 2))
             for agent in range(nb_agents):
+
+                if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+                    observation[agent][0] = agent
+                    observation[agent][1] = steps
+
                 if info['action_required'][agent]:
                     if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
                         # cache hit
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index c7c6d8a5c171eb8cc208520bf6dc8f5d6cfc2845..700600c337882271eebe519233b473449be3b1ab 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -7,7 +7,7 @@
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 from reinforcement_learning.policy import Policy
-from utils.shortest_Distance_walker import ShortestDistanceWalker
+from utils.shortest_distance_walker import ShortestDistanceWalker
 
 
 class DeadlockAvoidanceObservation(DummyObservationBuilder):
diff --git a/utils/shortest_distance_walker.py b/utils/shortest_distance_walker.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8121f46e681ebc37ea3d1afb6b4023d33f2e14
--- /dev/null
+++ b/utils/shortest_distance_walker.py
@@ -0,0 +1,94 @@
+import numpy as np
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+
+
+class ShortestDistanceWalker:
+    """Walks agents along their shortest paths using the environment's distance map."""
+
+    def __init__(self, env: RailEnv):
+        self.env = env
+
+    def walk(self, handle, position, direction):
+        """Advance one cell; returns (position, direction, distance, action, transitions)."""
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+        num_transitions = fast_count_nonzero(possible_transitions)
+        if num_transitions == 1:
+            # Straight or curved track: only one way to go.
+            new_direction = fast_argmax(possible_transitions)
+            new_position = get_new_position(position, new_direction)
+            dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
+        else:
+            # Switch: compare the remaining distance for each feasible direction.
+            min_distances = []
+            positions = []
+            directions = []
+            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    new_position = get_new_position(position, new_direction)
+                    min_distances.append(
+                        self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
+                    positions.append(new_position)
+                    directions.append(new_direction)
+                else:
+                    min_distances.append(np.inf)
+                    positions.append(None)
+                    directions.append(None)
+
+            # Choice index 0/1/2 maps to MOVE_LEFT/MOVE_FORWARD/MOVE_RIGHT, hence a + 1.
+            a = self.get_action(handle, min_distances)
+            return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
+
+    def get_action(self, handle, min_distances):
+        return np.argmin(min_distances)
+
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
+    def get_agent_position_and_direction(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        return position, direction
+
+    def walk_to_target(self, handle, position=None, direction=None, max_step=500):
+        if position is None and direction is None:
+            position, direction = self.get_agent_position_and_direction(handle)
+        elif position is None:
+            position, _ = self.get_agent_position_and_direction(handle)
+        elif direction is None:
+            _, direction = self.get_agent_position_and_direction(handle)
+
+        agent = self.env.agents[handle]
+        step = 0
+        while (position != agent.target) and (step < max_step):
+            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if position is None:
+                break
+            self.callback(handle, agent, position, direction, action, possible_transitions)
+            step += 1
+
+    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
+    def walk_one_step(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        possible_transitions = (0, 1, 0, 0)
+        if position == agent.target:
+            # Already at the target: nothing to walk.
+            return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
+        new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+        if new_position is None:
+            return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
+        self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
+        return new_position, new_direction, action, possible_transitions
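
Usage note (commentary, not part of the patch): the new walker is a template-method class; subclasses override callback() to observe every cell a shortest path crosses, which is presumably how dead_lock_avoidance_agent.py consumes it. A minimal sketch under that assumption; the subclass name PathRecorder and the variable local_env are invented for illustration:

    # Illustrative sketch only: PathRecorder and local_env are assumptions,
    # not part of this patch.
    from utils.shortest_distance_walker import ShortestDistanceWalker

    class PathRecorder(ShortestDistanceWalker):
        def __init__(self, env):
            super().__init__(env)
            self.visited = {}  # handle -> cells on the agent's shortest path

        def callback(self, handle, agent, position, direction, action, possible_transitions):
            # Invoked once per cell while walk_to_target() follows the distance map.
            self.visited.setdefault(handle, []).append(position)

    # local_env is assumed to be an already-reset flatland RailEnv.
    recorder = PathRecorder(local_env)
    for handle in range(local_env.get_num_agents()):
        recorder.walk_to_target(handle)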