From 02d8a6d7989d87355edb096a33972ad500427b0a Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 22 Oct 2020 10:11:54 +0200
Subject: [PATCH] Q&D

---
 .../multi_agent_training.py            | 50 ++++++++++-------
 reinforcement_learning/multi_policy.py |  6 --
 utils/dead_lock_avoidance_agent.py     | 55 ++++++++++++++-----
 utils/extra.py                         |  2 +-
 utils/shortest_Distance_walker.py      | 28 ++++++----
 5 files changed, 87 insertions(+), 54 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 7a80c78..c47b484 100644
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -194,8 +194,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         policy = PPOAgent(state_size, action_size, n_agents)
     if False:
         policy = MultiPolicy(state_size, action_size, n_agents, train_env)
-    if True:
-        policy = DeadLockAvoidanceAgent(train_env,state_size, action_size)
+    if False:
+        policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
 
     # Load existing policy
     if train_params.load_policy is not None:
@@ -253,6 +253,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         reset_timer.start()
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()
+        deadLockAvoidanceAgent = DeadLockAvoidanceAgent(train_env, state_size, action_size)
         reset_timer.end()
 
         if train_params.render:
@@ -273,20 +274,25 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         for step in range(max_steps - 1):
             inference_timer.start()
             policy.start_step()
+            deadLockAvoidanceAgent.start_step()
             for agent in train_env.get_agent_handles():
-                if info['action_required'][agent]:
-                    update_values[agent] = True
-                    action = policy.act(agent,agent_obs[agent], eps=eps_start)
-
-                    action_count[action] += 1
-                    actions_taken.append(action)
-                else:
-                    # An action is not required if the train hasn't joined the railway network,
-                    # if it already reached its target, or if is currently malfunctioning.
-                    update_values[agent] = False
-                    action = 0
+                action = deadLockAvoidanceAgent.act(agent, None, 0.0)
+                update_values[agent] = False
+                if action != RailEnvActions.STOP_MOVING:
+                    if info['action_required'][agent]:
+                        update_values[agent] = True
+                        action = policy.act(agent, agent_obs[agent], eps=eps_start)
+                        action_count[action] += 1
+                        actions_taken.append(action)
+                    else:
+                        # An action is not required if the train hasn't joined the railway network,
+                        # if it already reached its target, or if is currently malfunctioning.
+                        action = 0
+
                 action_dict.update({agent: action})
             policy.end_step()
+            deadLockAvoidanceAgent.end_step()
+
             inference_timer.end()
 
             # Environment step
@@ -458,22 +464,26 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         score = 0.0
 
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
+        deadLockAvoidanceAgent = DeadLockAvoidanceAgent(env, None, None)
 
         final_step = 0
 
         for step in range(max_steps - 1):
+            deadLockAvoidanceAgent.start_step()
             for agent in env.get_agent_handles():
                 if tree_observation.check_is_observation_valid(agent_obs[agent]):
                     agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
                                                                                    observation_radius=observation_radius)
 
-                action = 0
-                if info['action_required'][agent]:
-                    if tree_observation.check_is_observation_valid(agent_obs[agent]):
-                        action = policy.act(agent,agent_obs[agent], eps=0.0)
+                action = deadLockAvoidanceAgent.act(agent, None, 0)
+                if action != RailEnvActions.STOP_MOVING:
+                    if info['action_required'][agent]:
+                        if tree_observation.check_is_observation_valid(agent_obs[agent]):
+                            action = policy.act(agent, agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
 
             obs, all_rewards, done, info = env.step(action_dict)
+            deadLockAvoidanceAgent.end_step()
 
             for agent in env.get_agent_handles():
                 score += all_rewards[agent]
@@ -505,7 +515,7 @@ if __name__ == "__main__":
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=200000, type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.0001, type=float)
@@ -525,9 +535,9 @@ if __name__ == "__main__":
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
-    parser.add_argument("--max_depth", help="max depth", default=-1, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
     parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
 
     training_params = parser.parse_args()
diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py
index 2ce6d0f..765bcf5 100644
--- a/reinforcement_learning/multi_policy.py
+++ b/reinforcement_learning/multi_policy.py
@@ -13,7 +13,6 @@ class MultiPolicy(Policy):
         self.action_size = action_size
         self.memory = []
         self.loss = 0
-        self.dead_lock_avoidance_policy = DeadLockAvoidanceAgent(env, state_size, action_size)
         self.extra_policy = ExtraPolicy(state_size, action_size)
         self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)
 
@@ -40,9 +39,6 @@ class MultiPolicy(Policy):
         self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)
 
     def act(self, handle, state, eps=0.):
-        dead_lock_avoidance_action = self.dead_lock_avoidance_policy.act(handle, state, 0.0)
-        if dead_lock_avoidance_action == RailEnvActions.STOP_MOVING:
-            return RailEnvActions.STOP_MOVING
         action_extra_state = self.extra_policy.act(handle, state, 0.0)
         extended_state = np.copy(state)
         for action_itr in np.arange(self.action_size):
@@ -60,11 +56,9 @@ class MultiPolicy(Policy):
         self.extra_policy.test()
 
     def start_step(self):
-        self.dead_lock_avoidance_policy.start_step()
         self.extra_policy.start_step()
         self.ppo_policy.start_step()
 
     def end_step(self):
-        self.dead_lock_avoidance_policy.end_step()
         self.extra_policy.end_step()
         self.ppo_policy.end_step()
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 382959a..bb9dc3d 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -1,14 +1,14 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 
 from reinforcement_learning.policy import Policy
 from utils.shortest_Distance_walker import ShortestDistanceWalker
 
 
 class MyWalker(ShortestDistanceWalker):
-    def __init__(self, env: RailEnv, agent_positions):
+    def __init__(self, env: RailEnv, agent_positions, switches):
         super().__init__(env)
         self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
                                                      self.env.height,
@@ -22,25 +22,38 @@ class MyWalker(ShortestDistanceWalker):
 
         self.agent_positions = agent_positions
 
-        self.agent_map = {}
+        self.opp_agent_map = {}
+        self.same_agent_map = {}
+        self.switches = switches
 
     def get_action(self, handle, min_distances):
+        if min_distances[0] != np.inf:
+            m = min(min_distances)
+            if min_distances[0] < m + 5:
+                return 0
         return np.argmin(min_distances)
 
     def getData(self):
         return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
 
-    def callback(self, handle, agent, position, direction, action):
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
         opp_a = self.agent_positions[position]
         if opp_a != -1 and opp_a != handle:
             if self.env.agents[opp_a].direction != direction:
-                d = self.agent_map.get(handle, [])
+                d = self.opp_agent_map.get(handle, [])
                 if opp_a not in d:
                     d.append(opp_a)
-                self.agent_map.update({handle: d})
-        d = self.agent_map.get(handle, [])
-        if len(d) == 0:
-            self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+                self.opp_agent_map.update({handle: d})
+            else:
+                if len(self.opp_agent_map.get(handle, [])) == 0:
+                    d = self.same_agent_map.get(handle, [])
+                    if opp_a not in d:
+                        d.append(opp_a)
+                    self.same_agent_map.update({handle: d})
+
+        if len(self.opp_agent_map.get(handle, [])) == 0:
+            if self.switches.get(position, None) is None:
+                self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
         self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
 
 
@@ -52,6 +65,7 @@ class DeadLockAvoidanceAgent(Policy):
         self.memory = []
         self.loss = 0
         self.agent_can_move = {}
+        self.switches = {}
 
     def step(self, handle, state, action, reward, next_state, done):
         pass
@@ -61,11 +75,21 @@ class DeadLockAvoidanceAgent(Policy):
         check = self.agent_can_move.get(handle, None)
         if check is None:
             return RailEnvActions.STOP_MOVING
-
         return check[3]
 
     def reset(self):
-        pass
+        self.switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
+                        else:
+                            self.switches[pos].append(dir)
 
     def start_step(self):
         self.shortest_distance_mapper()
@@ -86,7 +110,7 @@ class DeadLockAvoidanceAgent(Policy):
             if agent.position is not None:
                 agent_positions[agent.position] = handle
 
-        my_walker = MyWalker(self.env, agent_positions)
+        my_walker = MyWalker(self.env, agent_positions, self.switches)
         for handle in range(self.env.get_num_agents()):
             agent = self.env.agents[handle]
             if agent.status <= RailAgentStatus.ACTIVE:
@@ -96,14 +120,15 @@ class DeadLockAvoidanceAgent(Policy):
 
         self.agent_can_move = {}
         agent_positions_map = (agent_positions > -1).astype(int)
         for handle in range(self.env.get_num_agents()):
-            opp_agents = my_walker.agent_map.get(handle, [])
+            opp_agents = my_walker.opp_agent_map.get(handle, [])
+            same_agents = my_walker.same_agent_map.get(handle, [])
             me = shortest_distance_agent_map[handle]
             delta = me
             next_step_ok = True
-            next_position, next_direction, action = my_walker.walk_one_step(handle)
+            next_position, next_direction, action, possible_transitions = my_walker.walk_one_step(handle)
             for opp_a in opp_agents:
                 opp = full_shortest_distance_agent_map[opp_a]
-                delta = (delta - opp - agent_positions_map > 0).astype(int)
+                delta = ((delta - opp - agent_positions_map) > 0).astype(int)
                 if (np.sum(delta) < 3):
                     next_step_ok = False
diff --git a/utils/extra.py b/utils/extra.py
index 1145521..83263c7 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -84,7 +84,7 @@ class Extra(ObservationBuilder):
     def getData(self):
         return self.shortest_distance_agent_counter, self.shortest_distance_agent_direction_counter
 
-    def callback(self, handle, agent, position, direction, action):
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
         self.shortest_distance_agent_counter[position] += 1
         self.shortest_distance_agent_direction_counter[(position[0], position[1], direction)] += 1
 
diff --git a/utils/shortest_Distance_walker.py b/utils/shortest_Distance_walker.py
index bd1d5b3..ad754b1 100644
--- a/utils/shortest_Distance_walker.py
+++ b/utils/shortest_Distance_walker.py
@@ -16,7 +16,7 @@ class ShortestDistanceWalker:
             new_position = get_new_position(position, new_direction)
 
             dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
-            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
         else:
             min_distances = []
             positions = []
@@ -34,28 +34,31 @@ class ShortestDistanceWalker:
                     directions.append(None)
 
             a = self.get_action(handle, min_distances)
-            return positions[a], directions[a], min_distances[a], a + 1
+            return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
 
     def get_action(self, handle, min_distances):
         return np.argmin(min_distances)
 
-    def callback(self, handle, agent, position, direction, action):
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
         pass
 
-    def walk_to_target(self, handle):
+    def walk_to_target(self, handle, max_step=500):
         agent = self.env.agents[handle]
         if agent.position is not None:
             position = agent.position
         else:
             position = agent.initial_position
         direction = agent.direction
-        while (position != agent.target):
-            position, direction, dist, action = self.walk(handle, position, direction)
+
+        step = 0
+        while (position != agent.target) and (step < max_step):
+            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
             if position is None:
                 break
-            self.callback(handle, agent, position, direction, action)
+            self.callback(handle, agent, position, direction, action, possible_transitions)
+            step += 1
 
-    def callback_one_step(self, handle, agent, position, direction, action):
+    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
         pass
 
     def walk_one_step(self, handle):
@@ -65,9 +68,10 @@ class ShortestDistanceWalker:
         else:
             position = agent.initial_position
         direction = agent.direction
+        possible_transitions = (0, 1, 0, 0)
         if (position != agent.target):
-            new_position, new_direction, dist, action = self.walk(handle, position, direction)
+            new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
             if new_position is None:
-                return position, direction, RailEnvActions.STOP_MOVING
-        self.callback_one_step(handle, agent, new_position, new_direction, action)
-        return new_position, new_direction, action
+                return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
+        self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
+        return new_position, new_direction, action, possible_transitions
--
GitLab
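
Illustrative note (not part of the patch): both train_agent() and eval_policy() now gate the learned policy behind the DeadLockAvoidanceAgent, consulting the policy only when the deadlock check does not demand STOP_MOVING. A minimal sketch of that selection pattern follows; the helper name select_actions and its parameter list are assumptions, while the act/start_step/end_step calls and info['action_required'] mirror the code above.

    # Sketch of the gating pattern introduced by this patch (illustrative only).
    from flatland.envs.rail_env import RailEnvActions

    def select_actions(env, policy, dead_lock_avoidance_agent, agent_obs, info, eps=0.0):
        action_dict = {}
        dead_lock_avoidance_agent.start_step()
        for agent in env.get_agent_handles():
            # The deadlock-avoidance agent ignores the observation, so None is passed.
            action = dead_lock_avoidance_agent.act(agent, None, 0.0)
            if action != RailEnvActions.STOP_MOVING:
                if info['action_required'][agent]:
                    # Only trains that may move and require an action query the learned policy.
                    action = policy.act(agent, agent_obs[agent], eps=eps)
                else:
                    action = RailEnvActions.DO_NOTHING
            action_dict[agent] = action
        dead_lock_avoidance_agent.end_step()
        return action_dict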