diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cec0542c520277903d0b0f6b7746a365fc9862e7..0c1e3f844b8cf543f94d52a7b01b1e3af6964ce9 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -28,6 +28,7 @@
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.timetable_generators import timetable_generator
 
 from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
+from flatland.envs.step_utils.transition_utils import check_valid_action
 from flatland.envs.step_utils import action_preprocessing
 from flatland.envs.step_utils import env_utils
@@ -437,6 +438,11 @@ class RailEnv(Environment):
             current_position, current_direction = agent.initial_position, agent.initial_direction
 
         action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
+
+        # Check transitions and bounds for executing the action in the given position and direction
+        if not check_valid_action(action, self.rail, current_position, current_direction):
+            action = RailEnvActions.STOP_MOVING
+
         return action
 
     def clear_rewards_dict(self):
@@ -579,14 +585,15 @@ class RailEnv(Environment):
             movement_allowed = movement_allowed and agent.state != TrainState.DONE
 
             # Agent is being added to map
-            if agent.state.is_on_map_state() and agent.state_machine.previous_state.is_off_map_state():
-                agent.position = agent.initial_position
-                agent.direction = agent.initial_direction
+            if agent.state.is_on_map_state():
+                if agent.state_machine.previous_state.is_off_map_state():
+                    agent.position = agent.initial_position
+                    agent.direction = agent.initial_direction
             # Speed counter completes
-            elif movement_allowed and (agent.speed_counter.is_cell_exit):
-                agent.position = agent_transition_data.position
-                agent.direction = agent_transition_data.direction
-                agent.state_machine.update_if_reached(agent.position, agent.target)
+                elif movement_allowed and (agent.speed_counter.is_cell_exit):
+                    agent.position = agent_transition_data.position
+                    agent.direction = agent_transition_data.direction
+                    agent.state_machine.update_if_reached(agent.position, agent.target)
 
             # Off map or on map state and position should match
             env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 628298f0ac99341016037092d7ab797a89a2c14f..a23bf4c6df135dfbf82d4e999e39e1ab68884c90 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -86,7 +86,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
                     expected_next_position[agent.handle] = neighbour
                     print(" improved (action) -> {}".format(actions[agent.handle]))
     _, rewards, dones, _ = env.step(actions)
-    return rewards
+    return rewards, dones
 
 
 def test_reward_function_conflict(rendering=False):
@@ -162,8 +162,9 @@ def test_reward_function_conflict(rendering=False):
         },
     }
     while iteration < 5:
-        rewards = _step_along_shortest_path(env, obs_builder, rail)
-
+        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
+        if dones["__all__"]:
+            break
         for agent in env.agents:
             # assert rewards[agent.handle] == 0
             expected_position = expected_positions[iteration + 1][agent.handle]
@@ -289,7 +290,9 @@ def test_reward_function_waiting(rendering=False):
     }
 
     while iteration < 7:
-        rewards = _step_along_shortest_path(env, obs_builder, rail)
+        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
+        if dones["__all__"]:
+            break
 
         if rendering:
             renderer.render_env(show=True, show_observations=True)
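For review context, a minimal sketch (not part of the patch) of the guard this change adds to action preprocessing: an action that has no valid transition from the agent's current cell and direction is downgraded to STOP_MOVING before it is executed. The helper name guard_invalid_action is illustrative only; check_valid_action and RailEnvActions are assumed to be importable from the modules referenced in the diff above.

# Illustrative sketch only; it mirrors the guard added in rail_env.py and is not part of the patch.
# Assumes check_valid_action(action, rail, position, direction) -> bool, as used in the diff.
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.step_utils.transition_utils import check_valid_action


def guard_invalid_action(action, rail, position, direction):
    """Return STOP_MOVING when `action` has no valid transition from
    (position, direction) on `rail`; otherwise return `action` unchanged."""
    if not check_valid_action(action, rail, position, direction):
        return RailEnvActions.STOP_MOVING
    return action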