From 409a87d483a7dc7b439f1009de47076cac2f6829 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 14 Jan 2021 09:44:46 +0100
Subject: [PATCH] number of agents (fixed or 1..)

---
 reinforcement_learning/multi_agent_training.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index b566d5a..340a21f 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -236,8 +236,12 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Reset environment
         reset_timer.start()
-        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
-        train_env_params.n_agents = episode_idx % number_of_agents + 1
+        if train_params.n_agent_fixed:
+            number_of_agents = n_agents
+            train_env_params.n_agents = n_agents
+        else:
+            number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
+            train_env_params.n_agents = episode_idx % number_of_agents + 1
 
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
@@ -506,6 +510,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5000, type=int)
+    parser.add_argument("--n_agent_fixed", help="hold the number of agent fixed", default=False, type=bool)
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
                         type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
-- 
GitLab