Skip to content
Snippets Groups Projects
Commit 6d56e096 authored by gmollard's avatar gmollard
Browse files

new predictor

parent ea93665f
No related branches found
No related tags found
No related merge requests found
......@@ -53,6 +53,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
else:
obs = self.env.reset()
# predictions = self.env.predict()
# if predictions != {}:
# # pred_pos is a 3 dimensions array (N_Agents, T_pred, 2) containing x and y coordinates of
......@@ -63,7 +64,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
o = dict()
for i_agent in range(len(self.env.agents)):
data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
num_features_per_node=8, current_depth=0)
# if predictions != {}:
# pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent)
#
......@@ -71,7 +73,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
# agent_id_one_hot[i_agent] = 1
# o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
# else:
o[i_agent] = obs[i_agent]
o[i_agent] = [data, distance, agent_data]
# needed for the renderer
self.rail = self.env.rail
......@@ -105,6 +108,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
for i_agent in range(len(self.env.agents)):
if i_agent not in self.agents_done:
data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
num_features_per_node=8, current_depth=0)
# if predictions != {}:
# pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent)
......@@ -112,7 +117,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
# agent_id_one_hot[i_agent] = 1
# o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
# else:
o[i_agent] = obs[i_agent]
o[i_agent] = [data, distance, agent_data]
r[i_agent] = rewards[i_agent]
d[i_agent] = dones[i_agent]
......
......@@ -55,6 +55,12 @@ class CustomPreprocessor(Preprocessor):
# return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0] * obs_space[2].shape[1]),)
def transform(self, observation):
print('OBSSSSSSSSSSSSSSSSSs', observation, observation.shape)
data = norm_obs_clip(observation[0])
distance = norm_obs_clip(observation[1])
agent_data = np.clip(observation[2], -1, 1)
return np.concatenate((np.concatenate((data, distance)), agent_data))
return norm_obs_clip(observation)
return np.concatenate([norm_obs_clip(observation[0]), norm_obs_clip(observation[1])])
# if len(observation) == 111:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment