Commit 96f670b3 authored by Erik Nygren's avatar Erik Nygren
Browse files

Major update to shortest path predictor

major update and bugfix for collision detection
parent 8176e2fe
Pipeline #1334 failed with stage
in 6 minutes and 13 seconds
......@@ -12,7 +12,7 @@ class Transitions:
Generic class that implements checks to control whether a
certain transition is allowed (agent facing a direction
`orientation' and moving into direction `direction')
`orientation' and moving into direction `orientation')
"""
def get_type(self):
......
......@@ -283,7 +283,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1)
self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1)
observation = observation + branch_observation
visited = visited.union(branch_visited)
else:
......@@ -351,22 +351,23 @@ class TreeObsForRailEnv(ObservationBuilder):
post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
# Look for opposing paths at distance num_step
if int_position in np.delete(self.predicted_pos[tot_dist], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[tot_dist][ca[0]] and tot_dist < potential_conflict:
if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]] and tot_dist < potential_conflict:
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]] and tot_dist < potential_conflict:
elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
......@@ -436,41 +437,6 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# #############################
# Modify here to append new / different features for each visited cell!
"""
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
observation = [0,
other_target_encountered,
other_agent_encountered,
np.inf,
np.inf,
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction
]
"""
if last_isTarget:
observation = [own_target_encountered,
......
......@@ -142,7 +142,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
min_dist = np.inf
for direction in range(4):
if cell_transitions[direction] == 1:
target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction]
neighbour_cell = get_new_position(agent.position, direction)
target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
if target_dist < min_dist:
min_dist = target_dist
new_direction = direction
......@@ -150,21 +151,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
else:
raise Exception("No transition possible {}".format(cell_transitions))
# which action to take for the transition?
action = None
for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
_, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent)
if np.array_equal(_new_position, new_position):
action = _action
break
assert action is not None
# update the agent's position and direction
agent.position = new_position
agent.direction = new_direction
# prediction is ready
prediction[index] = [index, *new_position, new_direction, action]
prediction[index] = [index, *new_position, new_direction, 0]
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
......
......@@ -9,7 +9,6 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool
"""Test predictions for `flatland` package."""
......@@ -187,7 +186,6 @@ def test_shortest_path_predictor(rendering=False):
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
expected_positions = [
[5, 6],
......@@ -260,35 +258,9 @@ def test_shortest_path_predictor(rendering=False):
[20.],
])
expected_actions = np.array([
[RailEnvActions.DO_NOTHING], # next [5,6]
[RailEnvActions.MOVE_FORWARD], # next [4,6]
[RailEnvActions.MOVE_FORWARD], # next [3,6]
[RailEnvActions.MOVE_RIGHT], # next [3,7]
[RailEnvActions.MOVE_FORWARD], # next [3,8]
[RailEnvActions.MOVE_FORWARD], # next [3,9]
[RailEnvActions.STOP_MOVING], # at [3,9] == target
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
])
assert np.array_equal(positions, expected_positions), \
"positions {}, expected {}".format(positions, expected_positions)
assert np.array_equal(directions, expected_directions), \
"directions {}, expected {}".format(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
assert np.array_equal(actions, expected_actions), \
"actions {}, expected {}".format(actions, expected_actions)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment