From d6103087ab59c61e8e0eb449385594833ccf5e7a Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 3 Nov 2020 16:05:33 +0100
Subject: [PATCH] ...

---
 reinforcement_learning/dddqn_policy.py |   6 +-
 .../multi_agent_training.py            |   4 +-
 run.py                                 |   4 +-
 utils/dead_lock_avoidance_agent.py     | 175 ++++++++++++++++++
 utils/fast_tree_obs.py                 |   6 +-
 5 files changed, 185 insertions(+), 10 deletions(-)
 create mode 100644 utils/dead_lock_avoidance_agent.py

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 32d7110..c1177b9 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -123,7 +123,11 @@ class DDDQNPolicy(Policy):
             self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
             self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
         else:
-            raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
+            if os.path.exists(filename):
+                self.qnetwork_local.load_state_dict(torch.load(filename))
+                self.qnetwork_target.load_state_dict(torch.load(filename))
+            else:
+                raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
 
     def save_replay_buffer(self, filename):
         memory = self.memory.memory
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 195b46a..9c78a72 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -322,7 +322,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Print logs
         if episode_idx % checkpoint_interval == 0:
-            torch.save(policy.qnetwork_local, './checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
+            policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
 
             if save_replay_buffer:
                 policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')
@@ -475,7 +475,7 @@ if __name__ == "__main__":
    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0, type=int)
    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                        type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=50, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
    parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
diff --git a/run.py b/run.py
index a2309a6..918d78c 100644
--- a/run.py
+++ b/run.py
@@ -28,7 +28,7 @@ from utils.observation_utils import normalize_observation
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "checkpoints/201014015722-1500.pth"
+checkpoint = "checkpoints/201103150429-2500.pth"
 
 # Use last action cache
 USE_ACTION_CACHE = True
@@ -55,7 +55,7 @@ action_size = 5
 
 policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
 if os.path.isfile(checkpoint):
-    policy.qnetwork_local = torch.load(checkpoint)
+    policy.load(checkpoint)
 else:
     print("Checkpoint not found, using untrained policy! (path: {})".format(checkpoint))
 
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
new file mode 100644
index 0000000..c7c6d8a
--- /dev/null
+++ b/utils/dead_lock_avoidance_agent.py
@@ -0,0 +1,175 @@
+from typing import Optional, List
+
+import matplotlib.pyplot as plt
+import numpy as np
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
+
+from reinforcement_learning.policy import Policy
+from utils.shortest_Distance_walker import ShortestDistanceWalker
+
+
+class DeadlockAvoidanceObservation(DummyObservationBuilder):
+    def __init__(self):
+        self.counter = 0
+
+    def get_many(self, handles: Optional[List[int]] = None) -> bool:
+        self.counter += 1
+        obs = np.ones((len(handles), 2))
+        for handle in handles:
+            obs[handle][0] = handle
+            obs[handle][1] = self.counter
+        return obs
+
+
+class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
+    def __init__(self, env: RailEnv, agent_positions, switches):
+        super().__init__(env)
+        self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                     self.env.height,
+                                                     self.env.width),
+                                                    dtype=int) - 1
+
+        self.full_shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                          self.env.height,
+                                                          self.env.width),
+                                                         dtype=int) - 1
+
+        self.agent_positions = agent_positions
+
+        self.opp_agent_map = {}
+        self.same_agent_map = {}
+        self.switches = switches
+
+    def getData(self):
+        return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
+
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
+        opp_a = self.agent_positions[position]
+        if opp_a != -1 and opp_a != handle:
+            if self.env.agents[opp_a].direction != direction:
+                d = self.opp_agent_map.get(handle, [])
+                if opp_a not in d:
+                    d.append(opp_a)
+                self.opp_agent_map.update({handle: d})
+            else:
+                if len(self.opp_agent_map.get(handle, [])) == 0:
+                    d = self.same_agent_map.get(handle, [])
+                    if opp_a not in d:
+                        d.append(opp_a)
+                    self.same_agent_map.update({handle: d})
+
+        if len(self.opp_agent_map.get(handle, [])) == 0:
+            if self.switches.get(position, None) is None:
+                self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+        self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+
+
+class DeadLockAvoidanceAgent(Policy):
+    def __init__(self, env: RailEnv, show_debug_plot=False):
+        self.env = env
+        self.memory = None
+        self.loss = 0
+        self.agent_can_move = {}
+        self.switches = {}
+        self.show_debug_plot = show_debug_plot
+
+    def step(self, state, action, reward, next_state, done):
+        pass
+
+    def act(self, state, eps=0.):
+        # agent = self.env.agents[state[0]]
+        check = self.agent_can_move.get(state[0], None)
+        if check is None:
+            return RailEnvActions.STOP_MOVING
+        return check[3]
+
+    def reset(self):
+        self.agent_positions = None
+        self.shortest_distance_walker = None
+        self.switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
+                        else:
+                            self.switches[pos].append(dir)
+
+    def start_step(self):
+        self.build_agent_position_map()
+        self.shortest_distance_mapper()
+        self.extract_agent_can_move()
+
+    def end_step(self):
+        pass
+
+    def get_actions(self):
+        pass
+
+    def build_agent_position_map(self):
+        # build map with agent positions (only active agents)
+        self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status == RailAgentStatus.ACTIVE:
+                if agent.position is not None:
+                    self.agent_positions[agent.position] = handle
+
+    def shortest_distance_mapper(self):
+        self.shortest_distance_walker = DeadlockAvoidanceShortestDistanceWalker(self.env,
+                                                                                self.agent_positions,
+                                                                                self.switches)
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status <= RailAgentStatus.ACTIVE:
+                self.shortest_distance_walker.walk_to_target(handle)
+
+    def extract_agent_can_move(self):
+        self.agent_can_move = {}
+        shortest_distance_agent_map, full_shortest_distance_agent_map = self.shortest_distance_walker.getData()
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status < RailAgentStatus.DONE:
+                next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle],
+                                                         self.shortest_distance_walker.same_agent_map.get(handle, []),
+                                                         self.shortest_distance_walker.opp_agent_map.get(handle, []),
+                                                         full_shortest_distance_agent_map)
+                if next_step_ok:
+                    next_position, next_direction, action, _ = self.shortest_distance_walker.walk_one_step(handle)
+                    self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
+
+        if self.show_debug_plot:
+            a = np.floor(np.sqrt(self.env.get_num_agents()))
+            b = np.ceil(self.env.get_num_agents() / a)
+            for handle in range(self.env.get_num_agents()):
+                plt.subplot(a, b, handle + 1)
+                plt.imshow(full_shortest_distance_agent_map[handle] + shortest_distance_agent_map[handle])
+            plt.show(block=False)
+            plt.pause(0.01)
+
+    def check_agent_can_move(self,
+                             my_shortest_walking_path,
+                             same_agents,
+                             opp_agents,
+                             full_shortest_distance_agent_map):
+        agent_positions_map = (self.agent_positions > -1).astype(int)
+        delta = my_shortest_walking_path
+        next_step_ok = True
+        for opp_a in opp_agents:
+            opp = full_shortest_distance_agent_map[opp_a]
+            delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int)
+            if np.sum(delta) < (3 + len(opp_agents)):
+                next_step_ok = False
+        return next_step_ok
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 12c91ca..7e4c934 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -294,8 +294,4 @@ class FastTreeObs(ObservationBuilder):
 
         self.env.dev_obs_dict.update({handle: visited})
 
-        return observation
-
-    @staticmethod
-    def agent_can_choose(observation):
-        return observation[7] == 1 or observation[8] == 1
+        return observation
\ No newline at end of file
-- 
GitLab
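
Usage sketch (illustrative only, not part of the diff above): the new DeadLockAvoidanceAgent is a
heuristic policy. reset() scans the rail once per episode to collect switch cells, and start_step()
must be called before act() on each simulation step so the agent-position and shortest-distance
maps are rebuilt. The run_episode helper below is hypothetical and assumes an already-constructed
flatland RailEnv; only the policy calls (reset/start_step/act/end_step) come from the patch.

    from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent

    def run_episode(env):
        # env: an already-constructed flatland.envs.rail_env.RailEnv (assumed)
        obs, info = env.reset()

        policy = DeadLockAvoidanceAgent(env)
        policy.reset()  # collect switch cells for this episode's rail

        done = {"__all__": False}
        while not done["__all__"]:
            policy.start_step()  # rebuild position map and per-agent shortest-distance maps

            # act() only reads the agent handle from state[0]; agents that cannot move
            # safely receive RailEnvActions.STOP_MOVING
            actions = {handle: policy.act([handle])
                       for handle in range(env.get_num_agents())}

            obs, rewards, done, info = env.step(actions)
            policy.end_step()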