Commit 4b960cbb authored by Dipam Chakraborty

check bounds on position on map

parent 3c1fff85
Pipeline #8505 failed with stages in 6 minutes and 24 seconds
@@ -28,6 +28,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.timetable_generators import timetable_generator
 from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
+from flatland.envs.step_utils.transition_utils import check_valid_action
 from flatland.envs.step_utils import action_preprocessing
 from flatland.envs.step_utils import env_utils
@@ -437,6 +438,11 @@ class RailEnv(Environment):
             current_position, current_direction = agent.initial_position, agent.initial_direction
         action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
+
+        # Check transitions, bounds for executing the action in the given position and direction
+        if not check_valid_action(action, self.rail, current_position, current_direction):
+            action = RailEnvActions.STOP_MOVING
+
         return action

     def clear_rewards_dict(self):
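In effect, any moving action whose transition is invalid at the agent's current cell and direction, including one that would step outside the grid, is now downgraded to a stop before execution. A minimal sketch of that guard in isolation, using only the names visible in this diff (the wrapper function itself is illustrative, not part of the commit):

```python
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.step_utils.transition_utils import check_valid_action

def bounded_action(action, rail, position, direction):
    # Illustrative helper mirroring the added lines: check_valid_action
    # verifies the action has a valid transition from (position, direction);
    # if not (e.g. the move would leave the map), fall back to STOP_MOVING.
    if not check_valid_action(action, rail, position, direction):
        return RailEnvActions.STOP_MOVING
    return action
```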
@@ -579,14 +585,15 @@ class RailEnv(Environment):
             movement_allowed = movement_allowed and agent.state != TrainState.DONE

             # Agent is being added to map
-            if agent.state.is_on_map_state() and agent.state_machine.previous_state.is_off_map_state():
-                agent.position = agent.initial_position
-                agent.direction = agent.initial_direction
+            if agent.state.is_on_map_state():
+                if agent.state_machine.previous_state.is_off_map_state():
+                    agent.position = agent.initial_position
+                    agent.direction = agent.initial_direction

             # Speed counter completes
             elif movement_allowed and (agent.speed_counter.is_cell_exit):
                 agent.position = agent_transition_data.position
                 agent.direction = agent_transition_data.direction
                 agent.state_machine.update_if_reached(agent.position, agent.target)
             # Off map or on map state and position should match
             env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
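The `state_position_sync_check` at the end is the invariant this restructuring protects: an agent in an on-map state must hold a grid position, while an off-map agent must not. An illustrative reconstruction of that contract (the real check lives in `flatland.envs.step_utils.env_utils`; this sketch names the expected behaviour, not the library code):

```python
def state_position_sync_check(state, position, handle):
    # Hypothetical reconstruction: on-map states require a position,
    # off-map states require position to be None; a mismatch means the
    # step logic above updated one without the other.
    if state.is_on_map_state() and position is None:
        raise ValueError(f"Agent {handle}: on-map state {state} without a position")
    if state.is_off_map_state() and position is not None:
        raise ValueError(f"Agent {handle}: off-map state {state} with position {position}")
```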
...
@@ -86,7 +86,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
                 expected_next_position[agent.handle] = neighbour
                 print("  improved (action) -> {}".format(actions[agent.handle]))
     _, rewards, dones, _ = env.step(actions)
-    return rewards
+    return rewards, dones


 def test_reward_function_conflict(rendering=False):
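`RailEnv.step` returns the usual `(observations, rewards, dones, info)` tuple; `dones` maps each agent handle to a boolean and carries the aggregate `"__all__"` key, so forwarding it lets callers stop stepping once every agent has finished. A caller sketch, assuming `env`, `obs_builder`, and `rail` are already set up as in these tests:

```python
# Hypothetical caller of the helper changed above.
rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
if dones["__all__"]:
    # Every agent is done; stepping further would act on finished agents.
    print("episode over, final rewards:", rewards)
```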
@@ -162,8 +162,9 @@ def test_reward_function_conflict(rendering=False):
         },
     }
     while iteration < 5:
-        rewards = _step_along_shortest_path(env, obs_builder, rail)
+        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
+        if dones["__all__"]:
+            break
         for agent in env.agents:
             # assert rewards[agent.handle] == 0
             expected_position = expected_positions[iteration + 1][agent.handle]
@@ -289,7 +290,9 @@ def test_reward_function_waiting(rendering=False):
     }

     while iteration < 7:
-        rewards = _step_along_shortest_path(env, obs_builder, rail)
+        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
+        if dones["__all__"]:
+            break

         if rendering:
             renderer.render_env(show=True, show_observations=True)
...