diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index cc85604e9b44a740929ba05536af866fb293e014..0ca62f9310c0c7a70946fdafb495e79c6e208a70 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -346,29 +346,28 @@ class RailEnv(Environment): # Check if agent breaks at this step malfunction = self._agent_malfunction(i_agent, action) - # if we're at the beginning of the cell, store the action - # As long as we're broken down at the beginning of the cell, we can choose other actions! - # This is a design choice made by Erik and Christian. - - # TODO refactor!!! - # If the agent can make an action + # Is the agent at the beginning of the cell? Then, it can take an action + # Design choice (Erik+Christian): + # as long as we're broken down at the beginning of the cell, we can choose other actions! if agent.speed_data['position_fraction'] == 0.0: if action == RailEnvActions.DO_NOTHING and agent.moving: # Keep moving action = RailEnvActions.MOVE_FORWARD - if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0: + if action == RailEnvActions.STOP_MOVING and agent.moving: # Only allow halting an agent on entering new cells. agent.moving = False self.rewards_dict[i_agent] += self.stop_penalty - if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): + if not agent.moving and not ( + action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): # Allow agent to start with any forward or direction action agent.moving = True self.rewards_dict[i_agent] += self.start_penalty - if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: - cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ + # Store the action + if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]: + _, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(action, agent) if all([new_cell_valid, transition_valid]): @@ -377,7 +376,7 @@ class RailEnv(Environment): # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, # try to keep moving forward! if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): - cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ + _, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) if all([new_cell_valid, transition_valid]): @@ -388,7 +387,6 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.stop_penalty agent.moving = False - action = RailEnvActions.DO_NOTHING else: # If the agent cannot move due to an invalid transition, we set its state to not moving @@ -396,10 +394,10 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.stop_penalty agent.moving = False - action = RailEnvActions.DO_NOTHING else: agent.speed_data['transition_action_on_cellexit'] = action + # if we're broken, nothing else to do if malfunction: continue @@ -422,16 +420,10 @@ class RailEnv(Environment): # Nothing left to do with broken agent continue - # Now perform a movement. - # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps) - # store the desired action in `transition_action_on_cellexit' (only if the desired transition is - # allowed! otherwise DO_NOTHING!) - # Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the - # position_fraction by the speed of the agent (regardless of action taken, as long as no - # STOP_MOVING, but that makes agent.moving=False) + # If agent.moving, increment the position_fraction by the speed of the agent # If the new position fraction is >= 1, reset to 0, and perform the stored - # transition_action_on_cellexit + # transition_action_on_cellexit if the cell is free. if agent.moving: @@ -445,9 +437,11 @@ class RailEnv(Environment): RailEnvActions.STOP_MOVING]: agent.speed_data['position_fraction'] = 0.0 else: + # cell and transition validity was checked when we stored transition_action_on_cellexit! cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( agent.speed_data['transition_action_on_cellexit'], agent) - assert cell_free == all([cell_free, new_cell_valid, transition_valid]) + if not cell_free == all([cell_free, new_cell_valid, transition_valid]): + warnings.warn("Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!") if cell_free: agent.position = new_position agent.direction = new_direction diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 8de36c81e4a13c0b7e7e5e556ad79234503ad31a..b8b1afaf433d7155e48d17a2c277a579c72798ce 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,8 +1,17 @@ -import numpy as np +from typing import List -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_schedule_generator +import numpy as np +from attr import attrib, attrs + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map +from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator +from flatland.utils.rendertools import RenderTool +from flatland.utils.simple_rail import make_simple_rail np.random.seed(1) @@ -86,3 +95,149 @@ def test_multi_speed_init(): if (step + 1) % (i_agent + 1) == 0: print(step, i_agent, env.agents[i_agent].position) old_pos[i_agent] = env.agents[i_agent].position + + +# TODO test malfunction +# TODO test other agent blocking +def test_multispeed_actions_no_malfunction(rendering=True): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # initialize agents_static + env.reset() + + @attrs + class Replay(object): + position = attrib() + direction = attrib() + action = attrib(type=RailEnvActions) + + @attrs + class TestConfig(object): + replay = attrib(type=List[Replay]) + target = attrib() + speed = attrib(type=float) + + # reset to set agents from agents_static + env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + + test_configs = [ + TestConfig( + replay=[ + Replay( + position=(3, 9), # east dead-end + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_LEFT + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=None + ), + Replay( + position=(5, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + + ], + target=(3, 0), # west dead-end + speed=0.5 + ) + ] + + # TODO test penalties! + agentStatic: EnvAgentStatic = env.agents_static[0] + for test_config in test_configs: + info_dict = { + 'action_required': [True] + } + for i, replay in enumerate(test_config.replay): + if i == 0: + # set the initial position + agentStatic.position = replay.position + agentStatic.direction = replay.direction + agentStatic.target = test_config.target + agentStatic.moving = True + agentStatic.speed_data['speed'] = test_config.speed + + # reset to set agents from agents_static + env.reset(False, False) + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + + if replay.action: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True)