From 1e092263a2d65104a29cf8bada806629a83c0529 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Fri, 11 Dec 2020 21:54:47 +0100 Subject: [PATCH] :-) --- .../multi_agent_training.py | 84 +++---------------- reinforcement_learning/ppo_agent.py | 46 +++++----- run.py | 18 ++-- utils/deadlock_check.py | 44 ++++++++++ utils/fast_tree_obs.py | 13 ++- 5 files changed, 102 insertions(+), 103 deletions(-) diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py index 01add0f..cc6cfce 100755 --- a/reinforcement_learning/multi_agent_training.py +++ b/reinforcement_learning/multi_agent_training.py @@ -22,6 +22,7 @@ from torch.utils.tensorboard import SummaryWriter from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.ppo_agent import PPOAgent +from utils.deadlock_check import get_agent_positions, check_for_deadlock base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -77,42 +78,6 @@ def create_rail_env(env_params, tree_observation): random_seed=seed ) - -def get_agent_positions(env): - agent_positions: np.ndarray = np.full((env.height, env.width), -1) - for agent_handle in env.get_agent_handles(): - agent = env.agents[agent_handle] - if agent.status == RailAgentStatus.ACTIVE: - position = agent.position - if position is None: - position = agent.initial_position - agent_positions[position] = agent_handle - return agent_positions - - -def check_for_dealock(handle, env, agent_positions): - agent = env.agents[handle] - if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED: - return False - - position = agent.position - if position is None: - position = agent.initial_position - possible_transitions = env.rail.get_transitions(*position, agent.direction) - num_transitions = fast_count_nonzero(possible_transitions) - for dir_loop in range(4): - if possible_transitions[dir_loop] == 1: - new_position = get_new_position(position, dir_loop) - opposite_agent = agent_positions[new_position] - if opposite_agent != handle and opposite_agent != -1: - num_transitions -= 1 - else: - return False - - is_deadlock = num_transitions <= 0 - return is_deadlock - - def train_agent(train_params, train_env_params, eval_env_params, obs_params): # Environment parameters n_agents = train_env_params.n_agents @@ -207,7 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): # Double Dueling DQN policy policy = DDDQNPolicy(state_size, action_size, train_params) - if True: + if False: policy = PPOAgent(state_size, action_size) # Load existing policy if train_params.load_policy is not "": @@ -256,7 +221,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): # Reset environment reset_timer.start() - number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200))) + number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 500))) train_env_params.n_agents = episode_idx % number_of_agents + 1 train_env = create_rail_env(train_env_params, tree_observation) @@ -318,34 +283,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): agent_positions = get_agent_positions(train_env) for agent_handle in train_env.get_agent_handles(): agent = train_env.agents[agent_handle] - act = action_dict.get(agent_handle, RailEnvActions.MOVE_FORWARD) if agent.status == RailAgentStatus.ACTIVE: - pos = agent.position - dir = agent.direction - possible_transitions = 
train_env.rail.get_transitions(*pos, dir) - num_transitions = fast_count_nonzero(possible_transitions) - if act == RailEnvActions.STOP_MOVING: - all_rewards[agent_handle] -= 2.0 - - if num_transitions == 1: - if act != RailEnvActions.MOVE_FORWARD: - all_rewards[agent_handle] -= 1.0 - if check_for_dealock(agent_handle, train_env, agent_positions): - all_rewards[agent_handle] -= 5.0 - elif agent.status == RailAgentStatus.READY_TO_DEPART: - all_rewards[agent_handle] -= 5.0 - else: - if False: - agent_positions = get_agent_positions(train_env) - for agent_handle in train_env.get_agent_handles(): - agent = train_env.agents[agent_handle] - act = action_dict.get(agent_handle, RailEnvActions.MOVE_FORWARD) - if agent.status == RailAgentStatus.ACTIVE: - if done[agent_handle] == False: - if check_for_dealock(agent_handle, train_env, agent_positions): - all_rewards[agent_handle] -= 1000.0 - done[agent_handle] = True + if done[agent_handle] == False: + if check_for_deadlock(agent_handle, train_env, agent_positions): + all_rewards[agent_handle] -= 1000.0 step_timer.end() @@ -559,17 +501,17 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=25000, type=int) - parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0, + parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=12000, type=int) + parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int) parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int) parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int) - parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=2000, type=int) - parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float) - parser.add_argument("--eps_end", help="min exploration", default=0.05, type=float) - parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float) - parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int) + parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int) + parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float) + parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float) + parser.add_argument("--eps_decay", help="exploration decay", default=0.99975, type=float) + parser.add_argument("--buffer_size", help="replay buffer size", default=int(32_000), type=int) parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int) parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str) parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False, diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py index e603e70..e97b265 100644 --- a/reinforcement_learning/ppo_agent.py +++ b/reinforcement_learning/ppo_agent.py @@ -9,10 +9,12 @@ from torch.distributions import Categorical # Hyperparameters from reinforcement_learning.policy import Policy -device = torch.device("cpu") # "cuda:0" if torch.cuda.is_available() else "cpu") +device = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("device:", device) +# https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html + class DataBuffers: def __init__(self): self.reset() @@ -44,7 +46,7 @@ class ActorCriticModel(nn.Module): nn.Tanh(), nn.Linear(hidsize2, action_size), nn.Softmax(dim=-1) - ) + ).to(device) self.critic = nn.Sequential( nn.Linear(state_size, hidsize1), @@ -52,7 +54,7 @@ class ActorCriticModel(nn.Module): nn.Linear(hidsize1, hidsize2), nn.Tanh(), nn.Linear(hidsize2, 1) - ) + ).to(device) def forward(self, x): raise NotImplementedError @@ -95,11 +97,11 @@ class PPOAgent(Policy): super(PPOAgent, self).__init__() # parameters - self.learning_rate = 0.1e-4 - self.gamma = 0.99 - self.surrogate_eps_clip = 0.2 - self.K_epoch = 30 - self.weight_loss = 1.0 + self.learning_rate = 1.0e-5 + self.gamma = 0.95 + self.surrogate_eps_clip = 0.1 + self.K_epoch = 50 + self.weight_loss = 0.5 self.weight_entropy = 0.01 # objects @@ -144,8 +146,8 @@ class PPOAgent(Policy): discounted_reward = 0 done_list.insert(0, 1) else: - discounted_reward = reward_i + self.gamma * discounted_reward done_list.insert(0, 0) + discounted_reward = reward_i + self.gamma * discounted_reward reward_list.insert(0, discounted_reward) state_next_list.insert(0, state_next_i) prob_a_list.insert(0, prob_action_i) @@ -160,22 +162,21 @@ class PPOAgent(Policy): torch.tensor(prob_a_list).to(device) # standard-normalize rewards - rewards = (rewards - rewards.mean()) / (rewards.std() + 1.e-5) + # rewards = (rewards - rewards.mean()) / (rewards.std() + 1.e-5) return states, actions, rewards, states_next, dones, prob_actions def train_net(self): - # Optimize policy for K epochs: - for _ in range(self.K_epoch): - # All agents have to propagate their experiences made during past episode - for handle in range(len(self.memory)): - # Extract agent's episode history (list of all transitions) - agent_episode_history = self.memory.get_transitions(handle) - if len(agent_episode_history) > 0: - # Convert the replay buffer to torch tensors (arrays) - states, actions, rewards, states_next, dones, probs_action = \ - self._convert_transitions_to_torch_tensors(agent_episode_history) - + # All agents have to propagate their experiences made during past episode + for handle in range(len(self.memory)): + # Extract agent's episode history (list of all transitions) + agent_episode_history = self.memory.get_transitions(handle) + if len(agent_episode_history) > 0: + # Convert the replay buffer to torch tensors (arrays) + states, actions, rewards, states_next, dones, probs_action = \ + self._convert_transitions_to_torch_tensors(agent_episode_history) + # Optimize policy for K epochs: + for _ in range(int(self.K_epoch)): # Evaluating actions (actor) and values (critic) logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions) @@ -201,8 +202,9 @@ class PPOAgent(Policy): self.optimizer.step() # Transfer the current loss to the agents loss (information) for debug purpose only - self.loss = loss.mean().detach().numpy() + self.loss = loss.mean().detach().cpu().numpy() + self.K_epoch = max(3, self.K_epoch - 0.01) # Reset all collect transition data self.memory.reset() diff --git a/run.py b/run.py index 1757f75..c4608bf 100644 --- a/run.py +++ b/run.py @@ -47,16 +47,18 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy # Print per-step logs VERBOSE = True USE_FAST_TREEOBS = True -USE_PPO_AGENT = True +USE_PPO_AGENT = False # Checkpoint to use (remember to push it!) 
-checkpoint = "./checkpoints/201124171810-7800.pth" # 18.249244799876152 DEPTH=2 AGENTS=10 -# checkpoint = "./checkpoints/201126150143-5200.pth" # 18.249244799876152 DEPTH=2 AGENTS=10 -# checkpoint = "./checkpoints/201126160144-2000.pth" # 18.249244799876152 DEPTH=2 AGENTS=10 -checkpoint = "./checkpoints/201127160352-2000.pth" -checkpoint = "./checkpoints/201130083154-2000.pth" - -EPSILON = 0.005 +checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 +# checkpoint = "./checkpoints/201126150143-5200.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 +# checkpoint = "./checkpoints/201126160144-2000.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 +checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786 +checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857 +checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504 +checkpoint = "./checkpoints/201211164554-8900.pth" # DDDQN: 17.44397192482364 + +EPSILON = 0.01 # Use last action cache USE_ACTION_CACHE = False diff --git a/utils/deadlock_check.py b/utils/deadlock_check.py index 6d414fa..d787c8c 100644 --- a/utils/deadlock_check.py +++ b/utils/deadlock_check.py @@ -1,5 +1,49 @@ +import numpy as np + from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import fast_count_nonzero + + +def get_agent_positions(env): + agent_positions: np.ndarray = np.full((env.height, env.width), -1) + for agent_handle in env.get_agent_handles(): + agent = env.agents[agent_handle] + if agent.status == RailAgentStatus.ACTIVE: + position = agent.position + if position is None: + position = agent.initial_position + agent_positions[position] = agent_handle + return agent_positions + + +def check_for_deadlock(handle, env, agent_positions, check_position=None, check_direction=None): + agent = env.agents[handle] + if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED: + return False + + position = agent.position + if position is None: + position = agent.initial_position + if check_position is not None: + position = check_position + direction = agent.direction + if check_direction is not None: + direction = check_direction + + possible_transitions = env.rail.get_transitions(*position, direction) + num_transitions = fast_count_nonzero(possible_transitions) + for dir_loop in range(4): + if possible_transitions[dir_loop] == 1: + new_position = get_new_position(position, dir_loop) + opposite_agent = agent_positions[new_position] + if opposite_agent != handle and opposite_agent != -1: + num_transitions -= 1 + else: + return False + + is_deadlock = num_transitions <= 0 + return is_deadlock def check_if_all_blocked(env): diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index b104916..c45477d 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -7,6 +7,7 @@ from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent +from utils.deadlock_check import check_for_deadlock, get_agent_positions """ LICENCE for the FastTreeObs Observation Builder @@ -25,7 +26,7 @@ class FastTreeObs(ObservationBuilder): def __init__(self, max_depth): self.max_depth = max_depth - self.observation_dim = 36 + self.observation_dim = 41 def build_data(self): if self.env is not None: @@ -244,6 +245,7 @@ class 
FastTreeObs(ObservationBuilder): def get_many(self, handles: Optional[List[int]] = None): self.dead_lock_avoidance_agent.start_step(train=False) + self.agent_positions = get_agent_positions(self.env) observations = super().get_many(handles) self.dead_lock_avoidance_agent.end_step(train=False) return observations @@ -328,6 +330,11 @@ class FastTreeObs(ObservationBuilder): observation[19 + dir_loop] = has_same_agent observation[23 + dir_loop] = has_target observation[27 + dir_loop] = int(np.math.isinf(new_cell_dist)) + observation[36] = int(check_for_deadlock(handle, + self.env, + self.agent_positions, + new_position, + branch_direction)) agents_on_switch, \ agents_near_to_switch, \ @@ -341,7 +348,9 @@ class FastTreeObs(ObservationBuilder): observation[10] = int(agents_near_to_switch_all) action = self.dead_lock_avoidance_agent.act([handle], 0.0) - observation[31] = int(action == RailEnvActions.STOP_MOVING) + observation[35] = int(action == RailEnvActions.STOP_MOVING) + + observation[40] = int(check_for_deadlock(handle, self.env, self.agent_positions)) self.env.dev_obs_dict.update({handle: visited}) -- GitLab
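
For context, a minimal usage sketch of the two helpers this patch consolidates in utils/deadlock_check.py (get_agent_positions and check_for_deadlock), mirroring how the patched training loop in reinforcement_learning/multi_agent_training.py applies its deadlock penalty. The surrounding names (train_env, action_dict, all_rewards, done) are assumed to come from that script, and the -1000.0 penalty is simply the value used in the patch, not a tuned constant:

    from flatland.envs.agent_utils import RailAgentStatus

    from utils.deadlock_check import get_agent_positions, check_for_deadlock

    # One environment step; train_env and action_dict are assumed to exist
    # as in the training script.
    obs, all_rewards, done, info = train_env.step(action_dict)

    # Grid of agent handles per cell (-1 = free cell), rebuilt once per step.
    agent_positions = get_agent_positions(train_env)

    for agent_handle in train_env.get_agent_handles():
        agent = train_env.agents[agent_handle]
        if agent.status == RailAgentStatus.ACTIVE and not done[agent_handle]:
            # check_for_deadlock returns True when every reachable neighbouring
            # cell is occupied by another agent, i.e. the agent can no longer
            # make progress from its current cell.
            if check_for_deadlock(agent_handle, train_env, agent_positions):
                all_rewards[agent_handle] -= 1000.0

The same helpers back the two new FastTreeObs features: observation[36] flags a deadlock on the branch currently being explored (via the optional check_position/check_direction arguments), and observation[40] flags a deadlock at the agent's current cell.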