From 97a275598dc03158ecc5ebf6b850a164feb3dc87 Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Thu, 23 May 2019 20:14:30 +0200
Subject: [PATCH] Add part 3 of getting started: custom observation and
 custom rail map examples; set observation_space in GlobalObs

---
 examples/custom_observation_example.py | 32 ++++++++++++++++++++++
 examples/custom_railmap_example.py     | 37 ++++++++++++++++++++++++++
 examples/simple_example_3.py           |  4 ++-
 flatland/envs/observations.py          |  6 +++++
 flatland/envs/rail_env.py              |  7 +++--
 5 files changed, 81 insertions(+), 5 deletions(-)
 create mode 100644 examples/custom_observation_example.py
 create mode 100644 examples/custom_railmap_example.py

diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
new file mode 100644
index 00000000..03bbe4b7
--- /dev/null
+++ b/examples/custom_observation_example.py
@@ -0,0 +1,32 @@
+import random
+
+from flatland.envs.generators import random_rail_generator, random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+from flatland.core.env_observation_builder import ObservationBuilder
+import numpy as np
+
+random.seed(100)
+np.random.seed(100)
+
+class CustomObs(ObservationBuilder):
+    """Minimal observation builder: get(handle) returns handle * ones((5,))."""
+    def __init__(self):
+        self.observation_space = [5]
+
+    def reset(self):
+        """Nothing to do on reset; returns None."""
+
+    def get(self, handle):
+        return handle * np.ones((5,))
+
+env = RailEnv(width=7, height=7,
+              rail_generator=random_rail_generator(),
+              number_of_agents=3,
+              obs_builder_object=CustomObs())
+
+# Step once with the action dict {0: 0}, then print the observation
+# vector returned for each agent.
+obs, all_rewards, done, _ = env.step({0: 0})
+for handle in range(env.get_num_agents()):
+    print("Agent ", handle,"'s observation: ", obs[handle])
diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
new file mode 100644
index 00000000..71f849de
--- /dev/null
+++ b/examples/custom_railmap_example.py
@@ -0,0 +1,37 @@
+import random
+
+from flatland.envs.generators import random_rail_generator, random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.core.transitions import RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
+from flatland.utils.rendertools import RenderTool
+import numpy as np
+
+random.seed(100)
+np.random.seed(100)
+
+def custom_rail_generator():
+    """Return a generator callable producing a zero-filled grid map and no agents."""
+    def generator(width, height, num_agents=0, num_resets=0):
+        # num_agents and num_resets are accepted but not used here.
+        empty_map = GridTransitionMap(width=width, height=height,
+                                      transitions=RailEnvTransitions())
+        empty_map.grid.fill(0)  # overwrite every grid cell with 0
+
+        # No agents are defined: positions, directions and targets stay empty.
+        no_positions, no_directions, no_targets = [], [], []
+        return empty_map, no_positions, no_directions, no_targets
+
+    return generator
+
+env = RailEnv(width=6, height=4,
+              rail_generator=custom_rail_generator(),
+              number_of_agents=1)
+env.reset()
+
+# Render the environment using the "QT" backend.
+env_renderer = RenderTool(env, gl="QT")
+env_renderer.renderEnv(show=True)
+
+# Block until Enter is pressed (presumably so the rendering stays visible).
+input("Press Enter to continue...")
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index e0830ff7..8aac0ccc 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -3,6 +3,7 @@ import random
 from flatland.envs.generators import random_rail_generator, random_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
+from flatland.envs.observations import TreeObsForRailEnv  # used as obs_builder_object below
 import numpy as np
 
 random.seed(100)
@@ -11,7 +12,8 @@ np.random.seed(100)
 env = RailEnv(width=7,
               height=7,
               rail_generator=random_rail_generator(),
-              number_of_agents=2)
+              number_of_agents=2,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
 # Print the distance map of each cell to the target of the first agent
 # for i in range(4):
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index af789f5d..7de63271 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -491,8 +491,14 @@ class GlobalObsForRailEnv(ObservationBuilder):
     """
 
     def __init__(self):
+        self.observation_space = ()  # placeholder; _set_env() assigns the real value
         super(GlobalObsForRailEnv, self).__init__()
 
+    def _set_env(self, env):
+        """Delegate to the parent, then set observation_space to [4, height, width]."""
+        super()._set_env(env)
+        self.observation_space = [4, self.env.height, self.env.width]
+
     def reset(self):
         self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
         for i in range(self.rail_obs.shape[0]):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 4ac00f18..be3f4d6b 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -90,6 +90,9 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
+        self.action_space = [1]
+        self.observation_space = self.obs_builder.observation_space # updated on resets?
+
         self.actions = [0] * number_of_agents
         self.rewards = [0] * number_of_agents
         self.done = False
@@ -112,10 +115,6 @@ class RailEnv(Environment):
 
         self.valid_positions = None
 
-        self.action_space = [1]
-        self.observation_space = self.obs_builder.observation_space # updated on resets?
-
-
     # no more agent_handles
     def get_agent_handles(self):
         return range(self.get_num_agents())
-- 
GitLab