diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 13781f121bbb2cf47c0f32767598814f109b5ffb..dd351aa76c11bd130242820626de25392c3f40cd 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -14,7 +14,7 @@ from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file from flatland.utils.ordered_set import OrderedSet -ShortestPathElement = \ +WalkingElement = \ NamedTuple('Path_Element', [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)]) @@ -87,7 +87,7 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, return valid_actions -def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPathElement]]: +def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[WalkingElement]]: # TODO: do we need to support unreachable targets? # TODO refactoring: unify with predictor (support agent.moving and max_depth) shortest_paths = dict() @@ -106,13 +106,13 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath best_next_action = next_action distance = next_action_distance - shortest_paths[a.handle].append(ShortestPathElement(position, direction, best_next_action)) + shortest_paths[a.handle].append(WalkingElement(position, direction, best_next_action)) position = best_next_action.next_position direction = best_next_action.next_direction shortest_paths[a.handle].append( - ShortestPathElement(position, direction, - RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction))) + WalkingElement(position, direction, + RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction))) return shortest_paths diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 6139b0e22a1b7f89d5f169ab3ae2ac13eda69d05..569cd3addddf77e0e328e7d244ec57078a90281f 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction -from flatland.envs.rail_env_utils import get_shortest_paths, ShortestPathElement +from flatland.envs.rail_env_utils import get_shortest_paths, WalkingElement from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool @@ -145,18 +145,18 @@ def test_shortest_path_predictor(rendering=False): paths = get_shortest_paths(env.distance_map)[0] assert paths == [ - ShortestPathElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), - next_direction=0)), - ShortestPathElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), - next_direction=0)), - ShortestPathElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), - next_direction=1)), - ShortestPathElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), - next_direction=1)), - ShortestPathElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), - next_direction=1)), - ShortestPathElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.STOP_MOVING, next_position=(3, 9), - next_direction=1))] + WalkingElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), + next_direction=0)), + WalkingElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), + next_direction=0)), + WalkingElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), + next_direction=1)), + WalkingElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), + next_direction=1)), + WalkingElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), + next_direction=1)), + WalkingElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.STOP_MOVING, next_position=(3, 9), + next_direction=1))] # extract the data predictions = env.obs_builder.predictions