From cb4df49c3ae2036b16bbf7f2f79711d1f86c66c4 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Fri, 18 Dec 2020 11:16:23 +0100
Subject: [PATCH] Enable MultiDecisionAgent, enlarge PPO replay buffer, map DO_NOTHING

---
 reinforcement_learning/multi_agent_training.py        | 2 +-
 reinforcement_learning/ppo_agent.py                   | 5 ++---
 reinforcement_learning/ppo_deadlockavoidance_agent.py | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index b9a819a..cce7ecc 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -176,7 +176,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         policy = PPOPolicy(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if False:
+    if True:
         policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)
 
     # Load existing policy
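
Note on the hunk above: the training script picks its policy through hard-coded "if False:" / "if True:" guards, and this patch flips the last guard so the PPO policy gets wrapped in MultiDecisionAgent. Below is a minimal sketch of the same selection driven by an explicit name instead of boolean guards; create_policy, the policy_name argument and the passed-in get_action_size callable are assumptions for illustration, while PPOPolicy and MultiDecisionAgent come from the files touched by this patch (imports assume the repository root is on the path).

    # Sketch only: choose the policy by name instead of editing if True/False guards.
    from reinforcement_learning.ppo_agent import PPOPolicy
    from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent

    def create_policy(policy_name, train_env, state_size, get_action_size):
        if policy_name == "ppo":
            return PPOPolicy(state_size, get_action_size())
        if policy_name == "multi_decision":
            # Wrap the learned PPO policy; MultiDecisionAgent can fall back to
            # the deadlock-avoidance heuristic per agent.
            base_policy = PPOPolicy(state_size, get_action_size())
            return MultiDecisionAgent(train_env, state_size, get_action_size(), base_policy)
        raise ValueError("unknown policy: " + policy_name)
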
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 43de9f7..5c0fc08 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -107,8 +107,8 @@ class PPOPolicy(Policy):
         self.weight_loss = 0.25
         self.weight_entropy = 0.01
 
-        self.buffer_size = 2_000
-        self.batch_size = 64
+        self.buffer_size = 32_000
+        self.batch_size = 1024
         self.buffer_min_size = 0
         self.use_replay_buffer = True
         self.device = device
@@ -187,7 +187,6 @@ class PPOPolicy(Policy):
                                                     reward_list, state_next_list,
                                                     done_list, prob_a_list)
 
-
         # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
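
The hunk above raises the PPO replay buffer from 2,000 to 32,000 transitions and the sampled mini-batch from 64 to 1,024, with buffer_min_size kept at 0 so sampling can start as soon as a full batch is available. A rough, self-contained sketch of how such a bounded buffer and batch size interact is shown below; SimpleReplayBuffer is an assumption for illustration, not the repository's replay-buffer implementation.

    # Illustrative sketch: a bounded buffer from which fixed-size mini-batches are drawn.
    import random
    from collections import deque

    class SimpleReplayBuffer:
        def __init__(self, buffer_size=32_000, batch_size=1024, buffer_min_size=0):
            self.memory = deque(maxlen=buffer_size)  # oldest transitions are evicted first
            self.batch_size = batch_size
            self.buffer_min_size = buffer_min_size

        def add(self, transition):
            self.memory.append(transition)

        def can_sample(self):
            return len(self.memory) >= max(self.buffer_min_size, self.batch_size)

        def sample(self):
            # Uniform random mini-batch of batch_size stored transitions.
            return random.sample(self.memory, self.batch_size)
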
diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py
index 737634c..6e8880c 100644
--- a/reinforcement_learning/ppo_deadlockavoidance_agent.py
+++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py
@@ -39,7 +39,7 @@ class MultiDecisionAgent(Policy):
                 act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
                 return map_rail_env_action(act)
         # Agent is still at target cell
-        return RailEnvActions.DO_NOTHING
+        return map_rail_env_action(RailEnvActions.DO_NOTHING)
 
     def save(self, filename):
         self.dead_lock_avoidance_agent.save(filename)
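
The last hunk makes the fallback branch of MultiDecisionAgent.act() consistent with the other branches: DO_NOTHING is now also passed through map_rail_env_action, so every code path returns an action in the same mapped action space. The sketch below only illustrates that idea; the concrete mapping is an assumption and not the repository's map_rail_env_action.

    # Assumed example of an action-space mapping; the real map_rail_env_action
    # in this repository may differ.
    from flatland.envs.rail_env import RailEnvActions

    def map_rail_env_action_example(action: RailEnvActions) -> int:
        # Example reduced action space: treat DO_NOTHING as STOP_MOVING and
        # shift the remaining actions down by one index.
        if action == RailEnvActions.DO_NOTHING:
            action = RailEnvActions.STOP_MOVING
        return int(action) - 1
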
-- 
GitLab