Commit f6e81e1a authored by Christian Eichenberger's avatar Christian Eichenberger 🏸 Committed by spiglerg
Browse files

Resolve "shortest-path-predictor"

parent a154294c
......@@ -13,6 +13,7 @@ case of multi-agent environments.
class PredictionBuilder:
"""
PredictionBuilder base class.
"""
def __init__(self, max_depth: int = 20):
......@@ -27,12 +28,15 @@ class PredictionBuilder:
"""
pass
def get(self, handle=0):
def get(self, custom_args=None, handle=0):
"""
Called whenever predict is called on the environment.
Called whenever get_many in the observation build is called.
Parameters
-------
custom_args: dict
Implementation-dependent custom arguments, see the sub-classes.
handle : int (optional)
Handle of the agent for which to compute the observation vector.
......
......@@ -3,6 +3,7 @@ The transitions module defines the base Transitions class and a
derived GridTransitions class, which allows for the specification of
possible transitions over a 2D grid.
"""
from enum import IntEnum
import numpy as np
......@@ -129,6 +130,16 @@ class Transitions:
"""
raise NotImplementedError()
def get_direction_enum(self) -> IntEnum:
raise NotImplementedError()
class Grid4TransitionsEnum(IntEnum):
NORTH = 0
EAST = 1
SOUTH = 2
WEST = 3
class Grid4Transitions(Transitions):
"""
......@@ -323,6 +334,20 @@ class Grid4Transitions(Transitions):
cell_transition = value
return cell_transition
def get_direction_enum(self) -> IntEnum:
return Grid4TransitionsEnum
class Grid8TransitionsEnum(IntEnum):
NORTH = 0
NORTH_EAST = 1
EAST = 2
SOUTH_EAST = 3
SOUTH = 4
SOUTH_WEST = 5
WEST = 6
NORTH_WEST = 7
class Grid8Transitions(Transitions):
"""
......@@ -504,6 +529,9 @@ class Grid8Transitions(Transitions):
return cell_transition
def get_direction_enum(self) -> IntEnum:
return Grid8TransitionsEnum
class RailEnvTransitions(Grid4Transitions):
"""
......
......@@ -7,6 +7,8 @@ a GridTransitionMap object.
import numpy as np
from flatland.core.transitions import Grid4TransitionsEnum
def get_direction(pos1, pos2):
"""
......@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2):
def get_new_position(position, movement):
if movement == 0: # NORTH
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == 1: # EAST
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == 3: # WEST
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
......
......@@ -6,6 +6,7 @@ from collections import deque
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.env_utils import coordinate_to_position
......@@ -48,16 +49,19 @@ class TreeObsForRailEnv(ObservationBuilder):
self.agents_previous_reset = agents
if compute_distance_map:
self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self._compute_distance_map()
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _compute_distance_map(self):
agents = self.env.agents
nAgents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _distance_map_walker(self, position, target_nr):
"""
......@@ -159,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == 0: # NORTH
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == 1: # EAST
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == 3: # WEST
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
def get_many(self, handles=[]):
......@@ -177,7 +181,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.predictor:
self.predicted_pos = {}
self.predicted_dir = {}
self.predictions = self.predictor.get(self.distance_map)
self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
for t in range(len(self.predictions[0])):
pos_list = []
dir_list = []
......@@ -796,8 +800,3 @@ class LocalObsForRailEnv(ObservationBuilder):
direction = self._get_one_hot_for_agent_direction(agent)
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
# class LocalObsForRailEnvImproved(ObservationBuilder):
# """
# Returns a local observation around the given agent
# """
......@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.env_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
......@@ -16,24 +17,28 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, distancemap, handle=None):
def get(self, custom_args=None, handle=None):
"""
Called whenever predict is called on the environment.
Called whenever get_many in the observation build is called.
Parameters
-------
custom_args: dict
Not used in this dummy implementation.
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
np.array
Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
......@@ -46,13 +51,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
for index in range(1, self.max_depth + 1):
action_done = False
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
RailEnvActions.STOP_MOVING]
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
continue
for action in action_priorities:
......@@ -63,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
# performed
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
prediction[index] = [index, *new_position, new_direction, action]
action_done = True
break
if not action_done:
......@@ -76,90 +80,95 @@ class DummyPredictorForRailEnv(PredictionBuilder):
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
DummyPredictorForRailEnv object.
ShortestPathPredictorForRailEnv object.
This object returns predictions for agents in the RailEnv environment.
This object returns shortest-path predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, distancemap, handle=None):
def get(self, custom_args=None, handle=None):
"""
Called whenever predict is called on the environment.
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
Parameters
-------
custom_args: dict
- distance_map : dict
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
np.array
Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
assert custom_args
distance_map = custom_args.get('distance_map')
assert distance_map is not None
prediction_dict = {}
agent_idx = 0
for agent in agents:
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
RailEnvActions.STOP_MOVING]
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
continue
if not agent.moving:
prediction[index] = [index, agent.position[0], agent.position[1], agent.direction,
RailEnvActions.STOP_MOVING]
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
new_position = None
new_direction = None
if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions)
new_position = self._new_position(agent.position, new_direction)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1:
min_dist = np.inf
for direct in range(4):
if cell_transitions[direct] == 1:
target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
for direction in range(4):
if cell_transitions[direction] == 1:
target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction]
if target_dist < min_dist:
min_dist = target_dist
new_direction = direct
new_position = self._new_position(agent.position, new_direction)
new_direction = direction
new_position = get_new_position(agent.position, new_direction)
else:
raise Exception("No transition possible {}".format(cell_transitions))
# which action to take for the transition?
action = None
for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
_, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent)
if np.array_equal(_new_position, new_position):
action = _action
break
assert action is not None
# update the agent's position and direction
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, new_position[0], new_position[1], new_direction, 0]
action_done = True
if not action_done:
raise Exception("Cannot move further. Something is wrong")
# prediction is ready
prediction[index] = [index, *new_position, new_direction, action]
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
agent_idx += 1
return prediction_dict
def _new_position(self, position, movement):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
......@@ -80,11 +80,3 @@ def test_global_obs():
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
def main():
test_global_obs()
if __name__ == "__main__":
main()
......@@ -4,15 +4,18 @@
import numpy as np
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool
"""Test predictions for `flatland` package."""
def test_predictions():
def make_simple_rail():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
......@@ -22,7 +25,6 @@ def test_predictions():
# |
# |
# |
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
......@@ -31,22 +33,17 @@ def test_predictions():
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
......@@ -56,26 +53,36 @@ def test_predictions():
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
def test_dummy_predictor(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
env.reset()
# set initial position and direction for testing...
env.agents[0].position = (5, 6)
env.agents[0].direction = 0
env.agents[0].target = (3., 0.)
env.agents[0].target = (3, 0)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
input("Continue?")
# test assertions
predictions = env.obs_builder.predictor.get(None)
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
......@@ -139,9 +146,149 @@ def test_predictions():
assert np.array_equal(actions, expected_actions)
def main():
test_predictions()
def test_shortest_path_predictor(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
agent = env.agents[0]
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
input("Continue?")
agent = env.agents[0]
assert agent.position == (5, 6)
assert agent.direction == 0
assert agent.target == (3, 9)
assert agent.moving
env.obs_builder._compute_distance_map()
distance_map = env.obs_builder.distance_map
assert distance_map[agent.handle, agent.position[0], agent.position[
1], agent.direction] == 5.0, "found {} instead of {}".format(
distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
# test assertions
env.obs_builder.get_many()
predictions = env.obs_builder.predictions
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
expected_positions = [
[5, 6],
[4, 6],
[3, 6],
[3, 7],
[3, 8],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
]
expected_directions = [
[Grid4TransitionsEnum.NORTH], # next is [5,6] heading north
[Grid4TransitionsEnum.NORTH], # next is [4,6] heading north
[Grid4TransitionsEnum.NORTH], # next is [3,6] heading north
[Grid4TransitionsEnum.EAST], # next is [3,7] heading east
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
]
expected_time_offsets = np.array([
[0.],
[1.],
[2.],
[3.],
[4.],
[5.],
[6.],
[7.],
[8.],
[9.],
[10.],
[11.],
[12.],
[13.],
[14.],
[15.],
[16.],
[17.],
[18.],
[19.],
[20.],
])
expected_actions = np.array([
[RailEnvActions.DO_NOTHING], # next [5,6]