diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 2aff675dfb7353aa253dad5811a572f259e4ef54..9550b71348c177adbaa30b4bd3e5307ba2e855ce 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -422,9 +422,9 @@ class GridTransitionMap(TransitionMap): # Check the adjacent cell is within bounds # if not, then this transition is invalid! if np.any(gPos2 < 0): - return False + continue if np.any(gPos2 >= grcMax): - return False + continue # Get the transitions out of gPos2, using iDirOut as the inbound direction # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 2281282977d8c9d972f13526efa9d96abaf84a52..deaabd02fc1f842fae8c36d9056db632791e67a8 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -11,6 +11,7 @@ import numpy as np from flatland.core.env import Environment from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv @@ -132,7 +133,7 @@ class RailEnv(Environment): """ self.rail_generator = rail_generator - self.rail = None + self.rail: GridTransitionMap = None self.width = width self.height = height @@ -222,6 +223,12 @@ class RailEnv(Environment): if regen_rail or self.rail is None: self.rail = tRailAgents[0] self.height, self.width = self.rail.grid.shape + for r in range(self.height): + for c in range(self.width): + rcPos = (r, c) + check = self.rail.cell_neighbours_valid(rcPos, True) + if not check: + print("WARNING: Invalid grid at {} -> {}".format(rcPos, check)) if replace_agents: self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])