Commit a2061acf authored by u214892's avatar u214892
Browse files

refactoring WalkingElement

parent 2d93454f
Pipeline #2781 passed with stages
in 38 minutes and 38 seconds
......@@ -14,7 +14,7 @@ from flatland.utils.ordered_set import OrderedSet
WalkingElement = \
NamedTuple('WalkingElement',
[('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)])
[('position', Tuple[int, int]), ('direction', int), ('next_action', Optional[RailEnvActions])])
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
......@@ -74,7 +74,7 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
def get_new_position_for_action(
agent_position: Tuple[int, int],
agent_direction: Grid4TransitionsEnum,
action: RailEnvNextAction,
action: RailEnvActions,
rail: GridTransitionMap) -> Tuple[int, int, int]:
"""
Get the next position for this action.
......@@ -245,7 +245,8 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
best_next_action = next_action
distance = next_action_distance
shortest_paths[agent.handle].append(WalkingElement(position, direction, best_next_action))
shortest_paths[agent.handle].append(
WalkingElement(position, direction, best_next_action.action if best_next_action is not None else None))
depth += 1
# if there is no way to continue, the rail must be disconnected!
......@@ -257,9 +258,7 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
position = best_next_action.next_position
direction = best_next_action.next_direction
if max_depth is None or depth < max_depth:
shortest_paths[agent.handle].append(
WalkingElement(position, direction,
RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction)))
shortest_paths[agent.handle].append(WalkingElement(position, direction, RailEnvActions.STOP_MOVING))
if agent_handle is not None:
_shortest_path_for_agent(distance_map.agents[agent_handle])
......@@ -279,7 +278,7 @@ def get_k_shortest_paths(env: RailEnv,
Computes the k shortest paths using modified Dijkstra
following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing
In contrast to the pseudo-code in wikipedia, we do not a allow for loopy paths.
We add the next_action_element
We add the next_action
Parameters
----------
......@@ -356,7 +355,7 @@ def get_k_shortest_paths(env: RailEnv,
if debug:
print(" looking at neighbor v={}".format((*new_position, new_direction)))
v = WalkingElement(position=new_position, direction=new_direction, next_action_element=None)
v = WalkingElement(position=new_position, direction=new_direction, next_action=None)
# CAVEAT: do not allow for loopy paths
if v in pu:
continue
......@@ -368,19 +367,18 @@ def get_k_shortest_paths(env: RailEnv,
# add actions to shortest paths
shortest_paths_with_action = []
for p in shortest_paths:
p_with_action = tuple(WalkingElement(position=el.position,
direction=el.direction,
next_action_element=RailEnvNextAction(
action=int(get_action_for_move(el.position,
el.direction,
p[i + 1].position,
p[i + 1].direction,
env.rail)),
next_position=p[i + 1].position,
next_direction=p[i + 1].direction)) for i, el in enumerate(p[:-1]))
target_walking_element = WalkingElement(position=p[-1].position, direction=p[-1].direction,
next_action_element=RailEnvNextAction(action=int(RailEnvActions.DO_NOTHING),
next_position=None, next_direction=None))
p_with_action = tuple(
WalkingElement(position=el.position,
direction=el.direction,
next_action=int(get_action_for_move(el.position,
el.direction,
p[i + 1].position,
p[i + 1].direction,
env.rail))) for i, el in
enumerate(p[:-1]))
target_walking_element = WalkingElement(position=p[-1].position,
direction=p[-1].direction,
next_action=int(RailEnvActions.DO_NOTHING))
shortest_paths_with_action.append(p_with_action + (target_walking_element,))
# return P
......
......@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
......@@ -146,18 +146,13 @@ def test_shortest_path_predictor(rendering=False):
paths = get_shortest_paths(env.distance_map)[0]
assert paths == [
WalkingElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6),
next_direction=0)),
WalkingElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6),
next_direction=0)),
WalkingElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7),
next_direction=1)),
WalkingElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8),
next_direction=1)),
WalkingElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9),
next_direction=1)),
WalkingElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.STOP_MOVING, next_position=(3, 9),
next_direction=1))]
WalkingElement((5, 6), 0, RailEnvActions.MOVE_FORWARD),
WalkingElement((4, 6), 0, RailEnvActions.MOVE_FORWARD),
WalkingElement((3, 6), 0, RailEnvActions.MOVE_FORWARD),
WalkingElement((3, 7), 1, RailEnvActions.MOVE_FORWARD),
WalkingElement((3, 8), 1, RailEnvActions.MOVE_FORWARD),
WalkingElement((3, 9), 1, RailEnvActions.STOP_MOVING)
]
# extract the data
predictions = env.obs_builder.predictions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment