From 155cd80421bf986d87ca22fdb88a38768dbbfc0e Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Mon, 21 Dec 2020 08:05:03 +0100
Subject: [PATCH] experiment with ppo

---
 .../multi_agent_training.py | 30 +++----------------
 run.py                      |  3 ++
 2 files changed, 7 insertions(+), 26 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index a3a2b1a..4c06133 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -174,7 +174,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, get_action_size(), train_params)
     if True:
-        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=True, in_parameters=train_params)
+        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
     if False:
@@ -282,28 +282,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             # Environment step
             step_timer.start()
             next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict))
-
-            # Reward shaping .Dead-lock .NotMoving .NotStarted
-            if True:
-                agent_positions = get_agent_positions(train_env)
-                for agent_handle in train_env.get_agent_handles():
-                    agent = train_env.agents[agent_handle]
-                    act = map_action(action_dict.get(agent_handle, map_rail_env_action(RailEnvActions.DO_NOTHING)))
-                    if agent.status == RailAgentStatus.ACTIVE:
-                        if done[agent_handle] == False:
-                            if check_for_deadlock(agent_handle, train_env, agent_positions):
-                                all_rewards[agent_handle] -= 5.0
-                            else:
-                                pos = agent.position
-                                possible_transitions = train_env.rail.get_transitions(*pos, agent.direction)
-                                num_transitions = fast_count_nonzero(possible_transitions)
-                                if num_transitions < 2 and ((act != RailEnvActions.MOVE_FORWARD) or
-                                                            (act != RailEnvActions.STOP_MOVING)):
-                                    all_rewards[agent_handle] -= 1.0
-                        else:
-                            all_rewards[agent_handle] *= 9.0
-                            all_rewards[agent_handle] += 1.0
-
             step_timer.end()

             # Render an episode at some interval
@@ -524,9 +502,9 @@ if __name__ == "__main__":
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
-    parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.99975, type=float)
+    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
+    parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(32_000), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
diff --git a/run.py b/run.py
index c04de85..6578fee 100644
--- a/run.py
+++ b/run.py
@@ -54,6 +54,9 @@ USE_PPO_AGENT = True

 # Checkpoint to use (remember to push it!)
 checkpoint = "./checkpoints/201219090514-8600.pth" #
 # checkpoint = "./checkpoints/201215212134-12000.pth" #
+checkpoint = "./checkpoints/201220171629-12000.pth" # DDDQN - EPSILON: 0.0 - 13.940848323912533
+checkpoint = "./checkpoints/201220203236-12000.pth" # PPO - EPSILON: 0.0 - 13.660942453931114
+checkpoint = "./checkpoints/201220214325-12000.pth" # PPO - EPSILON: 0.0 - 13.463600936043

 EPSILON = 0.0
--
GitLab