diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 435571c8216026b02e7ec4633ace2dcd85c15e4e..ead879a860d514ccb4233e9919257e459a7ab56c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -467,6 +467,7 @@ class RailEnv(Environment):
                             _action_stored = True
 
                 if not _action_stored:
+                    # If the agent cannot move because the chosen transition is invalid, penalize it and set its state to not moving
                     self.rewards_dict[i_agent] += self.invalid_action_penalty
                     self.rewards_dict[i_agent] += self.stop_penalty
 
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 57c0430d630615d02785545d0b54f3e9980c54d4..b0f274ba4c4b5453140fcc50bc6137e39e8e4f04 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -93,8 +93,7 @@ def test_multi_speed_init():
             old_pos[i_agent] = env.agents[i_agent].position
 
 
-# TODO test invalid actions!
-def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
+def test_multispeed_actions_no_malfunction_no_blocking():
     """Test that actions are correctly performed on cell exit for a single agent."""
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
@@ -195,7 +194,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
     run_replay_config(env, [test_config])
 
 
-def test_multispeed_actions_no_malfunction_blocking(rendering=True):
+def test_multispeed_actions_no_malfunction_blocking():
     """The second agent blocks the first because it is slower."""
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
@@ -377,7 +376,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
     run_replay_config(env, test_configs)
 
 
-def test_multispeed_actions_malfunction_no_blocking(rendering=True):
+def test_multispeed_actions_malfunction_no_blocking():
     """Test on a single agent whether action on cell exit work correctly despite malfunction."""
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
@@ -509,3 +508,86 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
         speed=0.5
     )
     run_replay_config(env, [test_config])
+
+
+# TODO the invalid-action penalty seems to be applied only when forward is not possible - is this the intended behaviour?
+def test_multispeed_actions_no_malfunction_invalid_actions():
+    """Test that invalid actions are auto-corrected to forward without penalty for a single agent."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    set_penalties_for_replay(env)
+    test_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(3, 9),  # east dead-end
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.start_penalty + env.step_penalty * 0.5  # auto-correction left to forward without penalty!
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+        ],
+        target=(3, 0),  # west dead-end
+        speed=0.5
+    )
+
+    run_replay_config(env, [test_config])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 5821cd05604aa79b0b55a73136161b49e847f01e..903120d868aa65833e7c2393ddfcc821c26da4f6 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,7 @@
 """Test Utils."""
 from typing import List, Tuple, Optional
 
+import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
@@ -28,9 +29,10 @@ class ReplayConfig(object):
 
 
 # ensure that env is working correctly with start/stop/invalidaction penalty different from 0
 def set_penalties_for_replay(env: RailEnv):
-    env.step_penalty = 13
-    env.stop_penalty = 19
-    env.invalid_action_penalty = 29
+    env.step_penalty = -7
+    env.start_penalty = -13
+    env.stop_penalty = -19
+    env.invalid_action_penalty = -29
 
 
@@ -74,8 +76,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
         agent.speed_data['speed'] = test_config.speed
 
         def _assert(a, actual, expected, msg):
-            assert actual == expected, "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg, actual,
-                                                                                         expected)
+            assert np.allclose(actual, expected), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
+                                                                                                    actual,
+                                                                                                    expected)
 
         action_dict = {}
 
@@ -100,10 +103,11 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
                 _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
 
         _, rewards_dict, _, info_dict = env.step(action_dict)
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
 
         for a, test_config in enumerate(test_configs):
             replay = test_config.replay[step]
             _assert(a, rewards_dict[a], replay.reward, 'reward')
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+
 
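
Reviewer note on the TODO in test_multispeed_actions_no_malfunction_invalid_actions: the replays above only ever see the step penalty because, as the rail_env.py hunk suggests, the invalid-action and stop penalties apply only when no action can be stored at all, i.e. when forward is impossible as well. A minimal standalone sketch of that assumed fallback rule (the Action enum and resolve helper are illustrative, not flatland's API):

    from enum import IntEnum

    class Action(IntEnum):  # illustrative stand-in for RailEnvActions
        STOP = 0
        LEFT = 1
        FORWARD = 2
        RIGHT = 3

    def resolve(chosen, valid, reward, invalid_action_penalty=-29, stop_penalty=-19):
        """Return (action actually executed, reward) under the assumed fallback rule."""
        if chosen in valid:
            return chosen, reward            # the chosen action is stored as-is
        if Action.FORWARD in valid:
            return Action.FORWARD, reward    # silent auto-correction, no penalty
        # nothing could be stored: penalize and stop, mirroring the two reward
        # lines added in the rail_env.py hunk above
        return Action.STOP, reward + invalid_action_penalty + stop_penalty

    # On a straight cell heading WEST only FORWARD is valid, so MOVE_RIGHT is corrected:
    print(resolve(Action.RIGHT, {Action.FORWARD}, 0.0))  # (<Action.FORWARD: 2>, 0.0)
    # When forward is impossible too, both penalties apply and the agent stops:
    print(resolve(Action.RIGHT, set(), 0.0))             # (<Action.STOP: 0>, -48.0)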
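The change from exact equality to np.allclose in _assert is what lets the reward checks work with fractional speeds: expected rewards such as env.step_penalty * 0.5 are accumulated in floating point, and repeated partial sums do not always reproduce the exact expected value. A standalone illustration of the failure mode (plain Python, not flatland code):

    import numpy as np

    # Ten partial steps at speed 0.1 should add up to exactly 1.0 full step,
    # but 0.1 has no exact binary representation, so the running sum drifts:
    accumulated = 0.0
    for _ in range(10):
        accumulated += 0.1

    print(accumulated == 1.0)             # False: accumulated is 0.9999999999999999
    print(np.allclose(accumulated, 1.0))  # True: equal within default tolerances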