Commit 77e5b4eb authored by Guillaume Mollard

first prediction test

parent 507f0e86
@@ -89,12 +89,15 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
             agent_id_one_hot[i_agent] = 1
             o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]

+        self.old_obs = o
+        oo = dict()
+        for i_agent in range(len(self.env.agents)):
+            oo[i_agent] = [o[i_agent], o[i_agent][0], o[i_agent][1], o[i_agent][2]]
         self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
-        return obs
+        return oo

     def step(self, action_dict):
         obs, rewards, dones, infos = self.env.step(action_dict)
@@ -105,6 +108,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         o = dict()
         # print(self.agents_done)
         # print(dones)
+        predictions = self.env.predict()
+        pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
@@ -153,15 +158,21 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         #
         # d[agent] = dones[agent]

-        for agent, done in dones.items():
-            if done and agent != '__all__':
-                self.agents_done.append(agent)
-
-        self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict

+        #print('Old OBS #####', self.old_obs)
+        oo = dict()
+        for i_agent in range(len(self.env.agents)):
+            if i_agent not in self.agents_done:
+                oo[i_agent] = [o[i_agent], self.old_obs[i_agent][0], self.old_obs[i_agent][1],
+                               self.old_obs[i_agent][2]]
+        self.old_obs = o
+        for agent, done in dones.items():
+            if done and agent != '__all__':
+                self.agents_done.append(agent)

         #print(obs)
         #return obs, rewards, dones, infos
         # oo = dict()
@@ -172,7 +183,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         # o['global_obs'] = np.ones((17, 17)) * 17
         # r['global_obs'] = 0
         # d['global_obs'] = True
-        return o, r, d, infos
+        return oo, r, d, infos

     def get_agent_handles(self):
         return self.env.get_agent_handles()
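Note on the new observation layout: after this change the wrapper returns, for each agent, the current observation bundle together with the previous step's tree observation, agent one-hot vector and prediction matrix (on reset, the "previous" observation is simply the current one). A minimal sketch of the resulting per-agent structure; the names tree_obs, agent_one_hot and pred_obs are illustrative placeholders, not identifiers from this commit:

    # Sketch only: layout of the per-agent observation produced by reset()/step().
    # Real values come from the environment; this only mirrors the structure.
    import numpy as np

    n_agents = 3
    tree_obs = np.zeros(147)               # flattened tree observation
    agent_one_hot = np.zeros(n_agents)     # one-hot agent id
    pred_obs = np.zeros((20, n_agents))    # predicted positions, 20 steps ahead

    current = [tree_obs, agent_one_hot, pred_obs]
    # reset(): the "old" observation equals the current one
    # step():  the "old" observation is the bundle from the previous time step
    obs_for_agent = [current, tree_obs, agent_one_hot, pred_obs]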
@@ -49,11 +49,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):

 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return obs_space.shape
+        return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,)

     def transform(self, observation):
         # if len(observation) == 111:
-        return norm_obs_clip(observation)
+        return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
         one_hot = observation[-3:]
         return np.append(obs, one_hot)
         # else:
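The new _init_shape doubles the concatenated size of the first three sub-spaces, which matches the current-plus-previous layout that transform flattens. A quick sanity check of that arithmetic, assuming the Tuple observation space defined in the training script (the agent count is illustrative):

    # Sketch: flattened length computed by _init_shape for the Tuple space below.
    n_agents = 3                       # illustrative value, not from the commit
    tree_len = 147                     # tree observation vector
    one_hot_len = n_agents             # agent id one-hot
    pred_len = 20 * n_agents           # flattened (20, n_agents) prediction matrix
    flat_len = (tree_len + one_hot_len + pred_len) * 2   # current + previous
    print(flat_len)                    # 420 when n_agents == 3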
@@ -61,7 +61,12 @@ def train(config, reporter):

     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Box(low=-1, high=1, shape=(147,))
+        obs_space = gym.spaces.Tuple((gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
+                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),
+                                      gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
+                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents']))))
         preprocessor = "tree_obs_prep"

     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
@@ -97,7 +102,7 @@ def train(config, reporter):
         raise ValueError("Undefined observation space")

-    act_space = gym.spaces.Discrete(4)
+    act_space = gym.spaces.Discrete(5)

     # Dict with the different policies to train
     policy_graphs = {
@@ -121,11 +126,11 @@ def train(config, reporter):
     trainer_config["horizon"] = config['horizon']
     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 2
-    trainer_config["num_gpus"] = 0
-    trainer_config["num_gpus_per_worker"] = 0
+    trainer_config["num_cpus_per_worker"] = 11
+    trainer_config["num_gpus"] = 0.5
+    trainer_config["num_gpus_per_worker"] = 0.5
     trainer_config["num_cpus_for_driver"] = 1
-    trainer_config["num_envs_per_worker"] = 1
+    trainer_config["num_envs_per_worker"] = 6
     trainer_config['entropy_coeff'] = config['entropy_coeff']
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
@@ -189,8 +194,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
             "lambda_gae": lambda_gae
         },
         resources_per_trial={
-            "cpu": 3,
-            "gpu": 0.0
+            "cpu": 12,
+            "gpu": 0.5
         },
         local_dir=local_dir
     )
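With resources_per_trial raised to 12 CPUs and half a GPU, two concurrent trials can share a single GPU. A minimal, hedged sketch of requesting fractional GPU resources for a function trainable in Ray Tune; the trainable below is a placeholder, not this project's train function:

    # Sketch only: fractional GPU allocation per trial in Ray Tune.
    import ray
    from ray import tune

    def dummy_train(config, reporter):
        # placeholder training loop; reports a fake metric each iteration
        for i in range(3):
            reporter(mean_accuracy=i)

    ray.init()
    tune.run(
        dummy_train,
        resources_per_trial={"cpu": 12, "gpu": 0.5},  # assumes a GPU is available
    )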