Skip to content
Snippets Groups Projects
Commit 788ac51b authored by gmollard's avatar gmollard
Browse files

small changes for the observation benchmark, nr_extra=30

parent 4470e889
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -13,7 +13,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
# number_of_agents=1, # number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2)): # obs_builder_object=TreeObsForRailEnv(max_depth=2)):
super(MultiAgentEnv, self).__init__() super(MultiAgentEnv, self).__init__()
self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5, self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5, nr_extra=30,
seed=config['seed'] * (1+config.vector_index)) seed=config['seed'] * (1+config.vector_index))
set_seed(config['seed'] * (1+config.vector_index)) set_seed(config['seed'] * (1+config.vector_index))
self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator, self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator,
......
run_experiment.name = "observation_benchmark_results" run_experiment.name = "observation_benchmark_results"
run_experiment.num_iterations = 1002 run_experiment.num_iterations = 1002
run_experiment.save_every = 200 run_experiment.save_every = 100
run_experiment.hidden_sizes = [32, 32] run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20 run_experiment.map_width = 20
......
...@@ -34,7 +34,7 @@ from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor ...@@ -34,7 +34,7 @@ from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor) ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor) ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor)
ray.init() ray.init(object_store_memory=150000000000, redis_max_memory=30000000000)
def train(config, reporter): def train(config, reporter):
...@@ -101,7 +101,7 @@ def train(config, reporter): ...@@ -101,7 +101,7 @@ def train(config, reporter):
trainer_config["horizon"] = config['horizon'] trainer_config["horizon"] = config['horizon']
trainer_config["num_workers"] = 0 trainer_config["num_workers"] = 0
trainer_config["num_cpus_per_worker"] = 8 trainer_config["num_cpus_per_worker"] = 10
trainer_config["num_gpus"] = 0.5 trainer_config["num_gpus"] = 0.5
trainer_config["num_gpus_per_worker"] = 0.5 trainer_config["num_gpus_per_worker"] = 0.5
trainer_config["num_cpus_for_driver"] = 2 trainer_config["num_cpus_for_driver"] = 2
...@@ -155,7 +155,7 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -155,7 +155,7 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"seed": seed "seed": seed
}, },
resources_per_trial={ resources_per_trial={
"cpu": 10, "cpu": 12,
"gpu": 0.5 "gpu": 0.5
}, },
local_dir=local_dir local_dir=local_dir
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment