Skip to content
Snippets Groups Projects
Commit fdd64ac1 authored by gmollard
Browse files

experiment with shortest path predictor

parent 44d3ca9d
No related branches found
No related tags found
No related merge requests found
run_experiment.name = "observation_benchmark_results" run_experiment.name = "observation_benchmark_results"
run_experiment.num_iterations = 2002 run_experiment.num_iterations = 2002
run_experiment.save_every = 50 run_experiment.save_every = 100
run_experiment.hidden_sizes = [32, 32] run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 8 run_experiment.map_width = 20
run_experiment.map_height = 8 run_experiment.map_height = 20
run_experiment.n_agents = 3 run_experiment.n_agents = {"grid_search": [3, 4, 5, 6, 7, 8]}
run_experiment.rail_generator = "complex_rail_generator" run_experiment.rail_generator = "complex_rail_generator"
run_experiment.nr_extra = 5#{"grid_search": [0, 5, 10, 20, 30, 40, 50, 60]} run_experiment.nr_extra = 5
run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_kl_coeff_{config[kl_coeff]}_horizon_{config[horizon]}_" run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}_"
run_experiment.horizon = {"grid_search": [50, 100]} #run_experiment.horizon =
run_experiment.seed = 123 run_experiment.seed = 123
#run_experiment.conv_model = {"grid_search": [True, False]} #run_experiment.conv_model = {"grid_search": [True, False]}
...@@ -18,9 +18,12 @@ run_experiment.conv_model = False ...@@ -18,9 +18,12 @@ run_experiment.conv_model = False
#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]} #run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
run_experiment.obs_builder = @TreeObsForRailEnv() run_experiment.obs_builder = @TreeObsForRailEnv()
TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv
TreeObsForRailEnv.max_depth = 2 TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5 LocalObsForRailEnv.view_radius = 5
run_experiment.entropy_coeff = 0.01 run_experiment.entropy_coeff = 0.001
run_experiment.kl_coeff = {"grid_search": [0, 0.2]} run_experiment.kl_coeff = 0.2 #{"grid_search": [0, 0.2]}
run_experiment.lambda_gae = 0.9# {"grid_search": [0.9, 1.0]} run_experiment.lambda_gae = 0.9 # {"grid_search": [0.9, 1.0]}
#run_experiment.predictor = "ShortestPathPredictorForRailEnv"
run_experiment.step_memory = 2
run_experiment.name = "memory_experiment_results" run_experiment.name = "memory_experiment_results"
run_experiment.num_iterations = 2002 run_experiment.num_iterations = 2002
run_experiment.save_every = 50 run_experiment.save_every = 50
run_experiment.hidden_sizes = {"grid_search": [[32, 32], [64, 64], [128, 128]]} run_experiment.hidden_sizes = [32, 32]#{"grid_search": [[32, 32], [64, 64], [128, 128]]}
run_experiment.map_width = 8 run_experiment.map_width = 8
run_experiment.map_height = 8 run_experiment.map_height = 8
...@@ -20,7 +20,7 @@ run_experiment.obs_builder = @TreeObsForRailEnv() ...@@ -20,7 +20,7 @@ run_experiment.obs_builder = @TreeObsForRailEnv()
TreeObsForRailEnv.max_depth = 2 TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5 LocalObsForRailEnv.view_radius = 5
run_experiment.entropy_coeff = {"grid_search": [1e-4, 1e-3, 1e-2]} run_experiment.entropy_coeff = 1e-4#{"grid_search": [1e-4, 1e-3, 1e-2]}
run_experiment.kl_coeff = 0.2 run_experiment.kl_coeff = 0.2
run_experiment.lambda_gae = 0.9 run_experiment.lambda_gae = 0.9
run_experiment.predictor = None#@DummyPredictorForRailEnv() run_experiment.predictor = None#@DummyPredictorForRailEnv()
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import gin import gin
import gym import gym
from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from importlib_resources import path from importlib_resources import path
# Import PPO trainer: we can replace these imports by any other trainer from RLLib. # Import PPO trainer: we can replace these imports by any other trainer from RLLib.
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
...@@ -11,6 +11,7 @@ from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph ...@@ -11,6 +11,7 @@ from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
from ray.rllib.models import ModelCatalog from ray.rllib.models import ModelCatalog
gin.external_configurable(DummyPredictorForRailEnv) gin.external_configurable(DummyPredictorForRailEnv)
gin.external_configurable(ShortestPathPredictorForRailEnv)
import ray import ray
...@@ -66,6 +67,7 @@ def on_episode_end(info): ...@@ -66,6 +67,7 @@ def on_episode_end(info):
score /= (len(episode._agent_reward_history) * 3 * episode.horizon) score /= (len(episode._agent_reward_history) * 3 * episode.horizon)
episode.custom_metrics["score"] = score episode.custom_metrics["score"] = score
def train(config, reporter): def train(config, reporter):
print('Init Env') print('Init Env')
...@@ -81,23 +83,12 @@ def train(config, reporter): ...@@ -81,23 +83,12 @@ def train(config, reporter):
"seed": config['seed'], "seed": config['seed'],
"obs_builder": config['obs_builder'], "obs_builder": config['obs_builder'],
"min_dist": config['min_dist'], "min_dist": config['min_dist'],
# "predictor": config["predictor"], "predictor": config["predictor"],
"step_memory": config["step_memory"]} "step_memory": config["step_memory"]}
# Observation space and action space definitions # Observation space and action space definitions
if isinstance(config["obs_builder"], TreeObsForRailEnv): if isinstance(config["obs_builder"], TreeObsForRailEnv):
obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)), )) obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)), ))
# gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
# gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
# 'step_memory'])
# if config['predictor'] is None:
# obs_space = gym.spaces.Tuple(
# (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * config['step_memory'])
# else:
# obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
# gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
# gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
# 'step_memory'])
preprocessor = "tree_obs_prep" preprocessor = "tree_obs_prep"
elif isinstance(config["obs_builder"], GlobalObsForRailEnv): elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
...@@ -152,7 +143,7 @@ def train(config, reporter): ...@@ -152,7 +143,7 @@ def train(config, reporter):
trainer_config['multiagent'] = {"policy_graphs": policy_graphs, trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
"policy_mapping_fn": policy_mapping_fn, "policy_mapping_fn": policy_mapping_fn,
"policies_to_train": list(policy_graphs.keys())} "policies_to_train": list(policy_graphs.keys())}
trainer_config["horizon"] = 1.5 * (config['map_width'] + config['map_height'])#config['horizon'] trainer_config["horizon"] = 3 * (config['map_width'] + config['map_height'])#config['horizon']
trainer_config["num_workers"] = 0 trainer_config["num_workers"] = 0
trainer_config["num_cpus_per_worker"] = 7 trainer_config["num_cpus_per_worker"] = 7
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment