From dc836c20d6caab6bf38e9109b23df5e819657b9e Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume@iccluster028.iccluster.epfl.ch>
Date: Thu, 13 Jun 2019 12:50:56 +0200
Subject: [PATCH] modified preprocessor to include predictor

---
 RLLib_training/custom_preprocessors.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py
index 1d1d214..86c159d 100644
--- a/RLLib_training/custom_preprocessors.py
+++ b/RLLib_training/custom_preprocessors.py
@@ -49,13 +49,13 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 
 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return (sum([space.shape[0] for space in obs_space]), )
-        return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,)
+        #return (sum([space.shape[0] for space in obs_space]), )
+        return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1]),)
 
     def transform(self, observation):
         # if len(observation) == 111:
-        return np.concatenate([norm_obs_clip(obs) for obs in observation])
-        #return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
+        #return np.concatenate([norm_obs_clip(obs) for obs in observation])
+        return np.concatenate([norm_obs_clip(observation[0]), observation[1], observation[2].flatten()])#, norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
         #one_hot = observation[-3:]
         #return np.append(obs, one_hot)
         # else:
-- 
GitLab