Skip to content
Snippets Groups Projects
Commit d5a04cae authored by gmollard
Browse files

observation benchmark with local observation

parent 8b1fc5bd
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,8 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from ray.rllib.utils.seed import seed as set_seed
import numpy as np
class RailEnvRLLibWrapper(MultiAgentEnv):
......@@ -13,15 +15,20 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
# number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2)):
super(MultiAgentEnv, self).__init__()
self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5, nr_extra=30,
seed=config['seed'] * (1+config.vector_index))
self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5,
nr_extra=30, seed=config['seed'] * (1+config.vector_index))
set_seed(config['seed'] * (1+config.vector_index))
self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator,
number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder'])
def reset(self):
    """Start a new episode on the wrapped environment.

    Resets ``self.agents_done`` to an empty list and returns whatever
    ``self.env.reset()`` produces, unchanged.
    """
    # Bookkeeping list is emptied before the wrapped env is reset.
    self.agents_done = []
    return self.env.reset()
def step(self, action_dict):
obs, rewards, dones, infos = self.env.step(action_dict)
......@@ -46,7 +53,15 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
#print(obs)
#return obs, rewards, dones, infos
# oo = dict()
# oo['agents'] = o
# o['global'] = np.ones((17, 17)) * 17
# o[0] = [o[0], np.ones((17, 17)) * 17]
# o['global_obs'] = np.ones((17, 17)) * 17
# r['global_obs'] = 0
# d['global_obs'] = True
return o, r, d, infos
# Return the agent handles by delegating to the wrapped environment
# (self.env.get_agent_handles()); the result is passed through unchanged.
def get_agent_handles(self):
return self.env.get_agent_handles()
......@@ -5,12 +5,13 @@ run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20
run_experiment.map_height = 20
run_experiment.n_agents = {"grid_search": [2, 5]}
run_experiment.n_agents = 5
run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_{config[n_agents]}_agents"
run_experiment.horizon = 50
run_experiment.seed = 123
run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]}
run_experiment.obs_builder = {"grid_search": [@LocalObsForRailEnv()]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5
......@@ -8,8 +8,9 @@ from flatland.envs.generators import complex_rail_generator
# Import PPO trainer: we can replace these imports by any other trainer from RLLib.
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
# from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
# from baselines.CustomPPOTrainer import PPOTrainer as Trainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
# from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
......@@ -27,9 +28,10 @@ import gin
from ray import tune
from ray.rllib.utils.seed import seed as set_seed
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, LocalObsForRailEnv
gin.external_configurable(TreeObsForRailEnv)
gin.external_configurable(GlobalObsForRailEnv)
gin.external_configurable(LocalObsForRailEnv)
from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
......@@ -76,6 +78,15 @@ def train(config, reporter):
gym.spaces.Box(low=0, high=1, shape=(4,))))
preprocessor = "global_obs_prep"
elif isinstance(config["obs_builder"], LocalObsForRailEnv):
view_radius = config["obs_builder"].view_radius
obs_space = gym.spaces.Tuple((
gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 16)),
gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 2)),
gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 4)),
gym.spaces.Box(low=0, high=1, shape=(4,))))
preprocessor = "global_obs_prep"
else:
raise ValueError("Undefined observation space")
......@@ -107,8 +118,9 @@ def train(config, reporter):
trainer_config["num_envs_per_worker"] = 1
trainer_config["env_config"] = env_config
trainer_config["batch_mode"] = "complete_episodes"
trainer_config['simple_optimizer'] = True
trainer_config['simple_optimizer'] = False
trainer_config['postprocess_inputs'] = True
trainer_config['log_level'] = 'WARN'
def logger_creator(conf):
"""Creates a Unified logger with a default logdir prefix
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment