diff --git a/checkpoints/201103180606-1400.pth.local b/checkpoints/201103180606-1400.pth.local
deleted file mode 100644
index 0b42ee0314e712d253f0d9057786a992c1ac0fd4..0000000000000000000000000000000000000000
Binary files a/checkpoints/201103180606-1400.pth.local and /dev/null differ
diff --git a/checkpoints/201103180606-1400.pth.target b/checkpoints/201103180606-1400.pth.target
deleted file mode 100644
index a751a2b15a0492f028861eddcb21b4cf6d3c0e6d..0000000000000000000000000000000000000000
Binary files a/checkpoints/201103180606-1400.pth.target and /dev/null differ
diff --git a/checkpoints/201103221432-3000.pth.local b/checkpoints/201103221432-3000.pth.local
deleted file mode 100644
index d965cbc82edabbd07b19440c1d7ba22de4a1614e..0000000000000000000000000000000000000000
Binary files a/checkpoints/201103221432-3000.pth.local and /dev/null differ
diff --git a/checkpoints/201103221432-3000.pth.target b/checkpoints/201103221432-3000.pth.target
deleted file mode 100644
index f03ee7722253c9a1f20a6b8361f7f7cec7ba1d03..0000000000000000000000000000000000000000
Binary files a/checkpoints/201103221432-3000.pth.target and /dev/null differ
diff --git a/checkpoints/201105050310-2300.pth.local b/checkpoints/201105050310-2300.pth.local
deleted file mode 100644
index 22d8eda9a7a272ad3dbbd2056aa91277048764e5..0000000000000000000000000000000000000000
Binary files a/checkpoints/201105050310-2300.pth.local and /dev/null differ
diff --git a/checkpoints/201105050310-2300.pth.target b/checkpoints/201105050310-2300.pth.target
deleted file mode 100644
index 443db7835c8b770f73f4b1ed7893c6d55ecdb0fc..0000000000000000000000000000000000000000
Binary files a/checkpoints/201105050310-2300.pth.target and /dev/null differ
diff --git a/checkpoints/201106073658-4400.pth.optimizer b/checkpoints/201106073658-4400.pth.optimizer
deleted file mode 100644
index a860868a8beb873d2ec52ee3cb14c7ed715ebdae..0000000000000000000000000000000000000000
Binary files a/checkpoints/201106073658-4400.pth.optimizer and /dev/null differ
diff --git a/checkpoints/201106073658-4400.pth.policy b/checkpoints/201106073658-4400.pth.policy
deleted file mode 100644
index fa6348dc3330fa4699b2ee429296bc97c67f044c..0000000000000000000000000000000000000000
Binary files a/checkpoints/201106073658-4400.pth.policy and /dev/null differ
diff --git a/checkpoints/201106090621-3300.pth.local b/checkpoints/201106090621-3300.pth.local
deleted file mode 100644
index 453e1cdb7f0166eb52de7e89a257b49be88e36f8..0000000000000000000000000000000000000000
Binary files a/checkpoints/201106090621-3300.pth.local and /dev/null differ
diff --git a/checkpoints/201106090621-3300.pth.target b/checkpoints/201106090621-3300.pth.target
deleted file mode 100644
index f94422b5b2c56cf46dae8cbbba7189d558581fa4..0000000000000000000000000000000000000000
Binary files a/checkpoints/201106090621-3300.pth.target and /dev/null differ
diff --git a/checkpoints/201106090621-4500.pth.local b/checkpoints/201106090621-4500.pth.local
deleted file mode 100644
index 5b608c3ba00b82a04d5150dc653ea3fd94a6de68..0000000000000000000000000000000000000000
Binary files a/checkpoints/201106090621-4500.pth.local and /dev/null differ
diff --git a/checkpoints/201106090621-4500.pth.target b/checkpoints/201106090621-4500.pth.target
deleted file mode 100644
index 8d1c3d4b211eacbb51704cefea4d7aeb083b50c5..0000000000000000000000000000000000000000
Binary files a/checkpoints/201106090621-4500.pth.target and /dev/null differ
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index db4b1a805e4d3fa5a3c29bc58224ec88d2d4b2f1..b28eb694a78b6f2a2de515eeee7a061bee3f2d3d 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -20,6 +20,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo.ppo_agent import PPOAgent
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -229,6 +230,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()
+
+        policy2 = DeadLockAvoidanceAgent(train_env)
+        policy2.reset()
+
         reset_timer.end()
 
         if train_params.render:
@@ -252,15 +257,21 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         max_steps = train_env._max_episode_steps
 
         # Run episode
+        agent_to_learn = 0
+        if train_env.get_num_agents() > 1:
+            agent_to_learn = np.random.choice(train_env.get_num_agents())
         for step in range(max_steps - 1):
             inference_timer.start()
             policy.start_step()
+            policy2.start_step()
             for agent in train_env.get_agent_handles():
                 if info['action_required'][agent]:
                     update_values[agent] = True
-                    action = policy.act(agent_obs[agent], eps=eps_start)
-
+                    if agent == agent_to_learn:
+                        action = policy.act(agent_obs[agent], eps=eps_start)
+                    else:
+                        action = policy2.act([agent], eps=eps_start)
                     action_count[action] += 1
                     actions_taken.append(action)
                 else:
@@ -270,6 +281,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     action = 0
                 action_dict.update({agent: action})
             policy.end_step()
+            policy2.end_step()
             inference_timer.end()
 
             # Environment step
@@ -291,10 +303,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    policy.step(agent,
-                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                agent_obs[agent],
-                                done[agent])
+                    if agent == agent_to_learn:
+                        policy.step(agent,
+                                    agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                    agent_obs[agent],
+                                    done[agent])
                     learn_timer.end()
 
                     agent_prev_obs[agent] = agent_obs[agent].copy()
@@ -481,7 +494,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
@@ -506,7 +519,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
parser.add_argument("--max_depth", help="max depth", default=2, type=int) training_params = parser.parse_args() env_params = [ diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py index be23960414fbb57628a400a300e9d90e00ae202e..a7431f85201def6f189ccdc6101a89428b598e47 100644 --- a/reinforcement_learning/ppo/ppo_agent.py +++ b/reinforcement_learning/ppo/ppo_agent.py @@ -39,11 +39,6 @@ class PPOAgent(Policy): # Decide on an action to take in the environment def act(self, state, eps=None): - # if eps is not None: - # # Epsilon-greedy action selection - # if np.random.random() < eps: - # return np.random.choice(np.arange(self.action_size)) - self.policy.eval() with torch.no_grad(): output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))