diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 3e461fd9fe3e13aa341c8e4aab1cde6af2854bf7..e2ea4bfea061fc42e21978a580e25e8e8139449b 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -22,9 +22,9 @@ from torch.utils.tensorboard import SummaryWriter
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo_agent import PPOAgent
 from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent
+from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import get_agent_positions, check_for_deadlock
-from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -174,9 +174,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     policy = DDDQNPolicy(state_size, get_action_size(), train_params)
     if False:
         policy = PPOAgent(state_size, get_action_size())
-    if True:
+    if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if True:
+    if False:
         policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)
 
     # Load existing policy
@@ -387,7 +387,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
            '\t 🎲 Epsilon: {:.3f} '
            '\t 🔀 Action Probs: {}'.format(
                episode_idx,
-               train_env_params.n_agents, train_env.get_num_agents(),
+               train_env_params.n_agents, number_of_agents,
                normalized_score,
                smoothed_normalized_score,
                100 * completion,
@@ -521,11 +521,11 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=12000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=3,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float)
diff --git a/run.py b/run.py
index 0ba9acc5a0f976ea7373e6925ef67411978f1a42..998add015fca7d6b57b492131249076bd2382367 100644
--- a/run.py
+++ b/run.py
@@ -49,26 +49,18 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
-USE_PPO_AGENT = True
+USE_PPO_AGENT = False
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
-# checkpoint = "./checkpoints/201126150143-5200.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
-# checkpoint = "./checkpoints/201126160144-2000.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
-checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786
-checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857
-checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504
-checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537
-checkpoint = "./checkpoints/201213181400-6800.pth" # PPO: 13.944402986414723
-checkpoint = "./checkpoints/201214140158-5000.pth" # USE_MULTI_DECISION_AGENT with DDDQN: 13.944402986414723
-checkpoint = "./checkpoints/201215120518-3700.pth" # USE_MULTI_DECISION_AGENT with PPO: 13.944402986414723
+checkpoint = "./checkpoints/201215160226-12000.pth" #
+# checkpoint = "./checkpoints/201215212134-12000.pth" #
 
 EPSILON = 0.0
 
 # Use last action cache
 USE_ACTION_CACHE = False
 USE_DEAD_LOCK_AVOIDANCE_AGENT = False # 21.54485505223213
-USE_MULTI_DECISION_AGENT = False
+USE_MULTI_DECISION_AGENT = True
 
 # Observation parameters (must match training parameters!)
 observation_tree_depth = 2
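For context, below is a minimal sketch of how the run.py toggles changed in this diff (USE_PPO_AGENT, USE_DEAD_LOCK_AVOIDANCE_AGENT, USE_MULTI_DECISION_AGENT, checkpoint) could map onto the policy classes used in multi_agent_training.py. The actual selection and checkpoint-loading code in run.py is not part of this diff, so the build_policy helper, its env/state_size/train_params arguments, and the load(path) call are illustrative assumptions; only the constructor calls mirror the ones visible in the training script above.

# Hypothetical helper (not part of this diff): shows how the run.py flags
# could select the submission policy. Everything except the imported classes
# and the flag values shown above is an assumption.
from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo_agent import PPOAgent
from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent
from utils.agent_action_config import get_action_size
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent

USE_PPO_AGENT = False
USE_DEAD_LOCK_AVOIDANCE_AGENT = False
USE_MULTI_DECISION_AGENT = True
checkpoint = "./checkpoints/201215160226-12000.pth"

def build_policy(env, state_size, train_params):
    # Base learner: constructor calls mirror multi_agent_training.py above.
    if USE_PPO_AGENT:
        policy = PPOAgent(state_size, get_action_size())
    else:
        policy = DDDQNPolicy(state_size, get_action_size(), train_params)

    if USE_DEAD_LOCK_AVOIDANCE_AGENT:
        # Rule-based deadlock avoidance, no learned checkpoint involved.
        policy = DeadLockAvoidanceAgent(env, get_action_size())
    elif USE_MULTI_DECISION_AGENT:
        # Wrap the learned policy with the deadlock-aware decision layer.
        policy = MultiDecisionAgent(env, state_size, get_action_size(), policy)

    policy.load(checkpoint)  # assumed load(path) interface, not confirmed by this diff
    return policy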