From 6e367dd68c7e59abe25c4762be61503d8949e7ac Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 29 Aug 2019 13:44:10 +0200
Subject: [PATCH] bugfix #141: check_path_exists and tests

---
 flatland/core/transition_map.py            |  29 +++++-
 flatland/envs/schedule_generators.py       |  34 ++----
 flatland/utils/simple_rail.py              |  37 +++++++
 tests/test_flatland_core_transition_map.py | 116 ++++++++++++++++++++-
 4 files changed, 185 insertions(+), 31 deletions(-)

diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index e32a5623..1de8d68f 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -7,6 +7,7 @@ from importlib_resources import path
 from numpy import array
 
 from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transitions import Transitions
 
 
@@ -298,8 +299,7 @@ class GridTransitionMap(TransitionMap):
         self.height = new_height
         self.grid = new_grid
 
-
-    def is_dead_end(self,rcPos):
+    def is_dead_end(self, rcPos):
         """
         Check if the cell is a dead-end
         :param rcPos: tuple(row, column) with grid coordinate
@@ -310,7 +310,30 @@ class GridTransitionMap(TransitionMap):
         while tmp > 0:
             nbits += (tmp & 1)
             tmp = tmp >> 1
-        return nbits==1
+        return nbits == 1
+
+    def _path_exists(self, start, direction, end):
+        # print("_path_exists({},{},{}".format(start, direction, end))
+        # BFS - Check if a path exists between the 2 nodes
+
+        visited = set()
+        stack = [(start, direction)]
+        while stack:
+            node = stack.pop()
+            node_position = node[0]
+            node_direction = node[1]
+            if node_position[0] == end[0] and node_position[1] == end[1]:
+                return True
+            if node not in visited:
+                visited.add(node)
+
+                moves = self.get_transitions(node_position[0], node_position[1], node_direction)
+                for move_index in range(4):
+                    if moves[move_index]:
+                        stack.append((get_new_position(node_position, move_index),
+                                      move_index))
+
+        return False
 
     def cell_neighbours_valid(self, rcPos, check_this_cell=False):
         """
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index c9f97ea4..59bc3e5d 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -131,29 +131,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
     """
 
     def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
-        def _path_exists(rail, start, direction, end):
-            # BFS - Check if a path exists between the 2 nodes
-
-            visited = set()
-            stack = [(start, direction)]
-            while stack:
-                node = stack.pop()
-                if node[0][0] == end[0] and node[0][1] == end[1]:
-                    return True
-                if node not in visited:
-                    visited.add(node)
-                    moves = rail.get_transitions(node[0][0], node[0][1], node[1])
-                    for move_index in range(4):
-                        if moves[move_index]:
-                            stack.append((get_new_position(node[0], move_index),
-                                          move_index))
-
-                    # If cell is a dead-end, append previous node with reversed
-                    # orientation!
-                    if rail.is_dead_end(node[0]):
-                        stack.append((node[0], (node[1] + 2) % 4))
-
-            return False
 
         valid_positions = []
         for r in range(rail.height):
@@ -194,6 +171,8 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
             re_generate = False
             for i in range(num_agents):
                 valid_movements = []
+                if rail.is_dead_end(agents_position[i]):
+                    print("   dead_end", agents_position[i])
                 for direction in range(4):
                     position = agents_position[i]
                     moves = rail.get_transitions(position[0], position[1], direction)
@@ -204,14 +183,15 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
                 valid_starting_directions = []
                 for m in valid_movements:
                     new_position = get_new_position(agents_position[i], m[1])
-                    if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0],
-                                                                              agents_target[i]):
+                    if m[0] not in valid_starting_directions and rail._path_exists(new_position, m[0],
+                                                                                   agents_target[i]):
                         valid_starting_directions.append(m[0])
 
                 if len(valid_starting_directions) == 0:
-                    re_generate = True
                     update_agents[i] = 1
-                    print("reset position for agents:",i, agents_position[i],agents_target[i])
+                    print("reset position for agents:", i, agents_position[i], agents_target[i])
+                    print("   dead_end", rail.is_dead_end(agents_position[i]))
+                    re_generate = True
                     break
                 else:
                     agents_direction[i] = valid_starting_directions[
diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py
index c5fe4860..67bd93dd 100644
--- a/flatland/utils/simple_rail.py
+++ b/flatland/utils/simple_rail.py
@@ -81,6 +81,43 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
     rail.grid = rail_map
     return rail, rail_map
 
+def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    # Note that that cells have invalid RailEnvTransitions!
+    #        |
+    #        |
+    #        |
+    # _ _ _  _ _ _  _  _ _ _
+    #                /
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_left = cells[2]
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6]  +
+        [[empty] * 3 + [dead_end_from_north] + [empty] * 6]  +
+        [[dead_end_from_east] + [horizontal_straight]  * 5 + [simple_switch_east_west_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
 
 def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     # We instantiate a very simple rail network on a 7x10 grid:
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index a4142316..eb35856a 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -1,6 +1,13 @@
 from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
 from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
 
 
 def test_grid4_get_transitions():
@@ -43,4 +50,111 @@ def test_grid8_set_transitions():
     grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
     assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
 
-# TODO GridTransitionMap
+
+def check_path(env, rail, position, direction, target, expected, rendering=False):
+    agent = env.agents_static[0]
+    agent.position = position  # south dead-end
+    agent.direction = direction  # north
+    agent.target = target  # east dead-end
+    agent.moving = True
+    # reset to set agents from agents_static
+    # env.reset(False, False)
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.render_env(show=True, show_observations=False)
+        input("Continue?")
+    assert rail._path_exists(agent.position, agent.direction, agent.target) == expected
+
+
+def test_path_exists(rendering=False):
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # reset to initialize agents_static
+    env.reset()
+
+    check_path(
+        env,
+        rail,
+        (5, 6),  # north of south dead-end
+        0,  # north
+        (3, 9),  # east dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (6, 6),  # south dead-end
+        2,  # south
+        (3, 9),  # east dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (3, 0),  # east dead-end
+        3,  # west
+        (0, 3),  # north dead-end
+        True
+    )
+    check_path(
+        env,
+        rail,
+        (5, 6),  # east dead-end
+        0,  # west
+        (1, 3),  # north dead-end
+        True)
+
+    check_path(
+        env,
+        rail,
+        (1,3),  # east dead-end
+        2,  # south
+        (3,3),  # north dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (1,3),  # east dead-end
+        0,  # north
+        (3,3),  # north dead-end
+        True
+    )
+
+
+def test_path_not_exists(rendering=False):
+    rail, rail_map = make_simple_rail_unconnected()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # reset to initialize agents_static
+    env.reset()
+
+    check_path(
+        env,
+        rail,
+        (5, 6),  # south dead-end
+        0,  # north
+        (0, 3),  # north dead-end
+        False
+    )
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.render_env(show=True, show_observations=False)
+        input("Continue?")
-- 
GitLab