From 6e2c1c9d136859faf3e24c29d97142a1538ee570 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 24 Jul 2019 18:48:15 -0400
Subject: [PATCH] Update local observation for RailEnv to better suit the task

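The local observation is now an oriented field of view ahead of the agent
(view_height cells along the heading, view_width cells to each side, with the
agent sitting `center` cells from the rear edge) instead of a square window of
radius view_radius centered on the agent. Per agent, get() returns a tuple
(local_rail_obs, obs_map_state, obs_other_agents_state, direction).

A minimal usage sketch, mirroring examples/training_example.py below; the
parameter values are illustrative only:

    from flatland.envs.generators import complex_rail_generator
    from flatland.envs.observations import LocalObsForRailEnv
    from flatland.envs.rail_env import RailEnv

    # Field of view: 10 cells along the heading, 2 cells to each side,
    # with the agent placed 2 cells from the rear edge of the view.
    obs_builder = LocalObsForRailEnv(view_width=2, view_height=10, center=2)
    env = RailEnv(width=20, height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1,
                                                        min_dist=8, max_dist=99999, seed=0),
                  obs_builder_object=obs_builder,
                  number_of_agents=2)

    obs = env.reset()
    local_rail_obs, target_map, other_agents_map, heading = obs[0]
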
---
 examples/training_example.py             |  10 ++-
 flatland/core/env_observation_builder.py |   1 +
 flatland/envs/observations.py            | 110 ++++++++++++++++-------
 3 files changed, 86 insertions(+), 35 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index 3139209..67e0c74 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -1,9 +1,10 @@
 import numpy as np
 
 from flatland.envs.generators import complex_rail_generator
-from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
 
 np.random.seed(1)
 
@@ -12,12 +13,14 @@ np.random.seed(1)
 #
 
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
 env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
-              obs_builder_object=TreeObservation,
+              obs_builder_object=LocalGridObs,
               number_of_agents=2)
 
+env_renderer = RenderTool(env, gl="PILSVG")
 
 # Import your own Agent or use RLlib to train agents on Flatland
 # As an example we use a random agent here
@@ -66,6 +69,7 @@ for trials in range(1, n_trials + 1):
 
     # Reset environment and get initial observations for all agents
     obs = env.reset()
+    env_renderer.reset()
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
 
@@ -80,6 +84,8 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
-        # reward and whether their are done
+        # reward and whether they are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
+        env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
+
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 060785f..4acdf16 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -74,6 +74,7 @@ class ObservationBuilder:
         direction[agent.direction] = 1
         return direction
 
+
 class DummyObservationBuilder(ObservationBuilder):
     """
     DummyObservationBuilder class which returns dummy observations
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 498d98e..fb8e107 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -698,71 +698,115 @@ class LocalObsForRailEnv(ObservationBuilder):
     The observation is composed of the following elements:
 
         - transition map array of the local environment around the given agent,
-          with dimensions (2*view_radius + 1, 2*view_radius + 1, 16),
+          with dimensions (view_height, 2 * view_width + 1, 16),
           assuming 16 bits encoding of transitions.
 
-        - Two 2D arrays (2*view_radius + 1, 2*view_radius + 1, 2) containing respectively,
-          if they are in the agent's vision range, its target position, the positions of the other targets.
+        - A 3D array (view_height, 2 * view_width + 1, 2) containing, within the agent's field of view,
+          the position of its own target in the first channel and the positions of the other agents'
+          targets in the second channel.
 
-        - A 3D array (2*view_radius + 1, 2*view_radius + 1, 4) containing the one hot encoding of directions
+        - A 3D array (view_height, 2 * view_width + 1, 4) containing the one-hot encoding of the directions
           of the other agents at their position coordinates, if they are in the agent's vision range.
 
-        - A 4 elements array with one hot encoding of the direction.
+        - A 4-element array with the one-hot encoding of the agent's direction.
     """
 
-    def __init__(self, view_radius):
+    def __init__(self, view_width, view_height, center):
         """
-        :param view_radius:
+        :param view_width: number of cells visible to each side of the agent (the view is 2 * view_width + 1 cells wide)
+        :param view_height: length of the field of view along the agent's direction of travel
+        :param center: offset of the agent from the rear edge of the field of view (cells behind the agent)
         """
         super(LocalObsForRailEnv, self).__init__()
-        self.view_radius = view_radius
+        self.view_width = view_width
+        self.view_height = view_height
+        self.center = center
+        self.max_padding = max(self.view_width, self.view_height)
 
     def reset(self):
-        # We build the transition map with a view_radius empty cells expansion on each side.
+        # We build the transition map with a max_padding empty-cell expansion on each side.
         # This helps to collect the local transition map view when the agent is close to a border.
-
-        self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius,
-                                  self.env.width + 2 * self.view_radius, 16))
+        self.max_padding = max(self.view_width, self.view_height)
+        self.rail_obs = np.zeros((self.env.height + 2 * self.max_padding,
+                                  self.env.width + 2 * self.max_padding, 16))
         for i in range(self.env.height):
             for j in range(self.env.width):
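+                # Unpack the cell's 16-bit transition code into a binary vector, zero-padded to 16 entries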
                 bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
-                self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
+                self.rail_obs[i + self.max_padding, j + self.max_padding] = np.array(bitlist)
 
     def get(self, handle):
         agents = self.env.agents
         agent = agents[handle]
+        agent_rel_pos = [0, 0]
 
-        local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1,
-                         agent.position[1]:agent.position[1] + 2 * self.view_radius + 1]
-
-        obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2))
+        # Offset the agent's position to account for the padding
+        agent_rel_pos[0] = agent.position[0] + self.max_padding
+        agent_rel_pos[1] = agent.position[1] + self.max_padding
 
-        obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4))
+        # Collect the rail information in the local field of view
+        local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs)
 
-        def relative_pos(pos):
-            return [agent.position[0] - pos[0], agent.position[1] - pos[1]]
+        # Maps for the agent's target and the other agents within the field of view (allocated but not populated)
+        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
+        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
 
-        def is_in(rel_pos):
-            return (abs(rel_pos[0]) <= self.view_radius) and (abs(rel_pos[1]) <= self.view_radius)
+        # Collect the visible cells as a set of coordinates for rendering
+        visited = self.field_of_view(agent.position, agent.direction)
 
-        target_rel_pos = relative_pos(agent.target)
-        if is_in(target_rel_pos):
-            obs_map_state[self.view_radius + np.array(target_rel_pos)][0] += 1
+        # Register the visible cells so the renderer can display this agent's observation
+        self.env.dev_obs_dict[handle] = visited
 
-        for i in range(len(agents)):
-            if i != handle:  # TODO: handle used as index...?
-                agent2 = agents[i]
+        direction = self._get_one_hot_for_agent_direction(agent)
 
-                agent_2_rel_pos = relative_pos(agent2.position)
-                if is_in(agent_2_rel_pos):
-                    obs_other_agents_state[self.view_radius + agent_2_rel_pos[0],
-                                           self.view_radius + agent_2_rel_pos[1]][agent2.direction] += 1
+        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
-                target_rel_pos_2 = relative_pos(agent2.position)
-                if is_in(target_rel_pos_2):
-                    obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1
+    def get_many(self, handles=None):
+        """
+        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
+        in the `handles' list.
+        """
 
-        direction = self._get_one_hot_for_agent_direction(agent)
+        if handles is None:
+            handles = []
+        observations = {}
+        for h in handles:
+            observations[h] = self.get(h)
+        return observations
 
-        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
+    def field_of_view(self, position, direction, state=None):
+        # Compute the local field of view for an agent in the environment
+        data_collection = False
+        if state is not None:
+            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
+            data_collection = True
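+        # The view's origin is the corner that lies `center` cells behind the agent and `view_width` cells
+        # to its left; h then advances along the agent's heading while w sweeps across the view left-to-right.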
+        if direction == 0:
+            origin = (position[0] + self.center, position[1] - self.view_width)
+        elif direction == 1:
+            origin = (position[0] - self.view_width, position[1] - self.center)
+        elif direction == 2:
+            origin = (position[0] - self.center, position[1] + self.view_width)
+        else:
+            origin = (position[0] + self.view_width, position[1] + self.center)
+        visible = set()
+        for h in range(self.view_height):
+            for w in range(2 * self.view_width + 1):
+                if direction == 0:
+                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
+                        visible.add((origin[0] - h, origin[1] + w))
+                    if data_collection:
+                        temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
+                elif direction == 1:
+                    if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
+                        visible.add((origin[0] + w, origin[1] + h))
+                    if data_collection:
+                        temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
+                elif direction == 2:
+                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
+                        visible.add((origin[0] + h, origin[1] - w))
+                    if data_collection:
+                        temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
+                else:
+                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
+                        visible.add((origin[0] - w, origin[1] - h))
+                    if data_collection:
+                        temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
+        if data_collection:
+            return temp_visible_data
+        else:
+            return visible
-- 
GitLab