Skip to content
Snippets Groups Projects
Commit a54b734a authored by gmollard's avatar gmollard
Browse files

observation benchmark script

parent bc400346
No related branches found
No related tags found
No related merge requests found
...@@ -54,3 +54,11 @@ class CustomPreprocessor(Preprocessor): ...@@ -54,3 +54,11 @@ class CustomPreprocessor(Preprocessor):
def transform(self, observation): def transform(self, observation):
return norm_obs_clip(observation) # return the preprocessed observation return norm_obs_clip(observation) # return the preprocessed observation
# class NoPreprocessor:
# def _init_shape(self, obs_space, options):
# num_features = 0
# for space in obs_space:
# Gin bindings for run_experiment(). The launcher script parses this file and
# then calls run_experiment(local_dir=...), so the remaining arguments are
# expected to come from the bindings below.
# NOTE(review): the result name still says "n_agents" although this config sits
# in the observation_benchmark folder — confirm whether it should be renamed.
run_experiment.name = "n_agents_results"
run_experiment.num_iterations = 1002
run_experiment.save_every = 200
# Passed through as the model's "fcnet_hiddens" setting in the trainer config.
run_experiment.hidden_sizes = [32, 32]
# Map dimensions; also used for the shape of the global observation space.
run_experiment.map_width = 20
run_experiment.map_height = 20
# A dict with a "grid_search" key — presumably expanded by Ray Tune into one
# trial per listed value (TODO confirm against the Tune version in use).
run_experiment.n_agents = {"grid_search": [2, 5]}
# Template containing {config[n_agents]}; presumably filled in per trial with
# str.format(config=...) — verify against run_experiment/train.
run_experiment.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
run_experiment.horizon = 50
run_experiment.seed = 123
# Grid search over the two observation builders.
# NOTE(review): in Gin, `@Name` without a trailing `()` passes the configurable
# itself rather than the result of calling it. train() branches on
# type(config["obs_builder"]) == <class>, which would not match a class object —
# confirm whether `@Name()` (an instance) is intended here.
run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv, @GlobalObsForRailEnv]}
# Binding for TreeObsForRailEnv's max_depth argument; it only takes effect when
# the class is invoked through Gin as a registered configurable.
TreeObsForRailEnv.max_depth = 2
...@@ -24,9 +24,12 @@ import tempfile ...@@ -24,9 +24,12 @@ import tempfile
import gin import gin
from ray import tune from ray import tune
from ray.rllib.utils.seed import seed as set_seed from ray.rllib.utils.seed import seed as set_seed
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor) ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
ray.init() ray.init()
...@@ -55,7 +58,22 @@ def train(config, reporter): ...@@ -55,7 +58,22 @@ def train(config, reporter):
"seed": config['seed']} "seed": config['seed']}
# Observation space and action space definitions # Observation space and action space definitions
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,)) if type(config["obs_builder"]) == TreeObsForRailEnv:
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
preprocessor = "tree_obs_prep"
elif type(config["obs_builder"]) == GlobalObsForRailEnv:
obs_space = gym.spaces.Tuple((
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
gym.spaces.Box(low=0, high=1, shape=(4, config['map_height'], config['map_width'])),
gym.spaces.Space(4)))
preprocessor = TupleFlatteningPreprocessor
else:
raise ValueError("Undefined observation space")
act_space = gym.spaces.Discrete(4) act_space = gym.spaces.Discrete(4)
# Dict with the different policies to train # Dict with the different policies to train
...@@ -69,7 +87,7 @@ def train(config, reporter): ...@@ -69,7 +87,7 @@ def train(config, reporter):
# Trainer configuration # Trainer configuration
trainer_config = DEFAULT_CONFIG.copy() trainer_config = DEFAULT_CONFIG.copy()
trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"} trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor}
trainer_config['multiagent'] = {"policy_graphs": policy_graphs, trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
"policy_mapping_fn": policy_mapping_fn, "policy_mapping_fn": policy_mapping_fn,
"policies_to_train": list(policy_graphs.keys())} "policies_to_train": list(policy_graphs.keys())}
...@@ -129,8 +147,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -129,8 +147,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"seed": seed "seed": seed
}, },
resources_per_trial={ resources_per_trial={
"cpu": 1, "cpu": 12,
"gpu": 0.0 "gpu": 0.5
}, },
local_dir=local_dir local_dir=local_dir
) )
...@@ -138,6 +156,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -138,6 +156,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
if __name__ == '__main__': if __name__ == '__main__':
gin.external_configurable(tune.grid_search) gin.external_configurable(tune.grid_search)
dir = '/home/guillaume/Desktop/distMAgent/baselines/experiment_configs/n_agents_experiment' # To Modify dir = '/mount/SDC/flatland/baselines/experiment_configs/observation_benchmark' # To Modify
gin.parse_config_file(dir + '/config.gin') gin.parse_config_file(dir + '/config.gin')
run_experiment(local_dir=dir) run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment