From 8a8a7853c2df2b617fed29fe4a738bc9c6567cd2 Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Thu, 23 May 2019 17:16:11 +0200
Subject: [PATCH] action_space and observation_space, issue #46

---
 flatland/core/env.py                     |  6 ++++++
 flatland/core/env_observation_builder.py |  4 ++++
 flatland/envs/observations.py            | 12 ++++++++----
 flatland/envs/rail_env.py                |  5 +++++
 4 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/flatland/core/env.py b/flatland/core/env.py
index 284afdff..5334b22f 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -9,6 +9,10 @@ class Environment:
     """
     Base interface for multi-agent environments in Flatland.
 
+    Derived environments should implement the following attributes:
+        action_space: tuple with the dimensions of the actions to be passed to the step method
+        observation_space: tuple with the dimensions of the observations returned by reset and step
+
     Agents are identified by agent ids (handles).
     Examples:
         >>> obs = env.reset()
@@ -39,6 +43,8 @@ class Environment:
     """
 
     def __init__(self):
+        self.action_space = ()
+        self.observation_space = ()
         pass
 
     def reset(self):
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 09a624e8..3cef545c 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -12,9 +12,13 @@ case of multi-agent environments.
 class ObservationBuilder:
     """
     ObservationBuilder base class.
+
+    Derived objects must implement and `observation_space' attribute as a tuple with the dimensuions of the returned
+    observations.
     """
 
     def __init__(self):
+        self.observation_space = ()
         pass
 
     def _set_env(self, env):
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index ba159cb7..af789f5d 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -19,6 +19,14 @@ class TreeObsForRailEnv(ObservationBuilder):
     def __init__(self, max_depth):
         self.max_depth = max_depth
 
+        # Compute the size of the returned observation vector
+        size = 0
+        pow4 = 1
+        for i in range(self.max_depth+1):
+            size += pow4
+            pow4 *= 4
+        self.observation_space = [size * 5]
+
     def reset(self):
         agents = self.env.agents
         nAgents = len(agents)
@@ -158,10 +166,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         the transitions. The order is:
             [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
 
-
-
-
-
         Each branch data is organized as:
             [root node information] +
             [recursive branch data from 'left'] +
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 74e7526c..4ac00f18 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -112,6 +112,10 @@ class RailEnv(Environment):
 
         self.valid_positions = None
 
+        self.action_space = [1]
+        self.observation_space = self.obs_builder.observation_space # updated on resets?
+
+
     # no more agent_handles
     def get_agent_handles(self):
         return range(self.get_num_agents())
@@ -160,6 +164,7 @@ class RailEnv(Environment):
 
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
+        self.observation_space = self.obs_builder.observation_space # <-- change on reset?
 
         # Return the new observation vectors for each agent
         return self._get_observations()
-- 
GitLab