Skip to content
Snippets Groups Projects
Commit dc836c20 authored by Guillaume Mollard's avatar Guillaume Mollard
Browse files

modified preprocessor to include predictor

parent f80e77a6
No related branches found
No related tags found
No related merge requests found
...@@ -49,13 +49,13 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): ...@@ -49,13 +49,13 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
class CustomPreprocessor(Preprocessor): class CustomPreprocessor(Preprocessor):
def _init_shape(self, obs_space, options): def _init_shape(self, obs_space, options):
return (sum([space.shape[0] for space in obs_space]), ) #return (sum([space.shape[0] for space in obs_space]), )
return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,) return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1]),)
def transform(self, observation): def transform(self, observation):
# if len(observation) == 111: # if len(observation) == 111:
return np.concatenate([norm_obs_clip(obs) for obs in observation]) #return np.concatenate([norm_obs_clip(obs) for obs in observation])
#return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()]) return np.concatenate([norm_obs_clip(observation[0]), observation[1], observation[2].flatten()])#, norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
#one_hot = observation[-3:] #one_hot = observation[-3:]
#return np.append(obs, one_hot) #return np.append(obs, one_hot)
# else: # else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment