Skip to content
Snippets Groups Projects
Commit 8b1fc5bd authored by gmollard's avatar gmollard
Browse files

changed dimensions for global observation with other agents direction

parent b106448c
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,8 @@ from flatland.envs.generators import complex_rail_generator
# Import PPO trainer: we can replace these imports by any other trainer from RLLib.
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
# from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
......@@ -34,7 +35,7 @@ from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor)
ray.init(object_store_memory=150000000000, redis_max_memory=30000000000)
ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000)
def train(config, reporter):
......@@ -62,9 +63,6 @@ def train(config, reporter):
"seed": config['seed'],
"obs_builder": config['obs_builder']}
print(config["obs_builder"])
print(config["obs_builder"].__class__)
print(type(TreeObsForRailEnv))
# Observation space and action space definitions
if isinstance(config["obs_builder"], TreeObsForRailEnv):
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
......@@ -73,7 +71,8 @@ def train(config, reporter):
elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
obs_space = gym.spaces.Tuple((
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
gym.spaces.Box(low=0, high=1, shape=(4, config['map_height'], config['map_width'])),
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 3)),
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 4)),
gym.spaces.Box(low=0, high=1, shape=(4,))))
preprocessor = "global_obs_prep"
......@@ -101,14 +100,15 @@ def train(config, reporter):
trainer_config["horizon"] = config['horizon']
trainer_config["num_workers"] = 0
trainer_config["num_cpus_per_worker"] = 10
trainer_config["num_gpus"] = 0.5
trainer_config["num_gpus_per_worker"] = 0.5
trainer_config["num_cpus_for_driver"] = 2
trainer_config["num_envs_per_worker"] = 10
trainer_config["num_cpus_per_worker"] = 3
trainer_config["num_gpus"] = 0
trainer_config["num_gpus_per_worker"] = 0
trainer_config["num_cpus_for_driver"] = 1
trainer_config["num_envs_per_worker"] = 1
trainer_config["env_config"] = env_config
trainer_config["batch_mode"] = "complete_episodes"
trainer_config['simple_optimizer'] = False
trainer_config['simple_optimizer'] = True
trainer_config['postprocess_inputs'] = True
def logger_creator(conf):
"""Creates a Unified logger with a default logdir prefix
......@@ -155,8 +155,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"seed": seed
},
resources_per_trial={
"cpu": 12,
"gpu": 0.5
"cpu": 2,
"gpu": 0.0
},
local_dir=local_dir
)
......@@ -164,6 +164,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
if __name__ == '__main__':
gin.external_configurable(tune.grid_search)
dir = '/mount/SDC/flatland/baselines/experiment_configs/observation_benchmark' # To Modify
dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/experiment_configs/observation_benchmark' # To Modify
gin.parse_config_file(dir + '/config.gin')
run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment