From 507f0e86309f6c24dd6e72777f1604c72188dc3e Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Tue, 11 Jun 2019 15:08:06 +0200
Subject: [PATCH] added simple conflict detection

---
 RLLib_training/RailEnvRLLibWrapper.py | 134 ++++++++++++++++++++------
 1 file changed, 102 insertions(+), 32 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index 704bb12..f9643a8 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -1,39 +1,44 @@
 from flatland.envs.rail_env import RailEnv
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.generators import random_rail_generator
 from ray.rllib.utils.seed import seed as set_seed
 from flatland.envs.generators import complex_rail_generator, random_rail_generator
 import numpy as np
+from flatland.envs.predictions import DummyPredictorForRailEnv
 
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
     def __init__(self, config):
-                 # width,
-                 # height,
-                 # rail_generator=random_rail_generator(),
-                 # number_of_agents=1,
-                 # obs_builder_object=TreeObsForRailEnv(max_depth=2)):
+
         super(MultiAgentEnv, self).__init__()
         if hasattr(config, "vector_index"):
             vector_index = config.vector_index
         else:
             vector_index = 1
 
+        self.predefined_env = False
+
         if config['rail_generator'] == "complex_rail_generator":
             self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
                                                           nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index))
-        else:
-            raise(Error)
+        elif config['rail_generator'] == "random_rail_generator":
             self.rail_generator = random_rail_generator()
+        elif config['rail_generator'] == "load_env":
+            self.predefined_env = True
+
+        else:
+            raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
 
         set_seed(config['seed'] * (1+vector_index))
         self.env = RailEnv(width=config["width"], height=config["height"],
                 number_of_agents=config["number_of_agents"],
-                obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator)
+                obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
+                prediction_builder_object=DummyPredictorForRailEnv())
 
-        # self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
+        if self.predefined_env:
+            self.env.load(config['load_env_path'])
+                # '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
 
         self.width = self.env.width
         self.height = self.env.height
@@ -42,19 +47,47 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
     
     def reset(self):
         self.agents_done = []
-        obs = self.env.reset()
+        if self.predefined_env:
+            obs = self.env.reset(False, False)
+        else:
+            obs = self.env.reset()
+
+        predictions = self.env.predict()
+        pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
+
         o = dict()
-        
-        
-        #for agent, _ in obs.items():                                     
-            #o[agent] = obs[agent]                                     
-        #    one_hot_agent_encoding = np.zeros(len(self.env.agents))   
-        #    one_hot_agent_encoding[agent] += 1                        
-        #    o[agent] = np.append(obs[agent], one_hot_agent_encoding)        
-        
-        # o['agents'] = obs
-        # obs[0] = [obs[0], np.ones((17, 17)) * 17]
-        # obs['global_obs'] = np.ones((17, 17)) * 17
+
+        for i_agent in range(len(self.env.agents)):
+
+            # prediction of collision that will be added to the observation
+            # Allows to the agent to know which other train is is about to meet (maybe will come
+            # up with a priority order of trains).
+            pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))
+
+            for time_offset in range(len(predictions[0])):
+
+                # We consider a time window of t-1; t+1 to find a collision
+                collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))
+
+                coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1]
+
+                # x coordinates of all other train in the time window
+                x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][
+                                                :, collision_window, 0]
+
+                # y coordinates of all other train in the time window
+                y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
+                                                :, collision_window, 1]
+
+                coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents
+
+                # collision_info here contains the index of the agent colliding with the current agent
+                for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
+                    pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1
+
+            agent_id_one_hot = np.zeros(len(self.env.agents))
+            agent_id_one_hot[i_agent] = 1
+            o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
 
 
         self.rail = self.env.rail
@@ -72,16 +105,53 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         o = dict()
         # print(self.agents_done)
         # print(dones)
-        for agent, done in dones.items():
-            if agent not in self.agents_done:
-                if agent != '__all__':
-        #            o[agent] = obs[agent]
-                    #one_hot_agent_encoding = np.zeros(len(self.env.agents))
-                    #one_hot_agent_encoding[agent] += 1
-                    o[agent] = obs[agent]#np.append(obs[agent], one_hot_agent_encoding)
-                    r[agent] = rewards[agent]
-    
-                d[agent] = dones[agent]
+
+        for i_agent in range(len(self.env.agents)):
+            if i_agent not in self.agents_done:
+                # prediction of collision that will be added to the observation
+                # Allows to the agent to know which other train is is about to meet (maybe will come
+                # up with a priority order of trains).
+                pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))
+
+                for time_offset in range(len(predictions[0])):
+
+                    # We consider a time window of t-1; t+1 to find a collision
+                    collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))
+
+                    coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1]
+
+                    # x coordinates of all other train in the time window
+                    x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][
+                                                    :, collision_window, 0]
+
+                    # y coordinates of all other train in the time window
+                    y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
+                                                    :, collision_window, 1]
+
+                    coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents
+
+                    # collision_info here contains the index of the agent colliding with the current agent
+                    for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
+                        pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1
+
+                agent_id_one_hot = np.zeros(len(self.env.agents))
+                agent_id_one_hot[i_agent] = 1
+                o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
+                r[i_agent] = rewards[i_agent]
+                d[i_agent] = dones[i_agent]
+
+        d['__all__'] = dones['__all__']
+
+        # for agent, done in dones.items():
+        #     if agent not in self.agents_done:
+        #         if agent != '__all__':
+        # #            o[agent] = obs[agent]
+        #             #one_hot_agent_encoding = np.zeros(len(self.env.agents))
+        #             #one_hot_agent_encoding[agent] += 1
+        #             o[agent] = obs[agent]#np.append(obs[agent], one_hot_agent_encoding)
+        #
+        #
+        #         d[agent] = dones[agent]
 
         for agent, done in dones.items():
             if done and agent != '__all__':
-- 
GitLab