From 3aced3b7ab3f81fab92cdc3b4acf8ca08ce9add7 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 28 Aug 2019 10:15:30 +0200
Subject: [PATCH] #44 bugfix sparse generator

---
 flatland/core/transition_map.py | 4 ++--
 flatland/envs/rail_env.py       | 9 ++++++++-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 2aff675d..9550b713 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -422,9 +422,9 @@ class GridTransitionMap(TransitionMap):
             # Check the adjacent cell is within bounds
             # if not, then this transition is invalid!
             if np.any(gPos2 < 0):
-                return False
+                continue
             if np.any(gPos2 >= grcMax):
-                return False
+                continue
 
             # Get the transitions out of gPos2, using iDirOut as the inbound direction
             # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 22812829..deaabd02 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -11,6 +11,7 @@ import numpy as np
 
 from flatland.core.env import Environment
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
@@ -132,7 +133,7 @@ class RailEnv(Environment):
         """
 
         self.rail_generator = rail_generator
-        self.rail = None
+        self.rail: GridTransitionMap = None
         self.width = width
         self.height = height
 
@@ -222,6 +223,12 @@ class RailEnv(Environment):
         if regen_rail or self.rail is None:
             self.rail = tRailAgents[0]
             self.height, self.width = self.rail.grid.shape
+            for r in range(self.height):
+                for c in range(self.width):
+                    rcPos = (r, c)
+                    check = self.rail.cell_neighbours_valid(rcPos, True)
+                    if not check:
+                        print("WARNING: Invalid grid at {} -> {}".format(rcPos, check))
 
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
-- 
GitLab