Commit c65896e5 authored by nilabha's avatar nilabha

corrected shape errors for global IL

parent cbd784c0
......@@ -45,7 +45,7 @@ class GlobalObservation(Observation):
def observation_space(self) -> gym.Space:
grid_shape = (self._config['max_width'], self._config['max_height'])
return gym.spaces.Tuple([gym.spaces.Box(low=0, high=np.inf, shape=grid_shape + (31,), dtype=np.float32)])
return gym.spaces.Box(low=0, high=np.inf, shape=grid_shape + (31,), dtype=np.float32)
class PaddedGlobalObsForRailEnv(ObservationBuilder):
......
flatland-random-sparse-small-global-marwil-fc-ppo:
run: MARWIL
env: flatland_sparse
stop:
timesteps_total: 1000000000 # 1e7
checkpoint_freq: 10
checkpoint_at_end: True
keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
config:
beta:
grid_search: [0,0.25,0.5,0.75, 1] # compare IL (beta=0) vs MARWIL [0,0.25,0.5,0.75, 1]
input: /tmp/flatland
input_evaluation: [is, wis, simulation]
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
train_batch_size: 1000 # 5000
rollout_fragment_length: 50 # 100
num_workers: 1
num_envs_per_worker: 1
batch_mode: truncate_episodes
observation_filter: NoFilter
num_gpus: 0
env_config:
observation: global
observation_config:
max_width: 45
max_height: 45
generator: sparse_rail_generator
generator_config: small_v0
wandb:
project: neurips2020-flatland-baselines
entity: nilabha2007
tags: ["small_v0", "global_obs", "MARWIL"] # TODO should be set programmatically
model:
custom_model: global_obs_model
custom_options:
architecture: impala
architecture_options:
residual_layers: [[16,2], [32, 4]]
......@@ -17,12 +17,12 @@ class GlobalObsModel(TFModelV2):
self._mask_unavailable_actions = self._options.get("mask_unavailable_actions", False)
if self._mask_unavailable_actions:
obs_space = obs_space.original_space['obs']
obs_space = obs_space['obs']
else:
obs_space = obs_space.original_space
obs_space = obs_space
observations = [tf.keras.layers.Input(shape=o.shape) for o in obs_space]
processed_observations = observations[0] #preprocess_obs(tuple(observations))
observations = tf.keras.layers.Input(shape=obs_space.shape)
processed_observations = observations # preprocess_obs(tuple(observations))
if self._options['architecture'] == 'nature':
conv_out = NatureCNN(activation_out=True, **self._options['architecture_options'])(processed_observations)
......@@ -34,7 +34,7 @@ class GlobalObsModel(TFModelV2):
baseline = tf.keras.layers.Dense(units=1)(conv_out)
self._model = tf.keras.Model(inputs=observations, outputs=[logits, baseline])
self.register_variables(self._model.variables)
self._model.summary()
# self._model.summary()
def forward(self, input_dict, state, seq_lens):
# obs = preprocess_obs(input_dict['obs'])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment