diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 294ffab233458f1f3b98c18be50743ba65bd2d73..0467fcd6b0086981c3dcd30d4bb25072361e96cb 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -236,7 +236,8 @@ class RailEnv(Environment):
 
         Relies on the rail_generator returning agent_static lists (pos, dir, target)
         """
-        # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
+        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
+        # can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
         if optionals and 'distance_map' in optionals:
@@ -256,6 +257,9 @@ class RailEnv(Environment):
         agents_hints = None
         if optionals and 'agents_hints' in optionals:
             agents_hints = optionals['agents_hints']
+
+        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
+        # why do we need static agents? could we do it more elegantly?
         self.agents_static = EnvAgentStatic.from_lists(
             *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
         self.restart_agents()
@@ -407,13 +411,14 @@
         # is the agent malfunctioning?
         malfunction = self._agent_malfunction(i_agent)
 
-        # if agent is broken, actions are ignored and agent does not move,
-        # the agent is not penalized in this step!
+        # if agent is broken, actions are ignored and agent does not move.
+        # full step penalty in this case
         if malfunction:
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
             return
 
         # Is the agent at the beginning of the cell? Then, it can take an action.
+        # As long as the agent is malfunctioning or stopped at the beginning of the cell, different actions may be taken!
         if agent.speed_data['position_fraction'] == 0.0:
             # No action has been supplied for this agent -> set DO_NOTHING as default
             if action is None:
@@ -464,7 +469,6 @@ class RailEnv(Environment):
                 if not _action_stored:
                     # If the agent cannot move due to an invalid transition, we set its state to not moving
                     self.rewards_dict[i_agent] += self.invalid_action_penalty
-                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                     self.rewards_dict[i_agent] += self.stop_penalty
 
                     agent.moving = False
@@ -497,6 +501,9 @@ class RailEnv(Environment):
                     agent.moving = False
                 else:
                     self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+        else:
+            # step penalty if not moving (stopped now or before)
+            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
         """
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 884a2a51f84a40a45acced32e7310dcf4d497944..8bd023cf251d609fd324fc30cba5a004b050ab66 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -4,13 +4,11 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
-from flatland.utils.rendertools import RenderTool
-from test_utils import Replay
+from test_utils import Replay, ReplayConfig, run_replay_config
 
 
 class SingleAgentNavigationObs(TreeObsForRailEnv):
@@ -155,7 +153,7 @@ def test_malfunction_process_statistically():
     assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction)
 
 
-def test_initial_malfunction(rendering=True):
+def test_initial_malfunction():
     random.seed(0)
     np.random.seed(0)
 
@@ -190,74 +188,55 @@
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
 
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=3
-        ),
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=2
-        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=1
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(27, 4),
-            direction=Grid4TransitionsEnum.NORTH,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
-
-
-def test_initial_malfunction_stop_moving(rendering=True):
+    replay_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty  # full step penalty when malfunctioning
+            ),
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty when malfunctioning
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=1,
+                reward=env.start_penalty + env.step_penalty * 1.0
+                # malfunctioning ends: starting and running at speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # running at speed 1.0
+            ),
+            Replay(
+                position=(27, 4),
+                direction=Grid4TransitionsEnum.NORTH,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # running at speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+    run_replay_config(env, [replay_config])
+
+
+def test_initial_malfunction_stop_moving():
     random.seed(0)
     np.random.seed(0)
 
@@ -292,80 +271,62 @@
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
 
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=3
-        ),
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=2
-        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action DO_NOTHING, agent should restart without moving
-        #
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.STOP_MOVING,
-            malfunction=1
-        ),
-        # we have stopped and do nothing --> should stand still
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=0
-        ),
-        # we start to move forward --> should go to next cell now
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    replay_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action STOP_MOVING, agent should restart without moving
+            #
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.STOP_MOVING,
+                malfunction=1,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we have stopped and do nothing --> should stand still
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=0,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we start to move forward --> should go to next cell now
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.start_penalty + env.step_penalty * 1.0  # starting and running at speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # running at speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+
+    run_replay_config(env, [replay_config])
 
 
 def test_initial_malfunction_do_nothing(rendering=True):
@@ -403,77 +364,58 @@
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
 
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
+    replay_config = ReplayConfig(
+        replay=[Replay(
             position=(28, 5),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=3
+            set_malfunction=3,
+            malfunction=3,
+            reward=env.step_penalty  # full step penalty while malfunctioning
         ),
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=2
-        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action DO_NOTHING, agent should restart without moving
-        #
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=1
-        ),
-        # we haven't started moving yet --> stay here
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=0
-        ),
-        # we start to move forward --> should go to next cell now
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=2,
+            reward=env.step_penalty  # full step penalty while malfunctioning
+        ),
+        # malfunction stops in the next step and we're still at the beginning of the cell
+        # --> if we take action DO_NOTHING, agent should restart without moving
+        #
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=1,
+            reward=env.step_penalty  # full step penalty while stopped
+        ),
+        # we haven't started moving yet --> stay here
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=0,
+            reward=env.step_penalty  # full step penalty while stopped
+        ),
+        # we start to move forward --> should go to next cell now
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=0,
+            reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+        ),
+        Replay(
+            position=(28, 4),
+            direction=Grid4TransitionsEnum.WEST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=0,
+            reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+        )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+
+    run_replay_config(env, [replay_config])
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 1cf0c325ac48e9e3d5ac04fb51b5f8462c867726..f2fd3613d922ecd5dc23bd6a0649936f7838fbd6 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,7 +1,6 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -9,7 +8,7 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid
 from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
-from test_utils import ReplayConfig, Replay
+from test_utils import ReplayConfig, Replay, run_replay_config
 
 
 np.random.seed(1)
@@ -95,7 +94,6 @@ def test_multi_speed_init():
             old_pos[i_agent] = env.agents[i_agent].position
 
 
-# TODO test penalties!
 # TODO test invalid actions!
 def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
     """Test that actions are correctly performed on cell exit for a single agent."""
@@ -108,123 +106,94 @@
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
 
-    # initialize agents_static
-    env.reset()
-
-    # reset to set agents from agents_static
-    env.reset(False, False)
-
-    if rendering:
-        renderer = RenderTool(env, gl="PILSVG")
 
     test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
-               action=RailEnvActions.MOVE_LEFT
+               action=RailEnvActions.MOVE_LEFT,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.STOP_MOVING
+               action=RailEnvActions.STOP_MOVING,
+               reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty
            ),
            #
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.STOP_MOVING
+               action=RailEnvActions.STOP_MOVING,
+               reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when stopped
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.start_penalty + env.step_penalty * 0.5  # starting + running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(5, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
-
        ],
        target=(3, 0),  # west dead-end
        speed=0.5
    )
 
-    agentStatic: EnvAgentStatic = env.agents_static[0]
-    info_dict = {
-        'action_required': [True]
-    }
-    for i, replay in enumerate(test_config.replay):
-        if i == 0:
-            # set the initial position
-            agentStatic.position = replay.position
-            agentStatic.direction = replay.direction
-            agentStatic.target = test_config.target
-            agentStatic.moving = True
-            agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    run_replay_config(env, [test_config])
 
 
 def test_multispeed_actions_no_malfunction_blocking(rendering=True):
@@ -238,80 +207,83 @@
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
 
-    # initialize agents_static
-    env.reset()
-
-    # reset to set agents from agents_static
-    env.reset(False, False)
-
-    if rendering:
-        renderer = RenderTool(env, gl="PILSVG")
-
     test_configs = [
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=RailEnvActions.MOVE_FORWARD
+                   action=RailEnvActions.MOVE_FORWARD,
+                   reward=env.start_penalty + env.step_penalty * 1.0 / 3.0  # starting and running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=RailEnvActions.MOVE_FORWARD
+                   action=RailEnvActions.MOVE_FORWARD,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=RailEnvActions.MOVE_FORWARD
+                   action=RailEnvActions.MOVE_FORWARD,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=RailEnvActions.MOVE_FORWARD
+                   action=RailEnvActions.MOVE_FORWARD,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                )
            ],
            target=(3, 0),  # west dead-end
@@ -321,69 +293,81 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
                Replay(
                    position=(3, 9),  # east dead-end
                    direction=Grid4TransitionsEnum.EAST,
-                   action=RailEnvActions.MOVE_FORWARD
+                   action=RailEnvActions.MOVE_FORWARD,
+                   reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
                ),
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=RailEnvActions.MOVE_FORWARD
+                   action=RailEnvActions.MOVE_FORWARD,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=RailEnvActions.MOVE_FORWARD
+                   action=RailEnvActions.MOVE_FORWARD,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=RailEnvActions.MOVE_LEFT
+                   action=RailEnvActions.MOVE_LEFT,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
-                   action=None
+                   action=None,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # not blocked, action required!
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
-                   action=RailEnvActions.MOVE_FORWARD
+                   action=RailEnvActions.MOVE_FORWARD,
+                   reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
            ],
            target=(3, 0),  # west dead-end
@@ -391,49 +375,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
        )
    ]
 
-
-    # TODO test penalties!
-    info_dict = {
-        'action_required': [True for _ in test_configs]
-    }
-    for step in range(len(test_configs[0].replay)):
-        if step == 0:
-            for a, test_config in enumerate(test_configs):
-                agentStatic: EnvAgentStatic = env.agents_static[a]
-                replay = test_config.replay[0]
-                # set the initial position
-                agentStatic.position = replay.position
-                agentStatic.direction = replay.direction
-                agentStatic.target = test_config.target
-                agentStatic.moving = True
-                agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(a, actual, expected, msg):
-            assert actual == expected, "[{}] {} {}: actual={}, expected={}".format(step, a, msg, actual, expected)
-
-        action_dict = {}
-
-        for a, test_config in enumerate(test_configs):
-            agent: EnvAgent = env.agents[a]
-            replay = test_config.replay[step]
-
-            _assert(a, agent.position, replay.position, 'position')
-            _assert(a, agent.direction, replay.direction, 'direction')
-
-            if replay.action is not None:
-                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
-                    step, a, True)
-                action_dict[a] = replay.action
-            else:
-                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
-                    step, a, False)
-        _, _, _, info_dict = env.step(action_dict)
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    run_replay_config(env, test_configs)
 
 
 def test_multispeed_actions_malfunction_no_blocking(rendering=True):
@@ -461,136 +403,118 @@
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # add additional step in the cell
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
-               malfunction=2  # recovers in two steps from now!
+               set_malfunction=2,  # recovers in two steps from now!
+               malfunction=2,
+               reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
-               action=None
+               action=None,
+               malfunction=1,
+               reward=env.step_penalty * 0.5  # recovered: running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
-               malfunction=2  # recovers in two steps from now!
+               set_malfunction=2,  # recovers in two steps from now!
+               malfunction=2,
+               reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_LEFT,
+               malfunction=1,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.STOP_MOVING
+               action=RailEnvActions.STOP_MOVING,
+               reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty for speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.STOP_MOVING
+               action=RailEnvActions.STOP_MOVING,
+               reward=env.step_penalty * 0.5  # step penalty for speed 0.5 while stopped
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # DO_NOTHING keeps moving!
            Replay(
                position=(5, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.DO_NOTHING
+               action=RailEnvActions.DO_NOTHING,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(5, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=None
+               action=None,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(6, 6),
                direction=Grid4TransitionsEnum.SOUTH,
-               action=RailEnvActions.MOVE_FORWARD
+               action=RailEnvActions.MOVE_FORWARD,
+               reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
        ],
        target=(3, 0),  # west dead-end
        speed=0.5
    )
-
-    # TODO test penalties!
-    agentStatic: EnvAgentStatic = env.agents_static[0]
-    info_dict = {
-        'action_required': [True]
-    }
-    for i, replay in enumerate(test_config.replay):
-        if i == 0:
-            # set the initial position
-            agentStatic.position = replay.position
-            agentStatic.direction = replay.direction
-            agentStatic.target = test_config.target
-            agentStatic.moving = True
-            agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-
-        if replay.malfunction > 0:
-            agent.malfunction_data['malfunction'] = replay.malfunction
-            agent.malfunction_data['moving_before_malfunction'] = agent.moving
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    run_replay_config(env, [test_config])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 967615833c4790d67af80a1d75e35174e2ff5e5a..88d669fa220e8d65f8e667532c39e0f0e6ad7a69 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -4,7 +4,9 @@ from typing import List, Tuple, Optional
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.rail_env import RailEnvActions, RailEnv
+from flatland.utils.rendertools import RenderTool
 
 
 @attrs
@@ -13,7 +15,8 @@ class Replay(object):
     direction = attrib(type=Grid4TransitionsEnum)
     action = attrib(type=RailEnvActions)
     malfunction = attrib(default=0, type=int)
-    penalty = attrib(default=None, type=Optional[float])
+    set_malfunction = attrib(default=None, type=Optional[int])
+    reward = attrib(default=None, type=Optional[float])
 
 
 @attrs
@@ -21,3 +24,78 @@ class ReplayConfig(object):
     replay = attrib(type=List[Replay])
     target = attrib(type=Tuple[int, int])
     speed = attrib(type=float)
+
+
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
+    """
+    Runs the replay configs and checks assertions.
+
+    *Initially*
+    - the position, direction, target and speed of the initial step are taken to initialize the agents
+
+    *Before each step*
+    - an action must be provided if and only if action_required was set in the previous step (initially all True)
+    - position and direction before the step are verified
+    - optionally, set_malfunction is applied
+    - malfunction is verified
+
+    *After each step*
+    - the reward received in the step is verified
+
+    Parameters
+    ----------
+    env
+    test_configs
+    rendering
+    """
+    if rendering:
+        renderer = RenderTool(env)
+        renderer.render_env(show=True, frames=False, show_observations=False)
+    info_dict = {
+        'action_required': [True for _ in test_configs]
+    }
+    for step in range(len(test_configs[0].replay)):
+        if step == 0:
+            for a, test_config in enumerate(test_configs):
+                agent: EnvAgent = env.agents[a]
+                replay = test_config.replay[0]
+                # set the initial position, direction, target and speed
+                agent.position = replay.position
+                agent.direction = replay.direction
+                agent.target = test_config.target
+                agent.speed_data['speed'] = test_config.speed
+
+        def _assert(a, actual, expected, msg):
+            assert actual == expected, "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg, actual,
+                                                                                         expected)
+
+        action_dict = {}
+
+        for a, test_config in enumerate(test_configs):
+            agent: EnvAgent = env.agents[a]
+            replay = test_config.replay[step]
+
+            _assert(a, agent.position, replay.position, 'position')
+            _assert(a, agent.direction, replay.direction, 'direction')
+
+            if replay.action is not None:
+                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
+                    step, a, True)
+                action_dict[a] = replay.action
+            else:
+                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
+                    step, a, False)
+
+            if replay.set_malfunction is not None:
+                agent.malfunction_data['malfunction'] = replay.set_malfunction
+                agent.malfunction_data['moving_before_malfunction'] = agent.moving
+            _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+
+        _, rewards_dict, _, info_dict = env.step(action_dict)
+
+        for a, test_config in enumerate(test_configs):
+            replay = test_config.replay[step]
+            _assert(a, rewards_dict[a], replay.reward, 'reward')
+
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
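
Usage sketch (illustrative, not part of the patch): a new replay-driven test needs only an env plus one ReplayConfig per agent; run_replay_config seeds each agent from step 0 and then asserts position, direction, malfunction and reward around every env.step(). The test below is a minimal sketch under stated assumptions, not a recorded run: test_replay_sketch, its start cell, directions and expected rewards are hypothetical, while the helpers (make_simple_rail, rail_from_grid_transition_map, random_schedule_generator) are the ones already imported by tests/test_multi_speed.py.

import numpy as np

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import Replay, ReplayConfig, run_replay_config

np.random.seed(1)


def test_replay_sketch():
    # hypothetical test: build a small deterministic env, as in test_multi_speed.py
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2))

    test_config = ReplayConfig(
        replay=[
            # step 0: agent sits at the east dead-end; MOVE_FORWARD turns it around
            Replay(
                position=(3, 9),  # assumed start cell (east dead-end)
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 1.0  # starting and running at speed 1.0
            ),
            # step 1: at speed 1.0 the agent has already reached the next cell
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            ),
        ],
        target=(3, 0),  # west dead-end
        speed=1.0  # full speed: one cell per step, so action_required every step
    )
    # run_replay_config seeds position/direction/target/speed at step 0,
    # then checks the state before each step and the reward it returns
    run_replay_config(env, [test_config])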