From 7ca1ed3b05e9d0ef189e4f114a84ce676b6e7e99 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sat, 17 Aug 2019 11:12:05 -0400
Subject: [PATCH] Fixed invalid transitions: now all nodes are valid. All
 starting positions are valid. Check for feasibility of map is not yet done.

---
 examples/training_example.py                  |  2 +-
 flatland/core/transition_map.py               | 56 +++++++++++++++++++
 flatland/envs/generators.py                   | 39 +++++++++----
 ...test_flatland_env_sparse_rail_generator.py | 11 ++--
 4 files changed, 89 insertions(+), 19 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index c038e7b4..60dc455f 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -77,7 +77,7 @@ for trials in range(1, n_trials + 1):
 
     score = 0
     # Run episode
-    for step in range(100):
+    for step in range(500):
         # Chose an action for each agent in the environment
         for a in range(env.get_num_agents()):
             action = agent.act(obs[a])
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 5e0f6cd7..018c8cd5 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -350,4 +350,60 @@ class GridTransitionMap(TransitionMap):
 
         return True
 
+    def fix_neighbours(self, rcPos, check_this_cell=False):
+        """
+        Repair the neighbours of the cell at rcPos = tuple(row, column).
+        For each outbound direction of this cell the adjacent cell is looked
+        up; if it has no transitions for that inbound direction, set_transition
+        is called on it with the mirror direction (presumably a dead-end back).
+
+        If check_this_cell is True, the cell itself is first checked with
+        transitions.is_valid and False is returned if that check fails.
+        NOTE(review): returns right after the first repaired neighbour, so
+        any further outbound directions are not examined - confirm intended.
+
+        Returns: False if a fix was applied or an outbound leaves the grid, else True
+        """
+        cell_transition = self.grid[tuple(rcPos)]
+
+        if check_this_cell:
+            if not self.transitions.is_valid(cell_transition):
+                return False
+
+        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
+        grcPos = array(rcPos)
+        grcMax = self.grid.shape
+
+        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out
+        lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
+        g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 x uint8 binary(0,1)
+        gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
+        giDirOut = np.argwhere(gDirOut)[:, 0]  # valid outbound directions as array of int
+
+        # loop over available outbound directions (indices) for rcPos
+        for iDirOut in giDirOut:
+            gdRC = gDir2dRC[iDirOut]  # row,col increment
+            gPos2 = grcPos + gdRC  # next cell in that direction
+
+            # Check the adjacent cell is within bounds
+            # if not, nothing can be fixed for this direction: report False
+            if np.any(gPos2 < 0):
+                return False
+            if np.any(gPos2 >= grcMax):
+                return False
+
+            # Get the transitions out of gPos2, using iDirOut as the inbound direction
+            # if there are none, ie (0,0,0,0), set the mirror transition on gPos2 and stop
+            t4Trans2 = self.get_transitions(*gPos2, iDirOut)
+            if any(t4Trans2):
+                continue
+            else:
+                self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
+                return False
+
+        return True
+
+
+def mirror(dir):
+    return (dir + 2) % 4  # opposite direction index: 0 <-> 2, 1 <-> 3
 # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 2441e3e7..a0eb7ca7 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -841,8 +841,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             to_close = True
             tries = 0
             while to_close:
-                x_tmp = 1 + np.random.randint(height - 1)
-                y_tmp = 1 + np.random.randint(width - 1)
+                x_tmp = 1 + np.random.randint(height - 2)
+                y_tmp = 1 + np.random.randint(width - 2)
                 to_close = False
                 for node_pos in node_positions:
                     if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
@@ -871,8 +871,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             available_nodes = available_nodes[np.argsort(node_dist)]
 
             # Set number of neighboring nodes
-            # np.random.randint(1, max_neigbours)
-
             if len(available_nodes) >= num_neighb:
                 connected_neighb_idx = available_nodes[
                                        0:2]  # np.random.choice(available_nodes, num_neighb, replace=False)
@@ -887,6 +885,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             node_stack.pop(0)
 
         # Place train stations close to the node
+        # We currently place them uniformly distributed among all cities
+
         train_stations = [[] for i in range(num_cities)]
 
         for station in range(num_trainstations):
@@ -911,6 +911,11 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             new_path = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node],
                                           (station_x, station_y))
 
+        # Fix all nodes with illegal transition maps
+        for current_node in node_positions:
+            if not grid_map.cell_neighbours_valid(current_node):
+                grid_map.fix_neighbours(current_node)
+
         # Generate start and target node directory for all agents.
         # Assure that start and target are not in the same node
         agent_start_targets_nodes = []
@@ -924,20 +929,24 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
 
         # Assign agents to slots
         for agent_idx in range(num_agents):
-            av_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
-            av_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
-            start_node = np.random.choice(av_start_nodes)
-            target_node = np.random.choice(av_target_nodes)
+            avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
+            avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
+            start_node = np.random.choice(avail_start_nodes)
+            target_node = np.random.choice(avail_target_nodes)
+            tries = 0
             while target_node == start_node:
-                target_node = np.random.choice(av_target_nodes)
+                target_node = np.random.choice(avail_target_nodes)
+                tries += 1
+                # Test again with new start node if no pair is found (This code needs to be improved)
+                if tries > 10:
+                    start_node = np.random.choice(avail_start_nodes)
+
             node_available_start[start_node] -= 1
             node_available_target[target_node] -= 1
-            print(node_available_target, node_available_start)
 
             agent_start_targets_nodes.append((start_node, target_node))
 
         # Place agents and targets within available train stations
-
         agents_position = []
         agents_target = []
         agents_direction = []
@@ -956,7 +965,13 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 start = train_stations[current_start_node][start_station_idx]
             agents_position.append((start[0], start[1]))
             agents_target.append((target[0], target[1]))
-            agents_direction.append(0)
+
+            # Orient the agent correctly. NOTE(review): `continue` keeps looping; `break` likely intended
+            for orientation in range(4):
+                transitions = grid_map.get_transitions(start[0], start[1], orientation)
+                if any(transitions) > 0:
+                    agents_direction.append(orientation)
+                    continue
             agent_idx += 1
 
         return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index abb03528..34e6c2b4 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -23,23 +23,22 @@ def test_realistic_rail_generator():
         env_renderer.close_window()
 
 def test_sparse_rail_generator():
-
-    env = RailEnv(width=50,
+    env = RailEnv(width=20,
                   height=50,
                   rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
                                                        num_intersections=3,  # Number of interesections in map
-                                                       num_trainstations=30,  # Number of possible start/targets on map
+                                                       num_trainstations=10,  # Number of possible start/targets on map
                                                        min_node_dist=10,  # Minimal distance of nodes
                                                        node_radius=2,  # Proximity of stations to city center
                                                        num_neighb=4,  # Number of connections to other cities
                                                        seed=15,  # Random seed
                                                        ),
-                  number_of_agents=20,
+                  number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
-    time.sleep(2)
+    time.sleep(20)
 
     env_renderer.gl.save_image("flatalnd_2_0.png")
-    time.sleep(100)
+    time.sleep(1)
-- 
GitLab