diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 487d44d70b3387fb03df6af5c7ae030862ba8d0a..4bdbe4eaf13c56ed36847c9b76402fd73c7347af 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -172,7 +172,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    USE_SINGLE_AGENT_TRAINING = False
+    USE_SINGLE_AGENT_TRAINING = True
+    UPDATE_POLICY2_N_EPISODE = 1000
     policy = DDDQNPolicy(state_size, action_size, train_params)
     # policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
@@ -227,7 +228,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()
 
-        if episode_idx % 100 == 0:
+        if episode_idx % UPDATE_POLICY2_N_EPISODE == 0:
             policy2 = policy.clone()
 
         reset_timer.end()
@@ -499,14 +500,14 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=54000, type=int)
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=2,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=2, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
+    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
     parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
@@ -563,6 +564,17 @@ if __name__ == "__main__":
             "malfunction_rate": 1 / 200,
             "seed": 0
         },
+        {
+            # Test_3
+            "n_agents": 58,
+            "x_dim": 40,
+            "y_dim": 40,
+            "n_cities": 5,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 200,
+            "seed": 0
+        },
     ]
 
     obs_params = {
diff --git a/run.py b/run.py
index f4eb48e4aef8bce112a71e68355b7c7e5b85adb2..08b858082fe95dd798cd45c6d3079b03bdcd3ed8 100644
--- a/run.py
+++ b/run.py
@@ -26,7 +26,10 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201112143850-4100.pth" # 21.543589381053096 DEPTH=2
+checkpoint = "./checkpoints/201112143850-5400.pth" # 21.220418678677177 DEPTH=2 AGENTS=10
+checkpoint = "./checkpoints/201113070245-5400.pth" # 19.690047767961005 DEPTH=2 AGENTS=20
+checkpoint = "./checkpoints/201113211844-6100.pth" # 19.690047767961005 DEPTH=2 AGENTS=20
+
 # Use last action cache
 USE_ACTION_CACHE = False
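
Note (not part of the patch): the change in multi_agent_training.py replaces the hard-coded "episode_idx % 100" refresh of policy2 with the new UPDATE_POLICY2_N_EPISODE constant, so the frozen clone of the learner is re-synced every 1000 episodes instead of every 100; USE_SINGLE_AGENT_TRAINING = True suggests that snapshot drives the non-learning agents, but that wiring is outside these hunks. A minimal, self-contained sketch of the snapshot pattern follows; SnapshotPolicy is a hypothetical stand-in for DDDQNPolicy, whose real clone() and training loop live in multi_agent_training.py.

    # Sketch only (assumed helper, not in the repo): the policy2 snapshot pattern
    # that UPDATE_POLICY2_N_EPISODE controls inside train_agent().
    import copy

    UPDATE_POLICY2_N_EPISODE = 1000  # refresh the frozen copy every N episodes


    class SnapshotPolicy:
        """Hypothetical stand-in for DDDQNPolicy; only clone() matters here."""

        def __init__(self):
            self.updates = 0

        def clone(self):
            # DDDQNPolicy.clone() is assumed to return a frozen copy of the current nets.
            return copy.deepcopy(self)


    def training_loop(policy, n_episodes):
        policy2 = policy.clone()  # frozen snapshot, refreshed only periodically
        for episode_idx in range(n_episodes):
            if episode_idx % UPDATE_POLICY2_N_EPISODE == 0:
                policy2 = policy.clone()  # re-sync the snapshot with the learner
            policy.updates += 1  # placeholder for one episode of learning on `policy`
        return policy2


    if __name__ == "__main__":
        print(training_loop(SnapshotPolicy(), 2500).updates)  # prints 2000: snapshot lags the learner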