Commit 77e5b4eb authored by Guillaume Mollard

first prediction test

parent 507f0e86
@@ -89,12 +89,15 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
             agent_id_one_hot[i_agent] = 1
             o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
         self.old_obs = o
+        oo = dict()
+        for i_agent in range(len(self.env.agents)):
+            oo[i_agent] = [o[i_agent], o[i_agent][0], o[i_agent][1], o[i_agent][2]]
         self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
-        return obs
+        return oo

     def step(self, action_dict):
         obs, rewards, dones, infos = self.env.step(action_dict)
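Note: `reset()` now returns, for each agent, the fresh composite observation together with a copy of its three parts, so the preprocessor always receives two stacked frames even on the very first step. A minimal sketch of the resulting structure; the 147-dim tree observation, the 20-step prediction horizon and the agent count are assumptions for illustration, not taken from this diff:

```python
import numpy as np

# Assumed sizes: 147-dim tree obs, 3 agents, 20 prediction steps.
n_agents = 3
tree_obs = np.zeros(147)
agent_id_one_hot = np.eye(n_agents)[0]   # one-hot id of agent 0
pred_obs = np.zeros((20, n_agents))      # predicted-positions block

o_agent = [tree_obs, agent_id_one_hot, pred_obs]
# No history exists at reset, so the frame is paired with its own parts:
oo_agent = [o_agent, o_agent[0], o_agent[1], o_agent[2]]
```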
@@ -105,6 +108,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         o = dict()
         # print(self.agents_done)
         # print(dones)
+        predictions = self.env.predict()
+        pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
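Note: `pred_pos` slices columns 1:3 out of every agent's prediction, i.e. the predicted (row, col) per timestep, and stacks them along a new leading agent axis. A small sketch, assuming each prediction value is a `(timesteps, k)` array with `k >= 3` (the column layout is an assumption):

```python
import numpy as np

# Two dummy agents, 20 predicted timesteps, 5 columns per row, with the
# predicted (row, col) position assumed to sit in columns 1 and 2.
predictions = {0: np.zeros((20, 5)), 1: np.ones((20, 5))}
pred_pos = np.concatenate([[x[:, 1:3]] for x in predictions.values()], axis=0)
print(pred_pos.shape)  # (2, 20, 2): (n_agents, timesteps, row/col)
```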
@@ -153,15 +158,21 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         #
         #     d[agent] = dones[agent]

-        for agent, done in dones.items():
-            if done and agent != '__all__':
-                self.agents_done.append(agent)

         self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
         #print('Old OBS #####', self.old_obs)
+        oo = dict()
+        for i_agent in range(len(self.env.agents)):
+            if i_agent not in self.agents_done:
+                oo[i_agent] = [o[i_agent], self.old_obs[i_agent][0], self.old_obs[i_agent][1],
+                               self.old_obs[i_agent][2]]
+
+        self.old_obs = o
+        for agent, done in dones.items():
+            if done and agent != '__all__':
+                self.agents_done.append(agent)
         #print(obs)
         #return obs, rewards, dones, infos
         # oo = dict()
@@ -172,7 +183,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         #     o['global_obs'] = np.ones((17, 17)) * 17
         #     r['global_obs'] = 0
         #     d['global_obs'] = True
-        return o, r, d, infos
+        return oo, r, d, infos

     def get_agent_handles(self):
         return self.env.get_agent_handles()
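Note: `step()` applies the same two-frame stacking as `reset()`, pairing each live agent's new observation with the three parts of the previous one before advancing the history (and, importantly, before marking newly done agents, which the moved `agents_done` loop ensures). A hypothetical helper summarising the pattern; `stack_with_previous` is not part of this commit:

```python
def stack_with_previous(o, old_obs, agents_done):
    # o: dict of fresh composite observations; old_obs: previous step's dict.
    oo = {}
    for i_agent, current in o.items():
        if i_agent not in agents_done:
            oo[i_agent] = [current, old_obs[i_agent][0],
                           old_obs[i_agent][1], old_obs[i_agent][2]]
    return oo, o  # o becomes the new old_obs
```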
@@ -49,11 +49,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):

 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return obs_space.shape
+        return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,)

     def transform(self, observation):
         # if len(observation) == 111:
-        return norm_obs_clip(observation)
+        return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
         one_hot = observation[-3:]
         return np.append(obs, one_hot)
         # else:
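Note: the new `_init_shape` is just the flattened length of the stacked observation, i.e. two frames of (tree obs + one-hot id + flattened prediction block), which is exactly what the new `transform()` concatenates. A quick sanity check with assumed sizes:

```python
# Assumed sizes: 147-dim tree obs, 3 agents, 20 prediction steps.
n_agents, tree_size, pred_steps = 3, 147, 20
flat_size = (tree_size + n_agents + pred_steps * n_agents) * 2
print(flat_size)  # 420 = (147 + 3 + 60) * 2, matching transform()'s output length
```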
@@ -61,7 +61,12 @@ def train(config, reporter):
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Box(low=-1, high=1, shape=(147,))
+        obs_space = gym.spaces.Tuple((gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
+                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),
+                                      gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
+                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents']))))
         preprocessor = "tree_obs_prep"

     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
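Note: the Tuple space is two identical frames of three Boxes (current and previous), mirroring the wrapper output above. An equivalent, slightly more compact construction; a sketch, not the committed code, with a stand-in agent count:

```python
import gym

n_agents = 3  # stand-in for config['n_agents']
frame = (gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
         gym.spaces.Box(low=0, high=1, shape=(n_agents,)),
         gym.spaces.Box(low=0, high=1, shape=(20, n_agents)))
obs_space = gym.spaces.Tuple(frame + frame)  # current frame + previous frame
```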
@@ -97,7 +102,7 @@ def train(config, reporter):
         raise ValueError("Undefined observation space")

-    act_space = gym.spaces.Discrete(4)
+    act_space = gym.spaces.Discrete(5)

     # Dict with the different policies to train
     policy_graphs = {
@@ -121,11 +126,11 @@ def train(config, reporter):
     trainer_config["horizon"] = config['horizon']
     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 2
-    trainer_config["num_gpus"] = 0
-    trainer_config["num_gpus_per_worker"] = 0
+    trainer_config["num_cpus_per_worker"] = 11
+    trainer_config["num_gpus"] = 0.5
+    trainer_config["num_gpus_per_worker"] = 0.5
     trainer_config["num_cpus_for_driver"] = 1
-    trainer_config["num_envs_per_worker"] = 1
+    trainer_config["num_envs_per_worker"] = 6
     trainer_config['entropy_coeff'] = config['entropy_coeff']
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
@@ -189,8 +194,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
             "lambda_gae": lambda_gae
         },
         resources_per_trial={
-            "cpu": 3,
-            "gpu": 0.0
+            "cpu": 12,
+            "gpu": 0.5
         },
         local_dir=local_dir
     )
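Note: the per-trial reservation has to cover driver plus worker, and the half GPU matches `num_gpus` above. A sketch of the arithmetic (an interpretation of the values, not committed code):

```python
# 1 driver CPU + 11 worker CPUs = 12; two trials can share one physical GPU.
resources_per_trial = {"cpu": 1 + 11, "gpu": 0.5}
```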