From 77e5b4eb2acfdcd424f2e4f139389972805f2b13 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume@iccluster028.iccluster.epfl.ch>
Date: Wed, 12 Jun 2019 10:09:20 +0200
Subject: [PATCH] first prediction test

---
 RLLib_training/RailEnvRLLibWrapper.py  | 27 ++++++++++++++++++--------
 RLLib_training/custom_preprocessors.py |  4 ++--
 RLLib_training/train_experiment.py     | 21 ++++++++++++--------
 3 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index f9643a8..ca2a78b 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -89,12 +89,15 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
             agent_id_one_hot[i_agent] = 1
             o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
-
+        self.old_obs = o
+        oo = dict()
+        for i_agent in range(len(self.env.agents)):
+            oo[i_agent] = [o[i_agent], o[i_agent][0], o[i_agent][1], o[i_agent][2]]
         self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
-        return obs
+        return oo

     def step(self, action_dict):
         obs, rewards, dones, infos = self.env.step(action_dict)
@@ -105,6 +108,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         o = dict()
         # print(self.agents_done)
         # print(dones)
+        predictions = self.env.predict()
+        pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)

         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
@@ -153,15 +158,21 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         #
         #     d[agent] = dones[agent]
-        for agent, done in dones.items():
-            if done and agent != '__all__':
-                self.agents_done.append(agent)
-
-        self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict

+        #print('Old OBS #####', self.old_obs)
+        oo = dict()
+        for i_agent in range(len(self.env.agents)):
+            if i_agent not in self.agents_done:
+                oo[i_agent] = [o[i_agent], self.old_obs[i_agent][0], self.old_obs[i_agent][1],
+                               self.old_obs[i_agent][2]]
+        self.old_obs = o
+        for agent, done in dones.items():
+            if done and agent != '__all__':
+                self.agents_done.append(agent)
+
         #print(obs)
         #return obs, rewards, dones, infos

         # oo = dict()
@@ -172,7 +183,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         #     o['global_obs'] = np.ones((17, 17)) * 17
         #     r['global_obs'] = 0
         #     d['global_obs'] = True
-        return o, r, d, infos
+        return oo, r, d, infos

     def get_agent_handles(self):
         return self.env.get_agent_handles()
diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py
index ddd2f65..45e6937 100644
--- a/RLLib_training/custom_preprocessors.py
+++ b/RLLib_training/custom_preprocessors.py
@@ -49,11 +49,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):

 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return obs_space.shape
+        return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,)

     def transform(self, observation):
         # if len(observation) == 111:
-        return norm_obs_clip(observation)
+        return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
         one_hot = observation[-3:]
         return np.append(obs, one_hot)
         # else:
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index 2a35ad3..d7bf261 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -61,7 +61,12 @@ def train(config, reporter):

     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Box(low=-1, high=1, shape=(147,))
+        obs_space = gym.spaces.Tuple((gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
+                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),
+                                      gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
+                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents']))))
         preprocessor = "tree_obs_prep"

     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
@@ -97,7 +102,7 @@ def train(config, reporter):
     else:
         raise ValueError("Undefined observation space")

-    act_space = gym.spaces.Discrete(4)
+    act_space = gym.spaces.Discrete(5)

     # Dict with the different policies to train
     policy_graphs = {
@@ -121,11 +126,11 @@ def train(config, reporter):
     trainer_config["horizon"] = config['horizon']

     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 2
-    trainer_config["num_gpus"] = 0
-    trainer_config["num_gpus_per_worker"] = 0
+    trainer_config["num_cpus_per_worker"] = 11
+    trainer_config["num_gpus"] = 0.5
+    trainer_config["num_gpus_per_worker"] = 0.5
     trainer_config["num_cpus_for_driver"] = 1
-    trainer_config["num_envs_per_worker"] = 1
+    trainer_config["num_envs_per_worker"] = 6
     trainer_config['entropy_coeff'] = config['entropy_coeff']
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
@@ -189,8 +194,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                       "lambda_gae": lambda_gae
                       },
            resources_per_trial={
-                "cpu": 3,
-                "gpu": 0.0
+                "cpu": 12,
+                "gpu": 0.5
            },
            local_dir=local_dir
            )
--
GitLab
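
Note (not part of the patch): the commit makes the wrapper return, per agent, the current observation triple plus the previous step's tree observation, agent-ID one-hot, and prediction block, and `CustomPreprocessor` flattens that nested structure into the vector whose size `_init_shape` computes from the new `Tuple` obs space. The sketch below is a minimal, standalone illustration of how those pieces line up; the shapes (147-entry tree observation, `n_agents = 3`, a 20-row prediction block) are taken from the obs space in `train_experiment.py`, and the simplified `norm_obs_clip` here is an assumption, not the repository's implementation.

```python
import numpy as np

def norm_obs_clip(obs, clip_min=-1, clip_max=1):
    # Simplified stand-in for custom_preprocessors.norm_obs_clip:
    # scale by the largest finite entry, then clip to [clip_min, clip_max].
    obs = np.asarray(obs, dtype=float)
    finite = obs[np.isfinite(obs)]
    max_obs = max(1.0, finite.max()) if finite.size else 1.0
    return np.clip(obs / max_obs, clip_min, clip_max)

n_agents = 3
tree_obs = np.random.rand(147)        # current tree observation for one agent
one_hot = np.eye(n_agents)[0]         # agent-ID one-hot
pred_obs = np.zeros((20, n_agents))   # prediction block, shape (20, n_agents)

# Structure returned by RailEnvRLLibWrapper.step() for each agent:
# [ [current tree, current one-hot, current predictions],
#   previous tree, previous one-hot, previous predictions ]
observation = [[tree_obs, one_hot, pred_obs], tree_obs, one_hot, pred_obs]

# Same flattening order as CustomPreprocessor.transform
flat = np.concatenate([
    norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(),
    norm_obs_clip(observation[1]), observation[2], observation[3].flatten(),
])

# Length matches _init_shape: (147 + n_agents + 20 * n_agents) * 2
assert flat.shape[0] == (147 + n_agents + 20 * n_agents) * 2
```

Stacking the previous observation alongside the current one doubles the flat vector, which is why `_init_shape` multiplies the per-step size by 2 and why the `Tuple` space in `train_experiment.py` repeats the three component spaces twice.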