diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index dc4fc33f0e8397f4610a40a633e660b73cb6c99c..542f587b1c3bea557c5d9e5f90e6210415fcf4a7 100644
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -9,6 +9,7 @@
 from pprint import pprint
 import numpy as np
 import psutil
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -196,8 +197,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         policy = PPOAgent(state_size, action_size, n_agents, train_env)
     if False:
         policy = MultiPolicy(state_size, action_size, n_agents, train_env)
-    if False:
-        policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
 
     # Load existing policy
     if train_params.load_policy is not None:
@@ -244,6 +243,7 @@
         training_id
     ))
 
+    rl_policy = policy
     for episode_idx in range(n_episodes + 1):
        step_timer = Timer()
        reset_timer = Timer()
@@ -254,6 +254,18 @@
         # Reset environment
         reset_timer.start()
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+
+        # train different number of agents : 1,2,3,... n_agents
+        for handle in range(train_env.get_num_agents()):
+            if (episode_idx % n_agents) < handle:
+                train_env.agents[handle].status = RailAgentStatus.DONE_REMOVED
+
+        # start with simple deadlock avoidance agent policy (imitation learning?)
+        if episode_idx < 500:
+            policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
+        else:
+            policy = rl_policy
+
         policy.reset()
         reset_timer.end()
 
@@ -512,7 +524,7 @@ if __name__ == "__main__":
     parser.add_argument("--eps_start", help="max exploration", default=0.5, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
     parser.add_argument("--eps_decay", help="exploration decay", default=0.9985, type=float)
-    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
+    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e6), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
     parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
@@ -528,8 +540,8 @@
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
     parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
     parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 62e76ea09bcaf4a58ab982c89399d2e9fe889f2c..39090a14edbe533297556853f8aecc8b34ddef14 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -148,6 +148,6 @@ class DeadLockAvoidanceAgent(Policy):
         for opp_a in opp_agents:
             opp = full_shortest_distance_agent_map[opp_a]
             delta = ((delta - opp - agent_positions_map) > 0).astype(int)
-            if (np.sum(delta) < 1 + len(opp_agents)):
+            if (np.sum(delta) < 2 + len(opp_agents)):
                 next_step_ok = False
         return next_step_ok
diff --git a/utils/extra.py b/utils/extra.py
index 340cb404a50bf9e19387e4ca227345a39d0f9db4..89ed0bb9ea2b7a993fe9eecd9332f0910294cbce 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -183,14 +183,21 @@ class Extra(ObservationBuilder):
         self.build_data()
         return
 
-    def fast_argmax(self, array):
-        if array[0] == 1:
-            return 0
-        if array[1] == 1:
-            return 1
-        if array[2] == 1:
-            return 2
-        return 3
+
+    def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
+        _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+        opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
+        local_walker = DeadlockAvoidanceShortestDistanceWalker(
+            self.env,
+            self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
+            self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
+        local_walker.walk_to_target(handle, new_position, branch_direction)
+        shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+        my_shortest_path_to_check = shortest_distance_agent_map[handle]
+        next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
+                                                                           opp_agents,
+                                                                           full_shortest_distance_agent_map)
+        return next_step_ok
 
     def _explore(self, handle, new_position, new_direction, depth=0):
@@ -332,18 +339,7 @@ class Extra(ObservationBuilder):
             observation[18 + dir_loop] = has_same_agent
             observation[22 + dir_loop] = has_switch
 
-            _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-            opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
-            local_walker = DeadlockAvoidanceShortestDistanceWalker(
-                self.env,
-                self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
-                self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
-            local_walker.walk_to_target(handle, new_position, branch_direction)
-            shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-            my_shortest_path_to_check = shortest_distance_agent_map[handle]
-            next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
-                                                                               opp_agents,
-                                                                               full_shortest_distance_agent_map)
+            next_step_ok = self._check_dead_lock_at_branching_position(handle, new_position, branch_direction)
             if next_step_ok:
                 observation[26 + dir_loop] = 1
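
The two training-loop additions above amount to an implicit curriculum: any agent whose handle exceeds episode_idx % n_agents is flagged RailAgentStatus.DONE_REMOVED, so the number of active trains cycles 1, 2, ..., n_agents across episodes, and the first 500 episodes are driven by the DeadLockAvoidanceAgent heuristic before control switches back to the stored rl_policy (PPO). Below is a minimal, dependency-free sketch of that schedule for reference; active_handles and policy_for_episode are hypothetical helper names used only for illustration and are not part of the patch, which instead mutates train_env.agents directly.

# schedule_sketch.py -- illustration only, mirrors the scheduling logic added
# in multi_agent_training.py without any Flatland dependencies.

WARM_START_EPISODES = 500  # matches the `episode_idx < 500` threshold in the patch


def active_handles(episode_idx, n_agents):
    # Handles that keep training this episode; in the patch, all other handles
    # get status = RailAgentStatus.DONE_REMOVED right after train_env.reset().
    return [h for h in range(n_agents) if (episode_idx % n_agents) >= h]


def policy_for_episode(episode_idx):
    # Heuristic "teacher" for the warm-start phase, learned policy afterwards.
    return "DeadLockAvoidanceAgent" if episode_idx < WARM_START_EPISODES else "rl_policy (PPO)"


if __name__ == "__main__":
    n_agents = 4
    for episode_idx in (0, 1, 2, 3, 4, 499, 500):
        print("episode {:3d}: active agents {} -> {}".format(
            episode_idx, active_handles(episode_idx, n_agents), policy_for_episode(episode_idx)))

Running the sketch shows the active-agent count growing from one back up to all n_agents every n_agents episodes, with the acting policy flipping from the heuristic to PPO at episode 500; stashing rl_policy before the loop is what allows the per-episode swap without losing the learned weights.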