Skip to content
Snippets Groups Projects
Commit 60bbde32 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch 'core-without-rendering' into 'master'

Core without rendering

See merge request flatland/flatland!281
parents c9b4f94b d9dc8c8f
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,6 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d ...@@ -7,7 +7,6 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_action_for_move from flatland.envs.rail_env_shortest_paths import get_action_for_move
from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
# ---- ActionPlan --------------- # ---- ActionPlan ---------------
# an action plan element represents the actions to be taken by an agent at the given time step # an action plan element represents the actions to be taken by an agent at the given time step
...@@ -264,35 +263,3 @@ class ControllerFromTrainruns(): ...@@ -264,35 +263,3 @@ class ControllerFromTrainruns():
# execute the action exactly minimum_cell_time before the entry into the next cell # execute the action exactly minimum_cell_time before the entry into the next cell
action = ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action) action = ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action)
action_plan.append(action) action_plan.append(action)
class ControllerFromTrainrunsReplayer():
"""Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction."""
@staticmethod
def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv, rendering: bool):
"""Replays this deterministic `ActionPlan` and verifies whether it is feasible."""
if rendering:
renderer = RenderTool(env, gl="PILSVG",
agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
show_debug=True,
clear_debug_text=True,
screen_height=1000,
screen_width=1000)
renderer.render_env(show=True, show_observations=False, show_predictions=False)
i = 0
while not env.dones['__all__'] and i <= env._max_episode_steps:
for agent_id, agent in enumerate(env.agents):
waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, i)
assert agent.position == waypoint.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
waypoint.position)
actions = ctl.act(i)
print("actions for {}: {}".format(i, actions))
obs, all_rewards, done, _ = env.step(actions)
if rendering:
renderer.render_env(show=True, show_observations=False, show_predictions=False)
i += 1
from typing import Callable
from flatland.action_plan.action_plan import ControllerFromTrainruns
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_trainrun_data_structures import Waypoint
ControllerFromTrainrunsReplayerRenderCallback = Callable[[RailEnv], None]
class ControllerFromTrainrunsReplayer():
"""Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction."""
@staticmethod
def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv,
call_back: ControllerFromTrainrunsReplayerRenderCallback = lambda *a, **k: None):
"""Replays this deterministic `ActionPlan` and verifies whether it is feasible.
Parameters
----------
ctl
env
call_back
Called before/after each step() call. The env is passed to it.
"""
call_back(env)
i = 0
while not env.dones['__all__'] and i <= env._max_episode_steps:
for agent_id, agent in enumerate(env.agents):
waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, i)
assert agent.position == waypoint.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
waypoint.position)
actions = ctl.act(i)
print("actions for {}: {}".format(i, actions))
obs, all_rewards, done, _ = env.step(actions)
call_back(env)
i += 1
from flatland.action_plan.action_plan import TrainrunWaypoint, ControllerFromTrainrunsReplayer, ActionPlanElement, \ from flatland.action_plan.action_plan import TrainrunWaypoint, ActionPlanElement, \
ControllerFromTrainruns ControllerFromTrainruns
from flatland.action_plan.action_plan_player import ControllerFromTrainrunsReplayer
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.schedule_generators import random_schedule_generator from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.utils.simple_rail import make_simple_rail from flatland.utils.simple_rail import make_simple_rail
...@@ -83,4 +85,16 @@ def test_action_plan(rendering: bool = False): ...@@ -83,4 +85,16 @@ def test_action_plan(rendering: bool = False):
deterministic_controller = ControllerFromTrainruns(env, chosen_path_dict) deterministic_controller = ControllerFromTrainruns(env, chosen_path_dict)
deterministic_controller.print_action_plan() deterministic_controller.print_action_plan()
ControllerFromTrainruns.assert_actions_plans_equal(expected_action_plan, deterministic_controller.action_plan) ControllerFromTrainruns.assert_actions_plans_equal(expected_action_plan, deterministic_controller.action_plan)
ControllerFromTrainrunsReplayer.replay_verify(deterministic_controller, env, rendering) if rendering:
renderer = RenderTool(env, gl="PILSVG",
agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
show_debug=True,
clear_debug_text=True,
screen_height=1000,
screen_width=1000)
def render(*argv):
if rendering:
renderer.render_env(show=True, show_observations=False, show_predictions=False)
ControllerFromTrainrunsReplayer.replay_verify(deterministic_controller, env, call_back=render)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment