diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py index 1d1d214cef1af84b99c719359a859712559974bc..86c159d3bbc11bc4c0fdd321d2c6ff6838488f4d 100644 --- a/RLLib_training/custom_preprocessors.py +++ b/RLLib_training/custom_preprocessors.py @@ -49,13 +49,13 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): class CustomPreprocessor(Preprocessor): def _init_shape(self, obs_space, options): - return (sum([space.shape[0] for space in obs_space]), ) - return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,) + #return (sum([space.shape[0] for space in obs_space]), ) + return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1]),) def transform(self, observation): # if len(observation) == 111: - return np.concatenate([norm_obs_clip(obs) for obs in observation]) - #return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()]) + #return np.concatenate([norm_obs_clip(obs) for obs in observation]) + return np.concatenate([norm_obs_clip(observation[0]), observation[1], observation[2].flatten()])#, norm_obs_clip(observation[1]), observation[2], observation[3].flatten()]) #one_hot = observation[-3:] #return np.append(obs, one_hot) # else: