Skip to content
Snippets Groups Projects
Commit 9890abf5 authored by Erik Nygren's avatar Erik Nygren
Browse files

added new prediction which follows shortest path

updated test to handle new predictor
parent 725b98de
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
np.random.seed(1) np.random.seed(1)
...@@ -11,7 +11,7 @@ np.random.seed(1) ...@@ -11,7 +11,7 @@ np.random.seed(1)
# Training on simple small tasks is the best way to get familiar with the environment # Training on simple small tasks is the best way to get familiar with the environment
# #
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()) TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
......
...@@ -177,7 +177,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -177,7 +177,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.predictor: if self.predictor:
self.predicted_pos = {} self.predicted_pos = {}
self.predicted_dir = {} self.predicted_dir = {}
self.predictions = self.predictor.get() self.predictions = self.predictor.get(self.distance_map)
for t in range(len(self.predictions[0])): for t in range(len(self.predictions[0])):
pos_list = [] pos_list = []
......
...@@ -16,7 +16,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -16,7 +16,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action. The prediction acts as if no other agent is in the environment and always takes the forward action.
""" """
def get(self, handle=None): def get(self, distancemap, handle=None):
""" """
Called whenever predict is called on the environment. Called whenever predict is called on the environment.
...@@ -72,3 +72,92 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -72,3 +72,92 @@ class DummyPredictorForRailEnv(PredictionBuilder):
agent.position = _agent_initial_position agent.position = _agent_initial_position
agent.direction = _agent_initial_direction agent.direction = _agent_initial_direction
return prediction_dict return prediction_dict
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
DummyPredictorForRailEnv object.
This object returns predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, distancemap, handle=None):
"""
Called whenever predict is called on the environment.
Parameters
-------
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
prediction_dict = {}
agent_idx = 0
for agent in agents:
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
for index in range(1, self.max_depth + 1):
action_done = False
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
RailEnvActions.STOP_MOVING]
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions)
new_position = self._new_position(agent.position, new_direction)
else:
for direct in range(4):
min_dist = np.inf
target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
if target_dist < min_dist:
min_dist = target_dist
new_direction = direct
new_position = self._new_position(agent.position, new_direction)
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, new_position[0], new_position[1], new_direction, 0]
action_done = True
if not action_done:
raise Exception("Cannot move further. Something is wrong")
prediction_dict[agent.handle] = prediction
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
agent_idx += 1
return prediction_dict
def _new_position(self, position, movement):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
...@@ -74,7 +74,7 @@ def test_predictions(): ...@@ -74,7 +74,7 @@ def test_predictions():
env.agents[0].direction = 0 env.agents[0].direction = 0
env.agents[0].target = (3., 0.) env.agents[0].target = (3., 0.)
predictions = env.obs_builder.predictor.get() predictions = env.obs_builder.predictor.get(None)
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment