diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index cc6cfcee87a1428d08f16357999750ed6665d541..3237ce7e32ebf38074896d58641079fcce8d1315 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -172,7 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):

     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if False:
+    if True:
         policy = PPOAgent(state_size, action_size)
     # Load existing policy
     if train_params.load_policy is not "":
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index e97b26512156d048087ea85999aa4f63b80a66d5..703c9560cf86750350abd26bd23a0ec1c55fb6be 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -9,7 +9,7 @@ from torch.distributions import Categorical
 # Hyperparameters
 from reinforcement_learning.policy import Policy

-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")#"cuda:0" if torch.cuda.is_available() else "cpu")
 print("device:", device)

@@ -145,8 +145,10 @@ class PPOAgent(Policy):
             if done_i:
                 discounted_reward = 0
                 done_list.insert(0, 1)
+                reward_i = 1
             else:
                 done_list.insert(0, 0)
+                reward_i = 0
             discounted_reward = reward_i + self.gamma * discounted_reward
             reward_list.insert(0, discounted_reward)
             state_next_list.insert(0, state_next_i)
diff --git a/run.py b/run.py
index 8a967756cf122cb25f968e13143e6d9423c2904e..667b7c4779fc527d803845ff25a045f6ccea016d 100644
--- a/run.py
+++ b/run.py
@@ -47,7 +47,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
-USE_PPO_AGENT = False
+USE_PPO_AGENT = True

 # Checkpoint to use (remember to push it!)
 checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
@@ -57,15 +57,16 @@ checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786
 checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857
 checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504
 checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537
+checkpoint = "./checkpoints/201212190452-6500.pth" # PPO: 13.944402986414723

-EPSILON = 0.01
+EPSILON = 0.0

 # Use last action cache
 USE_ACTION_CACHE = False
 USE_DEAD_LOCK_AVOIDANCE_AGENT = False # 21.54485505223213

 # Observation parameters (must match training parameters!)
-observation_tree_depth = 1
+observation_tree_depth = 2
 observation_radius = 10
 observation_max_path_depth = 30

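
Note on the ppo_agent.py return hunk: before the patch, the Monte Carlo return loop discounted the raw environment reward reward_i; after the patch, reward_i is overwritten so terminal transitions contribute 1 and all other transitions contribute 0. The following is a minimal, self-contained sketch of that patched computation, not the repository's actual method: it assumes transitions are (state, action, reward, next_state, action_prob, done) tuples iterated newest-first, as the hunk context suggests, and the function name, GAMMA value, and tuple layout are illustrative placeholders.

    # Sketch of the patched discounted-return computation (illustrative names/values).
    GAMMA = 0.95  # placeholder discount factor; the real value is a PPOAgent hyperparameter

    def discounted_returns(transitions, gamma=GAMMA):
        reward_list, done_list = [], []
        discounted_reward = 0
        # Walk the buffer from the most recent transition backwards,
        # inserting at index 0 to restore chronological order.
        for state_i, action_i, reward_i, state_next_i, prob_a_i, done_i in reversed(transitions):
            if done_i:
                discounted_reward = 0
                done_list.insert(0, 1)
                reward_i = 1  # patched behaviour: terminal transitions get a fixed reward of 1
            else:
                done_list.insert(0, 0)
                reward_i = 0  # patched behaviour: non-terminal rewards are zeroed out
            discounted_reward = reward_i + gamma * discounted_reward
            reward_list.insert(0, discounted_reward)
        return reward_list, done_list

As the sketch shows, the returns fed to the PPO update now depend only on where episodes terminate, not on the environment's per-step rewards.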