Skip to content
Snippets Groups Projects
Commit 6e367dd6 authored by u214892's avatar u214892
Browse files

bugfix #141: check_path_exists and tests

parent 28f5d0c0
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ from importlib_resources import path
from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transitions import Transitions
......@@ -298,8 +299,7 @@ class GridTransitionMap(TransitionMap):
self.height = new_height
self.grid = new_grid
def is_dead_end(self,rcPos):
def is_dead_end(self, rcPos):
"""
Check if the cell is a dead-end
:param rcPos: tuple(row, column) with grid coordinate
......@@ -310,7 +310,30 @@ class GridTransitionMap(TransitionMap):
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
return nbits==1
return nbits == 1
def _path_exists(self, start, direction, end):
# print("_path_exists({},{},{}".format(start, direction, end))
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
node_position = node[0]
node_direction = node[1]
if node_position[0] == end[0] and node_position[1] == end[1]:
return True
if node not in visited:
visited.add(node)
moves = self.get_transitions(node_position[0], node_position[1], node_direction)
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node_position, move_index),
move_index))
return False
def cell_neighbours_valid(self, rcPos, check_this_cell=False):
"""
......
......@@ -131,29 +131,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return True
if node not in visited:
visited.add(node)
moves = rail.get_transitions(node[0][0], node[0][1], node[1])
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
if rail.is_dead_end(node[0]):
stack.append((node[0], (node[1] + 2) % 4))
return False
valid_positions = []
for r in range(rail.height):
......@@ -194,6 +171,8 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
re_generate = False
for i in range(num_agents):
valid_movements = []
if rail.is_dead_end(agents_position[i]):
print(" dead_end", agents_position[i])
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions(position[0], position[1], direction)
......@@ -204,14 +183,15 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0],
agents_target[i]):
if m[0] not in valid_starting_directions and rail._path_exists(new_position, m[0],
agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
update_agents[i] = 1
print("reset position for agents:",i, agents_position[i],agents_target[i])
print("reset position for agents:", i, agents_position[i], agents_target[i])
print(" dead_end", rail.is_dead_end(agents_position[i]))
re_generate = True
break
else:
agents_direction[i] = valid_starting_directions[
......
......@@ -81,6 +81,43 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
rail.grid = rail_map
return rail, rail_map
def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# Note that that cells have invalid RailEnvTransitions!
# |
# |
# |
# _ _ _ _ _ _ _ _ _ _
# /
# |
# |
# |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
simple_switch_north_left = cells[2]
simple_switch_north_right = cells[10]
simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] +
[[empty] * 3 + [dead_end_from_north] + [empty] * 6] +
[[dead_end_from_east] + [horizontal_straight] * 5 + [simple_switch_east_west_south] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
......
from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
def test_grid4_get_transitions():
......@@ -43,4 +50,111 @@ def test_grid8_set_transitions():
grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
# TODO GridTransitionMap
def check_path(env, rail, position, direction, target, expected, rendering=False):
agent = env.agents_static[0]
agent.position = position # south dead-end
agent.direction = direction # north
agent.target = target # east dead-end
agent.moving = True
# reset to set agents from agents_static
# env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input("Continue?")
assert rail._path_exists(agent.position, agent.direction, agent.target) == expected
def test_path_exists(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# reset to initialize agents_static
env.reset()
check_path(
env,
rail,
(5, 6), # north of south dead-end
0, # north
(3, 9), # east dead-end
True
)
check_path(
env,
rail,
(6, 6), # south dead-end
2, # south
(3, 9), # east dead-end
True
)
check_path(
env,
rail,
(3, 0), # east dead-end
3, # west
(0, 3), # north dead-end
True
)
check_path(
env,
rail,
(5, 6), # east dead-end
0, # west
(1, 3), # north dead-end
True)
check_path(
env,
rail,
(1,3), # east dead-end
2, # south
(3,3), # north dead-end
True
)
check_path(
env,
rail,
(1,3), # east dead-end
0, # north
(3,3), # north dead-end
True
)
def test_path_not_exists(rendering=False):
rail, rail_map = make_simple_rail_unconnected()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# reset to initialize agents_static
env.reset()
check_path(
env,
rail,
(5, 6), # south dead-end
0, # north
(0, 3), # north dead-end
False
)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input("Continue?")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment