From 225bbed91e9961f4ac5599b05e3b53e97896af72 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 17 Nov 2020 10:49:09 +0100
Subject: [PATCH] refactored file name

---
 reinforcement_learning/dddqn_policy.py |  4 -
 .../multi_agent_training.py            | 23 ++++--
 run.py                                 |  5 +-
 utils/extra.py                         |  2 +
 utils/fast_tree_obs.py                 | 81 +++++++++----------
 5 files changed, 60 insertions(+), 55 deletions(-)

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index b34dd36..3eb54a3 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -66,10 +66,6 @@ class DDDQNPolicy(Policy):
         # Epsilon-greedy action selection
         if random.random() >= eps:
             return np.argmax(action_values.cpu().data.numpy())
-            qvals = action_values.cpu().data.numpy()[0]
-            qvals = qvals - np.min(qvals)
-            qvals = qvals / (1e-5 + np.sum(qvals))
-            return np.argmax(np.random.multinomial(1, qvals))
         else:
             return random.choice(np.arange(self.action_size))

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index be905e0..b3271b5 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -171,9 +171,14 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
     completion_window = deque(maxlen=checkpoint_interval)
 
+    # If USE_SINGLE_AGENT_TRAINING is set and episode_idx <= MAX_SINGLE_TRAINING_ITERATION, training is
+    # done with a single learning agent per episode. Every UPDATE_POLICY2_N_EPISODE episodes the second
+    # policy is replaced with the policy that is being trained.
+    USE_SINGLE_AGENT_TRAINING = True
+    MAX_SINGLE_TRAINING_ITERATION = 1000
+    UPDATE_POLICY2_N_EPISODE = 200
+
     # Double Dueling DQN policy
-    USE_SINGLE_AGENT_TRAINING = False
-    UPDATE_POLICY2_N_EPISODE = 1000
     policy = DDDQNPolicy(state_size, action_size, train_params)
     # policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
@@ -221,6 +226,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         preproc_timer = Timer()
         inference_timer = Timer()
 
+        if episode_idx > MAX_SINGLE_TRAINING_ITERATION:
+            USE_SINGLE_AGENT_TRAINING = False
+
         # Reset environment
         reset_timer.start()
         train_env_params.n_agents = episode_idx % n_agents + 1
@@ -293,6 +301,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if agent_obs[agent][26] == 1:
                     if act != RailEnvActions.STOP_MOVING:
                         all_rewards[agent] -= 10.0
+                if agent_obs[agent][27] == 1:
+                    if act == RailEnvActions.MOVE_LEFT or \
+                            act == RailEnvActions.MOVE_RIGHT or \
+                            act == RailEnvActions.DO_NOTHING:
+                        all_rewards[agent] -= 1.0
 
         step_timer.end()
@@ -310,7 +323,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if update_values[agent] or done['__all__']:
                 # Only learn from timesteps where somethings happened
                 learn_timer.start()
-                if agent in agent_to_learn:
+                if agent in agent_to_learn or not USE_SINGLE_AGENT_TRAINING:
                     policy.step(agent,
                                 agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
                                 agent_obs[agent],
@@ -501,8 +514,8 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=54000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
                         type=int)
diff --git a/run.py b/run.py
index 048dcde..a4fa62f 100644
--- a/run.py
+++ b/run.py
@@ -27,12 +27,11 @@ VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
 # checkpoint = "./checkpoints/201112143850-5400.pth"  # 21.220418678677177 DEPTH=2 AGENTS=10
-checkpoint = "./checkpoints/201113211844-6700.pth"  # 19.690047767961005 DEPTH=2 AGENTS=20
-
+checkpoint = "./checkpoints/201117082153-1500.pth"  # 21.570149424415636 DEPTH=2 AGENTS=10
 
 # Use last action cache
 USE_ACTION_CACHE = False
-USE_DEAD_LOCK_AVOIDANCE_AGENT = False
+USE_DEAD_LOCK_AVOIDANCE_AGENT = False  # 21.54485505223213
 
 # Observation parameters (must match training parameters!)
 observation_tree_depth = 2
diff --git a/utils/extra.py b/utils/extra.py
index 89ed0bb..c4df6a8 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -187,6 +187,7 @@ class Extra(ObservationBuilder):
     def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
         _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
         opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
+        same_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.same_agent_map.get(handle, [])
         local_walker = DeadlockAvoidanceShortestDistanceWalker(
             self.env,
             self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
@@ -196,6 +197,7 @@ class Extra(ObservationBuilder):
         my_shortest_path_to_check = shortest_distance_agent_map[handle]
         next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
                                                                            opp_agents,
+                                                                           same_agents,
                                                                            full_shortest_distance_agent_map)
         return next_step_ok
 
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index c388d2a..625a21e 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -25,7 +25,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 27
+        self.observation_dim = 32
 
     def build_data(self):
         if self.env is not None:
@@ -40,8 +40,8 @@ class FastTreeObs(ObservationBuilder):
         else:
             self.dead_lock_avoidance_agent = None
 
-    def find_all_cell_where_agent_can_choose(self):
-        switches = {}
+    def find_all_switches(self):
+        self.switches = {}
         for h in range(self.env.height):
             for w in range(self.env.width):
                 pos = (h, w)
@@ -49,12 +49,13 @@ class FastTreeObs(ObservationBuilder):
                     possible_transitions = self.env.rail.get_transitions(*pos, dir)
                     num_transitions = fast_count_nonzero(possible_transitions)
                     if num_transitions > 1:
-                        if pos not in switches.keys():
-                            switches.update({pos: [dir]})
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
                         else:
-                            switches[pos].append(dir)
+                            self.switches[pos].append(dir)
 
-        switches_neighbours = {}
+    def find_all_switch_neighbours(self):
+        self.switches_neighbours = {}
         for h in range(self.env.height):
             for w in range(self.env.width):
                 # look one step forward
@@ -64,35 +65,34 @@ class FastTreeObs(ObservationBuilder):
                     for d in range(4):
                         if possible_transitions[d] == 1:
                             new_cell = get_new_position(pos, d)
-                            if new_cell in switches.keys() and pos not in switches.keys():
-                                if pos not in switches_neighbours.keys():
-                                    switches_neighbours.update({pos: [dir]})
+                            if new_cell in self.switches.keys() and pos not in self.switches.keys():
+                                if pos not in self.switches_neighbours.keys():
+                                    self.switches_neighbours.update({pos: [dir]})
                                 else:
-                                    switches_neighbours[pos].append(dir)
+                                    self.switches_neighbours[pos].append(dir)
 
-        self.switches = switches
-        self.switches_neighbours = switches_neighbours
+    def find_all_cell_where_agent_can_choose(self):
+        self.find_all_switches()
+        self.find_all_switch_neighbours()
 
     def check_agent_decision(self, position, direction):
-        switches = self.switches
-        switches_neighbours = self.switches_neighbours
         agents_on_switch = False
         agents_on_switch_all = False
         agents_near_to_switch = False
         agents_near_to_switch_all = False
-        if position in switches.keys():
-            agents_on_switch = direction in switches[position]
+        if position in self.switches.keys():
+            agents_on_switch = direction in self.switches[position]
             agents_on_switch_all = True
 
-        if position in switches_neighbours.keys():
+        if position in self.switches_neighbours.keys():
             new_cell = get_new_position(position, direction)
-            if new_cell in switches.keys():
-                if not direction in switches[new_cell]:
-                    agents_near_to_switch = direction in switches_neighbours[position]
+            if new_cell in self.switches.keys():
+                if not direction in self.switches[new_cell]:
+                    agents_near_to_switch = direction in self.switches_neighbours[position]
                 else:
-                    agents_near_to_switch = direction in switches_neighbours[position]
+                    agents_near_to_switch = direction in self.switches_neighbours[position]
 
-            agents_near_to_switch_all = direction in switches_neighbours[position]
+            agents_near_to_switch_all = direction in self.switches_neighbours[position]
 
         return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
 
@@ -151,15 +151,6 @@ class FastTreeObs(ObservationBuilder):
             self.build_data()
             return
 
-    def fast_argmax(self, array):
-        if array[0] == 1:
-            return 0
-        if array[1] == 1:
-            return 1
-        if array[2] == 1:
-            return 2
-        return 3
-
     def _explore(self, handle, new_position, new_direction, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
@@ -269,6 +260,7 @@ class FastTreeObs(ObservationBuilder):
         # observation[24] : If there is a switch on the path which agent can not use -> 1
         # observation[25] : If there is a switch on the path which agent can not use -> 1
         # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1
+        # observation[27] : If the agent can only walk forward or stop -> 1
 
         observation = np.zeros(self.observation_dim)
         visited = []
@@ -313,18 +305,21 @@ class FastTreeObs(ObservationBuilder):
                 observation[14 + dir_loop] = has_opp_agent
                 observation[18 + dir_loop] = has_same_agent
                 observation[22 + dir_loop] = has_target
+                observation[26 + dir_loop] = int(np.math.isinf(new_cell_dist))
+
+        agents_on_switch, \
+        agents_near_to_switch, \
+        agents_near_to_switch_all, \
+        agents_on_switch_all = \
+            self.check_agent_decision(agent_virtual_position, agent.direction)
+        observation[7] = int(agents_on_switch)
+        observation[8] = int(agents_near_to_switch)
+        observation[9] = int(agents_near_to_switch_all)
+
+        action = self.dead_lock_avoidance_agent.act([handle], 0.0)
+        observation[30] = int(action == RailEnvActions.STOP_MOVING)
+        observation[31] = int(fast_count_nonzero(possible_transitions) == 1)
 
-        agents_on_switch, \
-        agents_near_to_switch, \
-        agents_near_to_switch_all, \
-        agents_on_switch_all = \
-            self.check_agent_decision(agent_virtual_position, agent.direction)
-        observation[7] = int(agents_on_switch)
-        observation[8] = int(agents_near_to_switch)
-        observation[9] = int(agents_near_to_switch_all)
-
-        action = self.dead_lock_avoidance_agent.act([handle], 0.0)
-        observation[26] = int(action == RailEnvActions.STOP_MOVING)
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
--
GitLab
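
Note on the dddqn_policy.py hunk: the deleted lines were unreachable (they followed a return), so act() now reduces to plain epsilon-greedy selection. Below is a minimal, self-contained sketch of that rule; the names q_values and action_size are illustrative stand-ins, not the real DDDQNPolicy internals.

import random

import numpy as np


def select_action(q_values, eps, action_size):
    """Epsilon-greedy selection, as left behind by this patch (sketch)."""
    # Exploit with probability (1 - eps): take the action with the largest Q-value.
    if random.random() >= eps:
        return int(np.argmax(q_values))
    # Explore otherwise: pick an action uniformly at random.
    return random.choice(np.arange(action_size))

# Example: select_action([0.1, 0.5, 0.2, 0.0, 0.3], eps=0.1, action_size=5)
# returns 1 in roughly 90% of calls and a random action otherwise.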
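
Note on the single-agent training schedule in multi_agent_training.py: while USE_SINGLE_AGENT_TRAINING is active (the first MAX_SINGLE_TRAINING_ITERATION episodes), only the agent(s) in agent_to_learn produce learning steps; once the flag is cleared, every agent does. A small sketch of the gate that wraps policy.step(), assuming agent_to_learn is the per-episode selection made elsewhere in the training loop.

def should_learn(agent, agent_to_learn, use_single_agent_training):
    # Single-agent phase: only the selected agent updates the policy.
    # After MAX_SINGLE_TRAINING_ITERATION episodes the flag is cleared and all agents update it.
    return agent in agent_to_learn or not use_single_agent_training

# should_learn(3, {3}, True)  -> True
# should_learn(1, {3}, True)  -> False
# should_learn(1, {3}, False) -> True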
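
Note on the reward shaping added in the same file: the two observation flags checked during training are used to penalize actions that contradict them. A standalone sketch; the action constants below are stand-ins for flatland's RailEnvActions enum, and the indices 26/27 follow the training-side checks in this patch.

# Stand-ins for flatland.envs.rail_env.RailEnvActions (illustrative values only).
DO_NOTHING, MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING = range(5)


def shape_reward(reward, obs, act):
    # obs[26] == 1: the dead-lock avoidance agent predicts a deadlock -> the agent should stop.
    if obs[26] == 1 and act != STOP_MOVING:
        reward -= 10.0
    # obs[27] == 1: the agent can only walk forward or stop -> penalize turning or waiting.
    if obs[27] == 1 and act in (MOVE_LEFT, MOVE_RIGHT, DO_NOTHING):
        reward -= 1.0
    return reward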