From c95088f84a483570b0d14fe5c5656507789ce393 Mon Sep 17 00:00:00 2001
From: flaurent <florian.laurent@gmail.com>
Date: Sat, 24 Oct 2020 16:16:37 +0200
Subject: [PATCH] Renamed FastTreeObs command line flag

---
 reinforcement_learning/multi_agent_training.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 80e2620..195b46a 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -110,7 +110,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Observation builder
     predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-    if not train_params.use_extra_observation:
+    if not train_params.use_fast_tree_observation:
         print("\nUsing standard TreeObs")
 
         def check_is_observation_valid(observation):
@@ -141,7 +141,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     eval_env = create_rail_env(eval_env_params, tree_observation)
     eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
 
-    if not train_params.use_extra_observation:
+    if not train_params.use_fast_tree_observation:
         # Calculate the state size given the depth of the tree observation and the number of features
         n_features_per_node = train_env.obs_builder.observation_dim
         n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
@@ -495,7 +495,7 @@ if __name__ == "__main__":
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
-    parser.add_argument("--use_extra_observation", help="extra observation", action='store_true')
+    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs", action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
 
     training_params = parser.parse_args()
@@ -557,7 +557,7 @@ if __name__ == "__main__":
 
     # FIXME hard-coded for sweep search
     # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
-    # training_params.use_extra_observation = True
+    # training_params.use_fast_tree_observation = True
 
     print("\nTraining parameters:")
     pprint(vars(training_params))
-- 
GitLab