diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py index 223b9059945db484b94b0916924aa739f32270e4..48407f645a07d3947053365a86398a90aa4e41d7 100644 --- a/RLLib_training/train_experiment.py +++ b/RLLib_training/train_experiment.py @@ -77,8 +77,7 @@ def train(config, reporter): obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,)) preprocessor = "tree_obs_prep" - elif isinstance(config["obs_builder"], GlobalObsForRailEnv) or \ - isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent): + elif isinstance(config["obs_builder"], GlobalObsForRailEnv): obs_space = gym.spaces.Tuple(( gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)), gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 8)), @@ -88,6 +87,16 @@ def train(config, reporter): else: preprocessor = "global_obs_prep" + elif isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent): + obs_space = gym.spaces.Tuple(( + gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)), + gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 5)), + gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 2)))) + if config['conv_model']: + preprocessor = "conv_obs_prep" + else: + preprocessor = "global_obs_prep" + elif isinstance(config["obs_builder"], LocalObsForRailEnv): view_radius = config["obs_builder"].view_radius obs_space = gym.spaces.Tuple((