diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 80e262018db4a3bdc38790fb5f1fd588c2e85516..195b46a4b19c2c2db0bf94a2fa1bdbb9a670f1cb 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -110,7 +110,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Observation builder
     predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
 
-    if not train_params.use_extra_observation:
+    if not train_params.use_fast_tree_observation:
         print("\nUsing standard TreeObs")
 
         def check_is_observation_valid(observation):
@@ -141,7 +141,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     eval_env = create_rail_env(eval_env_params, tree_observation)
     eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
 
-    if not train_params.use_extra_observation:
+    if not train_params.use_fast_tree_observation:
         # Calculate the state size given the depth of the tree observation and the number of features
         n_features_per_node = train_env.obs_builder.observation_dim
         n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
@@ -495,7 +495,7 @@ if __name__ == "__main__":
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
-    parser.add_argument("--use_extra_observation", help="extra observation", action='store_true')
+    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs", action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
 
     training_params = parser.parse_args()
@@ -557,7 +557,7 @@ if __name__ == "__main__":
 
     # FIXME hard-coded for sweep search
     # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
-    # training_params.use_extra_observation = True
+    # training_params.use_fast_tree_observation = True
 
     print("\nTraining parameters:")
     pprint(vars(training_params))
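
For context, a minimal sketch (not part of the patch) of how the renamed --use_fast_tree_observation switch behaves once this change is applied; the add_argument call mirrors the line introduced in the diff, while the parse_args inputs are purely illustrative:

import argparse

# Mirror of the renamed CLI switch introduced by this patch.
parser = argparse.ArgumentParser()
parser.add_argument("--use_fast_tree_observation",
                    help="use FastTreeObs instead of stock TreeObs",
                    action='store_true')

# With the flag, train_agent() takes the FastTreeObs branch.
assert parser.parse_args(["--use_fast_tree_observation"]).use_fast_tree_observation is True

# Without the flag, the script falls back to the standard TreeObs branch.
assert parser.parse_args([]).use_fast_tree_observation is False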