Commit 9890abf5 authored by Erik Nygren's avatar Erik Nygren
Browse files

added new prediction which follows shortest path

updated test to handle new predictor
parent 725b98de
Pipeline #1087 failed with stages
in 9 minutes and 27 seconds
......@@ -2,7 +2,7 @@ import numpy as np
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
np.random.seed(1)
......@@ -11,7 +11,7 @@ np.random.seed(1)
# Training on simple small tasks is the best way to get familiar with the environment
#
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv())
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
......
......@@ -177,7 +177,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.predictor:
self.predicted_pos = {}
self.predicted_dir = {}
self.predictions = self.predictor.get()
self.predictions = self.predictor.get(self.distance_map)
for t in range(len(self.predictions[0])):
pos_list = []
......
......@@ -16,7 +16,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, handle=None):
def get(self, distancemap, handle=None):
"""
Called whenever predict is called on the environment.
......@@ -72,3 +72,92 @@ class DummyPredictorForRailEnv(PredictionBuilder):
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
return prediction_dict
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
DummyPredictorForRailEnv object.
This object returns predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, distancemap, handle=None):
"""
Called whenever predict is called on the environment.
Parameters
-------
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
prediction_dict = {}
agent_idx = 0
for agent in agents:
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
for index in range(1, self.max_depth + 1):
action_done = False
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
RailEnvActions.STOP_MOVING]
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions)
new_position = self._new_position(agent.position, new_direction)
else:
for direct in range(4):
min_dist = np.inf
target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
if target_dist < min_dist:
min_dist = target_dist
new_direction = direct
new_position = self._new_position(agent.position, new_direction)
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, new_position[0], new_position[1], new_direction, 0]
action_done = True
if not action_done:
raise Exception("Cannot move further. Something is wrong")
prediction_dict[agent.handle] = prediction
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
agent_idx += 1
return prediction_dict
def _new_position(self, position, movement):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
......@@ -74,7 +74,7 @@ def test_predictions():
env.agents[0].direction = 0
env.agents[0].target = (3., 0.)
predictions = env.obs_builder.predictor.get()
predictions = env.obs_builder.predictor.get(None)
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment