Skip to content
Snippets Groups Projects
Commit 44b55f20 authored by u214892's avatar u214892
Browse files

#168 #163 multispeed and penalty testing

parent d2ac83fe
No related branches found
No related tags found
No related merge requests found
......@@ -236,7 +236,8 @@ class RailEnv(Environment):
Relies on the rail_generator returning agent_static lists (pos, dir, target)
"""
# TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
# TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
# can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
if optionals and 'distance_map' in optionals:
......@@ -256,6 +257,9 @@ class RailEnv(Environment):
agents_hints = None
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
# TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
# why do we need static agents? could we it more elegantly?
self.agents_static = EnvAgentStatic.from_lists(
*self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
self.restart_agents()
......@@ -356,6 +360,8 @@ class RailEnv(Environment):
# Perform step on all agents
for i_agent in range(self.get_num_agents()):
if self._elapsed_steps - 1 == 3 and i_agent == 0:
a = 5
self._step_agent(i_agent, action_dict_.get(i_agent))
# Check for end of episode + set global reward to all rewards!
......@@ -407,13 +413,14 @@ class RailEnv(Environment):
# is the agent malfunctioning?
malfunction = self._agent_malfunction(i_agent)
# if agent is broken, actions are ignored and agent does not move,
# the agent is not penalized in this step!
# if agent is broken, actions are ignored and agent does not move.
# full step penalty in this case
if malfunction:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
# Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell, different actions may be taken!
if agent.speed_data['position_fraction'] == 0.0:
# No action has been supplied for this agent -> set DO_NOTHING as default
if action is None:
......@@ -464,7 +471,6 @@ class RailEnv(Environment):
if not _action_stored:
# If the agent cannot move due to an invalid transition, we set its state to not moving
self.rewards_dict[i_agent] += self.invalid_action_penalty
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False
......@@ -497,6 +503,9 @@ class RailEnv(Environment):
agent.moving = False
else:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
else:
# step penalty if not moving (stopped now or before)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
"""
......
......@@ -4,13 +4,11 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from test_utils import Replay
from test_utils import Replay, ReplayConfig, run_replay_config
class SingleAgentNavigationObs(TreeObsForRailEnv):
......@@ -155,7 +153,7 @@ def test_malfunction_process_statistically():
assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction)
def test_initial_malfunction(rendering=True):
def test_initial_malfunction():
random.seed(0)
np.random.seed(0)
......@@ -190,74 +188,55 @@ def test_initial_malfunction(rendering=True):
stochastic_data=stochastic_data, # Malfunction data generator
)
if rendering:
renderer = RenderTool(env)
renderer.render_env(show=True, frames=False, show_observations=False)
_action = dict()
replay_steps = [
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=3
),
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=1
),
Replay(
position=(28, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0
),
Replay(
position=(27, 4),
direction=Grid4TransitionsEnum.NORTH,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0
)
]
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(replay_steps):
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
agent: EnvAgent = env.agents[0]
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
_assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
if replay.action is not None:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
def test_initial_malfunction_stop_moving(rendering=True):
replay_config = ReplayConfig(
replay=[
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
reward=env.step_penalty # full step penalty when malfunctioning
),
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2,
reward=env.step_penalty # full step penalty when malfunctioning
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=1,
reward=env.start_penalty + env.step_penalty * 1.0
# malfunctioning ends: starting and running at speed 1.0
),
Replay(
position=(28, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0
),
Replay(
position=(27, 4),
direction=Grid4TransitionsEnum.NORTH,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0
)
],
speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target
)
run_replay_config(env, [replay_config])
def test_initial_malfunction_stop_moving():
random.seed(0)
np.random.seed(0)
......@@ -292,80 +271,62 @@ def test_initial_malfunction_stop_moving(rendering=True):
stochastic_data=stochastic_data, # Malfunction data generator
)
if rendering:
renderer = RenderTool(env)
renderer.render_env(show=True, frames=False, show_observations=False)
_action = dict()
replay_steps = [
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=3
),
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
#
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=1
),
# we have stopped and do nothing --> should stand still
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0
),
# we start to move forward --> should go to next cell now
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0
),
Replay(
position=(28, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0
)
]
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(replay_steps):
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
agent: EnvAgent = env.agents[0]
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
_assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
if replay.action is not None:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
replay_config = ReplayConfig(
replay=[
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
set_malfunction=3,
malfunction=3,
reward=env.step_penalty # full step penalty when stopped
),
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
reward=env.step_penalty # full step penalty when stopped
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action STOP_MOVING, agent should restart without moving
#
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=1,
reward=env.step_penalty # full step penalty while stopped
),
# we have stopped and do nothing --> should stand still
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
reward=env.step_penalty # full step penalty while stopped
),
# we start to move forward --> should go to next cell now
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0 # full step penalty while stopped
),
Replay(
position=(28, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0 # full step penalty while stopped
)
],
speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target
)
run_replay_config(env, [replay_config])
def test_initial_malfunction_do_nothing(rendering=True):
......@@ -403,77 +364,58 @@ def test_initial_malfunction_do_nothing(rendering=True):
stochastic_data=stochastic_data, # Malfunction data generator
)
if rendering:
renderer = RenderTool(env)
renderer.render_env(show=True, frames=False, show_observations=False)
_action = dict()
replay_steps = [
Replay(
replay_config = ReplayConfig(
replay=[Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=3
set_malfunction=3,
malfunction=3,
reward=env.step_penalty # full step penalty while malfunctioning
),
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
#
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=1
),
# we haven't started moving yet --> stay here
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0
),
# we start to move forward --> should go to next cell now
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0
),
Replay(
position=(28, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0
)
]
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(replay_steps):
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
agent: EnvAgent = env.agents[0]
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
_assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
if replay.action is not None:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
reward=env.step_penalty # full step penalty while malfunctioning
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
#
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=1,
reward=env.step_penalty # full step penalty while stopped
),
# we haven't started moving yet --> stay here
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
reward=env.step_penalty # full step penalty while stopped
),
# we start to move forward --> should go to next cell now
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0 # start penalty + step penalty for speed 1.0
),
Replay(
position=(28, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0 # step penalty for speed 1.0
)
],
speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target
)
run_replay_config(env, [replay_config])
This diff is collapsed.
......@@ -4,7 +4,9 @@ from typing import List, Tuple, Optional
from attr import attrs, attrib
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.rail_env import RailEnvActions, RailEnv
from flatland.utils.rendertools import RenderTool
@attrs
......@@ -13,7 +15,8 @@ class Replay(object):
direction = attrib(type=Grid4TransitionsEnum)
action = attrib(type=RailEnvActions)
malfunction = attrib(default=0, type=int)
penalty = attrib(default=None, type=Optional[float])
set_malfunction = attrib(default=None, type=Optional[int])
reward = attrib(default=None, type=Optional[float])
@attrs
......@@ -21,3 +24,78 @@ class ReplayConfig(object):
replay = attrib(type=List[Replay])
target = attrib(type=Tuple[int, int])
speed = attrib(type=float)
def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
"""
Runs the replay configs and checks assertions.
*Initially*
- the position, direction, target and speed of the initial step are taken to initialize the agents
*Before each step*
- action must only be provided if action_required from previous step (initally all True)
- position, direction before step are verified
- optionally, set_malfunction is applied
- malfunction is verified
*After each step*
- reward is verified after step
Parameters
----------
env
test_configs
rendering
"""
if rendering:
renderer = RenderTool(env)
renderer.render_env(show=True, frames=False, show_observations=False)
info_dict = {
'action_required': [True for _ in test_configs]
}
for step in range(len(test_configs[0].replay)):
if step == 0:
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
replay = test_config.replay[0]
# set the initial position
agent.position = replay.position
agent.direction = replay.direction
agent.target = test_config.target
agent.speed_data['speed'] = test_config.speed
def _assert(a, actual, expected, msg):
assert actual == expected, "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg, actual,
expected)
action_dict = {}
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
replay = test_config.replay[step]
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
if replay.action is not None:
assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
step, a, True)
action_dict[a] = replay.action
else:
assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
step, a, False)
if replay.set_malfunction is not None:
agent.malfunction_data['malfunction'] = replay.set_malfunction
agent.malfunction_data['moving_before_malfunction'] = agent.moving
_assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
_, rewards_dict, _, info_dict = env.step(action_dict)
for a, test_config in enumerate(test_configs):
replay = test_config.replay[step]
_assert(a, rewards_dict[a], replay.reward, 'reward')
if rendering:
renderer.render_env(show=True, show_observations=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment