From 493b8a2027c83ce4683853a29bd9813a9be51eea Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Thu, 18 Apr 2019 16:46:00 +0200
Subject: [PATCH] started implementation of global view and the corresponding
 test

---
 flatland/core/env_observation_builder.py | 24 +++++++++
 tests/test_env_observation_builder.py    | 69 ++++++++++++++++++++++++
 2 files changed, 93 insertions(+)
 create mode 100644 tests/test_env_observation_builder.py

diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index a83c0c51..ac62946d 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 # TODO: add docstrings, pylint, etc...
 
 
@@ -23,3 +25,25 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # raise NotImplementedError()
         return []
+
+
+class GlobalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    The observation is composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width, 16),
+          assuming 16 bits encoding of transitions.
+
+        - Four 2D arrays containing respectively the position of the given agent,
+          the position of its target, the positions of the other agents and of
+          their target.
+    """
+    def __init__(self, env):
+        super(GlobalObsForRailEnv, self).__init__(env)
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                self.rail_obs[i, j] = self.env.rail.get_transitions((i, j))
+
+
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
new file mode 100644
index 00000000..c3f5dfd5
--- /dev/null
+++ b/tests/test_env_observation_builder.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
+from flatland.core.transitions import Grid4Transitions
+import numpy as np
+
+"""Tests for `flatland` package."""
+
+
+def test_global_obs():
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ /_\ _ _  _  _ _ _
+    #               \ /
+    #                |
+    #                |
+    #                |
+
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond drossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
+
+    transitions = Grid4Transitions([])
+    empty = cells[0]
+
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+
+    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
+    double_switch_north_horizontal_straight = transitions.rotate_transition(
+        double_switch_south_horizontal_straight, 180)
+
+
+
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6]*2 +
+        [[horizontal_straight] * 3 + [double_switch_north_horizontal_straight] +
+        [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+        [horizontal_straight] * 3] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6], dtype=np.uint16)
+
+    print(rail_map.shape)
+
+test_global_obs()
+
+
+
+
+
+
+
+
+
+
-- 
GitLab