Skip to content
Snippets Groups Projects
Commit 3fb4cb97 authored by u214892's avatar u214892
Browse files

#167 bugfix action_on_cellexit

parent 4fae3ccb
No related branches found
No related tags found
No related merge requests found
...@@ -346,29 +346,28 @@ class RailEnv(Environment): ...@@ -346,29 +346,28 @@ class RailEnv(Environment):
# Check if agent breaks at this step # Check if agent breaks at this step
malfunction = self._agent_malfunction(i_agent, action) malfunction = self._agent_malfunction(i_agent, action)
# if we're at the beginning of the cell, store the action # Is the agent at the beginning of the cell? Then, it can take an action
# As long as we're broken down at the beginning of the cell, we can choose other actions! # Design choice (Erik+Christian):
# This is a design choice made by Erik and Christian. # as long as we're broken down at the beginning of the cell, we can choose other actions!
# TODO refactor!!!
# If the agent can make an action
if agent.speed_data['position_fraction'] == 0.0: if agent.speed_data['position_fraction'] == 0.0:
if action == RailEnvActions.DO_NOTHING and agent.moving: if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving # Keep moving
action = RailEnvActions.MOVE_FORWARD action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0: if action == RailEnvActions.STOP_MOVING and agent.moving:
# Only allow halting an agent on entering new cells. # Only allow halting an agent on entering new cells.
agent.moving = False agent.moving = False
self.rewards_dict[i_agent] += self.stop_penalty self.rewards_dict[i_agent] += self.stop_penalty
if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): if not agent.moving and not (
action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action # Allow agent to start with any forward or direction action
agent.moving = True agent.moving = True
self.rewards_dict[i_agent] += self.start_penalty self.rewards_dict[i_agent] += self.start_penalty
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: # Store the action
cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]:
_, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(action, agent) self._check_action_on_agent(action, agent)
if all([new_cell_valid, transition_valid]): if all([new_cell_valid, transition_valid]):
...@@ -377,7 +376,7 @@ class RailEnv(Environment): ...@@ -377,7 +376,7 @@ class RailEnv(Environment):
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward! # try to keep moving forward!
if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ _, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_valid, transition_valid]): if all([new_cell_valid, transition_valid]):
...@@ -388,7 +387,6 @@ class RailEnv(Environment): ...@@ -388,7 +387,6 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += self.stop_penalty self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False agent.moving = False
action = RailEnvActions.DO_NOTHING
else: else:
# If the agent cannot move due to an invalid transition, we set its state to not moving # If the agent cannot move due to an invalid transition, we set its state to not moving
...@@ -396,10 +394,10 @@ class RailEnv(Environment): ...@@ -396,10 +394,10 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += self.stop_penalty self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False agent.moving = False
action = RailEnvActions.DO_NOTHING
else: else:
agent.speed_data['transition_action_on_cellexit'] = action agent.speed_data['transition_action_on_cellexit'] = action
# if we're broken, nothing else to do
if malfunction: if malfunction:
continue continue
...@@ -422,16 +420,10 @@ class RailEnv(Environment): ...@@ -422,16 +420,10 @@ class RailEnv(Environment):
# Nothing left to do with broken agent # Nothing left to do with broken agent
continue continue
# Now perform a movement. # Now perform a movement.
# If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps) # If agent.moving, increment the position_fraction by the speed of the agent
# store the desired action in `transition_action_on_cellexit' (only if the desired transition is
# allowed! otherwise DO_NOTHING!)
# Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the
# position_fraction by the speed of the agent (regardless of action taken, as long as no
# STOP_MOVING, but that makes agent.moving=False)
# If the new position fraction is >= 1, reset to 0, and perform the stored # If the new position fraction is >= 1, reset to 0, and perform the stored
# transition_action_on_cellexit # transition_action_on_cellexit if the cell is free.
if agent.moving: if agent.moving:
...@@ -445,9 +437,11 @@ class RailEnv(Environment): ...@@ -445,9 +437,11 @@ class RailEnv(Environment):
RailEnvActions.STOP_MOVING]: RailEnvActions.STOP_MOVING]:
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
else: else:
# cell and transition validity was checked when we stored transition_action_on_cellexit!
cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
agent.speed_data['transition_action_on_cellexit'], agent) agent.speed_data['transition_action_on_cellexit'], agent)
assert cell_free == all([cell_free, new_cell_valid, transition_valid]) if not cell_free == all([cell_free, new_cell_valid, transition_valid]):
warnings.warn("Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!")
if cell_free: if cell_free:
agent.position = new_position agent.position = new_position
agent.direction = new_direction agent.direction = new_direction
......
import numpy as np from typing import List
from flatland.envs.rail_env import RailEnv import numpy as np
from flatland.envs.rail_generators import complex_rail_generator from attr import attrib, attrs
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
np.random.seed(1) np.random.seed(1)
...@@ -86,3 +95,149 @@ def test_multi_speed_init(): ...@@ -86,3 +95,149 @@ def test_multi_speed_init():
if (step + 1) % (i_agent + 1) == 0: if (step + 1) % (i_agent + 1) == 0:
print(step, i_agent, env.agents[i_agent].position) print(step, i_agent, env.agents[i_agent].position)
old_pos[i_agent] = env.agents[i_agent].position old_pos[i_agent] = env.agents[i_agent].position
# TODO test malfunction
# TODO test other agent blocking
def test_multispeed_actions_no_malfunction(rendering=True):
    """Replay a scripted action sequence for a single agent and check its trajectory.

    One agent is placed on the rail returned by ``make_simple_rail()`` and
    driven with a fixed list of ``Replay`` steps at speed 0.5. Before every
    environment step the agent's position and direction are compared with the
    values recorded in the replay, and ``info_dict['action_required']`` is
    checked against whether the replay step carries an action (``action=None``
    marks a step on which no action is expected to be required).

    :param rendering: when true, a ``RenderTool`` is created and the
        environment is rendered after every step.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )

    # initialize agents_static
    env.reset()

    @attrs
    class Replay(object):
        # Expected agent position before the step is executed.
        position = attrib()
        # Expected agent direction before the step is executed.
        direction = attrib()
        # Action to submit for this step; None means "submit no action".
        action = attrib(type=RailEnvActions)

    @attrs
    class TestConfig(object):
        # Ordered list of steps to replay.
        replay = attrib(type=List[Replay])
        # Target cell assigned to the agent.
        target = attrib()
        # Value written to the agent's speed_data['speed'].
        speed = attrib(type=float)

    # reset to set agents from agents_static
    env.reset(False, False)

    # Always bind the name so later references cannot hit an unbound local.
    renderer = RenderTool(env, gl="PILSVG") if rendering else None

    test_configs = [
        TestConfig(
            replay=[
                Replay(
                    position=(3, 9),  # east dead-end
                    direction=Grid4TransitionsEnum.EAST,
                    action=RailEnvActions.MOVE_FORWARD
                ),
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_LEFT
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None
                ),
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=RailEnvActions.STOP_MOVING
                ),
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=RailEnvActions.STOP_MOVING
                ),
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=RailEnvActions.MOVE_FORWARD
                ),
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=None
                ),
                Replay(
                    position=(5, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=RailEnvActions.MOVE_FORWARD
                ),
            ],
            target=(3, 0),  # west dead-end
            speed=0.5
        )
    ]

    def _assert(step, actual, expected, msg):
        # The step index is included so a failure points at the replay entry.
        assert actual == expected, "[{}] {}: actual={}, expected={}".format(step, msg, actual, expected)

    # TODO test penalties!
    agentStatic: EnvAgentStatic = env.agents_static[0]
    for test_config in test_configs:
        # Before the first step an action is expected to be required.
        info_dict = {
            'action_required': [True]
        }
        for i, replay in enumerate(test_config.replay):
            if i == 0:
                # set the initial position
                agentStatic.position = replay.position
                agentStatic.direction = replay.direction
                agentStatic.target = test_config.target
                agentStatic.moving = True
                agentStatic.speed_data['speed'] = test_config.speed

                # reset to set agents from agents_static
                env.reset(False, False)

            agent: EnvAgent = env.agents[0]

            _assert(i, agent.position, replay.position, 'position')
            _assert(i, agent.direction, replay.direction, 'direction')

            # Compare against None explicitly: an enum member may itself be
            # falsy (e.g. an integer-valued DO_NOTHING -- TODO confirm) and
            # must still be submitted as an action rather than skipped.
            if replay.action is not None:
                assert info_dict['action_required'][0], "[{}] expecting action_required={}".format(i, True)
                _, _, _, info_dict = env.step({0: replay.action})
            else:
                assert not info_dict['action_required'][0], "[{}] expecting action_required={}".format(i, False)
                _, _, _, info_dict = env.step({})

            if rendering:
                renderer.render_env(show=True, show_observations=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment