Skip to content
Snippets Groups Projects
Commit c7ca42e0 authored by u214892's avatar u214892
Browse files

66 shortest-path-predictor: cleanup and unit test; not working yet

parent 28ddb598
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,8 @@ a GridTransitionMap object. ...@@ -7,6 +7,8 @@ a GridTransitionMap object.
import numpy as np import numpy as np
from flatland.core.transitions import Grid4TransitionsEnum
def get_direction(pos1, pos2): def get_direction(pos1, pos2):
""" """
...@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2): ...@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2):
def get_new_position(position, movement): def get_new_position(position, movement):
if movement == 0: # NORTH """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1]) return (position[0] - 1, position[1])
elif movement == 1: # EAST elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1) return (position[0], position[1] + 1)
elif movement == 2: # SOUTH elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1]) return (position[0] + 1, position[1])
elif movement == 3: # WEST elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1) return (position[0], position[1] - 1)
......
...@@ -6,6 +6,7 @@ from collections import deque ...@@ -6,6 +6,7 @@ from collections import deque
import numpy as np import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.env_utils import coordinate_to_position from flatland.envs.env_utils import coordinate_to_position
...@@ -162,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -162,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
Utility function that converts a compass movement over a 2D grid to new positions (r, c). Utility function that converts a compass movement over a 2D grid to new positions (r, c).
""" """
if movement == 0: # NORTH if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1]) return (position[0] - 1, position[1])
elif movement == 1: # EAST elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1) return (position[0], position[1] + 1)
elif movement == 2: # SOUTH elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1]) return (position[0] + 1, position[1])
elif movement == 3: # WEST elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1) return (position[0], position[1] - 1)
def get_many(self, handles=[]): def get_many(self, handles=[]):
......
...@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder. ...@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
import numpy as np import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.env_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions from flatland.envs.rail_env import RailEnvActions
...@@ -55,8 +56,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -55,8 +56,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
action_done = False action_done = False
# if we're at the target, stop moving... # if we're at the target, stop moving...
if agent.position == agent.target: if agent.position == agent.target:
prediction[index] = [index, *agent.target, agent.direction, prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
RailEnvActions.STOP_MOVING]
continue continue
for action in action_priorities: for action in action_priorities:
...@@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
if np.sum(cell_transitions) == 1: if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions) new_direction = np.argmax(cell_transitions)
new_position = self._new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1: elif np.sum(cell_transitions) > 1:
min_dist = np.inf min_dist = np.inf
for direction in range(4): for direction in range(4):
...@@ -144,11 +144,22 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -144,11 +144,22 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
if target_dist < min_dist: if target_dist < min_dist:
min_dist = target_dist min_dist = target_dist
new_direction = direction new_direction = direction
new_position = self._new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
else:
raise Exception("No transition possible {}".format(cell_transitions))
action = None
for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
cell_isFree, new_cell_isValid, new_direction, _new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent)
if np.array_equal(_new_position, new_position):
action = _action
break
assert action is not None
agent.position = new_position agent.position = new_position
agent.direction = new_direction agent.direction = new_direction
prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD] prediction[index] = [index, *new_position, new_direction, action]
action_done = True action_done = True
if not action_done: if not action_done:
raise Exception("Cannot move further. Something is wrong") raise Exception("Cannot move further. Something is wrong")
...@@ -159,16 +170,3 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -159,16 +170,3 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
agent.direction = _agent_initial_direction agent.direction = _agent_initial_direction
return prediction_dict return prediction_dict
def _new_position(self, position, movement):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment