Skip to content
Snippets Groups Projects
Commit 38d81491 authored by gmollard
Browse files

small changes to test modified tree obs, but not working

parent 23f5ddf1
No related branches found
No related tags found
No related merge requests found
......@@ -19,13 +19,14 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
vector_index = config.vector_index
else:
vector_index = 1
self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5,
nr_extra=30, seed=config['seed'] * (1+vector_index))
#self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5,
# nr_extra=30, seed=config['seed'] * (1+vector_index))
set_seed(config['seed'] * (1+vector_index))
self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator,
#self.env = RailEnv(width=config["width"], height=config["height"],
self.env = RailEnv(width=10, height=20,
number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder'])
self.env.load('./baselines/torch_training/railway/complex_scene.pkl')
self.env.load('/mount/SDC/flatland/baselines/torch_training/railway/complex_scene.pkl')
self.width = self.env.width
self.height = self.env.height
......@@ -45,7 +46,6 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
self.agents = self.env.agents
self.agents_static = self.env.agents_static
self.dev_obs_dict = self.env.dev_obs_dict
return obs
def step(self, action_dict):
......
......@@ -50,10 +50,10 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
class CustomPreprocessor(Preprocessor):
def _init_shape(self, obs_space, options):
return (105,)
return (111,)
def transform(self, observation):
if len(observation) == 105:
if len(observation) == 111:
return norm_obs_clip(observation)
else:
return observation
......
run_experiment.name = "n_agents_results"
run_experiment.name = "observation_benchmark_results"
run_experiment.num_iterations = 1002
run_experiment.save_every = 200
run_experiment.hidden_sizes = [32, 32]
run_experiment.save_every = 100
run_experiment.hidden_sizes = [32,32]
run_experiment.map_width = 20
run_experiment.map_height = 20
run_experiment.n_agents = {"grid_search": [1]}#, 2, 5, 10]}
run_experiment.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
run_experiment.n_agents = {"grid_search": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_entropy_coeff_{config[entropy_coeff]}_{config[n_agents]}_agents_"
run_experiment.horizon = 50
run_experiment.seed = 123
run_experiment.entropy_coeff = {"grid_search": [1e-3, 1e-2, 0]}
run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv()]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5
run_experiment.name = "observation_benchmark_loaded_env_results"
run_experiment.num_iterations = 1002
run_experiment.save_every = 50
run_experiment.hidden_sizes = 32
run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20
run_experiment.map_height = 20
......@@ -10,9 +10,10 @@ run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__._
run_experiment.horizon = 50
run_experiment.seed = 123
run_experiment.conv_model = False
run_experiment.entropy_coeff = 1e-2
run_experiment.obs_builder = {"grid_search": [@LocalObsForRailEnv(), @TreeObsForRailEnv(), @GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}
run_experiment.obs_builder = @TreeObsForRailEnv()#{"grid_search": [@LocalObsForRailEnv(), @TreeObsForRailEnv(), @GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent()]}
TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5
......@@ -52,6 +52,10 @@ def train(config, reporter):
set_seed(config['seed'], config['seed'], config['seed'])
config['map_width']= 20
config['map_height']= 10
config['n_agents'] = 8
# Example configuration to generate a random rail
env_config = {"width": config['map_width'],
"height": config['map_height'],
......@@ -62,7 +66,7 @@ def train(config, reporter):
# Observation space and action space definitions
if isinstance(config["obs_builder"], TreeObsForRailEnv):
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(111,))
preprocessor = "tree_obs_prep"
elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
......@@ -191,6 +195,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
if __name__ == '__main__':
gin.external_configurable(tune.grid_search)
dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/RLLib_training/experiment_configs/conv_model_test' # To Modify
dir = '/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/observation_benchmark_loaded_env' # To Modify
gin.parse_config_file(dir + '/config.gin')
run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment