From 2c91e4868ca9437e5e75450e4df784feb9310866 Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Fri, 18 Oct 2019 19:04:51 +0200
Subject: [PATCH] Add a custom observation builder

---
 my_observation_builder.py | 106 ++++++++++++++++++++++++++++++++++++++
 run.py                    |  29 +++++++----
 2 files changed, 125 insertions(+), 10 deletions(-)
 create mode 100644 my_observation_builder.py

diff --git a/my_observation_builder.py b/my_observation_builder.py
new file mode 100644
index 0000000..da2a209
--- /dev/null
+++ b/my_observation_builder.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python 
+
+import collections
+from typing import Optional, List, Dict, Tuple
+
+import numpy as np
+
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
+from flatland.utils.ordered_set import OrderedSet
+
+
+class CustomObservationBuilder(ObservationBuilder):
+    """
+    Template for building a custom observation builder for the RailEnv class
+
+    The observation in this case composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width),\
+          where the value at X,Y will represent the 16 bits encoding of transition-map at that point.
+        
+        - the individual agent object (with position, direction, target information available)
+
+    """
+    def __init__(self):
+        super(CustomObservationBuilder, self).__init__()
+
+    def set_env(self, env: Environment):
+        super().set_env(env)
+        # Note :
+        # The instantiations which depend on parameters of the Env object should be 
+        # done here, as it is only here that the updated self.env instance is available
+        self.rail_obs = np.zeros((self.env.height, self.env.width))
+        print("Env Width : ", self.env.width, "Env Height : ", self.env.height)
+
+    def reset(self):
+        """
+        Called internally on every env.reset() call, 
+        to reset any observation specific variables that are being used
+        """
+        self.rail_obs[:] = 0        
+        for _x in range(self.env.width):
+            for _y in range(self.env.height):
+                # Get the transition map value at location _x, _y
+                transition_value = self.env.rail.get_full_transitions(_y, _x)
+                self.rail_obs[_y, _x] = transition_value
+        print("Responding to obs_builder.reset()")
+
+    def get(self, handle: int = 0):
+        """
+        Returns the built observation for a single agent with handle : handle
+
+        In this particular case, we return 
+        - the global transition_map of the RailEnv,
+        - a tuple containing, the current agent's:
+            - state
+            - position
+            - direction
+            - initial_position
+            - target
+        """
+
+        agent = self.env.agents[handle]
+        """
+        Available information for each agent object : 
+
+        - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        - agent.position : Current position of the agent
+        - agent.direction : Current direction of the agent
+        - agent.initial_position : Initial Position of the agent
+        - agent.target : Target position of the agent
+        """
+
+        status = agent.status
+        position = agent.position
+        direction = agent.direction
+        initial_position = agent.initial_position
+        target = agent.target
+
+        
+        """
+        You can also optionally access the states of the rest of the agents by 
+        using something similar to 
+
+        for i in range(len(self.env.agents)):
+            other_agent: EnvAgent = self.env.agents[i]
+
+            # ignore other agents not in the grid any more
+            if other_agent.status == RailAgentStatus.DONE_REMOVED:
+                continue
+
+            ## Gather other agent specific params 
+            other_agent_status = other_agent.status
+            other_agent_position = other_agent.position
+            other_agent_direction = other_agent.direction
+            other_agent_initial_position = other_agent.initial_position
+            other_agent_target = other_agent.target
+
+            ## Do something nice here if you wish
+        """
+        return self.rail_obs, (status, position, direction, initial_position, target)
+
diff --git a/run.py b/run.py
index 15a97b8..5c5bb9a 100644
--- a/run.py
+++ b/run.py
@@ -1,6 +1,6 @@
 from flatland.evaluators.client import FlatlandRemoteClient
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from my_observation_builder import CustomObservationBuilder
 import numpy as np
 import time
 
@@ -31,10 +31,14 @@ def my_controller(obs, number_of_agents):
 # the example here : 
 # https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
 #####################################################################
-my_observation_builder = TreeObsForRailEnv(
-                                max_depth=3,
-                                predictor=ShortestPathPredictorForRailEnv()
-                            )
+my_observation_builder = CustomObservationBuilder()
+
+# Or if you want to use your own approach to build the observation from the env_step, 
+# please feel free to pass a DummyObservationBuilder() object as mentioned below,
+# and that will just return a placeholder True for all observation, and you 
+# can build your own Observation for all the agents as your please.
+# my_observation_builder = DummyObservationBuilder()
+
 
 #####################################################################
 # Main evaluation loop
@@ -55,9 +59,11 @@ while True:
     # You can also pass your custom observation_builder object
     # to allow you to have as much control as you wish 
     # over the observation of your choice.
+    time_start = time.time()
     observation, info = remote_client.env_create(
                     obs_builder_object=my_observation_builder
                 )
+    env_creation_time = time.time() - time_start
     if not observation:
         #
         # If the remote_client returns False on a `env_create` call,
@@ -66,7 +72,7 @@ while True:
         # and hence its safe to break out of the main evaluation loop
         break
     
-    #print("Evaluation Number : {}".format(evaluation_number))
+    print("Evaluation Number : {}".format(evaluation_number))
 
     #####################################################################
     # Access to a local copy of the environment
@@ -95,12 +101,12 @@ while True:
     # or when the number of time steps has exceed max_time_steps, which 
     # is defined by : 
     #
-    # max_time_steps = int(1.5 * (env.width + env.height))
+    # max_time_steps = int(4 * 2 * (env.width + env.height + 20))
     #
     time_taken_by_controller = []
     time_taken_per_step = []
-
-    for k in range(10):
+    steps = 0
+    while True:
         #####################################################################
         # Evaluation of a single episode
         #
@@ -119,6 +125,7 @@ while True:
         # are returned by the remote copy of the env
         time_start = time.time()
         observation, all_rewards, done, info = remote_client.env_step(action)
+        steps += 1
         time_taken = time.time() - time_start
         time_taken_per_step.append(time_taken)
 
@@ -136,6 +143,8 @@ while True:
     print("="*100)
     print("Evaluation Number : ", evaluation_number)
     print("Current Env Path : ", remote_client.current_env_path)
+    print("Env Creation Time : ", env_creation_time)
+    print("Number of Steps : ", steps)
     print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
     print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
     print("="*100)
-- 
GitLab