Commit cbd784c0 authored by nilabha's avatar nilabha

Update shapes

parent ea0fbf8a
......@@ -45,7 +45,7 @@ class GlobalObservation(Observation):
def observation_space(self) -> gym.Space:
grid_shape = (self._config['max_width'], self._config['max_height'])
return gym.spaces.Tuple([gym.spaces.Box(low=0, high=np.inf, shape=grid_shape + (16,), dtype=np.float32)])
return gym.spaces.Tuple([gym.spaces.Box(low=0, high=np.inf, shape=grid_shape + (31,), dtype=np.float32)])
class PaddedGlobalObsForRailEnv(ObservationBuilder):
......
......@@ -22,7 +22,7 @@ class GlobalObsModel(TFModelV2):
obs_space = obs_space.original_space
observations = [tf.keras.layers.Input(shape=o.shape) for o in obs_space]
processed_observations = observations # preprocess_obs(tuple(observations))
processed_observations = observations[0] #preprocess_obs(tuple(observations))
if self._options['architecture'] == 'nature':
conv_out = NatureCNN(activation_out=True, **self._options['architecture_options'])(processed_observations)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment