Skip to content
Snippets Groups Projects
Commit 7efb0816 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files
parents 9a4a1f0b dc586f23
No related branches found
No related tags found
No related merge requests found
...@@ -13,6 +13,7 @@ case of multi-agent environments. ...@@ -13,6 +13,7 @@ case of multi-agent environments.
class PredictionBuilder: class PredictionBuilder:
""" """
PredictionBuilder base class. PredictionBuilder base class.
""" """
def __init__(self, max_depth: int = 20): def __init__(self, max_depth: int = 20):
...@@ -27,12 +28,15 @@ class PredictionBuilder: ...@@ -27,12 +28,15 @@ class PredictionBuilder:
""" """
pass pass
def get(self, handle=0): def get(self, custom_args=None, handle=0):
""" """
Called whenever predict is called on the environment. Called whenever get_many in the observation build is called.
Parameters Parameters
------- -------
custom_args: dict
Implementation-dependent custom arguments, see the sub-classes.
handle : int (optional) handle : int (optional)
Handle of the agent for which to compute the observation vector. Handle of the agent for which to compute the observation vector.
......
...@@ -3,6 +3,7 @@ The transitions module defines the base Transitions class and a ...@@ -3,6 +3,7 @@ The transitions module defines the base Transitions class and a
derived GridTransitions class, which allows for the specification of derived GridTransitions class, which allows for the specification of
possible transitions over a 2D grid. possible transitions over a 2D grid.
""" """
from enum import IntEnum
import numpy as np import numpy as np
...@@ -129,6 +130,16 @@ class Transitions: ...@@ -129,6 +130,16 @@ class Transitions:
""" """
raise NotImplementedError() raise NotImplementedError()
def get_direction_enum(self) -> IntEnum:
    """
    Return the enumeration that names the directions of this transitions type.

    The base class defines no directions; subclasses such as
    `Grid4Transitions` and `Grid8Transitions` override this method.

    Raises
    ------
    NotImplementedError
        Always, because the base class has no direction enumeration.
    """
    raise NotImplementedError()
class Grid4TransitionsEnum(IntEnum):
    """
    Compass directions of a 4-connected grid.

    Members are numbered clockwise starting at north, so they compare equal
    to the plain integers 0..3.
    """

    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3
class Grid4Transitions(Transitions): class Grid4Transitions(Transitions):
""" """
...@@ -323,6 +334,20 @@ class Grid4Transitions(Transitions): ...@@ -323,6 +334,20 @@ class Grid4Transitions(Transitions):
cell_transition = value cell_transition = value
return cell_transition return cell_transition
def get_direction_enum(self) -> IntEnum:
    """
    Return the direction enumeration used by the 4-connected grid transitions.

    Returns
    -------
    IntEnum
        The `Grid4TransitionsEnum` class itself (not one of its members).
    """
    return Grid4TransitionsEnum
class Grid8TransitionsEnum(IntEnum):
    """
    Compass directions of an 8-connected grid.

    Members are numbered clockwise starting at north, so they compare equal
    to the plain integers 0..7.
    """

    NORTH = 0
    NORTH_EAST = 1
    EAST = 2
    SOUTH_EAST = 3
    SOUTH = 4
    SOUTH_WEST = 5
    WEST = 6
    NORTH_WEST = 7
class Grid8Transitions(Transitions): class Grid8Transitions(Transitions):
""" """
...@@ -504,6 +529,9 @@ class Grid8Transitions(Transitions): ...@@ -504,6 +529,9 @@ class Grid8Transitions(Transitions):
return cell_transition return cell_transition
def get_direction_enum(self) -> IntEnum:
    """
    Return the direction enumeration used by the 8-connected grid transitions.

    Returns
    -------
    IntEnum
        The `Grid8TransitionsEnum` class itself (not one of its members).
    """
    return Grid8TransitionsEnum
class RailEnvTransitions(Grid4Transitions): class RailEnvTransitions(Grid4Transitions):
""" """
......
...@@ -7,6 +7,8 @@ a GridTransitionMap object. ...@@ -7,6 +7,8 @@ a GridTransitionMap object.
import numpy as np import numpy as np
from flatland.core.transitions import Grid4TransitionsEnum
def get_direction(pos1, pos2): def get_direction(pos1, pos2):
""" """
...@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2): ...@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2):
def get_new_position(position, movement):
    """
    Utility function that converts a compass movement over a 2D grid to new positions (r, c).

    Parameters
    -------
    position : sequence
        Current (row, column) position; only the first two items are read.
    movement : int or Grid4TransitionsEnum
        Direction of the move, compared against the `Grid4TransitionsEnum` members.

    Returns
    -------
    tuple or None
        The (row, column) position after one step in the given direction, or
        None when `movement` matches none of the four directions.
    """
    row, col = position[0], position[1]
    if movement == Grid4TransitionsEnum.NORTH:
        return (row - 1, col)
    if movement == Grid4TransitionsEnum.EAST:
        return (row, col + 1)
    if movement == Grid4TransitionsEnum.SOUTH:
        return (row + 1, col)
    if movement == Grid4TransitionsEnum.WEST:
        return (row, col - 1)
    # Unknown movement: keep the historical behaviour of returning None.
    return None
......
...@@ -6,6 +6,7 @@ from collections import deque ...@@ -6,6 +6,7 @@ from collections import deque
import numpy as np import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.env_utils import coordinate_to_position from flatland.envs.env_utils import coordinate_to_position
...@@ -48,16 +49,19 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -48,16 +49,19 @@ class TreeObsForRailEnv(ObservationBuilder):
self.agents_previous_reset = agents self.agents_previous_reset = agents
if compute_distance_map: if compute_distance_map:
self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, self._compute_distance_map()
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] def _compute_distance_map(self):
agents = self.env.agents
# Update local lookup table for all agents' target locations nAgents = len(agents)
self.location_has_target = {tuple(agent.target): 1 for agent in agents} self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _distance_map_walker(self, position, target_nr): def _distance_map_walker(self, position, target_nr):
""" """
...@@ -159,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -159,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
Utility function that converts a compass movement over a 2D grid to new positions (r, c). Utility function that converts a compass movement over a 2D grid to new positions (r, c).
""" """
if movement == 0: # NORTH if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1]) return (position[0] - 1, position[1])
elif movement == 1: # EAST elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1) return (position[0], position[1] + 1)
elif movement == 2: # SOUTH elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1]) return (position[0] + 1, position[1])
elif movement == 3: # WEST elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1) return (position[0], position[1] - 1)
def get_many(self, handles=[]): def get_many(self, handles=[]):
...@@ -177,7 +181,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -177,7 +181,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.predictor: if self.predictor:
self.predicted_pos = {} self.predicted_pos = {}
self.predicted_dir = {} self.predicted_dir = {}
self.predictions = self.predictor.get(self.distance_map) self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
for t in range(len(self.predictions[0])): for t in range(len(self.predictions[0])):
pos_list = [] pos_list = []
dir_list = [] dir_list = []
...@@ -796,8 +800,3 @@ class LocalObsForRailEnv(ObservationBuilder): ...@@ -796,8 +800,3 @@ class LocalObsForRailEnv(ObservationBuilder):
direction = self._get_one_hot_for_agent_direction(agent) direction = self._get_one_hot_for_agent_direction(agent)
return local_rail_obs, obs_map_state, obs_other_agents_state, direction return local_rail_obs, obs_map_state, obs_other_agents_state, direction
# class LocalObsForRailEnvImproved(ObservationBuilder):
# """
# Returns a local observation around the given agent
# """
...@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder. ...@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
import numpy as np import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.env_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions from flatland.envs.rail_env import RailEnvActions
...@@ -16,24 +17,28 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -16,24 +17,28 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action. The prediction acts as if no other agent is in the environment and always takes the forward action.
""" """
def get(self, distancemap, handle=None): def get(self, custom_args=None, handle=None):
""" """
Called whenever predict is called on the environment. Called whenever get_many in the observation build is called.
Parameters Parameters
------- -------
custom_args: dict
Not used in this dummy implementation.
handle : int (optional) handle : int (optional)
Handle of the agent for which to compute the observation vector. Handle of the agent for which to compute the observation vector.
Returns Returns
------- -------
function np.array
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements: Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
- time_offset - time_offset
- position axis 0 - position axis 0
- position axis 1 - position axis 1
- direction - direction
- action taken to come here - action taken to come here
The prediction at 0 is the current position, direction etc.
""" """
agents = self.env.agents agents = self.env.agents
if handle: if handle:
...@@ -46,13 +51,12 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -46,13 +51,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
_agent_initial_position = agent.position _agent_initial_position = agent.position
_agent_initial_direction = agent.direction _agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
for index in range(1, self.max_depth + 1): for index in range(1, self.max_depth + 1):
action_done = False action_done = False
# if we're at the target, stop moving... # if we're at the target, stop moving...
if agent.position == agent.target: if agent.position == agent.target:
prediction[index] = [index, agent.target[0], agent.target[1], agent.direction, prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
RailEnvActions.STOP_MOVING]
continue continue
for action in action_priorities: for action in action_priorities:
...@@ -63,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -63,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
# performed # performed
agent.position = new_position agent.position = new_position
agent.direction = new_direction agent.direction = new_direction
prediction[index] = [index, new_position[0], new_position[1], new_direction, action] prediction[index] = [index, *new_position, new_direction, action]
action_done = True action_done = True
break break
if not action_done: if not action_done:
...@@ -76,90 +80,95 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -76,90 +80,95 @@ class DummyPredictorForRailEnv(PredictionBuilder):
class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and, at every cell,
    follows the transition with the smallest distance to the agent's target.
    """

    def get(self, custom_args=None, handle=None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.

        Parameters
        -------
        custom_args: dict
            - distance_map : indexed as [agent handle, row, column, direction]
        handle : int (optional)
            Handle of the agent for which to compute the prediction.
            When None, predictions are computed for all agents.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent
            a np.array of shape (max_depth + 1, 5) whose rows are:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.
        """
        # Handle 0 is a valid agent handle, so test against None instead of truthiness.
        if handle is None:
            agents = self.env.agents
        else:
            agents = [self.env.agents[handle]]

        assert custom_args, "custom_args must contain 'distance_map'"
        distance_map = custom_args.get('distance_map')
        assert distance_map is not None, "custom_args must contain 'distance_map'"

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            try:
                prediction_dict[agent.handle] = self._predict_agent(agent, distance_map)
            finally:
                # cleanup: the agent is moved while predicting, so always put it back,
                # even when the prediction fails half-way.
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

        return prediction_dict

    def _predict_agent(self, agent, distance_map):
        """Walk `agent` along its shortest path and return its (max_depth + 1, 5) prediction."""
        prediction = np.zeros(shape=(self.max_depth + 1, 5))
        prediction[0] = [0, *agent.position, agent.direction, 0]

        for index in range(1, self.max_depth + 1):
            # if we're at the target, stop moving...
            if agent.position == agent.target:
                prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                continue
            if not agent.moving:
                prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                continue

            new_position, new_direction = self._next_cell(agent, distance_map)
            action = self._action_leading_to(agent, new_position)

            # update the agent's position and direction
            agent.position = new_position
            agent.direction = new_direction

            # prediction is ready
            prediction[index] = [index, *new_position, new_direction, action]

        return prediction

    def _next_cell(self, agent, distance_map):
        """Return the (position, direction) reached by the shortest-path transition of `agent`."""
        cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
        nb_transitions = np.sum(cell_transitions)

        new_direction = None
        if nb_transitions == 1:
            # Only one way out of this cell.
            new_direction = np.argmax(cell_transitions)
        elif nb_transitions > 1:
            # Several ways out: take the one closest to the target.
            min_dist = np.inf
            for direction in range(4):
                if cell_transitions[direction] == 1:
                    target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction]
                    if target_dist < min_dist:
                        min_dist = target_dist
                        new_direction = direction
        else:
            raise Exception("No transition possible {}".format(cell_transitions))

        return get_new_position(agent.position, new_direction), new_direction

    def _action_leading_to(self, agent, new_position):
        """Return the movement action that brings `agent` to `new_position`."""
        action = None
        for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
            # The fourth item of the result is used as the position reached by `_action`.
            _, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent)
            if np.array_equal(_new_position, new_position):
                action = _action
                break
        assert action is not None
        return action
...@@ -80,11 +80,3 @@ def test_global_obs(): ...@@ -80,11 +80,3 @@ def test_global_obs():
# If this assertion is wrong, it means that the observation returned # If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell # places the agent on an empty cell
assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
def main():
test_global_obs()
if __name__ == "__main__":
main()
...@@ -4,15 +4,18 @@ ...@@ -4,15 +4,18 @@
import numpy as np import numpy as np
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool
"""Test predictions for `flatland` package.""" """Test predictions for `flatland` package."""
def test_predictions(): def make_simple_rail():
# We instantiate a very simple rail network on a 7x10 grid: # We instantiate a very simple rail network on a 7x10 grid:
# | # |
# | # |
...@@ -22,7 +25,6 @@ def test_predictions(): ...@@ -22,7 +25,6 @@ def test_predictions():
# | # |
# | # |
# | # |
cells = [int('0000000000000000', 2), # empty cell - Case 0 cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch int('1001001000100000', 2), # Case 2 - simple switch
...@@ -31,22 +33,17 @@ def test_predictions(): ...@@ -31,22 +33,17 @@ def test_predictions():
int('1100110000110011', 2), # Case 5 - double slip switch int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([]) transitions = Grid4Transitions([])
empty = cells[0] empty = cells[0]
dead_end_from_south = cells[7] dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1] vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90) horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6] double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition( double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180) double_switch_south_horizontal_straight, 180)
rail_map = np.array( rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
...@@ -56,26 +53,36 @@ def test_predictions(): ...@@ -56,26 +53,36 @@ def test_predictions():
[horizontal_straight] * 2 + [dead_end_from_west]] + [horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1], rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions) height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map rail.grid = rail_map
return rail, rail_map
def test_dummy_predictor(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail), rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
) )
env.reset() env.reset()
# set initial position and direction for testing... # set initial position and direction for testing...
env.agents[0].position = (5, 6) env.agents[0].position = (5, 6)
env.agents[0].direction = 0 env.agents[0].direction = 0
env.agents[0].target = (3., 0.) env.agents[0].target = (3, 0)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
input("Continue?")
# test assertions
predictions = env.obs_builder.predictor.get(None) predictions = env.obs_builder.predictor.get(None)
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0]))) actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
...@@ -139,9 +146,149 @@ def test_predictions(): ...@@ -139,9 +146,149 @@ def test_predictions():
assert np.array_equal(actions, expected_actions) assert np.array_equal(actions, expected_actions)
def test_shortest_path_predictor(rendering=False):
    """
    Check the shortest-path predictor on the simple rail network.

    A single agent starts at (5, 6) heading north with target (3, 9). The
    predictor runs with its default depth, so 21 prediction rows are expected.

    Parameters
    -------
    rendering : bool (optional)
        When True, render the environment and wait for user input before
        running the assertions.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_GridTransitionMap_generator(rail),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    env.reset()

    agent = env.agents[0]
    agent.position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.renderEnv(show=True, show_observations=False)
        input("Continue?")

    assert agent.position == (5, 6)
    assert agent.direction == 0
    assert agent.target == (3, 9)
    assert agent.moving

    # The distance map has to be rebuilt because the target was changed after reset().
    env.obs_builder._compute_distance_map()
    distance_map = env.obs_builder.distance_map
    start_distance = distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction]
    assert start_distance == 5.0, "found {} instead of {}".format(start_distance, 5.0)

    # test assertions
    env.obs_builder.get_many()
    agent_prediction = np.asarray(env.obs_builder.predictions[0])

    # Columns: time_offset, position axis 0, position axis 1, direction, action.
    time_offsets = agent_prediction[:, 0:1]
    positions = agent_prediction[:, 1:3]
    directions = agent_prediction[:, 3:4]
    actions = agent_prediction[:, 4:5]

    # Five cells on the way, then the agent stays on its target (3, 9).
    expected_positions = [[5, 6], [4, 6], [3, 6], [3, 7], [3, 8]] + [[3, 9]] * 16

    # North up to the switch at (3, 6), east from (3, 7) onwards.
    expected_directions = [[Grid4TransitionsEnum.NORTH]] * 3 + [[Grid4TransitionsEnum.EAST]] * 18

    expected_time_offsets = np.arange(21, dtype=float).reshape(-1, 1)

    expected_actions = np.array(
        [[RailEnvActions.DO_NOTHING]] +  # next [5,6]
        [[RailEnvActions.MOVE_FORWARD]] * 2 +  # next [4,6], [3,6]
        [[RailEnvActions.MOVE_RIGHT]] +  # next [3,7]
        [[RailEnvActions.MOVE_FORWARD]] * 2 +  # next [3,8], [3,9]
        [[RailEnvActions.STOP_MOVING]] * 15  # at [3,9] == target
    )

    assert np.array_equal(positions, expected_positions), \
        "positions {}, expected {}".format(positions, expected_positions)
    assert np.array_equal(directions, expected_directions), \
        "directions {}, expected {}".format(directions, expected_directions)
    assert np.array_equal(time_offsets, expected_time_offsets), \
        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
    assert np.array_equal(actions, expected_actions), \
        "actions {}, expected {}".format(actions, expected_actions)
...@@ -204,8 +204,3 @@ def test_dead_end(): ...@@ -204,8 +204,3 @@ def test_dead_end():
rail_env.reset() rail_env.reset()
rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)] rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
if __name__ == "__main__":
test_rail_environment_single_agent()
test_dead_end()
...@@ -4,7 +4,3 @@ from examples.play_model import main ...@@ -4,7 +4,3 @@ from examples.play_model import main
def test_main():
    """Run the example `main` once for each `sGL` value ("PIL" and "PILSVG")."""
    for gl_name in ("PIL", "PILSVG"):
        main(render=True, n_steps=20, n_trials=2, sGL=gl_name)
if __name__ == "__main__":
test_main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment