diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index a049ae260860e946ac8b27f9e83ab11bf4ed2920..3a75aa81193d2355f71a05d8825bc64da4547f6f 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -46,8 +46,6 @@ def a_star(grid_map: GridTransitionMap,
     """
     rail_shape = grid_map.grid.shape
 
-    tmp = np.zeros(rail_shape) - 10
-
     start_node = AStarNode(start, None)
     end_node = AStarNode(end, None)
     open_nodes = OrderedSet()
@@ -114,8 +112,6 @@ def a_star(grid_map: GridTransitionMap,
             child.h = a_star_distance_function(child.pos, end_node.pos)
             child.f = child.g + child.h
 
-            tmp[child.pos[0]][child.pos[1]] = child.f
-
             # already in the open list?
             if child in open_nodes:
                 continue
diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py
index 98652459d7a7ac7f1694ac53fe1d0a12880ab8b2..75cef7b4d3aea783140a5c08c3498a0bc321fb62 100644
--- a/flatland/core/grid/grid4_utils.py
+++ b/flatland/core/grid/grid4_utils.py
@@ -1,8 +1,8 @@
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid_utils import IntVector2DArray
+from flatland.core.grid.grid_utils import IntVector2D
 
 
-def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4TransitionsEnum:
+def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
     """
     Assumes pos1 and pos2 are adjacent location on grid.
     Returns direction (int) that can be used with transitions.
@@ -10,13 +10,13 @@ def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4Transi
     diff_0 = pos2[0] - pos1[0]
     diff_1 = pos2[1] - pos1[1]
     if diff_0 < 0:
-        return 0
+        return Grid4TransitionsEnum.NORTH
     if diff_0 > 0:
-        return 2
+        return Grid4TransitionsEnum.SOUTH
     if diff_1 > 0:
-        return 1
+        return Grid4TransitionsEnum.EAST
     if diff_1 < 0:
-        return 3
+        return Grid4TransitionsEnum.WEST
     raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index d6f47abfd85cfa1cc7e72e27aeb4f7ededa975dd..fce3ffdf320a3c38d7f0551151ffdc8debe6ab5d 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -7,22 +7,25 @@ a GridTransitionMap object.
 
 from flatland.core.grid.grid4_astar import a_star
 from flatland.core.grid.grid4_utils import get_direction, mirror
-from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
+from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
 
 
-def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                            start: IntVector2D,
-                            end: IntVector2D,
-                            flip_start_node_trans=False,
-                            flip_end_node_trans=False,
-                            a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+def connect_basic_operation(
+        rail_trans: RailEnvTransitions,
+        grid_map: GridTransitionMap,
+        start: IntVector2D,
+        end: IntVector2D,
+        flip_start_node_trans=False,
+        flip_end_node_trans=False,
+        a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     """
-    Creates a new path [start,end] in grid_map, based on rail_trans.
+    Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
+    returns the path created as a list of positions.
     """
     # in the worst case we will need to do a A* search, so we might as well set that up
-    path = a_star(grid_map, start, end, a_star_distance_function)
+    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function)
     if len(path) < 2:
         return []
     current_dir = get_direction(path[0], path[1])
@@ -71,23 +74,24 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
 
 def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                  start: IntVector2D, end: IntVector2D,
-                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function)
 
 
 def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                   start: IntVector2D, end: IntVector2D,
-                  a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                  a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function)
 
 
 def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                        start: IntVector2D, end: IntVector2D,
-                       a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                       a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function)
 
 
 def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                      start: IntVector2D, end: IntVector2D,
-                     a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                     a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index c81ef9dc82df0817f3f3fc42798392d7ffdbcf5e..862774319ce58c2625b227b89b77940c12016e89 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -237,7 +237,8 @@ class RailEnv(Environment):
 
         Relies on the rail_generator returning agent_static lists (pos, dir, target)
         """
-        # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
+        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
+        #  can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
         if optionals and 'distance_map' in optionals:
@@ -257,6 +258,9 @@ class RailEnv(Environment):
         agents_hints = None
         if optionals and 'agents_hints' in optionals:
            agents_hints = optionals['agents_hints']
+
+        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
+        #  why do we need static agents? could we do it more elegantly?
         self.agents_static = EnvAgentStatic.from_lists(
             *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
         self.restart_agents()
@@ -408,13 +412,14 @@ class RailEnv(Environment):
         # is the agent malfunctioning?
         malfunction = self._agent_malfunction(i_agent)
 
-        # if agent is broken, actions are ignored and agent does not move,
-        # the agent is not penalized in this step!
+        # if agent is broken, actions are ignored and agent does not move.
+        # full step penalty in this case
         if malfunction:
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
             return
 
         # Is the agent at the beginning of the cell? Then, it can take an action.
+        # As long as the agent is malfunctioning or stopped at the beginning of the cell, a new action may still be chosen!
         if agent.speed_data['position_fraction'] == 0.0:
             # No action has been supplied for this agent -> set DO_NOTHING as default
             if action is None:
@@ -463,9 +468,9 @@ class RailEnv(Environment):
                         _action_stored = True
 
             if not _action_stored:
+                # If the agent cannot move due to an invalid transition, we set its state to not moving
                 self.rewards_dict[i_agent] += self.invalid_action_penalty
-                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                 self.rewards_dict[i_agent] += self.stop_penalty
 
                 agent.moving = False
@@ -498,6 +503,9 @@ class RailEnv(Environment):
                     agent.moving = False
                 else:
                     self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+        else:
+            # step penalty if not moving (stopped now or before)
+            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
         """
diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py
index b95922cf67febdaa0aad396459bc446bc31adfea..cf5c8592708eef237bcf29308032df49753860bd 100644
--- a/tests/test_flatland_envs_env_utils.py
+++ b/tests/test_flatland_envs_env_utils.py
@@ -2,8 +2,8 @@ import numpy as np
 import pytest
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
 from flatland.core.grid.grid4_utils import get_direction
+from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
 
 depth_to_test = 5
 positions_to_test = [0, 5, 1, 6, 20, 30]
@@ -31,4 +31,4 @@ def test_get_direction():
     assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
     assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
     with pytest.raises(Exception, match="Could not determine direction"):
-        get_direction((0, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
+        get_direction((0, 0), (0, 0))
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 55d3526757123230fb351dbf67dbfc269e58b6ac..1b3c6adead4d0d82fd676efcc051fc66b4486ef8 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,16 +1,15 @@
 import random
+from typing import Dict
 
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
-from flatland.utils.rendertools import RenderTool
-from test_utils import Replay
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 
 class SingleAgentNavigationObs(TreeObsForRailEnv):
@@ -54,7 +53,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
                 min_distances.append(np.inf)
 
         observation = [0, 0, 0]
         observation[np.argmin(min_distances)] = 1
 
         return observation
@@ -83,7 +82,6 @@ def 
test_malfunction_process(): agent_halts = 0 total_down_time = 0 - agent_malfunctioning = False agent_old_position = env.agents[0].position for step in range(100): actions = {} @@ -142,12 +140,12 @@ def test_malfunction_process_statistically(): env.reset() nb_malfunction = 0 for step in range(100): - action_dict = {} + action_dict: Dict[int, RailEnvActions] = {} for agent in env.agents: if agent.malfunction_data['malfunction'] > 0: nb_malfunction += 1 # We randomly select an action - action_dict[agent.handle] = np.random.randint(4) + action_dict[agent.handle] = RailEnvActions(np.random.randint(4)) env.step(action_dict) @@ -155,7 +153,7 @@ def test_malfunction_process_statistically(): assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction) -def test_initial_malfunction(rendering=True): +def test_initial_malfunction(): random.seed(0) np.random.seed(0) @@ -189,75 +187,56 @@ def test_initial_malfunction(rendering=True): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - - if rendering: - renderer = RenderTool(env) - renderer.render_env(show=True, frames=False, show_observations=False) - _action = dict() - - replay_steps = [ - Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.MOVE_FORWARD, - malfunction=3 - ), - Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.MOVE_FORWARD, - malfunction=2 - ), - # malfunction stops in the next step and we're still at the beginning of the cell - # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell - Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.MOVE_FORWARD, - malfunction=1 - ), - Replay( - position=(28, 4), - direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD, - malfunction=0 - ), - Replay( - position=(27, 4), - direction=Grid4TransitionsEnum.NORTH, - action=RailEnvActions.MOVE_FORWARD, - malfunction=0 - ) - ] - - info_dict = { - 'action_required': [True] - } - - for i, replay in enumerate(replay_steps): - - def _assert(actual, expected, msg): - assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) - - agent: EnvAgent = env.agents[0] - - _assert(agent.position, replay.position, 'position') - _assert(agent.direction, replay.direction, 'direction') - _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') - - if replay.action is not None: - assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) - _, _, _, info_dict = env.step({0: replay.action}) - - else: - assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) - _, _, _, info_dict = env.step({}) - - if rendering: - renderer.render_env(show=True, show_observations=True) - - -def test_initial_malfunction_stop_moving(rendering=True): + set_penalties_for_replay(env) + replay_config = ReplayConfig( + replay=[ + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + set_malfunction=3, + malfunction=3, + reward=env.step_penalty # full step penalty when malfunctioning + ), + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=2, + reward=env.step_penalty # full step penalty when malfunctioning + ), + # malfunction stops in the next step and we're still at the beginning of the cell + # --> if we take action 
MOVE_FORWARD, agent should restart and move to the next cell + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=1, + reward=env.start_penalty + env.step_penalty * 1.0 + # malfunctioning ends: starting and running at speed 1.0 + ), + Replay( + position=(28, 4), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0, + reward=env.step_penalty * 1.0 # running at speed 1.0 + ), + Replay( + position=(27, 4), + direction=Grid4TransitionsEnum.NORTH, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0, + reward=env.step_penalty * 1.0 # running at speed 1.0 + ) + ], + speed=env.agents[0].speed_data['speed'], + target=env.agents[0].target + ) + run_replay_config(env, [replay_config]) + + +def test_initial_malfunction_stop_moving(): random.seed(0) np.random.seed(0) @@ -291,84 +270,66 @@ def test_initial_malfunction_stop_moving(rendering=True): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - - if rendering: - renderer = RenderTool(env) - renderer.render_env(show=True, frames=False, show_observations=False) - _action = dict() - - replay_steps = [ - Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.DO_NOTHING, - malfunction=3 - ), - Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.DO_NOTHING, - malfunction=2 - ), - # malfunction stops in the next step and we're still at the beginning of the cell - # --> if we take action DO_NOTHING, agent should restart without moving - # - Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.STOP_MOVING, - malfunction=1 - ), - # we have stopped and do nothing --> should stand still - Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.DO_NOTHING, - malfunction=0 - ), - # we start to move forward --> should go to next cell now - Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.MOVE_FORWARD, - malfunction=0 - ), - Replay( - position=(28, 4), - direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD, - malfunction=0 - ) - ] - - info_dict = { - 'action_required': [True] - } - - for i, replay in enumerate(replay_steps): - - def _assert(actual, expected, msg): - assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) - - agent: EnvAgent = env.agents[0] - - _assert(agent.position, replay.position, 'position') - _assert(agent.direction, replay.direction, 'direction') - _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') - - if replay.action is not None: - assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) - _, _, _, info_dict = env.step({0: replay.action}) - - else: - assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) - _, _, _, info_dict = env.step({}) - - if rendering: - renderer.render_env(show=True, show_observations=True) - - -def test_initial_malfunction_do_nothing(rendering=True): + set_penalties_for_replay(env) + replay_config = ReplayConfig( + replay=[ + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + set_malfunction=3, + malfunction=3, + reward=env.step_penalty # full step penalty when stopped + ), + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + 
action=RailEnvActions.DO_NOTHING,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty when stopped
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action STOP_MOVING, agent should restart without moving
+            #
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.STOP_MOVING,
+                malfunction=1,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we have stopped and do nothing --> should stand still
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=0,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we start to move forward --> should go to next cell now
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.start_penalty + env.step_penalty * 1.0  # starting and running at speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # running at speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+
+    run_replay_config(env, [replay_config])
+
+
+def test_initial_malfunction_do_nothing():
     random.seed(0)
     np.random.seed(0)
@@ -402,78 +363,59 @@
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
    )
-
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-        _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=3
-        ),
-        Replay(
+    set_penalties_for_replay(env)
+    replay_config = ReplayConfig(
+        replay=[Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
-            malfunction=2
+            set_malfunction=3,
+            malfunction=3,
+            reward=env.step_penalty  # full step penalty while malfunctioning
        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action DO_NOTHING, agent should restart without moving
-        #
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=1
-        ),
-        # we haven't started moving yet --> stay here
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=0
-        ),
-        # we start to move forward --> should go to next cell now
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = 
env.step({0: replay.action}) - - else: - assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) - _, _, _, info_dict = env.step({}) - - if rendering: - renderer.render_env(show=True, show_observations=True) + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=2, + reward=env.step_penalty # full step penalty while malfunctioning + ), + # malfunction stops in the next step and we're still at the beginning of the cell + # --> if we take action DO_NOTHING, agent should restart without moving + # + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=1, + reward=env.step_penalty # full step penalty while stopped + ), + # we haven't started moving yet --> stay here + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=0, + reward=env.step_penalty # full step penalty while stopped + ), + # we start to move forward --> should go to next cell now + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0, + reward=env.start_penalty + env.step_penalty * 1.0 # start penalty + step penalty for speed 1.0 + ), + Replay( + position=(28, 4), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0, + reward=env.step_penalty * 1.0 # step penalty for speed 1.0 + ) + ], + speed=env.agents[0].speed_data['speed'], + target=env.agents[0].target + ) + + run_replay_config(env, [replay_config]) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 1cf0c325ac48e9e3d5ac04fb51b5f8462c867726..b0f274ba4c4b5453140fcc50bc6137e39e8e4f04 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,15 +1,13 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator -from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail -from test_utils import ReplayConfig, Replay +from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay np.random.seed(1) @@ -95,9 +93,7 @@ def test_multi_speed_init(): old_pos[i_agent] = env.agents[i_agent].position -# TODO test penalties! -# TODO test invalid actions! 
-def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): +def test_multispeed_actions_no_malfunction_no_blocking(): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], @@ -108,126 +104,97 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - # initialize agents_static - env.reset() - - # reset to set agents from agents_static - env.reset(False, False) - - if rendering: - renderer = RenderTool(env, gl="PILSVG") - + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ Replay( position=(3, 9), # east dead-end direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5 ), Replay( position=(3, 9), direction=Grid4TransitionsEnum.EAST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 6), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_LEFT + action=RailEnvActions.MOVE_LEFT, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 6), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(4, 6), direction=Grid4TransitionsEnum.SOUTH, - action=RailEnvActions.STOP_MOVING + action=RailEnvActions.STOP_MOVING, + reward=env.stop_penalty + env.step_penalty * 0.5 # stopping and step penalty ), # Replay( position=(4, 6), direction=Grid4TransitionsEnum.SOUTH, - action=RailEnvActions.STOP_MOVING + action=RailEnvActions.STOP_MOVING, + reward=env.step_penalty * 0.5 # step penalty for speed 0.5 when stopped ), Replay( position=(4, 6), direction=Grid4TransitionsEnum.SOUTH, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.start_penalty + env.step_penalty * 0.5 # starting + running at speed 0.5 ), Replay( position=(4, 6), direction=Grid4TransitionsEnum.SOUTH, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(5, 6), direction=Grid4TransitionsEnum.SOUTH, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), - ], target=(3, 0), # west dead-end speed=0.5 ) - agentStatic: EnvAgentStatic = env.agents_static[0] - info_dict = { - 'action_required': [True] - } - for i, replay in enumerate(test_config.replay): - if i == 0: - # set the initial position - agentStatic.position = replay.position - agentStatic.direction = replay.direction - agentStatic.target = test_config.target - agentStatic.moving = True - 
agentStatic.speed_data['speed'] = test_config.speed - - # reset to set agents from agents_static - env.reset(False, False) - - def _assert(actual, expected, msg): - assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) - - agent: EnvAgent = env.agents[0] - - _assert(agent.position, replay.position, 'position') - _assert(agent.direction, replay.direction, 'direction') - - if replay.action is not None: - assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) - _, _, _, info_dict = env.step({0: replay.action}) + run_replay_config(env, [test_config]) - else: - assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) - _, _, _, info_dict = env.step({}) - if rendering: - renderer.render_env(show=True, show_observations=True) - - -def test_multispeed_actions_no_malfunction_blocking(rendering=True): +def test_multispeed_actions_no_malfunction_blocking(): """The second agent blocks the first because it is slower.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], @@ -237,81 +204,84 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - - # initialize agents_static - env.reset() - - # reset to set agents from agents_static - env.reset(False, False) - - if rendering: - renderer = RenderTool(env, gl="PILSVG") - + set_penalties_for_replay(env) test_configs = [ ReplayConfig( replay=[ Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.start_penalty + env.step_penalty * 1.0 / 3.0 # starting and running at speed 1/3 ), Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 6), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 6), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 6), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 5), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 5), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ), Replay( position=(3, 5), 
direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3 ) ], target=(3, 0), # west dead-end @@ -321,69 +291,81 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): Replay( position=(3, 9), # east dead-end direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5 ), Replay( position=(3, 9), direction=Grid4TransitionsEnum.EAST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), # blocked although fraction >= 1.0 Replay( position=(3, 9), direction=Grid4TransitionsEnum.EAST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), # blocked although fraction >= 1.0 Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), # blocked although fraction >= 1.0 Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 6), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_LEFT + action=RailEnvActions.MOVE_LEFT, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 6), direction=Grid4TransitionsEnum.WEST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), # not blocked, action required! Replay( position=(4, 6), direction=Grid4TransitionsEnum.SOUTH, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), ], target=(3, 0), # west dead-end @@ -391,52 +373,10 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): ) ] + run_replay_config(env, test_configs) + - # TODO test penalties! 
- info_dict = { - 'action_required': [True for _ in test_configs] - } - for step in range(len(test_configs[0].replay)): - if step == 0: - for a, test_config in enumerate(test_configs): - agentStatic: EnvAgentStatic = env.agents_static[a] - replay = test_config.replay[0] - # set the initial position - agentStatic.position = replay.position - agentStatic.direction = replay.direction - agentStatic.target = test_config.target - agentStatic.moving = True - agentStatic.speed_data['speed'] = test_config.speed - - # reset to set agents from agents_static - env.reset(False, False) - - def _assert(a, actual, expected, msg): - assert actual == expected, "[{}] {} {}: actual={}, expected={}".format(step, a, msg, actual, expected) - - action_dict = {} - - for a, test_config in enumerate(test_configs): - agent: EnvAgent = env.agents[a] - replay = test_config.replay[step] - - _assert(a, agent.position, replay.position, 'position') - _assert(a, agent.direction, replay.direction, 'direction') - - if replay.action is not None: - assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format( - step, a, True) - action_dict[a] = replay.action - else: - assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format( - step, a, False) - _, _, _, info_dict = env.step(action_dict) - - if rendering: - renderer.render_env(show=True, show_observations=True) - - -def test_multispeed_actions_malfunction_no_blocking(rendering=True): +def test_multispeed_actions_malfunction_no_blocking(): """Test on a single agent whether action on cell exit work correctly despite malfunction.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], @@ -447,107 +387,202 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - # initialize agents_static - env.reset() - - # reset to set agents from agents_static - env.reset(False, False) - - if rendering: - renderer = RenderTool(env, gl="PILSVG") - + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ Replay( position=(3, 9), # east dead-end direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5 ), Replay( position=(3, 9), direction=Grid4TransitionsEnum.EAST, - action=None + action=None, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_FORWARD + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5 # running at speed 0.5 ), # add additional step in the cell Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, action=None, - malfunction=2 # recovers in two steps from now! 
+                set_malfunction=2,  # recovers in two steps from now!
+                malfunction=2,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
             ),
             # agent recovers in this step
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                malfunction=1,
+                reward=env.step_penalty * 0.5  # recovered: running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
-                malfunction=2  # recovers in two steps from now!
+                set_malfunction=2,  # recovers in two steps from now!
+                malfunction=2,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
             ),
             # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_LEFT,
+                malfunction=1,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty for speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 while stopped
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             # DO_NOTHING keeps moving!
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.DO_NOTHING
+                action=RailEnvActions.DO_NOTHING,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(6, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+
         ],
         target=(3, 0),  # west dead-end
         speed=0.5
     )
+    run_replay_config(env, [test_config])
+
+
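A note on the two malfunction fields used above: `set_malfunction` injects a breakdown into `malfunction_data` before the step is taken, while `malfunction` only asserts the counter. A minimal sketch of the distinction, with field names as declared in `test_utils.Replay` (the positions are illustrative):

```python
# Sketch: set_malfunction *writes* malfunction_data before env.step(),
# malfunction is the value the harness *asserts*; reward is checked after the step.
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from test_utils import Replay

break_now = Replay(
    position=(3, 8),
    direction=Grid4TransitionsEnum.WEST,
    action=None,
    set_malfunction=2,  # inject: agent breaks down for two steps
    malfunction=2,      # assert: counter reads 2 in this step
)
recovering = Replay(
    position=(3, 8),
    direction=Grid4TransitionsEnum.WEST,
    action=None,
    malfunction=1,      # assert only: counter has ticked down, nothing injected
)
```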
+# TODO invalid action penalty seems only given when forward is not possible - is this the intended behaviour?
+def test_multispeed_actions_no_malfunction_invalid_actions():
+    """Test that invalid actions are auto-corrected on cell exit for a single agent."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    set_penalties_for_replay(env)
+    test_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(3, 9),  # east dead-end
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.start_penalty + env.step_penalty * 0.5  # MOVE_LEFT is auto-corrected to forward without penalty!
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
         ],
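What this new test pins down: a turn action at a cell that only allows forward is silently executed as MOVE_FORWARD, and `invalid_action_penalty` appears to fire only when not even forward is possible (hence the TODO above). A sketch of that decision rule (a hypothetical helper, not flatland API; `possible` maps each action to whether a transition exists):

```python
# Hypothetical helper illustrating the auto-correction the test exercises:
# a turn action in a cell with only a forward transition is treated as
# MOVE_FORWARD and incurs no invalid-action penalty.
from flatland.envs.rail_env import RailEnvActions

def effective_action(action, possible):
    turns = (RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT)
    if action in turns and not possible[action] and possible[RailEnvActions.MOVE_FORWARD]:
        return RailEnvActions.MOVE_FORWARD  # silently corrected, no penalty
    return action

possible = {RailEnvActions.MOVE_LEFT: False,
            RailEnvActions.MOVE_FORWARD: True,
            RailEnvActions.MOVE_RIGHT: False}
assert effective_action(RailEnvActions.MOVE_LEFT, possible) == RailEnvActions.MOVE_FORWARD
```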
@@ -555,42 +590,4 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
         speed=0.5
     )
 
-    # TODO test penalties!
-    agentStatic: EnvAgentStatic = env.agents_static[0]
-    info_dict = {
-        'action_required': [True]
-    }
-    for i, replay in enumerate(test_config.replay):
-        if i == 0:
-            # set the initial position
-            agentStatic.position = replay.position
-            agentStatic.direction = replay.direction
-            agentStatic.target = test_config.target
-            agentStatic.moving = True
-            agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-
-        if replay.malfunction > 0:
-            agent.malfunction_data['malfunction'] = replay.malfunction
-            agent.malfunction_data['moving_before_malfunction'] = agent.moving
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    run_replay_config(env, [test_config])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6347bd0f5048350c099ba2568dac7caba74baf2d..903120d868aa65833e7c2393ddfcc821c26da4f6 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,10 +1,13 @@
 """Test Utils."""
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
+import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.rail_env import RailEnvActions, RailEnv
+from flatland.utils.rendertools import RenderTool
 
 
 @attrs
@@ -13,6 +16,8 @@ class Replay(object):
     direction = attrib(type=Grid4TransitionsEnum)
     action = attrib(type=RailEnvActions)
     malfunction = attrib(default=0, type=int)
+    set_malfunction = attrib(default=None, type=Optional[int])
+    reward = attrib(default=None, type=Optional[float])
 
 
 @attrs
@@ -20,3 +25,89 @@ class ReplayConfig(object):
     replay = attrib(type=List[Replay])
     target = attrib(type=Tuple[int, int])
     speed = attrib(type=float)
+
+
+# ensure that the env works correctly with start/stop/invalid-action penalties different from 0
+def set_penalties_for_replay(env: RailEnv):
+    env.step_penalty = -7
+    env.start_penalty = -13
+    env.stop_penalty = -19
+    env.invalid_action_penalty = -29
+
+
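The point of distinct, prime-valued penalties is that each expected reward in a replay is a sum that is unlikely to be reproduced by accident, so a failed assertion pinpoints the missing or doubled term. For example, at speed 0.5:

```python
# With step=-7, start=-13, stop=-19, invalid=-29, the usual penalty
# combinations yield distinct sums, so a reward mismatch identifies the term.
stop_now = -19 + -7 * 0.5   # stop_penalty + step_penalty * speed == -22.5
start_now = -13 + -7 * 0.5  # start_penalty + step_penalty * speed == -16.5
assert stop_now != start_now
```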
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
+    """
+    Runs the replay configs and checks assertions.
+
+    *Initially*
+    - the position, direction, target and speed of the initial step are taken to initialize the agents
+
+    *Before each step*
+    - action must only be provided if action_required from previous step (initially all True)
+    - position, direction before step are verified
+    - optionally, set_malfunction is applied
+    - malfunction is verified
+
+    *After each step*
+    - reward is verified after step
+
+    Parameters
+    ----------
+    env
+    test_configs
+    rendering
+    """
+    if rendering:
+        renderer = RenderTool(env)
+        renderer.render_env(show=True, frames=False, show_observations=False)
+    info_dict = {
+        'action_required': [True for _ in test_configs]
+    }
+
+    for step in range(len(test_configs[0].replay)):
+        if step == 0:
+            for a, test_config in enumerate(test_configs):
+                agent: EnvAgent = env.agents[a]
+                replay = test_config.replay[0]
+                # set the initial position
+                agent.position = replay.position
+                agent.direction = replay.direction
+                agent.target = test_config.target
+                agent.speed_data['speed'] = test_config.speed
+
+        def _assert(a, actual, expected, msg):
+            assert np.allclose(actual, expected), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
+                                                                                                    actual,
+                                                                                                    expected)
+
+        action_dict = {}
+
+        for a, test_config in enumerate(test_configs):
+            agent: EnvAgent = env.agents[a]
+            replay = test_config.replay[step]
+
+            _assert(a, agent.position, replay.position, 'position')
+            _assert(a, agent.direction, replay.direction, 'direction')
+
+            if replay.action is not None:
+                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
+                    step, a, True)
+                action_dict[a] = replay.action
+            else:
+                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
+                    step, a, False)
+
+            if replay.set_malfunction is not None:
+                agent.malfunction_data['malfunction'] = replay.set_malfunction
+                agent.malfunction_data['moving_before_malfunction'] = agent.moving
+            _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+
+        _, rewards_dict, _, info_dict = env.step(action_dict)
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
+
+        for a, test_config in enumerate(test_configs):
+            replay = test_config.replay[step]
+            _assert(a, rewards_dict[a], replay.reward, 'reward')
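For reference, a minimal use of the new harness might look like this (a sketch modeled on the tests above; the `env` fixture and the coordinates are assumptions, any rail with a matching straight cell would do):

```python
# Minimal end-to-end sketch of the replay harness (hypothetical positions;
# env is assumed to be a RailEnv with a single agent).
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.rail_env import RailEnvActions
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay

def test_single_forward_step(env):
    set_penalties_for_replay(env)
    config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 1.0,  # starting at speed 1.0
            ),
        ],
        target=(3, 0),
        speed=1.0,
    )
    run_replay_config(env, [config])
```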