Skip to content
Snippets Groups Projects
Commit c7ca42e0 authored by u214892's avatar u214892
Browse files

66 shortest-path-predictor: cleanup and unit test; not working yet

parent 28ddb598
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,8 @@ a GridTransitionMap object.
import numpy as np
from flatland.core.transitions import Grid4TransitionsEnum
def get_direction(pos1, pos2):
"""
......@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2):
def get_new_position(position, movement):
if movement == 0: # NORTH
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == 1: # EAST
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == 3: # WEST
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
......
......@@ -6,6 +6,7 @@ from collections import deque
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.env_utils import coordinate_to_position
......@@ -162,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == 0: # NORTH
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == 1: # EAST
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == 3: # WEST
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
def get_many(self, handles=[]):
......
......@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.env_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
......@@ -55,8 +56,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
action_done = False
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, *agent.target, agent.direction,
RailEnvActions.STOP_MOVING]
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
continue
for action in action_priorities:
......@@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions)
new_position = self._new_position(agent.position, new_direction)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1:
min_dist = np.inf
for direction in range(4):
......@@ -144,11 +144,22 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
if target_dist < min_dist:
min_dist = target_dist
new_direction = direction
new_position = self._new_position(agent.position, new_direction)
new_position = get_new_position(agent.position, new_direction)
else:
raise Exception("No transition possible {}".format(cell_transitions))
action = None
for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
cell_isFree, new_cell_isValid, new_direction, _new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent)
if np.array_equal(_new_position, new_position):
action = _action
break
assert action is not None
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD]
prediction[index] = [index, *new_position, new_direction, action]
action_done = True
if not action_done:
raise Exception("Cannot move further. Something is wrong")
......@@ -159,16 +170,3 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
agent.direction = _agent_initial_direction
return prediction_dict
def _new_position(self, position, movement):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment