Commit 15884968 authored by u214892's avatar u214892
Browse files

66 shortest-path-predictor: cleanup and unit test

parent ea12d2ed
Pipeline #1113 failed with stage
in 9 minutes and 26 seconds
......@@ -13,6 +13,7 @@ case of multi-agent environments.
class PredictionBuilder:
"""
PredictionBuilder base class.
"""
def __init__(self, max_depth: int = 20):
......@@ -27,12 +28,15 @@ class PredictionBuilder:
"""
pass
def get(self, handle=0):
def get(self, custom_args=None, handle=0):
"""
Called whenever predict is called on the environment.
Called whenever get_many in the observation build is called.
Parameters
-------
custom_args: dict
Implementation-dependent custom arguments, see the sub-classes.
handle : int (optional)
Handle of the agent for which to compute the observation vector.
......
......@@ -3,6 +3,7 @@ The transitions module defines the base Transitions class and a
derived GridTransitions class, which allows for the specification of
possible transitions over a 2D grid.
"""
from enum import IntEnum
import numpy as np
......@@ -129,6 +130,16 @@ class Transitions:
"""
raise NotImplementedError()
def get_direction_enum(self) -> IntEnum:
raise NotImplementedError()
class Grid4TransitionsEnum(IntEnum):
NORTH = 0
EAST = 1
SOUTH = 2
WEST = 3
class Grid4Transitions(Transitions):
"""
......@@ -323,6 +334,20 @@ class Grid4Transitions(Transitions):
cell_transition = value
return cell_transition
def get_direction_enum(self) -> IntEnum:
return Grid4TransitionsEnum
class Grid8TransitionsEnum(IntEnum):
NORTH = 0
NORTH_EAST = 1
EAST = 2
SOUTH_EAST = 3
SOUTH = 4
SOUTH_WEST = 5
WEST = 6
NORTH_WEST = 7
class Grid8Transitions(Transitions):
"""
......@@ -504,6 +529,9 @@ class Grid8Transitions(Transitions):
return cell_transition
def get_direction_enum(self) -> IntEnum:
return Grid8TransitionsEnum
class RailEnvTransitions(Grid4Transitions):
"""
......
......@@ -48,16 +48,19 @@ class TreeObsForRailEnv(ObservationBuilder):
self.agents_previous_reset = agents
if compute_distance_map:
self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self._compute_distance_map()
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _compute_distance_map(self):
agents = self.env.agents
nAgents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _distance_map_walker(self, position, target_nr):
"""
......@@ -177,7 +180,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.predictor:
self.predicted_pos = {}
self.predicted_dir = {}
self.predictions = self.predictor.get(self.distance_map)
self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
for t in range(len(self.predictions[0])):
pos_list = []
dir_list = []
......@@ -796,8 +799,3 @@ class LocalObsForRailEnv(ObservationBuilder):
direction = self._get_one_hot_for_agent_direction(agent)
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
# class LocalObsForRailEnvImproved(ObservationBuilder):
# """
# Returns a local observation around the given agent
# """
......@@ -16,24 +16,28 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, distancemap, handle=None):
def get(self, custom_args=None, handle=None):
"""
Called whenever predict is called on the environment.
Called whenever get_many in the observation build is called.
Parameters
-------
custom_args: dict
Not used in this dummy implementation.
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
np.array
Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1) x 5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
......@@ -46,12 +50,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
for index in range(1, self.max_depth + 1):
action_done = False
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
prediction[index] = [index, *agent.target, agent.direction,
RailEnvActions.STOP_MOVING]
continue
......@@ -63,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
# performed
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
prediction[index] = [index, *new_position, new_direction, action]
action_done = True
break
if not action_done:
......@@ -76,51 +80,55 @@ class DummyPredictorForRailEnv(PredictionBuilder):
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
DummyPredictorForRailEnv object.
ShortestPathPredictorForRailEnv object.
This object returns predictions for agents in the RailEnv environment.
This object returns shortest-path predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, distancemap, handle=None):
def get(self, custom_args=None, handle=None):
"""
Called whenever predict is called on the environment.
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
Parameters
-------
custom_args: dict
- distance_map : dict
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
np.array
Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1) x 5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
assert custom_args
distance_map = custom_args.get('distance_map')
assert distance_map is not None
prediction_dict = {}
agent_idx = 0
for agent in agents:
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
RailEnvActions.STOP_MOVING]
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
continue
if not agent.moving:
prediction[index] = [index, agent.position[0], agent.position[1], agent.direction,
RailEnvActions.STOP_MOVING]
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
......@@ -130,24 +138,25 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
new_position = self._new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1:
min_dist = np.inf
for direct in range(4):
if cell_transitions[direct] == 1:
target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
for direction in range(4):
if cell_transitions[direction] == 1:
target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction]
if target_dist < min_dist:
min_dist = target_dist
new_direction = direct
new_direction = direction
new_position = self._new_position(agent.position, new_direction)
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, new_position[0], new_position[1], new_direction, 0]
prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD]
action_done = True
if not action_done:
raise Exception("Cannot move further. Something is wrong")
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
agent_idx += 1
return prediction_dict
......
......@@ -80,11 +80,3 @@ def test_global_obs():
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
def main():
test_global_obs()
if __name__ == "__main__":
main()
......@@ -4,15 +4,18 @@
import numpy as np
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool
"""Test predictions for `flatland` package."""
def test_predictions():
def make_simple_rail():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
......@@ -22,7 +25,6 @@ def test_predictions():
# |
# |
# |
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
......@@ -31,22 +33,17 @@ def test_predictions():
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
......@@ -56,26 +53,36 @@ def test_predictions():
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
def test_dummy_predictor(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
env.reset()
# set initial position and direction for testing...
env.agents[0].position = (5, 6)
env.agents[0].direction = 0
env.agents[0].target = (3., 0.)
env.agents[0].target = (3, 0)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
input("Continue?")
# test assertions
predictions = env.obs_builder.predictor.get(None)
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
......@@ -139,9 +146,149 @@ def test_predictions():
assert np.array_equal(actions, expected_actions)
def main():
test_predictions()
def test_shortest_path_predictor(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
agent = env.agents[0]
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
# input("Continue?")
agent = env.agents[0]
assert agent.position == (5, 6)
assert agent.direction == 0
assert agent.target == (3, 9)
assert agent.moving
env.obs_builder._compute_distance_map()
distance_map = env.obs_builder.distance_map
assert distance_map[agent.handle, agent.position[0], agent.position[
1], agent.direction] == 5.0, "found {} instead of {}".format(
distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
# test assertions
env.obs_builder.get_many()
predictions = env.obs_builder.predictions
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
expected_positions = [
[5, 6],
[4, 6],
[3, 6],
[3, 7],
[3, 8],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
]
expected_directions = [
[Grid4TransitionsEnum.NORTH], # next is [5,6] heading north
[Grid4TransitionsEnum.NORTH], # next is [4,6] heading north
[Grid4TransitionsEnum.NORTH], # next is [3,6] heading north
[Grid4TransitionsEnum.EAST], # next is [3,7] heading east
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
]
expected_time_offsets = np.array([
[0.],
[1.],
[2.],
[3.],
[4.],
[5.],
[6.],
[7.],
[8.],
[9.],
[10.],
[11.],
[12.],
[13.],
[14.],
[15.],
[16.],
[17.],
[18.],
[19.],
[20.],
])
expected_actions = np.array([
[RailEnvActions.DO_NOTHING], # next [5,6]
[RailEnvActions.MOVE_FORWARD], # next [4,6]
[RailEnvActions.MOVE_FORWARD], # next [3,6]
[RailEnvActions.MOVE_RIGHT], # next [3,7]
[RailEnvActions.MOVE_FORWARD], # next [3,8]
[RailEnvActions.MOVE_FORWARD], # next [3,9]
[RailEnvActions.STOP_MOVING], # at [3,9] == target
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
])
if __name__ == "__main__":
main()
assert np.array_equal(positions, expected_positions), \
"positions {}, expected {}".format(positions, expected_positions)
assert np.array_equal(directions, expected_directions), \
"directions {}, expected {}".format(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
assert np.array_equal(actions, expected_actions), \
"actions {}, expected {}".format(actions, expected_actions)
......@@ -204,8 +204,3 @@ def test_dead_end():
rail_env.reset()
rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
if __name__ == "__main__":
test_rail_environment_single_agent()
test_dead_end()
......@@ -4,7 +4,3 @@ from examples.play_model import main
def test_main():
main(render=True, n_steps=20, n_trials=2, sGL="PIL")
main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")
if __name__ == "__main__":
test_main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment