From 94293fa87a125d5aa97ff4a2a595d2eb5bf5d6e7 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Tue, 9 Apr 2019 14:55:40 +0200
Subject: [PATCH] Simple tests for rail_env

---
 .gitignore                   |  3 ++
 flatland/core/env.py         |  9 ++++--
 flatland/core/transitions.py | 14 ++++-----
 requirements_dev.txt         |  2 ++
 tests/test_environments.py   | 57 +++++++++++++++++++++++++++++++-----
 5 files changed, 69 insertions(+), 16 deletions(-)

diff --git a/.gitignore b/.gitignore
index 84229f4..ca0c210 100644
--- a/.gitignore
+++ b/.gitignore
@@ -71,6 +71,9 @@ target/
 # Jupyter Notebook
 .ipynb_checkpoints
 
+# PyCharm
+.idea/
+
 # pyenv
 .python-version
 
diff --git a/flatland/core/env.py b/flatland/core/env.py
index 4a147d0..2ecee63 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -108,7 +108,7 @@ class RailEnv:
         0: do nothing
         1: turn left and move to the next cell
         2: move to the next cell in front of the agent
-        3: turn righ tand move to the next cell
+        3: turn right and move to the next cell
 
     Moving forward in a dead-end cell makes the agent turn 180 degrees and step
     to the cell it came from.
@@ -276,6 +276,7 @@ class RailEnv:
                                             self.rail[pos[0]][pos[1]],
                                             reverse_direction,
                                             reverse_direction)
+
                         if valid_transition:
                             direction = reverse_direction
                             movement = direction
@@ -285,7 +286,11 @@ class RailEnv:
                 # Is it a legal move?  1) transition allows the movement in the
                 # cell,  2) the new cell is not empty (case 0),  3) the cell is
                 # free, i.e., no agent is currently in that cell
-                if self.rail[new_position[0]][new_position[1]] > 0:
+                if new_position[1] >= self.width or new_position[0] >= self.height or\
+                    new_position[0] < 0 or new_position[1] < 0:
+                    new_cell_isValid = False
+
+                elif self.rail[new_position[0]][new_position[1]] > 0:
                     new_cell_isValid = True
                 else:
                     new_cell_isValid = False
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index dc4ffe8..9c4f05e 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -440,13 +440,13 @@ class RailEnvTransitions(GridTransitions):
     """
 
     transition_list = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000000000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
+                       int('1000000000100000', 2),  # Case 1 - straight
+                       int('1001001000100000', 2),  # Case 2 - simple switch
+                       int('1000010000100001', 2),  # Case 3 - diamond drossing
+                       int('1001011000100001', 2),  # Case 4 - single slip switch
+                       int('1100110000110011', 2),  # Case 5 - double slip switch
+                       int('0101001000000010', 2),  # Case 6 - symmetrical switch
+                       int('0010000000000000', 2)]  # Case 7 - dead end
 
     def __init__(self):
         super(RailEnvTransitions, self).__init__(
diff --git a/requirements_dev.txt b/requirements_dev.txt
index cba711f..ad6c0ad 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -11,3 +11,5 @@ twine==1.12.1
 pytest==3.8.2
 pytest-runner==4.2
 sphinx-rtd-theme==0.4.3
+
+numpy==1.16.2
diff --git a/tests/test_environments.py b/tests/test_environments.py
index e95c398..435002b 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -4,15 +4,17 @@
 from flatland.core.env import RailEnv
 from flatland.core.transitions import GridTransitions
 import numpy as np
+import random
 
 """Tests for `flatland` package."""
 
 
 
-def test_rail_environment():
+def test_rail_environment_single_agent():
+
     cells = [int('0000000000000000', 2),  # empty cell - Case 0
                        int('1000000000100000', 2),  # Case 1 - straight
-                       int('1001001000000000', 2),  # Case 2 - simple switch
+                       int('1001001000100000', 2),  # Case 2 - simple switch
                        int('1000010000100001', 2),  # Case 3 - diamond drossing
                        int('1001011000100001', 2),  # Case 4 - single slip switch
                        int('1100110000110011', 2),  # Case 5 - double slip switch
@@ -29,8 +31,9 @@ def test_rail_environment():
     vertical_line = cells[1]
     south_symmetrical_switch = cells[6]
     north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
-    south_east_turn = int('0100000000100000', 2)  # Simple turn not in the base transitions ?
+    south_east_turn = int('0100000000000010', 2)  # Simple turn not in the base transitions ?
     south_west_turn = transitions.rotate_transition(south_east_turn, 90)
+    # print(bytes(south_west_turn))
     north_east_turn = transitions.rotate_transition(south_east_turn, 270)
     north_west_turn = transitions.rotate_transition(south_east_turn, 180)
 
@@ -40,12 +43,52 @@ def test_rail_environment():
                    dtype=np.uint16)
 
     rail_env = RailEnv(rail_map, number_of_agents=1)
+    for _ in range(200):
+        _ = rail_env.reset()
+
+        # We do not care about target for the moment
+        rail_env.agents_target[0] = [-1, -1]
 
-    # Check that trains are always initialized at a consistent position / direction.
-    # They should always be able to go somewhere.
-    for _ in range(1000):
-        obs = rail_env.reset()
+        # Check that trains are always initialized at a consistent position / direction.
+        # They should always be able to go somewhere.
         assert(transitions.get_transitions_from_orientation(
             rail_map[rail_env.agents_position[0]],
             rail_env.agents_direction[0]) != (0, 0, 0, 0))
 
+        initial_pos = rail_env.agents_position[0]
+
+        valid_active_actions_done = 0
+        pos = initial_pos
+        while valid_active_actions_done < 6:
+            # We randomly select an action
+            action = np.random.randint(4)
+
+            _, _, _, _ = rail_env.step({0: action})
+
+            prev_pos = pos
+            pos = rail_env.agents_position[0]
+            if prev_pos != pos:
+                valid_active_actions_done += 1
+
+        # After 6 movements on this railway network, the train should be back to its original
+        # position.
+        assert(initial_pos[0] == rail_env.agents_position[0][0])
+
+        # We check that the train always attains its target after some time
+        for _ in range(200):
+            _ = rail_env.reset()
+
+            done = False
+            while not done:
+                # We randomly select an action
+                action = np.random.randint(4)
+
+                _, _, dones, _ = rail_env.step({0: action})
+
+                done = dones['__all__']
+
+
+
+
+
+
-- 
GitLab