Skip to content
Snippets Groups Projects
Commit 96f670b3 authored by Erik Nygren's avatar Erik Nygren
Browse files

Major update to shortest path predictor

major update and bugfix for collision detection
parent 8176e2fe
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,7 @@ class Transitions:
Generic class that implements checks to control whether a
certain transition is allowed (agent facing a direction
`orientation' and moving into direction `direction')
`orientation' and moving into direction `orientation')
"""
def get_type(self):
......
......@@ -283,7 +283,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1)
self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1)
observation = observation + branch_observation
visited = visited.union(branch_visited)
else:
......@@ -351,22 +351,23 @@ class TreeObsForRailEnv(ObservationBuilder):
post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
# Look for opposing paths at distance num_step
if int_position in np.delete(self.predicted_pos[tot_dist], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[tot_dist][ca[0]] and tot_dist < potential_conflict:
if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]] and tot_dist < potential_conflict:
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]] and tot_dist < potential_conflict:
elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
......@@ -436,41 +437,6 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# #############################
# Modify here to append new / different features for each visited cell!
"""
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
observation = [0,
other_target_encountered,
other_agent_encountered,
np.inf,
np.inf,
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction
]
"""
if last_isTarget:
observation = [own_target_encountered,
......
......@@ -142,7 +142,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
min_dist = np.inf
for direction in range(4):
if cell_transitions[direction] == 1:
target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction]
neighbour_cell = get_new_position(agent.position, direction)
target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
if target_dist < min_dist:
min_dist = target_dist
new_direction = direction
......@@ -150,21 +151,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
else:
raise Exception("No transition possible {}".format(cell_transitions))
# which action to take for the transition?
action = None
for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
_, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent)
if np.array_equal(_new_position, new_position):
action = _action
break
assert action is not None
# update the agent's position and direction
agent.position = new_position
agent.direction = new_direction
# prediction is ready
prediction[index] = [index, *new_position, new_direction, action]
prediction[index] = [index, *new_position, new_direction, 0]
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
......
......@@ -9,7 +9,6 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool
"""Test predictions for `flatland` package."""
......@@ -187,7 +186,6 @@ def test_shortest_path_predictor(rendering=False):
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
expected_positions = [
[5, 6],
......@@ -260,35 +258,9 @@ def test_shortest_path_predictor(rendering=False):
[20.],
])
expected_actions = np.array([
[RailEnvActions.DO_NOTHING], # next [5,6]
[RailEnvActions.MOVE_FORWARD], # next [4,6]
[RailEnvActions.MOVE_FORWARD], # next [3,6]
[RailEnvActions.MOVE_RIGHT], # next [3,7]
[RailEnvActions.MOVE_FORWARD], # next [3,8]
[RailEnvActions.MOVE_FORWARD], # next [3,9]
[RailEnvActions.STOP_MOVING], # at [3,9] == target
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
[RailEnvActions.STOP_MOVING],
])
assert np.array_equal(positions, expected_positions), \
"positions {}, expected {}".format(positions, expected_positions)
assert np.array_equal(directions, expected_directions), \
"directions {}, expected {}".format(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
assert np.array_equal(actions, expected_actions), \
"actions {}, expected {}".format(actions, expected_actions)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment