From 77e5b4eb2acfdcd424f2e4f139389972805f2b13 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume@iccluster028.iccluster.epfl.ch>
Date: Wed, 12 Jun 2019 10:09:20 +0200
Subject: [PATCH] first prediction test

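Stack, for every agent, the current observation with the one from the previous
step. RailEnvRLLibWrapper.reset() and step() now return a per-agent list holding
the current [tree obs, agent one-hot, predicted positions] together with the
previous step's components, CustomPreprocessor flattens and normalises this
pair, and train_experiment.py declares a matching Tuple observation space,
switches the action space to Discrete(5) and adjusts the Ray resource settings
used for this experiment.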
---
 RLLib_training/RailEnvRLLibWrapper.py  | 27 ++++++++++++++++++--------
 RLLib_training/custom_preprocessors.py |  4 ++--
 RLLib_training/train_experiment.py     | 21 ++++++++++++--------
 3 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index f9643a8..ca2a78b 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -89,12 +89,15 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
             agent_id_one_hot[i_agent] = 1
             o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
 
-
+        # Keep the freshly built observations and, since there is no earlier step yet,
+        # pair each agent's observation with itself as the "previous" observation.
+        self.old_obs = o
+        oo = dict()
+        for i_agent in range(len(self.env.agents)):
+            oo[i_agent] = [o[i_agent], o[i_agent][0], o[i_agent][1], o[i_agent][2]]
         self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
-        return obs
+        return oo
 
     def step(self, action_dict):
         obs, rewards, dones, infos = self.env.step(action_dict)
@@ -105,6 +108,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         o = dict()
         # print(self.agents_done)
         # print(dones)
+        # Predicted positions per agent over the prediction horizon (columns 1:3 of each prediction).
+        predictions = self.env.predict()
+        pred_pos = np.concatenate([[x[:, 1:3]] for x in predictions.values()], axis=0)
 
         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
@@ -153,15 +158,21 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         #
         #         d[agent] = dones[agent]
 
-        for agent, done in dones.items():
-            if done and agent != '__all__':
-                self.agents_done.append(agent)
-
-        self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
+        #print('Old OBS #####', self.old_obs)
+        # Pair each live agent's new observation with the components of its previous one,
+        # then remember the new observations for the next step.
+        oo = dict()
+        for i_agent in range(len(self.env.agents)):
+            if i_agent not in self.agents_done:
+                oo[i_agent] = [o[i_agent], self.old_obs[i_agent][0], self.old_obs[i_agent][1],
+                               self.old_obs[i_agent][2]]
 
+        self.old_obs = o
+        for agent, done in dones.items():
+            if done and agent != '__all__':
+                self.agents_done.append(agent)
+
         #print(obs)
         #return obs, rewards, dones, infos
         # oo = dict()
@@ -172,7 +183,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         # o['global_obs'] = np.ones((17, 17)) * 17
         # r['global_obs'] = 0
         # d['global_obs'] = True
-        return o, r, d, infos
+        return oo, r, d, infos
 
     def get_agent_handles(self):
         return self.env.get_agent_handles()
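
As an illustration (not part of the diff above), the per-agent structures the wrapper now
produces look like this; the shapes follow the Tuple space declared in train_experiment.py
and the variable names are only for this sketch:

    import numpy as np

    n_agents = 3
    tree_obs = np.zeros(147)                 # tree observation for one agent
    agent_one_hot = np.eye(n_agents)[0]      # one-hot ID of agent 0
    pred_obs = np.zeros((20, n_agents))      # predicted-position block used in the observation

    # o[i_agent] as built in reset()/step():
    o_agent = [tree_obs, agent_one_hot, pred_obs]

    # reset(): the "previous" observation (elements 1-3) is simply the current one.
    oo_agent_reset = [o_agent, o_agent[0], o_agent[1], o_agent[2]]

    # step(): elements 1-3 come from self.old_obs, the observation of the previous step.
    old_agent = [np.ones(147), agent_one_hot, pred_obs]
    oo_agent_step = [o_agent, old_agent[0], old_agent[1], old_agent[2]]

    # pred_pos in step(): assuming each agent's prediction has one row per future time step,
    # the slice [:, 1:3] keeps the predicted row/column for every step.
    predictions = {0: np.zeros((20, 5)), 1: np.zeros((20, 5)), 2: np.zeros((20, 5))}
    pred_pos = np.concatenate([[x[:, 1:3]] for x in predictions.values()], axis=0)
    print(pred_pos.shape)  # (3, 20, 2): n_agents x prediction depth x (row, col)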
diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py
index ddd2f65..45e6937 100644
--- a/RLLib_training/custom_preprocessors.py
+++ b/RLLib_training/custom_preprocessors.py
@@ -49,11 +49,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 
 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return obs_space.shape
+        # Flattened length per agent: tree obs + agent one-hot + predicted positions,
+        # doubled because the current and the previous step are stacked.
+        return ((sum([space.shape[0] for space in obs_space[:2]])
+                 + obs_space[2].shape[0] * obs_space[2].shape[1]) * 2,)
 
     def transform(self, observation):
         # if len(observation) == 111:
-        return norm_obs_clip(observation)
+        # observation[0] is the current [tree obs, one-hot, predictions]; observation[1:] are the
+        # previous step's components. Normalise the tree observations and flatten everything.
+        return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(),
+                               norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
         one_hot = observation[-3:]
         return np.append(obs, one_hot)
         # else:
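
As a quick sanity check of the new flattening (not part of the diff), with a simplified
stand-in for norm_obs_clip (the real function is defined earlier in this file) and shapes
taken from the Tuple space in train_experiment.py:

    import numpy as np

    def norm_obs_clip_sketch(obs, clip_min=-1, clip_max=1):
        # Simplified stand-in: scale by the largest finite value, then clip.
        obs = np.asarray(obs, dtype=float)
        finite = obs[np.isfinite(obs)]
        max_obs = max(1.0, finite.max()) if finite.size else 1.0
        return np.clip(obs / max_obs, clip_min, clip_max)

    n_agents = 3
    tree_obs, one_hot, pred = np.random.rand(147), np.eye(n_agents)[0], np.zeros((20, n_agents))

    # One agent's observation as produced by RailEnvRLLibWrapper: current list + previous components.
    observation = [[tree_obs, one_hot, pred], tree_obs, one_hot, pred]
    flat = np.concatenate([norm_obs_clip_sketch(observation[0][0]), observation[0][1],
                           observation[0][2].flatten(), norm_obs_clip_sketch(observation[1]),
                           observation[2], observation[3].flatten()])

    # Matches what _init_shape now reports: (147 + n_agents + 20 * n_agents) * 2
    assert flat.shape == ((147 + n_agents + 20 * n_agents) * 2,)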
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index 2a35ad3..d7bf261 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -61,7 +61,12 @@ def train(config, reporter):
 
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Box(low=-1, high=1, shape=(147,))
+        # (tree obs, agent one-hot ID, predicted positions) for the current step,
+        # followed by the same three spaces for the previous step.
+        obs_space = gym.spaces.Tuple((gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
+                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),
+                                      gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
+                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents']))))
         preprocessor = "tree_obs_prep"
 
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
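
The six Box spaces above are the same three shapes repeated for the current and the previous
step; a possible, more compact way to express this (a sketch only, with an illustrative helper
name, not part of the diff):

    import gym

    def single_step_spaces(n_agents, tree_size=147, pred_depth=20):
        # (tree observation, agent one-hot ID, predicted positions) for one time step
        return (gym.spaces.Box(low=0, high=float('inf'), shape=(tree_size,)),
                gym.spaces.Box(low=0, high=1, shape=(n_agents,)),
                gym.spaces.Box(low=0, high=1, shape=(pred_depth, n_agents)))

    # Current step followed by the previous step, matching CustomPreprocessor's flattening order.
    obs_space = gym.spaces.Tuple(single_step_spaces(n_agents=3) * 2)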
@@ -97,7 +102,7 @@ def train(config, reporter):
         raise ValueError("Undefined observation space")
 
 
-    act_space = gym.spaces.Discrete(4)
+    act_space = gym.spaces.Discrete(5)  # RailEnv has 5 actions: do nothing, left, forward, right, stop
 
     # Dict with the different policies to train
     policy_graphs = {
@@ -121,11 +126,11 @@ def train(config, reporter):
     trainer_config["horizon"] = config['horizon']
 
     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 2
-    trainer_config["num_gpus"] = 0
-    trainer_config["num_gpus_per_worker"] = 0
+    trainer_config["num_cpus_per_worker"] = 11
+    trainer_config["num_gpus"] = 0.5
+    trainer_config["num_gpus_per_worker"] = 0.5
     trainer_config["num_cpus_for_driver"] = 1
-    trainer_config["num_envs_per_worker"] = 1
+    trainer_config["num_envs_per_worker"] = 6
     trainer_config['entropy_coeff'] = config['entropy_coeff']
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
@@ -189,8 +194,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "lambda_gae": lambda_gae
                 },
         resources_per_trial={
-            "cpu": 3,
-            "gpu": 0.0
+            "cpu": 12,
+            "gpu": 0.5
         },
         local_dir=local_dir
     )
-- 
GitLab