Commit 96b8f258 authored by Christian Eichenberger's avatar Christian Eichenberger 🏸
Browse files

Merge branch 'get-shortest-paths-from-distance-map' into 'master'

get shortest paths from distance map

See merge request !202
parents 60da3eb9 18787d4e
Pipeline #2179 passed with stages
in 84 minutes and 2 seconds
......@@ -18,27 +18,21 @@ class DistanceMap:
self.agents: List[EnvAgent] = agents
self.rail: Optional[GridTransitionMap] = None
"""
Set the distance map
"""
def set(self, distance_map: np.ndarray):
"""
Set the distance map
"""
self.distance_map = distance_map
"""
Get the distance map
"""
def get(self) -> np.ndarray:
"""
Get the distance map
"""
if self.reset_was_called:
self.reset_was_called = False
nb_agents = len(self.agents)
compute_distance_map = True
if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation):
compute_distance_map = False
for i in range(nb_agents):
if self.agents[i].target != self.agents_previous_computation[i].target:
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_computation is None and self.distance_map is not None:
compute_distance_map = False
......@@ -51,12 +45,12 @@ class DistanceMap:
return self.distance_map
"""
Reset the distance map
"""
def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
Reset the distance map
"""
self.reset_was_called = True
self.agents = agents
self.agents: List[EnvAgent] = agents
self.rail = rail
self.env_height = rail.height
self.env_width = rail.width
......@@ -110,7 +104,8 @@ class DistanceMap:
return max_distance
def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1):
def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance,
enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
......@@ -134,8 +129,7 @@ class DistanceMap:
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
# is_valid = True
desired_movement_from_new_cell)
if is_valid:
"""
......
......@@ -5,8 +5,9 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet
......@@ -59,7 +60,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
continue
for action in action_priorities:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent)
if all([new_cell_isValid, transition_isValid]):
# move and change direction to face the new_direction that was
......@@ -92,6 +93,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
Does not take into account future positions of other agents!
If there is no shortest path, the agent just stands still and stops moving.
Parameters
----------
......@@ -106,14 +110,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
- position axis 0
- position axis 1
- direction
- action taken to come here
- action taken to come here (not implemented yet)
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
distance_map = self.env.distance_map
assert distance_map is not None
distance_map: DistanceMap = self.env.distance_map
shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)
prediction_dict = {}
for agent in agents:
......@@ -123,52 +128,35 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
shortest_path = shortest_paths[agent.handle]
# if there is a shortest path, remove the initial position
if shortest_path:
shortest_path = shortest_path[1:]
new_direction = _agent_initial_direction
new_position = _agent_initial_position
visited = OrderedSet()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
visited.add((agent.position[0], agent.position[1], agent.direction))
continue
if not agent.moving:
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
visited.add((agent.position[0], agent.position[1], agent.direction))
# if we're at the target or not moving, stop moving until max_depth is reached
if new_position == agent.target or not agent.moving or not shortest_path:
prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
visited.add((*new_position, agent.direction))
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
new_direction = np.argmax(cell_transitions)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
min_dist = np.inf
no_dist_found = True
for direction in range(4):
if cell_transitions[direction] == 1:
neighbour_cell = get_new_position(agent.position, direction)
target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
if target_dist < min_dist or no_dist_found:
min_dist = target_dist
new_direction = direction
no_dist_found = False
new_position = get_new_position(agent.position, new_direction)
elif index % times_per_cell == 0:
raise Exception("No transition possible {}".format(cell_transitions))
# update the agent's position and direction
agent.position = new_position
agent.direction = new_direction
if index % times_per_cell == 0:
new_position = shortest_path[0].position
new_direction = shortest_path[0].direction
shortest_path = shortest_path[1:]
# prediction is ready
prediction[index] = [index, *new_position, new_direction, 0]
visited.add((new_position[0], new_position[1], new_direction))
visited.add((*new_position, new_direction))
# TODO: very bady side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
self.env.dev_pred_dict[agent.handle] = visited
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
return prediction_dict
......@@ -4,7 +4,7 @@ Definition of the RailEnv environment.
# TODO: _ this is a global method --> utils or remove later
import warnings
from enum import IntEnum
from typing import List, Set, NamedTuple, Optional, Tuple, Dict
from typing import List, NamedTuple, Optional, Tuple, Dict
import msgpack
import msgpack_numpy as m
......@@ -20,7 +20,6 @@ from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
from flatland.utils.ordered_set import OrderedSet
m.patch()
......@@ -241,9 +240,6 @@ class RailEnv(Environment):
# can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
if optionals and 'distance_map' in optionals:
self.distance_map.set(optionals['distance_map'])
if regen_rail or self.rail is None:
self.rail = rail
self.height, self.width = self.rail.grid.shape
......@@ -253,6 +249,11 @@ class RailEnv(Environment):
check = self.rail.cell_neighbours_valid(rc_pos, True)
if not check:
warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check))
# TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
# hacky: we must re-compute the distance map and not use the initial distance_map loaded from file by
# rail_from_file!!!
elif optionals and 'distance_map' in optionals:
self.distance_map.set(optionals['distance_map'])
if replace_agents:
agents_hints = None
......@@ -587,60 +588,6 @@ class RailEnv(Environment):
transition_valid = True
return new_direction, transition_valid
@staticmethod
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
agent_position: Tuple[int, int],
rail: GridTransitionMap) -> Set[RailEnvNextAction]:
"""
Get the valid move actions (forward, left, right) for an agent.
Parameters
----------
agent_direction : Grid4TransitionsEnum
agent_position: Tuple[int,int]
rail : GridTransitionMap
Returns
-------
Set of `RailEnvNextAction` (tuples of (action,position,direction))
Possible move actions (forward,left,right) and the next position/direction they lead to.
It is not checked that the next cell is free.
"""
valid_actions: Set[RailEnvNextAction] = OrderedSet()
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if rail.is_dead_end(agent_position):
action = RailEnvActions.MOVE_FORWARD
exit_direction = (agent_direction + 2) % 4
if possible_transitions[exit_direction]:
new_position = get_new_position(agent_position, exit_direction)
valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
elif num_transitions == 1:
action = RailEnvActions.MOVE_FORWARD
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
else:
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
if new_direction == agent_direction:
action = RailEnvActions.MOVE_FORWARD
elif new_direction == (agent_direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif new_direction == (agent_direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
raise Exception("Illegal state")
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
return valid_actions
def _get_observations(self):
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
......
import math
from typing import Dict, List, Optional, NamedTuple, Tuple, Set
import matplotlib.pyplot as plt
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
from flatland.utils.ordered_set import OrderedSet
WalkingElement = \
NamedTuple('WalkingElement',
[('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)])
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
agent_position: Tuple[int, int],
rail: GridTransitionMap) -> Set[RailEnvNextAction]:
"""
Get the valid move actions (forward, left, right) for an agent.
Parameters
----------
agent_direction : Grid4TransitionsEnum
agent_position: Tuple[int,int]
rail : GridTransitionMap
Returns
-------
Set of `RailEnvNextAction` (tuples of (action,position,direction))
Possible move actions (forward,left,right) and the next position/direction they lead to.
It is not checked that the next cell is free.
"""
valid_actions: Set[RailEnvNextAction] = OrderedSet()
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if rail.is_dead_end(agent_position):
action = RailEnvActions.MOVE_FORWARD
exit_direction = (agent_direction + 2) % 4
if possible_transitions[exit_direction]:
new_position = get_new_position(agent_position, exit_direction)
valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
elif num_transitions == 1:
action = RailEnvActions.MOVE_FORWARD
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
else:
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
if new_direction == agent_direction:
action = RailEnvActions.MOVE_FORWARD
elif new_direction == (agent_direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif new_direction == (agent_direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
raise Exception("Illegal state")
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
return valid_actions
# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None) \
-> Dict[int, Optional[List[WalkingElement]]]:
"""
Computes the shortest path for each agent to its target and the action to be taken to do so.
The paths are derived from a `DistanceMap`.
If there is no path (rail disconnected), the path is given as None.
The agent state (moving or not) and its speed are not taken into account
Parameters
----------
distance_map
Returns
-------
Dict[int, Optional[List[WalkingElement]]]
"""
shortest_paths = dict()
def _shortest_path_for_agent(agent):
position = agent.position
direction = agent.direction
shortest_paths[agent.handle] = []
distance = math.inf
depth = 0
while (position != agent.target and (max_depth is None or depth < max_depth)):
next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
best_next_action = None
for next_action in next_actions:
next_action_distance = distance_map.get()[
agent.handle, next_action.next_position[0], next_action.next_position[
1], next_action.next_direction]
if next_action_distance < distance:
best_next_action = next_action
distance = next_action_distance
shortest_paths[agent.handle].append(WalkingElement(position, direction, best_next_action))
depth += 1
# if there is no way to continue, the rail must be disconnected!
# (or distance map is incorrect)
if best_next_action is None:
shortest_paths[agent.handle] = None
return
position = best_next_action.next_position
direction = best_next_action.next_direction
if max_depth is None or depth < max_depth:
shortest_paths[agent.handle].append(
WalkingElement(position, direction,
RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction)))
for agent in distance_map.agents:
_shortest_path_for_agent(agent)
return shortest_paths
def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
if agent_handle >= distance_map.get().shape[0]:
print("Error: agent_handle cannot be larger than actual number of agents")
return
# take min value of all 4 directions
min_distance_map = np.min(distance_map.get(), axis=3)
plt.imshow(min_distance_map[agent_handle][:][:])
plt.show()
import numpy as np
import matplotlib.pyplot as plt
from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -21,13 +17,3 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b
schedule_generator=schedule_from_file(file_name, load_from_package),
obs_builder_object=obs_builder_object)
return environment
def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
if agent_handle >= distance_map.get().shape[0]:
print("Error: agent_handle cannot be larger than actual number of agents")
return
# take min value of all 4 directions
min_distance_map = np.min(distance_map.get(), axis=3)
plt.imshow(min_distance_map[agent_handle][:][:])
plt.show()
......@@ -174,7 +174,6 @@ class PILGL(GraphicsLayer):
self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255))
def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER):
print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer)
xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
self.text(*xyPixLeftTop, strText, layer)
......@@ -606,7 +605,6 @@ class PILSVG(PILGL):
self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER)
if show_debug:
print("Call text:")
self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx))
def set_cell_occupied(self, agent_idx, row, col):
......
......@@ -45,6 +45,46 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
return rail, rail_map
def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# Note that that cells have invalid RailEnvTransitions!
# |
# |
# |
# _ _ _ _\ _ _ _ _ _
# /
# |
# |
# |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
simple_switch_north_left = cells[2]
simple_switch_north_right = cells[10]
simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[simple_switch_east_west_north] +
[dead_end_from_west] + [dead_end_from_east] + [simple_switch_east_west_south] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# |
......
......@@ -7,7 +7,8 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
......@@ -142,6 +143,21 @@ def test_shortest_path_predictor(rendering=False):
1], agent.direction] == 5.0, "found {} instead of {}".format(
distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
paths = get_shortest_paths(env.distance_map)[0]
assert paths == [
WalkingElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6),
next_direction=0)),
WalkingElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6),
next_direction=0)),
WalkingElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7),
next_direction=1)),
WalkingElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8),
next_direction=1)),
WalkingElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9),
next_direction=1)),
WalkingElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.STOP_MOVING, next_position=(3, 9),
next_direction=1))]
# extract the data
predictions = env.obs_builder.predictions
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
......@@ -220,12 +236,13 @@ def test_shortest_path_predictor(rendering=False):
[20.],
])
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
assert np.array_equal(positions, expected_positions), \
"positions {}, expected {}".format(positions, expected_positions)
assert np.array_equal(directions, expected_directions), \
"directions {}, expected {}".format(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
def test_shortest_path_predictor_conflicts(rendering=False):
......
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_disconnected_simple_rail
def test_get_shortest_paths_unreachable():
rail, rail_map = make_disconnected_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
# set the initial position
agent = env.agents_static[0]
agent.position = (3, 1) # west dead-end
agent.direction = Grid4TransitionsEnum.WEST