Forked from
Flatland / Flatland
1351 commits behind the upstream repository.
predictions.py 7.41 KiB
"""
Collection of environment-specific PredictionBuilder.
"""
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment. At every step it
    takes the first feasible action among MOVE_FORWARD, MOVE_LEFT and MOVE_RIGHT
    (in that order of priority).
    """

    def get(self, handle=None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        -------
        handle : int (optional)
            Handle of the agent for which to compute the prediction.
            If None, predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent a np.ndarray of
            (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If none of the prioritised actions is feasible for an agent at some step.
        """
        # A handle of 0 is a valid agent index, so test against None instead of
        # relying on truthiness (which would silently select *all* agents).
        if handle is None:
            agents = self.env.agents
        else:
            agents = [self.env.agents[handle]]

        # Loop-invariant: the order in which actions are tried at every step.
        action_priorities = (RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT)

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            # The simulation temporarily moves the real agent object; the finally
            # clause guarantees it is put back even if a step raises.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue

                    for action in action_priorities:
                        _, new_cell_isValid, new_direction, new_position, transition_isValid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_isValid and transition_isValid:
                            # move and change direction to face the new_direction that was
                            # performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            break
                    else:
                        # No action in the priority list was feasible.
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            prediction_dict[agent.handle] = prediction

        return prediction_dict
class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and, whenever several
    transitions are possible, follows the one with the smallest value in the
    environment's distance map.
    """

    def __init__(self, max_depth=20):
        """
        Parameters
        -------
        max_depth : int (optional)
            Number of future steps to predict (default 20).
        """
        self.max_depth = max_depth

    def get(self, handle=None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.

        As a side effect, the set of visited (row, column, direction) triples of each
        agent is stored in self.env.dev_pred_dict under the agent's handle.

        Parameters
        -------
        handle : int (optional)
            Handle of the agent for which to compute the prediction.
            If None, predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent a np.ndarray of
            (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        AssertionError
            If the environment has no distance map.
        Exception
            If an agent stands on a cell with no possible transition.
        """
        # A handle of 0 is a valid agent index, so test against None instead of
        # relying on truthiness (which would silently select *all* agents).
        if handle is None:
            agents = self.env.agents
        else:
            agents = [self.env.agents[handle]]

        distance_map = self.env.distance_map
        if distance_map is None:
            # Explicit raise (same exception type as the former assert) so the
            # check is not stripped when running under `python -O`.
            raise AssertionError("ShortestPathPredictorForRailEnv requires env.distance_map")

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            agent_speed = agent.speed_data["speed"]
            # Number of time steps spent on each cell; clamped to 1 so that the
            # modulo below can never divide by zero if the reciprocal truncates to 0.
            times_per_cell = max(1, int(np.reciprocal(agent_speed)))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
            new_direction = _agent_initial_direction
            new_position = _agent_initial_position
            visited = set()

            # The simulation temporarily moves the real agent object; the finally
            # clause guarantees it is put back even if a step raises.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        visited.add((agent.position[0], agent.position[1], agent.direction))
                        continue
                    if not agent.moving:
                        prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                        visited.add((agent.position[0], agent.position[1], agent.direction))
                        continue

                    # The agent only advances to a new cell every `times_per_cell`
                    # steps; in between it keeps its previous position and direction.
                    if index % times_per_cell == 0:
                        # Take shortest possible path
                        cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
                        num_transitions = np.sum(cell_transitions)

                        if num_transitions == 1:
                            new_direction = np.argmax(cell_transitions)
                            new_position = get_new_position(agent.position, new_direction)
                        elif num_transitions > 1:
                            min_dist = np.inf
                            no_dist_found = True
                            for direction in range(4):
                                if cell_transitions[direction] == 1:
                                    neighbour_cell = get_new_position(agent.position, direction)
                                    target_dist = distance_map[
                                        agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                                    # `no_dist_found` makes sure the first candidate is taken
                                    # even when its distance is not below the initial infinity.
                                    if target_dist < min_dist or no_dist_found:
                                        min_dist = target_dist
                                        new_direction = direction
                                        no_dist_found = False
                            new_position = get_new_position(agent.position, new_direction)
                        else:
                            raise Exception("No transition possible {}".format(cell_transitions))

                    # update the agent's position and direction
                    agent.position = new_position
                    agent.direction = new_direction

                    # prediction is ready
                    prediction[index] = [index, *new_position, new_direction, 0]
                    visited.add((new_position[0], new_position[1], new_direction))
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict