From f92415eedc114cce4aef7bf52f1829021f30762e Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Wed, 12 Jun 2019 18:37:53 +0200
Subject: [PATCH] added get_many() method to observation builders, base class
 and derived. Also removed the render() methods in the env base class and
 railenv

---
 flatland/core/env.py                     |  6 ------
 flatland/core/env_observation_builder.py | 21 +++++++++++++++++++++
 flatland/envs/observations.py            | 13 +++++++++++++
 flatland/envs/rail_env.py                |  9 +--------
 4 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/flatland/core/env.py b/flatland/core/env.py
index 32691f50..3618d965 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -99,12 +99,6 @@ class Environment:
         """
         raise NotImplementedError()
 
-    def render(self):
-        """
-        Perform rendering of the environment.
-        """
-        raise NotImplementedError()
-
     def get_agent_handles(self):
         """
         Returns a list of agents' handles to be used as keys in the step()
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index b30c2b1f..53e7a068 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -30,6 +30,27 @@ class ObservationBuilder:
         """
         raise NotImplementedError()
 
+    def get_many(self, handles=[]):
+        """
+        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
+        in the `handles' list.
+
+        Parameters
+        -------
+        handles : list of handles (optional)
+            List with the handles of the agents for which to compute the observation vector.
+
+        Returns
+        -------
+        function
+            A dictionary of observation structures, specific to the corresponding environment, with handles from
+            `handles' as keys.
+        """
+        observations = {}
+        for h in handles:
+            observations[h] = self.get(h)
+        return observations
+
     def get(self, handle=0):
         """
         Called whenever an observation has to be computed for the `env' environment, possibly
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 676051d8..f3c0a7c3 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -167,6 +167,19 @@ class TreeObsForRailEnv(ObservationBuilder):
         elif movement == 3:  # WEST
             return (position[0], position[1] - 1)
 
+    def get_many(self, handles=[]):
+        """
+        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
+        in the `handles' list.
+        """
+
+        # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
+
+        observations = {}
+        for h in handles:
+            observations[h] = self.get(h)
+        return observations
+
     def get(self, handle):
         """
         Computes the current observation for agent `handle' in env
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 6cd64514..44ed3f77 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -330,10 +330,7 @@ class RailEnv(Environment):
         return new_direction, transition_isValid
 
     def _get_observations(self):
-        self.obs_dict = {}
-        self.debug_obs_dict = {}
-        for iAgent in range(self.get_num_agents()):
-            self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
+        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
 
     def _get_predictions(self):
@@ -341,10 +338,6 @@ class RailEnv(Environment):
             return {}
         return {}
 
-    def render(self):
-        # TODO:
-        pass
-
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
-- 
GitLab