diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py index 8f7e1ec6c77dab11b6e33fcf1d8960d74bc1c304..0fd0f59727bb384a4545b8189e7844d8540367b1 100644 --- a/RLLib_training/RailEnvRLLibWrapper.py +++ b/RLLib_training/RailEnvRLLibWrapper.py @@ -53,6 +53,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv): else: obs = self.env.reset() + # predictions = self.env.predict() # if predictions != {}: # # pred_pos is a 3 dimensions array (N_Agents, T_pred, 2) containing x and y coordinates of @@ -63,7 +64,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv): o = dict() for i_agent in range(len(self.env.agents)): - + data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]), + num_features_per_node=8, current_depth=0) # if predictions != {}: # pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent) # @@ -71,7 +73,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv): # agent_id_one_hot[i_agent] = 1 # o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs] # else: - o[i_agent] = obs[i_agent] + + o[i_agent] = [data, distance, agent_data] # needed for the renderer self.rail = self.env.rail @@ -105,6 +108,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv): for i_agent in range(len(self.env.agents)): if i_agent not in self.agents_done: + data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]), + num_features_per_node=8, current_depth=0) # if predictions != {}: # pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent) @@ -112,7 +117,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv): # agent_id_one_hot[i_agent] = 1 # o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs] # else: - o[i_agent] = obs[i_agent] + o[i_agent] = [data, distance, agent_data] r[i_agent] = rewards[i_agent] d[i_agent] = dones[i_agent] diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py index 6f9cad767ad5bbbf823b99fe320bf3dcdf1cad1e..41e895bee95cde7bc59220ce2eee4eaabd458651 100644 --- a/RLLib_training/custom_preprocessors.py +++ b/RLLib_training/custom_preprocessors.py @@ -55,6 +55,12 @@ class CustomPreprocessor(Preprocessor): # return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0] * obs_space[2].shape[1]),) def transform(self, observation): + print('OBSSSSSSSSSSSSSSSSSs', observation, observation.shape) + data = norm_obs_clip(observation[0]) + distance = norm_obs_clip(observation[1]) + agent_data = np.clip(observation[2], -1, 1) + + return np.concatenate((np.concatenate((data, distance)), agent_data)) return norm_obs_clip(observation) return np.concatenate([norm_obs_clip(observation[0]), norm_obs_clip(observation[1])]) # if len(observation) == 111: