From 7bb7aebd8893d291415505441e9245e36b8de914 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 5 Nov 2020 21:12:41 +0100
Subject: [PATCH] DQN & PPO: pass agent handle to Policy.step, use PPOAgent for multi-agent training

---
 reinforcement_learning/dddqn_policy.py        |  2 +-
 .../multi_agent_training.py                   | 27 +++++++------------
 reinforcement_learning/policy.py              | 11 +++++++-
 reinforcement_learning/ppo/ppo_agent.py       | 24 +++++------------
 .../single_agent_training.py                  |  4 ++-
 5 files changed, 31 insertions(+), 37 deletions(-)

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 2cf7ad2..1c323c3 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -67,7 +67,7 @@ class DDDQNPolicy(Policy):
         else:
             return random.choice(np.arange(self.action_size))
 
-    def step(self, state, action, reward, next_state, done):
+    def step(self, handle, state, action, reward, next_state, done):
         assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
 
         # Save experience in replay memory
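
The hunk above only widens DDDQNPolicy.step() to accept the agent handle. A minimal sketch of the resulting call shape, assuming the shared replay buffer simply ignores the handle; SketchDDDQNPolicy and its deque buffer are illustrative stand-ins, not the real implementation:

    from collections import deque, namedtuple

    Experience = namedtuple("Experience", ["state", "action", "reward", "next_state", "done"])

    class SketchDDDQNPolicy:
        """Illustrative stand-in: the real DDDQNPolicy keeps its training logic;
        the patch merely adds the handle parameter for interface compatibility."""

        def __init__(self, buffer_size=10_000):
            self.evaluation_mode = False
            self.memory = deque(maxlen=buffer_size)

        def step(self, handle, state, action, reward, next_state, done):
            assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
            # A single shared buffer does not need the handle; it is accepted only
            # so that every Policy subclass exposes the same signature.
            self.memory.append(Experience(state, action, reward, next_state, done))
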
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 1d99a9b..6f250a8 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -18,13 +18,14 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
 
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
+
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
 
 from utils.timer import Timer
 from utils.observation_utils import normalize_observation
-from utils.fast_tree_obs import FastTreeObs, fast_tree_obs_check_agent_deadlock
-from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from utils.fast_tree_obs import FastTreeObs
 
 try:
     import wandb
@@ -171,8 +172,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, action_size, train_params)
-
+    # policy = DDDQNPolicy(state_size, action_size, train_params)
+    policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
         policy.load(train_params.load_policy)
@@ -226,6 +227,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         train_env_params.n_agents = episode_idx % n_agents + 1
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+        policy.reset()
         reset_timer.end()
 
         if train_params.render:
@@ -288,18 +290,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where something happened
                     learn_timer.start()
-                    call_step = True
-
-                    if not (agent_obs[agent][7] == 1 or agent_obs[agent][8] == 1 or agent_obs[agent][4] == 1):
-                        if action_dict.get(agent) == RailEnvActions.MOVE_FORWARD:
-                            call_step = np.random.random() < 0.1
-                    if fast_tree_obs_check_agent_deadlock(agent_obs[agent]):
-                        all_rewards[agent] -= 10
-                        call_step = True
-                    if call_step:
-                        policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                    agent_obs[agent],
-                                    done[agent])
+                    policy.step(agent,
+                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                agent_obs[agent],
+                                done[agent])
                     learn_timer.end()
 
                     agent_prev_obs[agent] = agent_obs[agent].copy()
@@ -444,7 +438,6 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         score = 0.0
 
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
-
         final_step = 0
 
         for step in range(max_steps - 1):
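
With the observation-dependent gating and the fast_tree_obs deadlock penalty removed, every update timestep now feeds the transition straight to the policy, keyed by the agent handle. A small self-contained sketch of the call order; RecordingPolicy and the dummy observations are assumptions made only for illustration:

    import numpy as np

    class RecordingPolicy:
        # Hypothetical stand-in for PPOAgent / DDDQNPolicy: just prints what it receives.
        def step(self, handle, state, action, reward, next_state, done):
            print(f"agent {handle}: action={action}, reward={reward}, done={done}")

    policy = RecordingPolicy()
    agent = 0
    agent_prev_obs = {agent: np.zeros(4)}
    agent_prev_action = {agent: 2}
    all_rewards = {agent: -1.0}
    agent_obs = {agent: np.ones(4)}
    done = {agent: False, "__all__": False}

    policy.step(agent,
                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
                agent_obs[agent],
                done[agent])
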
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index b8714d1..b605aa3 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -1,5 +1,5 @@
 class Policy:
-    def step(self, state, action, reward, next_state, done):
+    def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
 
     def act(self, state, eps=0.):
@@ -16,3 +16,12 @@ class Policy:
 
     def end_step(self):
         pass
+
+    def load_replay_buffer(self, filename):
+        pass
+
+    def test(self):
+        pass
+
+    def reset(self):
+        pass
\ No newline at end of file
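
Because the new load_replay_buffer(), test() and reset() hooks default to no-ops, the training loop can call them on any policy unconditionally and a subclass only overrides what it needs. A minimal hypothetical subclass (CountingPolicy is not part of the patch):

    from reinforcement_learning.policy import Policy

    class CountingPolicy(Policy):
        """Hypothetical example: only reset() is overridden; the other hooks
        inherit the no-op defaults added by this patch."""

        def __init__(self):
            self.episodes = 0

        def step(self, handle, state, action, reward, next_state, done):
            pass  # nothing to learn in this toy example

        def act(self, state, eps=0.):
            return 0  # always pick action 0 in this sketch

        def reset(self):
            self.episodes += 1  # called once per environment reset in the training loop
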
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index e43cb30..663a05a 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -20,7 +20,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 class PPOAgent(Policy):
-    def __init__(self, state_size, action_size, num_agents, env):
+    def __init__(self, state_size, action_size, num_agents):
         self.action_size = action_size
         self.state_size = state_size
         self.num_agents = num_agents
@@ -31,7 +31,7 @@ class PPOAgent(Policy):
         self.memory = ReplayBuffer(BUFFER_SIZE)
         self.t_step = 0
         self.loss = 0
-        self.env = env
+        self.num_agents = num_agents
 
     def reset(self):
         self.finished = [False] * len(self.episodes)
@@ -39,21 +39,11 @@ class PPOAgent(Policy):
 
     # Decide on an action to take in the environment
 
-    def act(self, handle, state, eps=None):
-        if True:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-                return Categorical(output).sample().item()
-
-        # Epsilon-greedy action selection
-        if random.random() > eps:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-                return Categorical(output).sample().item()
-        else:
-            return random.choice(np.arange(self.action_size))
+    def act(self, state, eps=None):
+        self.policy.eval()
+        with torch.no_grad():
+            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
+            return Categorical(output).sample().item()
 
     # Record the results of the agent's action and update the model
     def step(self, handle, state, action, reward, next_state, done):
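
After this change the agent is constructed without the environment, and act() always samples from the categorical policy output (the epsilon-greedy branch is gone, so eps is ignored). A sketch of the intended call pattern; the concrete sizes below are assumptions, not values taken from the patch:

    import numpy as np
    from reinforcement_learning.ppo.ppo_agent import PPOAgent

    state_size, action_size, n_agents = 231, 5, 3   # assumed sizes for illustration

    policy = PPOAgent(state_size, action_size, n_agents)   # no env argument any more
    policy.reset()                                          # clear per-episode buffers

    obs = np.zeros(state_size, dtype=np.float32)
    action = policy.act(obs)                      # samples from Categorical(policy(obs))
    policy.step(0, obs, action, -1.0, obs, False) # handle 0, reward -1.0
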
diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py
index 236d1a7..bfcc886 100644
--- a/reinforcement_learning/single_agent_training.py
+++ b/reinforcement_learning/single_agent_training.py
@@ -146,7 +146,9 @@ def train_agent(n_episodes):
             for agent in range(env.get_num_agents()):
                 # Only update the values when we are done or when an action was taken and thus relevant information is present
                 if update_values or done[agent]:
-                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+                    policy.step(agent,
+                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                agent_obs[agent], done[agent])
 
                     agent_prev_obs[agent] = agent_obs[agent].copy()
                     agent_prev_action[agent] = action_dict[agent]
-- 
GitLab