From 671e7521dfeefe8511dbde77eb6e7a675c48eb67 Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Thu, 18 Apr 2019 11:37:12 +0200
Subject: [PATCH] ObservationBuilder dummy + relevant modifications to env

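RailEnv now delegates the computation of per-agent observations to an
ObservationBuilder-derived class passed to its constructor and instantiated
with the environment itself. A dummy TreeObsForRailEnv builder is provided as
the default; its reset()/get() methods are still stubs. reset() and step()
return the observation dictionary produced by the builder instead of the
previous _get_observation_for_agent() placeholder.

Rough usage sketch (the `rail` transition map and the MyObsBuilder class
below are illustrative only; `rail` would typically come from the generators
in flatland/utils/rail_env_generator.py):

    from flatland.core.env import RailEnv
    from flatland.core.env_observation_builder import ObservationBuilder

    class MyObsBuilder(ObservationBuilder):
        def reset(self):
            # Recompute any cached data when the environment is regenerated.
            pass

        def get(self, handle):
            # Build the observation vector for agent `handle`.
            return [self.env.agents_position[handle],
                    self.env.agents_direction[handle]]

    rail = ...  # a GridTransitionMap, e.g. from rail_env_generator
    env = RailEnv(rail, number_of_agents=2,
                  custom_observation_builder=MyObsBuilder)
    obs_dict = env.reset()  # {agent handle: observation}
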
---
 flatland/core/env.py                          | 55 +++++++++++--------
 flatland/core/env_observation_builder.py      | 27 ++++++++++
 .../{transitionmap.py => transition_map.py}   |  0
 flatland/utils/rail_env_generator.py          |  2 +-
 tests/test_environments.py                    |  2 +-
 5 files changed, 62 insertions(+), 24 deletions(-)
 create mode 100644 flatland/core/env_observation_builder.py
 rename flatland/core/{transitionmap.py => transition_map.py} (100%)

diff --git a/flatland/core/env.py b/flatland/core/env.py
index 950365fd..a7e63fd4 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -5,6 +5,8 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv
 """
 import random
 
+from .env_observation_builder import TreeObsForRailEnv
+
 
 class Environment:
     """
@@ -118,7 +120,10 @@ class RailEnv:
     beta to be passed as parameters to __init__().
     """
 
-    def __init__(self, rail, number_of_agents=1):
+    def __init__(self,
+                 rail,
+                 number_of_agents=1,
+                 custom_observation_builder=TreeObsForRailEnv):
         """
         Environment init.
 
@@ -128,6 +133,9 @@ class RailEnv:
             The transition matrix that defines the environment.
         number_of_agents : int
             Number of agents to spawn on the map.
+        custom_observation_builder : ObservationBuilder class
+            ObservationBuilder-derived class that is instantiated with this
+            environment and provides an observation vector for each agent.
         """
 
         self.rail = rail
@@ -136,10 +144,16 @@ class RailEnv:
 
         self.number_of_agents = number_of_agents
 
+        self.obs_builder = custom_observation_builder(env=self)
+
         self.actions = [0]*self.number_of_agents
         self.rewards = [0]*self.number_of_agents
         self.done = False
 
+        self.dones = {"__all__": False}
+        self.obs_dict = {}
+        self.rewards_dict = {}
+
         self.agents_handles = list(range(self.number_of_agents))
 
     def get_agent_handles(self):
@@ -192,10 +206,11 @@ class RailEnv:
                     self.agents_direction[i] = random.sample(
                                                valid_starting_directions, 1)[0]
 
-        obs_dict = {}
-        for handle in self.agents_handles:
-            obs_dict[handle] = self._get_observation_for_agent(handle)
-        return obs_dict
+        # Reset the state of the observation builder with the new environment
+        self.obs_builder.reset()
+
+        # Return the new observation vectors for each agent
+        return self._get_observations()
 
     def step(self, action_dict):
         alpha = 1.0
@@ -206,15 +221,12 @@ class RailEnv:
         global_reward = 1 * beta
 
         # Reset the step rewards
-        rewards_dict = {}
+        self.rewards_dict = {}
         for handle in self.agents_handles:
-            rewards_dict[handle] = 0
+            self.rewards_dict[handle] = 0
 
         if self.dones["__all__"]:
-            obs_dict = {}
-            for handle in self.agents_handles:
-                obs_dict[handle] = self._get_observation_for_agent(handle)
-            return obs_dict, rewards_dict, self.dones, {}
+            return self._get_observations(), self.rewards_dict, self.dones, {}
 
         for i in range(len(self.agents_handles)):
             handle = self.agents_handles[i]
@@ -307,14 +319,14 @@ class RailEnv:
                     self.agents_direction[i] = movement
                 else:
                     # the action was not valid, add penalty
-                    rewards_dict[handle] += invalid_action_penalty
+                    self.rewards_dict[handle] += invalid_action_penalty
 
             # if agent is not in target position, add step penalty
             if self.agents_position[i][0] == self.agents_target[i][0] and \
                self.agents_position[i][1] == self.agents_target[i][1]:
                 self.dones[handle] = True
             else:
-                rewards_dict[handle] += step_penalty
+                self.rewards_dict[handle] += step_penalty
 
         # Check for end of episode + add global reward to all rewards!
         num_agents_in_target_position = 0
@@ -325,17 +337,14 @@
 
         if num_agents_in_target_position == self.number_of_agents:
             self.dones["__all__"] = True
-            rewards_dict = [r+global_reward for r in rewards_dict]
+            self.rewards_dict = {handle: r + global_reward
+                                 for handle, r in self.rewards_dict.items()}
 
         # Reset the step actions (in case some agent doesn't 'register_action'
         # on the next step)
         self.actions = [0]*self.number_of_agents
 
-        obs_dict = {}
-        for handle in self.agents_handles:
-            obs_dict[handle] = self._get_observation_for_agent(handle)
-
-        return obs_dict, rewards_dict, self.dones, {}
+        return self._get_observations(), self.rewards_dict, self.dones, {}
 
     def _new_position(self, position, movement):
         if movement == 0:    # NORTH
@@ -376,9 +384,11 @@ class RailEnv:
 
         return 0
 
-    def _get_observation_for_agent(self, handle):
-        # TODO:
-        return None
+    def _get_observations(self):
+        self.obs_dict = {}
+        for handle in self.agents_handles:
+            self.obs_dict[handle] = self.obs_builder.get(handle)
+        return self.obs_dict
 
     def render(self):
         # TODO:
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
new file mode 100644
index 00000000..17061854
--- /dev/null
+++ b/flatland/core/env_observation_builder.py
@@ -0,0 +1,27 @@
+"""ObservationBuilder objects compute the observations of a RailEnv for each agent."""
+# TODO: pylint, type hints, etc...
+
+
+class ObservationBuilder:
+    """Base class; concrete observation builders implement reset() and get()."""
+
+    def __init__(self, env):
+        self.env = env
+
+    def reset(self):
+        raise NotImplementedError()
+
+    def get(self, handle):
+        raise NotImplementedError()
+
+
+class TreeObsForRailEnv(ObservationBuilder):
+    """Dummy tree observation builder; returns an empty observation for now."""
+
+    def reset(self):
+        # TODO: precompute distances, etc...
+        pass
+
+    def get(self, handle):
+        # TODO: compute the observation for agent `handle'
+        return []
diff --git a/flatland/core/transitionmap.py b/flatland/core/transition_map.py
similarity index 100%
rename from flatland/core/transitionmap.py
rename to flatland/core/transition_map.py
diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py
index 5b292f03..69e5b831 100644
--- a/flatland/utils/rail_env_generator.py
+++ b/flatland/utils/rail_env_generator.py
@@ -6,7 +6,7 @@ import random
 import numpy as np
 
 from flatland.core.transitions import RailEnvTransitions
-from flatland.core.transitionmap import GridTransitionMap
+from flatland.core.transition_map import GridTransitionMap
 
 
 def generate_rail_from_manual_specifications(rail_spec):
diff --git a/tests/test_environments.py b/tests/test_environments.py
index 66e6bef4..03544b08 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -3,7 +3,7 @@
 
 from flatland.core.env import RailEnv
 from flatland.core.transitions import Grid4Transitions
-from flatland.core.transitionmap import GridTransitionMap
+from flatland.core.transition_map import GridTransitionMap
 import numpy as np
 
 """Tests for `flatland` package."""
-- 
GitLab