From 5c8a88cf573c29a575cbdc9000a7ce78a2d2db49 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 11 Nov 2020 17:52:48 +0100
Subject: [PATCH] FastTreeObs (fix) -> 0.8773

---
 reinforcement_learning/dddqn_policy.py | 14 ++++++-
 .../multi_agent_training.py            | 40 ++++++++++---------
 reinforcement_learning/policy.py       |  5 ++-
 utils/fast_tree_obs.py                 |  4 +-
 4 files changed, 40 insertions(+), 23 deletions(-)

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 6218ab8..0350a94 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -17,6 +17,7 @@ class DDDQNPolicy(Policy):
     """Dueling Double DQN policy"""

     def __init__(self, state_size, action_size, parameters, evaluation_mode=False):
+        self.parameters = parameters
         self.evaluation_mode = evaluation_mode

         self.state_size = state_size
@@ -59,11 +60,16 @@ class DDDQNPolicy(Policy):
         self.qnetwork_local.eval()
         with torch.no_grad():
             action_values = self.qnetwork_local(state)
+        self.qnetwork_local.train()

         # Epsilon-greedy action selection
-        if random.random() > eps:
+        if random.random() >= eps:
             return np.argmax(action_values.cpu().data.numpy())
+            qvals = action_values.cpu().data.numpy()[0]
+            qvals = qvals - np.min(qvals)
+            qvals = qvals / (1e-5 + np.sum(qvals))
+            return np.argmax(np.random.multinomial(1, qvals))
         else:
             return random.choice(np.arange(self.action_size))


@@ -148,6 +154,12 @@
         self.act(np.array([[0] * self.state_size]))
         self._learn()

+    def clone(self):
+        me = DDDQNPolicy(self.state_size, self.action_size, self.parameters, evaluation_mode=True)
+        me.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+        me.qnetwork_target = copy.deepcopy(self.qnetwork_target)
+        return me
+


 Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 0f26a20..1a1f512 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -19,7 +19,6 @@ from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter

 from reinforcement_learning.dddqn_policy import DDDQNPolicy
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent

 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -173,6 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)

     # Double Dueling DQN policy
+    USE_SINGLE_AGENT_TRAINING = False
     policy = DDDQNPolicy(state_size, action_size, train_params)
     # policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
@@ -227,8 +227,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()

-        policy2 = DeadLockAvoidanceAgent(train_env)
-        policy2.reset()
+        if episode_idx % 100 == 0:
+            policy2 = policy.clone()

         reset_timer.end()

@@ -253,9 +253,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         max_steps = train_env._max_episode_steps

         # Run episode
-        agent_to_learn = 0
+        agent_to_learn = [0]
         if train_env.get_num_agents() > 1:
-            agent_to_learn = np.random.choice(train_env.get_num_agents())
+            agent_to_learn = np.unique(np.random.choice(train_env.get_num_agents(), train_env.get_num_agents()))
+            # agent_to_learn = np.arange(train_env.get_num_agents())
+
         for step in range(max_steps - 1):
             inference_timer.start()
             policy.start_step()
@@ -263,11 +265,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             for agent in train_env.get_agent_handles():
                 if info['action_required'][agent]:
                     update_values[agent] = True
-
-                    if agent == agent_to_learn or True:
+                    if agent in agent_to_learn or not USE_SINGLE_AGENT_TRAINING:
                         action = policy.act(agent_obs[agent], eps=eps_start)
                     else:
-                        action = policy2.act([agent], eps=eps_start)
+                        action = policy2.act(agent_obs[agent], eps=eps_start)
                     action_count[action] += 1
                     actions_taken.append(action)
                 else:
@@ -316,7 +317,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    if agent == agent_to_learn:
+                    if agent in agent_to_learn:
                         policy.step(agent,
                                     agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
                                     agent_obs[agent],
@@ -507,27 +508,28 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+                        type=int)
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
+    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
     parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
                         type=bool)
     parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
-    parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
-    parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
+    parser.add_argument("--gamma", help="discount factor", default=0.97, type=float)
+    parser.add_argument("--tau", help="soft update of target parameters", default=0.5e-3, type=float)
     parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
     parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
-    parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
-    parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
-    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
+    parser.add_argument("--update_every", help="how often to update the network", default=10, type=int)
+    parser.add_argument("--use_gpu", help="use GPU if available", default=True, type=bool)
+    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=4, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index c7621a6..c7300de 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -24,4 +24,7 @@ class Policy:
         pass

     def reset(self):
-        pass
\ No newline at end of file
+        pass
+
+    def clone(self):
+        return self
\ No newline at end of file
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index db22a8f..0666ef4 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -222,8 +222,8 @@ class FastTreeObs(ObservationBuilder):
                                                 dir_loop,
                                                 depth + 1)
                 visited.append(v)
-                has_opp_agent = max(has_opp_agent, hoa)
-                has_same_agent = max(has_same_agent, hsa)
+                has_opp_agent += hoa * 2 ** (-1 - depth)
+                has_same_agent += hsa * 2 ** (-1 - depth)
                 has_target = max(has_target, ht)
             return has_opp_agent, has_same_agent, has_target, visited
         else:
-- 
GitLab