Commit e4930fae authored by u214892's avatar u214892
Browse files

get shortest paths from distance map

parent dd6ae4da
Pipeline #2137 passed with stages
in 39 minutes and 1 second
......@@ -21,12 +21,14 @@ class DistanceMap:
"""
Set the distance map
"""
def set(self, distance_map: np.ndarray):
self.distance_map = distance_map
"""
Get the distance map
"""
def get(self) -> np.ndarray:
if self.reset_was_called:
......@@ -54,9 +56,10 @@ class DistanceMap:
"""
Reset the distance map
"""
def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
self.reset_was_called = True
self.agents = agents
self.agents: List[EnvAgent] = agents
self.rail = rail
self.env_height = rail.height
self.env_width = rail.width
......@@ -110,7 +113,8 @@ class DistanceMap:
return max_distance
def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1):
def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance,
enforce_target_direction=-1):
"""
Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
minimum distances from each target cell.
......@@ -134,8 +138,7 @@ class DistanceMap:
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
# is_valid = True
desired_movement_from_new_cell)
if is_valid:
"""
......
......@@ -4,7 +4,7 @@ Definition of the RailEnv environment.
# TODO: _ this is a global method --> utils or remove later
import warnings
from enum import IntEnum
from typing import List, Set, NamedTuple, Optional, Tuple, Dict
from typing import List, NamedTuple, Optional, Tuple, Dict
import msgpack
import msgpack_numpy as m
......@@ -20,7 +20,6 @@ from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
from flatland.utils.ordered_set import OrderedSet
m.patch()
......@@ -587,60 +586,6 @@ class RailEnv(Environment):
transition_valid = True
return new_direction, transition_valid
@staticmethod
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
agent_position: Tuple[int, int],
rail: GridTransitionMap) -> Set[RailEnvNextAction]:
"""
Get the valid move actions (forward, left, right) for an agent.
Parameters
----------
agent_direction : Grid4TransitionsEnum
agent_position: Tuple[int,int]
rail : GridTransitionMap
Returns
-------
Set of `RailEnvNextAction` (tuples of (action,position,direction))
Possible move actions (forward,left,right) and the next position/direction they lead to.
It is not checked that the next cell is free.
"""
valid_actions: Set[RailEnvNextAction] = OrderedSet()
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if rail.is_dead_end(agent_position):
action = RailEnvActions.MOVE_FORWARD
exit_direction = (agent_direction + 2) % 4
if possible_transitions[exit_direction]:
new_position = get_new_position(agent_position, exit_direction)
valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
elif num_transitions == 1:
action = RailEnvActions.MOVE_FORWARD
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
else:
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
if new_direction == agent_direction:
action = RailEnvActions.MOVE_FORWARD
elif new_direction == (agent_direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif new_direction == (agent_direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
raise Exception("Illegal state")
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
return valid_actions
def _get_observations(self):
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
......
import math
from typing import Tuple, Set, Dict, List
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnv, RailEnvNextAction, RailEnvActions
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.utils.ordered_set import OrderedSet
def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None):
......@@ -17,3 +27,78 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b
schedule_generator=schedule_from_file(file_name, load_from_package),
obs_builder_object=obs_builder_object)
return environment
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
agent_position: Tuple[int, int],
rail: GridTransitionMap) -> Set[RailEnvNextAction]:
"""
Get the valid move actions (forward, left, right) for an agent.
Parameters
----------
agent_direction : Grid4TransitionsEnum
agent_position: Tuple[int,int]
rail : GridTransitionMap
Returns
-------
Set of `RailEnvNextAction` (tuples of (action,position,direction))
Possible move actions (forward,left,right) and the next position/direction they lead to.
It is not checked that the next cell is free.
"""
valid_actions: Set[RailEnvNextAction] = OrderedSet()
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if rail.is_dead_end(agent_position):
action = RailEnvActions.MOVE_FORWARD
exit_direction = (agent_direction + 2) % 4
if possible_transitions[exit_direction]:
new_position = get_new_position(agent_position, exit_direction)
valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
elif num_transitions == 1:
action = RailEnvActions.MOVE_FORWARD
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
else:
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
if new_direction == agent_direction:
action = RailEnvActions.MOVE_FORWARD
elif new_direction == (agent_direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif new_direction == (agent_direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
raise Exception("Illegal state")
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
return valid_actions
def get_shorts_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]:
shortest_paths = dict()
for a in distance_map.agents:
position = a.position
direction = a.direction
shortest_paths[a.handle] = []
while (position != a.target):
next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
best = math.inf
best_next_action = None
for next_action in next_actions:
if distance_map.get()[a.handle, position[0], position[1], direction] < best:
best_next_action = next_action
position = best_next_action.next_position
direction = best_next_action.next_direction
shortest_paths[a.handle].append(best_next_action)
return shortest_paths
......@@ -7,7 +7,8 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
from flatland.envs.rail_env_utils import get_shorts_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
......@@ -142,6 +143,14 @@ def test_shortest_path_predictor(rendering=False):
1], agent.direction] == 5.0, "found {} instead of {}".format(
distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
paths = get_shorts_paths(env.distance_map)[0]
assert paths == [
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), next_direction=1),
RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), next_direction=1)]
# extract the data
predictions = env.obs_builder.predictions
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment