From 155cd80421bf986d87ca22fdb88a38768dbbfc0e Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Mon, 21 Dec 2020 08:05:03 +0100
Subject: [PATCH] Experiment with PPO

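Switch training to the PPO policy without the experience replay buffer
(use_replay_buffer=False), remove the in-loop reward shaping block
(deadlock / not-moving / not-started penalties and the done bonus), and
reset the exploration schedule to eps_start=1.0, eps_end=0.01,
eps_decay=0.9975. run.py now points at the latest PPO checkpoint
(201220214325-12000.pth).
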
---
 .../multi_agent_training.py                   | 30 +++----------------
 run.py                                        |  3 ++
 2 files changed, 7 insertions(+), 26 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index a3a2b1a..4c06133 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -174,7 +174,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, get_action_size(), train_params)
     if True:
-        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=True, in_parameters=train_params)
+        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
     if False:
@@ -282,28 +282,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             # Environment step
             step_timer.start()
             next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict))
-
-            # Reward shaping .Dead-lock .NotMoving .NotStarted
-            if True:
-                agent_positions = get_agent_positions(train_env)
-                for agent_handle in train_env.get_agent_handles():
-                    agent = train_env.agents[agent_handle]
-                    act = map_action(action_dict.get(agent_handle, map_rail_env_action(RailEnvActions.DO_NOTHING)))
-                    if agent.status == RailAgentStatus.ACTIVE:
-                        if done[agent_handle] == False:
-                            if check_for_deadlock(agent_handle, train_env, agent_positions):
-                                all_rewards[agent_handle] -= 5.0
-                            else:
-                                pos = agent.position
-                                possible_transitions = train_env.rail.get_transitions(*pos, agent.direction)
-                                num_transitions = fast_count_nonzero(possible_transitions)
-                                if num_transitions < 2 and ((act != RailEnvActions.MOVE_FORWARD) or
-                                                            (act != RailEnvActions.STOP_MOVING)):
-                                    all_rewards[agent_handle] -= 1.0
-                        else:
-                            all_rewards[agent_handle] *= 9.0
-                            all_rewards[agent_handle] += 1.0
-
             step_timer.end()
 
             # Render an episode at some interval
@@ -524,9 +502,9 @@ if __name__ == "__main__":
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
-    parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.99975, type=float)
+    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
+    parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(32_000), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
diff --git a/run.py b/run.py
index c04de85..6578fee 100644
--- a/run.py
+++ b/run.py
@@ -54,6 +54,9 @@ USE_PPO_AGENT = True
 # Checkpoint to use (remember to push it!)
 checkpoint = "./checkpoints/201219090514-8600.pth"  #
 # checkpoint = "./checkpoints/201215212134-12000.pth"  #
+checkpoint = "./checkpoints/201220171629-12000.pth"  # DDDQN - EPSILON: 0.0 - 13.940848323912533
+checkpoint = "./checkpoints/201220203236-12000.pth"  # PPO - EPSILON: 0.0 - 13.660942453931114
+checkpoint = "./checkpoints/201220214325-12000.pth"  # PPO - EPSILON: 0.0 - 13.463600936043
 
 EPSILON = 0.0
 
-- 
GitLab