diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 80e262018db4a3bdc38790fb5f1fd588c2e85516..195b46a4b19c2c2db0bf94a2fa1bdbb9a670f1cb 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -110,7 +110,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Observation builder
     predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-    if not train_params.use_extra_observation:
+    if not train_params.use_fast_tree_observation:
         print("\nUsing standard TreeObs")
 
         def check_is_observation_valid(observation):
@@ -141,7 +141,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     eval_env = create_rail_env(eval_env_params, tree_observation)
     eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
 
-    if not train_params.use_extra_observation:
+    if not train_params.use_fast_tree_observation:
         # Calculate the state size given the depth of the tree observation and the number of features
         n_features_per_node = train_env.obs_builder.observation_dim
         n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
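
For reference, a worked instance of the state-size formula in the hunk above, as a minimal sketch with illustrative values (the depth and per-node feature count are assumptions, not taken from this diff); each tree node branches four ways, so a tree of depth d contributes sum(4**i for i in range(d + 1)) nodes:

    # Illustrative values only; the real ones come from train_params and the observation builder.
    observation_tree_depth = 2
    n_features_per_node = 11  # observation_dim reported by the tree observation builder
    n_nodes = sum(4 ** i for i in range(observation_tree_depth + 1))  # 1 + 4 + 16 = 21
    state_size = n_nodes * n_features_per_node  # 21 * 11 = 231
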
@@ -495,7 +495,7 @@ if __name__ == "__main__":
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
-    parser.add_argument("--use_extra_observation", help="extra observation", action='store_true')
+    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of the standard TreeObs", action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
 
     training_params = parser.parse_args()
@@ -557,7 +557,7 @@ if __name__ == "__main__":
 
     # FIXME hard-coded for sweep search
     # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
-    # training_params.use_extra_observation = True
+    # training_params.use_fast_tree_observation = True
 
     print("\nTraining parameters:")
     pprint(vars(training_params))
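
Taken together, the renamed flag simply selects which observation builder the training run uses. Below is a minimal sketch of that gating, assuming the builder names and import paths used by typical Flatland baselines (FastTreeObs, TreeObsForRailEnv); it mirrors the structure of train_agent rather than reproducing its exact code, and predictor, observation_tree_depth and train_params come from the surrounding setup shown in the hunks above:

    # Assumed imports; the FastTreeObs path in particular may differ in this repository.
    from flatland.envs.observations import TreeObsForRailEnv
    from utils.fast_tree_obs import FastTreeObs

    if train_params.use_fast_tree_observation:
        tree_observation = FastTreeObs(max_depth=observation_tree_depth)
    else:
        print("\nUsing standard TreeObs")
        tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth,
                                             predictor=predictor)

With the rename, a run opting into the faster builder would be started along the lines of: python reinforcement_learning/multi_agent_training.py --use_fast_tree_observation (remaining arguments omitted).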