diff --git a/RailEnvRLLibWrapper.py b/RailEnvRLLibWrapper.py
index da54ad01f6d5536c0aab129509ae501f93e5ead9..fbdee61432f0aa2f558f543c002e9846570dc140 100644
--- a/RailEnvRLLibWrapper.py
+++ b/RailEnvRLLibWrapper.py
@@ -3,6 +3,8 @@
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 from ray.rllib.utils.seed import seed as set_seed
+import numpy as np
+
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
@@ -13,15 +15,20 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
     #                  number_of_agents=1,
     #                  obs_builder_object=TreeObsForRailEnv(max_depth=2)):
         super(MultiAgentEnv, self).__init__()
-        self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5, nr_extra=30,
-                                                       seed=config['seed'] * (1+config.vector_index))
+        self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5,
+                                                       nr_extra=30, seed=config['seed'] * (1+config.vector_index))
         set_seed(config['seed'] * (1+config.vector_index))
         self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator,
                            number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder'])
 
     def reset(self):
         self.agents_done = []
-        return self.env.reset()
+        obs = self.env.reset()
+        o = dict()
+        # o['agents'] = obs
+        # obs[0] = [obs[0], np.ones((17, 17)) * 17]
+        # obs['global_obs'] = np.ones((17, 17)) * 17
+        return obs
 
     def step(self, action_dict):
         obs, rewards, dones, infos = self.env.step(action_dict)
@@ -46,7 +53,15 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         #print(obs)
         #return obs, rewards, dones, infos
 
+        # oo = dict()
+        # oo['agents'] = o
+        # o['global'] = np.ones((17, 17)) * 17
+
+        # o[0] = [o[0], np.ones((17, 17)) * 17]
+        # o['global_obs'] = np.ones((17, 17)) * 17
+        # r['global_obs'] = 0
+        # d['global_obs'] = True
         return o, r, d, infos
-    
+
     def get_agent_handles(self):
         return self.env.get_agent_handles()
diff --git a/experiment_configs/CustomModels.py b/experiment_configs/CustomModels.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/experiment_configs/observation_benchmark/config.gin b/experiment_configs/observation_benchmark/config.gin
index 9f3c0727dd6b5a922b2a8b212e4bf5e6f77f0dab..f5a4dc80396e26476296a4d8b83cb4882a0f1033 100644
--- a/experiment_configs/observation_benchmark/config.gin
+++ b/experiment_configs/observation_benchmark/config.gin
@@ -5,12 +5,13 @@
 run_experiment.hidden_sizes = [32, 32]
 
 run_experiment.map_width = 20
 run_experiment.map_height = 20
-run_experiment.n_agents = {"grid_search": [2, 5]}
+run_experiment.n_agents = 5
 run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_{config[n_agents]}_agents"
 run_experiment.horizon = 50
 run_experiment.seed = 123
 
-run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]}
+run_experiment.obs_builder = {"grid_search": [@LocalObsForRailEnv()]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
 TreeObsForRailEnv.max_depth = 2
+LocalObsForRailEnv.view_radius = 5
 
diff --git a/train_experiment.py b/train_experiment.py
index 0c1af1727abaaaf3078e24d7a250f071bda6cb9c..16c52b657db58c6f69e8cf91004d4259461be2cd 100644
--- a/train_experiment.py
+++ b/train_experiment.py
@@ -8,8 +8,9 @@ from flatland.envs.generators import complex_rail_generator
 # Import PPO trainer: we can replace these imports by any other trainer from RLLib.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
 from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
-# from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
-from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
+# from baselines.CustomPPOTrainer import PPOTrainer as Trainer
+from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
+# from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
 from ray.rllib.models import ModelCatalog
 
 from ray.tune.logger import pretty_print
@@ -27,9 +28,10 @@ import gin
 
 from ray import tune
 from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, LocalObsForRailEnv
 gin.external_configurable(TreeObsForRailEnv)
 gin.external_configurable(GlobalObsForRailEnv)
+gin.external_configurable(LocalObsForRailEnv)
 
 from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
 
@@ -76,6 +78,15 @@ def train(config, reporter):
             gym.spaces.Box(low=0, high=1, shape=(4,))))
         preprocessor = "global_obs_prep"
 
+    elif isinstance(config["obs_builder"], LocalObsForRailEnv):
+        view_radius = config["obs_builder"].view_radius
+        obs_space = gym.spaces.Tuple((
+            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 16)),
+            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 2)),
+            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 4)),
+            gym.spaces.Box(low=0, high=1, shape=(4,))))
+        preprocessor = "global_obs_prep"
+
     else:
         raise ValueError("Undefined observation space")
 
@@ -107,8 +118,9 @@ def train(config, reporter):
     trainer_config["num_envs_per_worker"] = 1
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
-    trainer_config['simple_optimizer'] = True
+    trainer_config['simple_optimizer'] = False
     trainer_config['postprocess_inputs'] = True
+    trainer_config['log_level'] = 'WARN'
 
     def logger_creator(conf):
         """Creates a Unified logger with a default logdir prefix