diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index c62a749e21b879522ba8004765ee405d75558b6e..dc4fc33f0e8397f4610a40a633e660b73cb6c99c 100644
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -33,6 +33,8 @@
 from reinforcement_learning.multi_policy import MultiPolicy
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
+# https://github.com/dongminlee94/deep_rl
+
 try:
     import wandb
 
@@ -508,8 +510,8 @@ if __name__ == "__main__":
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=0.5, type=float)
-    parser.add_argument("--eps_end", help="min exploration", default=0.0001, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.9997, type=float)
+    parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.9985, type=float)
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
@@ -527,7 +529,8 @@ if __name__ == "__main__":
     parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
     parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
     parser.add_argument("--max_depth", help="max depth", default=1, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+                        type=int)
     parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
 
     training_params = parser.parse_args()
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index 2ba5a68bb30171e6279742e50dbdf1753846346e..ec904e4598ae1a15fcff8f0bd1aa3b4cd2f5f3e9 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -9,14 +9,14 @@
 from reinforcement_learning.policy import Policy
 from reinforcement_learning.ppo.model import PolicyNetwork
 from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer
 
-BUFFER_SIZE = 32_000
-BATCH_SIZE = 4096
-GAMMA = 0.8
+BUFFER_SIZE = 128_000
+BATCH_SIZE = 8192
+GAMMA = 0.95
 LR = 0.5e-4
 CLIP_FACTOR = .005
 UPDATE_EVERY = 30
-device = torch.device("cpu")  # "cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 class PPOAgent(Policy):
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 71140c3ac8fad54116e40514d94e370390f4d1d4..62e76ea09bcaf4a58ab982c89399d2e9fe889f2c 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -7,7 +7,7 @@
 from reinforcement_learning.policy import Policy
 from utils.shortest_Distance_walker import ShortestDistanceWalker
 
-class MyWalker(ShortestDistanceWalker):
+class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
     def __init__(self, env: RailEnv, agent_positions, switches):
         super().__init__(env)
         self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
@@ -26,13 +26,6 @@ class MyWalker(ShortestDistanceWalker):
         self.same_agent_map = {}
         self.switches = switches
 
-    def get_action(self, handle, min_distances):
-        if min_distances[0] != np.inf:
-            m = min(min_distances)
-            if min_distances[0] < m + 5:
-                return 0
-        return np.argmin(min_distances)
-
     def getData(self):
         return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
 
@@ -58,7 +51,7 @@
 
 
 class DeadLockAvoidanceAgent(Policy):
-    def __init__(self, env: RailEnv, state_size, action_size):
+    def __init__(self, env: RailEnv, state_size, action_size, show_debug_plot=False):
         self.env = env
         self.action_size = action_size
         self.state_size = state_size
@@ -66,6 +59,7 @@
         self.loss = 0
         self.agent_can_move = {}
         self.switches = {}
+        self.show_debug_plot = show_debug_plot
 
     def step(self, handle, state, action, reward, next_state, done):
         pass
@@ -78,6 +72,8 @@
         return check[3]
 
     def reset(self):
+        self.agent_positions = None
+        self.shortest_distance_walker = None
         self.switches = {}
         for h in range(self.env.height):
             for w in range(self.env.width):
@@ -92,7 +88,9 @@
                             self.switches[pos].append(dir)
 
     def start_step(self):
+        self.build_agent_position_map()
         self.shortest_distance_mapper()
+        self.extract_agent_can_move()
 
     def end_step(self):
         pass
@@ -100,49 +98,56 @@
     def get_actions(self):
         pass
 
-    def shortest_distance_mapper(self):
-
+    def build_agent_position_map(self):
         # build map with agent positions (only active agents)
-        agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+        self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
         for handle in range(self.env.get_num_agents()):
             agent = self.env.agents[handle]
             if agent.status == RailAgentStatus.ACTIVE:
                 if agent.position is not None:
-                    agent_positions[agent.position] = handle
+                    self.agent_positions[agent.position] = handle
 
-        my_walker = MyWalker(self.env, agent_positions, self.switches)
+    def shortest_distance_mapper(self):
+        self.shortest_distance_walker = DeadlockAvoidanceShortestDistanceWalker(self.env,
+                                                                                self.agent_positions,
+                                                                                self.switches)
         for handle in range(self.env.get_num_agents()):
             agent = self.env.agents[handle]
             if agent.status <= RailAgentStatus.ACTIVE:
-                my_walker.walk_to_target(handle)
-        shortest_distance_agent_map, full_shortest_distance_agent_map = my_walker.getData()
+                self.shortest_distance_walker.walk_to_target(handle)
 
-        delta_data = np.copy(full_shortest_distance_agent_map)
+    def extract_agent_can_move(self):
         self.agent_can_move = {}
-        agent_positions_map = (agent_positions > -1).astype(int)
+        shortest_distance_agent_map, full_shortest_distance_agent_map = self.shortest_distance_walker.getData()
         for handle in range(self.env.get_num_agents()):
-            opp_agents = my_walker.opp_agent_map.get(handle, [])
-            same_agents = my_walker.same_agent_map.get(handle, [])
-            me = shortest_distance_agent_map[handle]
-            delta = me
-            next_step_ok = True
-            next_position, next_direction, action, possible_transitions = my_walker.walk_one_step(handle)
-            for opp_a in opp_agents:
-                opp = full_shortest_distance_agent_map[opp_a]
-                delta = ((delta - opp - agent_positions_map) > 0).astype(int)
-                delta_data[handle] += np.clip(delta,0,1)
-                if (np.sum(delta) < 3):
-                    next_step_ok = False
-
-            if next_step_ok:
-                self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
-
-        if False:
+            agent = self.env.agents[handle]
+            if agent.status < RailAgentStatus.DONE:
+                next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle],
+                                                         self.shortest_distance_walker.opp_agent_map.get(handle, []),
+                                                         full_shortest_distance_agent_map)
+                if next_step_ok:
+                    next_position, next_direction, action, _ = self.shortest_distance_walker.walk_one_step(handle)
+                    self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
+
+        if self.show_debug_plot:
             a = np.floor(np.sqrt(self.env.get_num_agents()))
             b = np.ceil(self.env.get_num_agents() / a)
             for handle in range(self.env.get_num_agents()):
                 plt.subplot(a, b, handle + 1)
-                plt.imshow(delta_data[handle])
-                # plt.colorbar()
+                plt.imshow(full_shortest_distance_agent_map[handle] + shortest_distance_agent_map[handle])
             plt.show(block=False)
             plt.pause(0.01)
+
+    def check_agent_can_move(self,
+                             my_shortest_walking_path,
+                             opp_agents,
+                             full_shortest_distance_agent_map):
+        agent_positions_map = (self.agent_positions > -1).astype(int)
+        delta = my_shortest_walking_path
+        next_step_ok = True
+        for opp_a in opp_agents:
+            opp = full_shortest_distance_agent_map[opp_a]
+            delta = ((delta - opp - agent_positions_map) > 0).astype(int)
+            if (np.sum(delta) < 1 + len(opp_agents)):
+                next_step_ok = False
+        return next_step_ok
diff --git a/utils/extra.py b/utils/extra.py
index 4b14b840c7cd2286ab8c1216763e3f79e635c2d5..340cb404a50bf9e19387e4ca227345a39d0f9db4 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -1,11 +1,11 @@
 import numpy as np
-
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnvActions, fast_argmax, fast_count_nonzero
 
 from reinforcement_learning.policy import Policy
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent, DeadlockAvoidanceShortestDistanceWalker
 
 
 class ExtraPolicy(Policy):
@@ -56,11 +56,14 @@
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 26
+        self.observation_dim = 31
 
     def build_data(self):
+        self.dead_lock_avoidance_agent = None
         if self.env is not None:
             self.env.dev_obs_dict = {}
+            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, None, None)
+
         self.switches = {}
         self.switches_neighbours = {}
         self.debug_render_list = []
@@ -248,6 +251,10 @@
         return has_opp_agent, has_same_agent, has_switch, visited
 
     def get(self, handle):
+
+        if handle == 0:
+            self.dead_lock_avoidance_agent.start_step()
+
         # all values are [0,1]
         # observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
         # observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
@@ -275,6 +282,11 @@
         # observation[23] : If there is a switch on the path which agent can not use -> 1
         # observation[24] : If there is a switch on the path which agent can not use -> 1
         # observation[25] : If there is a switch on the path which agent can not use -> 1
+        # observation[26] : Is there a deadlock signal on shortest path walk(s) (direction 0)-> 1
+        # observation[27] : Is there a deadlock signal on shortest path walk(s) (direction 1)-> 1
+        # observation[28] : Is there a deadlock signal on shortest path walk(s) (direction 2)-> 1
+        # observation[29] : Is there a deadlock signal on shortest path walk(s) (direction 3)-> 1
+        # observation[30] : Is there a deadlock signal on shortest path walk(s) (current position check)-> 1
 
         observation = np.zeros(self.observation_dim)
         visited = []
@@ -320,6 +332,21 @@
                 observation[18 + dir_loop] = has_same_agent
                 observation[22 + dir_loop] = has_switch
 
+                _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+                opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
+                local_walker = DeadlockAvoidanceShortestDistanceWalker(
+                    self.env,
+                    self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
+                    self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
+                local_walker.walk_to_target(handle, new_position, branch_direction)
+                shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+                my_shortest_path_to_check = shortest_distance_agent_map[handle]
+                next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
+                                                                                   opp_agents,
+                                                                                   full_shortest_distance_agent_map)
+                if next_step_ok:
+                    observation[26 + dir_loop] = 1
+
         agents_on_switch, \
         agents_near_to_switch, \
         agents_near_to_switch_all, \
@@ -329,6 +356,8 @@
         observation[8] = int(agents_near_to_switch)
         observation[9] = int(agents_near_to_switch_all)
 
+        observation[30] = int(self.dead_lock_avoidance_agent.act(handle, None, 0) == RailEnvActions.STOP_MOVING)
+
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
diff --git a/utils/shortest_Distance_walker.py b/utils/shortest_Distance_walker.py
index ad754b154dd28d8d3f7e0e376a2fd8eda8c89079..62b686fff0f61a13c48220565553d7e63067739a 100644
--- a/utils/shortest_Distance_walker.py
+++ b/utils/shortest_Distance_walker.py
@@ -42,14 +42,24 @@ class ShortestDistanceWalker:
     def callback(self, handle, agent, position, direction, action, possible_transitions):
         pass
 
-    def walk_to_target(self, handle, max_step=500):
+    def get_agent_position_and_direction(self, handle):
         agent = self.env.agents[handle]
         if agent.position is not None:
             position = agent.position
         else:
             position = agent.initial_position
         direction = agent.direction
+        return position, direction
 
+    def walk_to_target(self, handle, position=None, direction=None, max_step=500):
+        if position is None and direction is None:
+            position, direction = self.get_agent_position_and_direction(handle)
+        elif position is None:
+            position, _ = self.get_agent_position_and_direction(handle)
+        elif direction is None:
+            _, direction = self.get_agent_position_and_direction(handle)
+
+        agent = self.env.agents[handle]
         step = 0
         while (position != agent.target) and (step < max_step):
             position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
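
Note on the refactored deadlock check: check_agent_can_move (its full body is in the utils/dead_lock_avoidance_agent.py hunk above) repeatedly subtracts each opponent's full shortest-distance map and the occupied-cell map from the agent's own shortest-path map, and requires at least 1 + len(opp_agents) free cells to remain. The snippet below is a minimal standalone sketch of that arithmetic on a hypothetical 1-D corridor; the helper function and the toy arrays are illustrative only, while the real method operates on the (height, width) maps produced by DeadlockAvoidanceShortestDistanceWalker.

    import numpy as np

    def check_agent_can_move(my_path, opp_paths, agent_positions_map):
        # Same arithmetic as DeadLockAvoidanceAgent.check_agent_can_move, on toy 1-D maps.
        delta = my_path
        next_step_ok = True
        for opp in opp_paths:
            # keep only cells of my remaining path that are neither on the opponent's
            # path nor currently occupied by an active agent
            delta = ((delta - opp - agent_positions_map) > 0).astype(int)
            if np.sum(delta) < 1 + len(opp_paths):
                next_step_ok = False
        return next_step_ok

    corridor_me = np.array([1, 1, 1, 1, 1, 1])    # my shortest path uses every cell
    corridor_opp = np.array([1, 1, 1, 1, 1, 1])   # opponent approaches on the same corridor
    occupied = np.array([1, 0, 0, 0, 0, 1])       # me on cell 0, opponent on cell 5

    print(check_agent_can_move(corridor_me, [corridor_opp], occupied))             # False -> deadlock signal
    print(check_agent_can_move(corridor_me, [np.zeros(6, dtype=int)], occupied))   # True  -> path looks free

Compared with the previous hard-coded threshold of 3 free cells, the new 1 + len(opp_agents) bound grows with the number of opposing agents sharing the route.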
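
The new eps_end/eps_decay defaults in multi_agent_training.py also shorten the exploration schedule considerably. Assuming the decay factor is applied once per episode (eps <- max(eps_end, eps_decay * eps), as the argument names suggest), the number of episodes until epsilon bottoms out can be estimated with the small helper below; it is only an illustration of that arithmetic, not part of the patch.

    import math

    def episodes_until_min_eps(eps_start, eps_end, eps_decay):
        # Smallest n with eps_start * eps_decay ** n <= eps_end.
        return math.ceil(math.log(eps_end / eps_start) / math.log(eps_decay))

    print(episodes_until_min_eps(0.5, 0.0001, 0.9997))  # old defaults: roughly 28,000 episodes
    print(episodes_until_min_eps(0.5, 0.01, 0.9985))    # new defaults: roughly 2,600 episodes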