From b2f7c0a5c26c72472aabaac5ffebd8a601d363eb Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume.mollard2@gmail.com> Date: Wed, 29 May 2019 10:37:51 +0200 Subject: [PATCH] corrected obs_shape for global obs direction dependent --- RLLib_training/train_experiment.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py index 223b905..48407f6 100644 --- a/RLLib_training/train_experiment.py +++ b/RLLib_training/train_experiment.py @@ -77,8 +77,7 @@ def train(config, reporter): obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,)) preprocessor = "tree_obs_prep" - elif isinstance(config["obs_builder"], GlobalObsForRailEnv) or \ - isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent): + elif isinstance(config["obs_builder"], GlobalObsForRailEnv): obs_space = gym.spaces.Tuple(( gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)), gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 8)), @@ -88,6 +87,16 @@ def train(config, reporter): else: preprocessor = "global_obs_prep" + elif isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent): + obs_space = gym.spaces.Tuple(( + gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)), + gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 5)), + gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 2)))) + if config['conv_model']: + preprocessor = "conv_obs_prep" + else: + preprocessor = "global_obs_prep" + elif isinstance(config["obs_builder"], LocalObsForRailEnv): view_radius = config["obs_builder"].view_radius obs_space = gym.spaces.Tuple(( -- GitLab