Skip to content
Snippets Groups Projects
Commit b2f7c0a5 authored by gmollard's avatar gmollard
Browse files

corrected obs_shape for global obs direction dependent

parent 088c5f16
No related branches found
No related tags found
No related merge requests found
...@@ -77,8 +77,7 @@ def train(config, reporter): ...@@ -77,8 +77,7 @@ def train(config, reporter):
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,)) obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
preprocessor = "tree_obs_prep" preprocessor = "tree_obs_prep"
elif isinstance(config["obs_builder"], GlobalObsForRailEnv) or \ elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent):
obs_space = gym.spaces.Tuple(( obs_space = gym.spaces.Tuple((
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)), gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 8)), gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 8)),
...@@ -88,6 +87,16 @@ def train(config, reporter): ...@@ -88,6 +87,16 @@ def train(config, reporter):
else: else:
preprocessor = "global_obs_prep" preprocessor = "global_obs_prep"
elif isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent):
obs_space = gym.spaces.Tuple((
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 5)),
gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 2))))
if config['conv_model']:
preprocessor = "conv_obs_prep"
else:
preprocessor = "global_obs_prep"
elif isinstance(config["obs_builder"], LocalObsForRailEnv): elif isinstance(config["obs_builder"], LocalObsForRailEnv):
view_radius = config["obs_builder"].view_radius view_radius = config["obs_builder"].view_radius
obs_space = gym.spaces.Tuple(( obs_space = gym.spaces.Tuple((
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment