Commit 8216da27 authored by u214892's avatar u214892
Browse files

bugfix get-shortest-path

parent e4930fae
Pipeline #2138 canceled with stages
......@@ -18,19 +18,18 @@ class DistanceMap:
self.agents: List[EnvAgent] = agents
self.rail: Optional[GridTransitionMap] = None
"""
Set the distance map
"""
def set(self, distance_map: np.ndarray):
"""
Set the distance map
"""
self.distance_map = distance_map
"""
Get the distance map
"""
def get(self) -> np.ndarray:
def get(self) -> np.ndarray:
"""
Get the distance map
"""
if self.reset_was_called:
self.reset_was_called = False
......@@ -53,11 +52,12 @@ class DistanceMap:
return self.distance_map
"""
Reset the distance map
"""
def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
Reset the distance map
"""
self.reset_was_called = True
self.agents: List[EnvAgent] = agents
self.rail = rail
......
......@@ -83,22 +83,26 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
return valid_actions
def get_shorts_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]:
def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]:
# TODO: do we need to support unreachable targets?
# TODO refactoring: unify with predictor (support agent.moving and max_depth)
shortest_paths = dict()
for a in distance_map.agents:
position = a.position
direction = a.direction
shortest_paths[a.handle] = []
distance = math.inf
while (position != a.target):
next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
best = math.inf
best_next_action = None
for next_action in next_actions:
if distance_map.get()[a.handle, position[0], position[1], direction] < best:
next_action_distance = distance_map.get()[a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction]
if next_action_distance < distance:
best_next_action = next_action
distance = next_action_distance
position = best_next_action.next_position
direction = best_next_action.next_direction
shortest_paths[a.handle].append(best_next_action)
return shortest_paths
......@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
from flatland.envs.rail_env_utils import get_shorts_paths
from flatland.envs.rail_env_utils import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
......@@ -143,7 +143,7 @@ def test_shortest_path_predictor(rendering=False):
1], agent.direction] == 5.0, "found {} instead of {}".format(
distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
paths = get_shorts_paths(env.distance_map)[0]
paths = get_shortest_paths(env.distance_map)[0]
assert paths == [
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0),
......
import numpy as np
from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
from flatland.envs.rail_env_utils import load_flatland_environment_from_file, get_shortest_paths
def test_get_shortest_paths():
env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
actual = get_shortest_paths(env.distance_map)
expected = {
0: [RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(1, 2), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(1, 3), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 3), next_direction=2),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 4), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 5), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 6), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 7), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 8), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 9), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 10), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 11), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 12), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 13), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 14), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 15), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 16), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 17), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 18), next_direction=1)],
1: [RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 17), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 16), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 16), next_direction=0),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 15), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 14), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 13), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 12), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 11), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 10), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 9), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 8), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 7), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 6), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 5), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 4), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 3), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 2), next_direction=3),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 1), next_direction=3)]
}
for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment