Forked from
Flatland / Flatland
528 commits behind the upstream repository.
test_action_plan.py 6.28 KiB
from flatland.action_plan.action_plan import PathScheduleElement, CellPin, ActionPlanReplayer
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import WalkingElement
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail
def test_action_plan(rendering: bool = False):
    """Check that ActionPlanReplayer builds the expected action plan and can replay it.

    Two agents are positioned by hand on the simple rail, each one is given a
    scheduled path, and the plan produced by ``ActionPlanReplayer`` is compared
    with a hand-written expectation before it is replayed against the env.

    :param rendering: passed through unchanged to ``replay_verify``.
    """

    def scheduled_path(*pins):
        """Expand ``(scheduled_at, r, c, d)`` tuples into PathScheduleElements."""
        return [
            PathScheduleElement(scheduled_at=at, cell_pin=CellPin(r=row, c=col, d=heading))
            for at, row, col, heading in pins
        ]

    def plan(*entries):
        """Expand ``(time, position, direction, action)`` tuples into plan entries."""
        return [
            (at, WalkingElement(position=where, direction=heading, next_action=action))
            for at, where, heading, action in entries
        ]

    forward = RailEnvActions.MOVE_FORWARD
    halt = RailEnvActions.STOP_MOVING

    rail, rail_map = make_simple_rail()
    tree_observer = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
    env = RailEnv(
        width=rail_map.shape[1],
        height=rail_map.shape[0],
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(seed=77),
        number_of_agents=2,
        obs_builder_object=tree_observer,
        remove_agents_at_target=True,
    )
    env.reset()

    # Override the generated schedule: agent 0 runs along row 3 from column 0 to column 8.
    env.agents[0].initial_position = (3, 0)
    env.agents[0].target = (3, 8)
    env.agents[0].initial_direction = Grid4TransitionsEnum.WEST

    # Agent 1 starts at the opposite end of row 3 and heads for (0, 3).
    env.agents[1].initial_position = (3, 8)
    env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
    env.agents[1].target = (0, 3)
    # Half speed: in the schedule below agent 1 advances one cell every two time steps.
    env.agents[1].speed_data['speed'] = 0.5

    # NOTE(review): the three False flags presumably stop reset() from regenerating
    # rail/schedule so the manual agent setup above is kept - confirm against RailEnv.reset.
    env.reset(False, False, False)

    for handle, agent in enumerate(env.agents):
        print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))

    # Columns: scheduled_at, r, c, d.
    # NOTE(review): the closing pin of each path uses d=5, unlike the 0-3 values used
    # elsewhere; presumably a "target reached" marker - confirm in action_plan.
    chosen_path_dict = {
        0: scheduled_path(
            (0, 3, 0, 3),
            (2, 3, 1, 1),
            (3, 3, 2, 1),
            (14, 3, 3, 1),
            (15, 3, 4, 1),
            (16, 3, 5, 1),
            (17, 3, 6, 1),
            (18, 3, 7, 1),
            (19, 3, 8, 1),
            (20, 3, 8, 5),
        ),
        1: scheduled_path(
            (0, 3, 8, 3),
            (3, 3, 7, 3),
            (5, 3, 6, 3),
            (7, 3, 5, 3),
            (9, 3, 4, 3),
            (11, 3, 3, 3),
            (13, 2, 3, 0),
            (15, 1, 3, 0),
            (17, 0, 3, 0),
            (18, 0, 3, 5),
        ),
    }

    # Columns: time step, position, direction, action to take.
    expected_action_plan = [
        plan(
            (0, None, 3, forward),  # action that makes the agent enter the grid
            (1, (3, 0), 3, forward),  # action that makes it enter the cell properly
            (2, (3, 1), 1, forward),
            (3, (3, 2), 1, halt),
            (13, (3, 2), 1, forward),
            (14, (3, 3), 1, forward),
            (15, (3, 4), 1, forward),
            (16, (3, 5), 1, forward),
            (17, (3, 6), 1, forward),
            (18, (3, 7), 1, forward),
            (19, None, 1, halt),
        ),
        plan(
            (0, None, 3, forward),
            (1, (3, 8), 3, forward),
            (3, (3, 7), 3, forward),
            (5, (3, 6), 3, forward),
            (7, (3, 5), 3, forward),
            (9, (3, 4), 3, forward),
            (11, (3, 3), 3, RailEnvActions.MOVE_RIGHT),
            (13, (2, 3), 0, forward),
            (15, (1, 3), 0, forward),
            (17, None, 0, halt),
        ),
    ]

    max_episode_steps = 50
    replayer = ActionPlanReplayer(env, chosen_path_dict)
    replayer.print_action_plan()

    # Run the library's own comparison first, then a plain equality check on top.
    ActionPlanReplayer.compare_action_plans(expected_action_plan, replayer.action_plan)
    assert replayer.action_plan == expected_action_plan, \
        "expected {}, found {}".format(expected_action_plan, replayer.action_plan)

    replayer.replay_verify(max_episode_steps, env, rendering)