From dc836c20d6caab6bf38e9109b23df5e819657b9e Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume@iccluster028.iccluster.epfl.ch> Date: Thu, 13 Jun 2019 12:50:56 +0200 Subject: [PATCH] modified preprocessor to include predictor --- RLLib_training/custom_preprocessors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py index 1d1d214..86c159d 100644 --- a/RLLib_training/custom_preprocessors.py +++ b/RLLib_training/custom_preprocessors.py @@ -49,13 +49,13 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): class CustomPreprocessor(Preprocessor): def _init_shape(self, obs_space, options): - return (sum([space.shape[0] for space in obs_space]), ) - return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,) + #return (sum([space.shape[0] for space in obs_space]), ) + return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1]),) def transform(self, observation): # if len(observation) == 111: - return np.concatenate([norm_obs_clip(obs) for obs in observation]) - #return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()]) + #return np.concatenate([norm_obs_clip(obs) for obs in observation]) + return np.concatenate([norm_obs_clip(observation[0]), observation[1], observation[2].flatten()])#, norm_obs_clip(observation[1]), observation[2], observation[3].flatten()]) #one_hot = observation[-3:] #return np.append(obs, one_hot) # else: -- GitLab