diff --git a/train_experiment.py b/train_experiment.py index 68b45684bc4dcaf5efa190f731d03a436e94dbaa..0c1af1727abaaaf3078e24d7a250f071bda6cb9c 100644 --- a/train_experiment.py +++ b/train_experiment.py @@ -8,7 +8,8 @@ from flatland.envs.generators import complex_rail_generator # Import PPO trainer: we can replace these imports by any other trainer from RLLib. from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph +# from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph +from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph from ray.rllib.models import ModelCatalog from ray.tune.logger import pretty_print @@ -34,7 +35,7 @@ from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor) ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor) -ray.init(object_store_memory=150000000000, redis_max_memory=30000000000) +ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000) def train(config, reporter): @@ -62,9 +63,6 @@ def train(config, reporter): "seed": config['seed'], "obs_builder": config['obs_builder']} - print(config["obs_builder"]) - print(config["obs_builder"].__class__) - print(type(TreeObsForRailEnv)) # Observation space and action space definitions if isinstance(config["obs_builder"], TreeObsForRailEnv): obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,)) @@ -73,7 +71,8 @@ def train(config, reporter): elif isinstance(config["obs_builder"], GlobalObsForRailEnv): obs_space = gym.spaces.Tuple(( gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)), - gym.spaces.Box(low=0, high=1, shape=(4, config['map_height'], config['map_width'])), + gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 3)), + gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 4)), gym.spaces.Box(low=0, high=1, shape=(4,)))) preprocessor = "global_obs_prep" @@ -101,14 +100,15 @@ def train(config, reporter): trainer_config["horizon"] = config['horizon'] trainer_config["num_workers"] = 0 - trainer_config["num_cpus_per_worker"] = 10 - trainer_config["num_gpus"] = 0.5 - trainer_config["num_gpus_per_worker"] = 0.5 - trainer_config["num_cpus_for_driver"] = 2 - trainer_config["num_envs_per_worker"] = 10 + trainer_config["num_cpus_per_worker"] = 3 + trainer_config["num_gpus"] = 0 + trainer_config["num_gpus_per_worker"] = 0 + trainer_config["num_cpus_for_driver"] = 1 + trainer_config["num_envs_per_worker"] = 1 trainer_config["env_config"] = env_config trainer_config["batch_mode"] = "complete_episodes" - trainer_config['simple_optimizer'] = False + trainer_config['simple_optimizer'] = True + trainer_config['postprocess_inputs'] = True def logger_creator(conf): """Creates a Unified logger with a default logdir prefix @@ -155,8 +155,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, "seed": seed }, resources_per_trial={ - "cpu": 12, - "gpu": 0.5 + "cpu": 2, + "gpu": 0.0 }, local_dir=local_dir ) @@ -164,6 +164,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, if __name__ == '__main__': gin.external_configurable(tune.grid_search) - dir = '/mount/SDC/flatland/baselines/experiment_configs/observation_benchmark' # To Modify + dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/experiment_configs/observation_benchmark' # To Modify gin.parse_config_file(dir + '/config.gin') run_experiment(local_dir=dir)