Skip to content
Snippets Groups Projects
Commit 15db8d20 authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '163-multispeed-tests-penalties' into 'master'

Resolve "Test for Multi-Speed"

Closes #163 and #168

See merge request flatland/flatland!194
parents 7f186674 ee1fc8fd
No related branches found
No related tags found
No related merge requests found
...@@ -46,8 +46,6 @@ def a_star(grid_map: GridTransitionMap, ...@@ -46,8 +46,6 @@ def a_star(grid_map: GridTransitionMap,
""" """
rail_shape = grid_map.grid.shape rail_shape = grid_map.grid.shape
tmp = np.zeros(rail_shape) - 10
start_node = AStarNode(start, None) start_node = AStarNode(start, None)
end_node = AStarNode(end, None) end_node = AStarNode(end, None)
open_nodes = OrderedSet() open_nodes = OrderedSet()
...@@ -114,8 +112,6 @@ def a_star(grid_map: GridTransitionMap, ...@@ -114,8 +112,6 @@ def a_star(grid_map: GridTransitionMap,
child.h = a_star_distance_function(child.pos, end_node.pos) child.h = a_star_distance_function(child.pos, end_node.pos)
child.f = child.g + child.h child.f = child.g + child.h
tmp[child.pos[0]][child.pos[1]] = child.f
# already in the open list? # already in the open list?
if child in open_nodes: if child in open_nodes:
continue continue
......
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2DArray from flatland.core.grid.grid_utils import IntVector2D
def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4TransitionsEnum: def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
""" """
Assumes pos1 and pos2 are adjacent location on grid. Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions. Returns direction (int) that can be used with transitions.
...@@ -10,13 +10,13 @@ def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4Transi ...@@ -10,13 +10,13 @@ def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4Transi
diff_0 = pos2[0] - pos1[0] diff_0 = pos2[0] - pos1[0]
diff_1 = pos2[1] - pos1[1] diff_1 = pos2[1] - pos1[1]
if diff_0 < 0: if diff_0 < 0:
return 0 return Grid4TransitionsEnum.NORTH
if diff_0 > 0: if diff_0 > 0:
return 2 return Grid4TransitionsEnum.SOUTH
if diff_1 > 0: if diff_1 > 0:
return 1 return Grid4TransitionsEnum.EAST
if diff_1 < 0: if diff_1 < 0:
return 3 return Grid4TransitionsEnum.WEST
raise Exception("Could not determine direction {}->{}".format(pos1, pos2)) raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
......
...@@ -7,22 +7,25 @@ a GridTransitionMap object. ...@@ -7,22 +7,25 @@ a GridTransitionMap object.
from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_basic_operation(
start: IntVector2D, rail_trans: RailEnvTransitions,
end: IntVector2D, grid_map: GridTransitionMap,
flip_start_node_trans=False, start: IntVector2D,
flip_end_node_trans=False, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): flip_start_node_trans=False,
flip_end_node_trans=False,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
""" """
Creates a new path [start,end] in grid_map, based on rail_trans. Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
returns the path created as a list of positions.
""" """
# in the worst case we will need to do a A* search, so we might as well set that up # in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(grid_map, start, end, a_star_distance_function) path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function)
if len(path) < 2: if len(path) < 2:
return [] return []
current_dir = get_direction(path[0], path[1]) current_dir = get_direction(path[0], path[1])
...@@ -71,23 +74,24 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi ...@@ -71,23 +74,24 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function) return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function)
def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function) return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function)
def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance
) -> IntVector2DArray:
return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function) return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function)
def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function) return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function)
...@@ -237,7 +237,8 @@ class RailEnv(Environment): ...@@ -237,7 +237,8 @@ class RailEnv(Environment):
Relies on the rail_generator returning agent_static lists (pos, dir, target) Relies on the rail_generator returning agent_static lists (pos, dir, target)
""" """
# TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition? # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
# can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
if optionals and 'distance_map' in optionals: if optionals and 'distance_map' in optionals:
...@@ -257,6 +258,9 @@ class RailEnv(Environment): ...@@ -257,6 +258,9 @@ class RailEnv(Environment):
agents_hints = None agents_hints = None
if optionals and 'agents_hints' in optionals: if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints'] agents_hints = optionals['agents_hints']
# TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
# why do we need static agents? could we it more elegantly?
self.agents_static = EnvAgentStatic.from_lists( self.agents_static = EnvAgentStatic.from_lists(
*self.schedule_generator(self.rail, self.get_num_agents(), agents_hints)) *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
self.restart_agents() self.restart_agents()
...@@ -408,13 +412,14 @@ class RailEnv(Environment): ...@@ -408,13 +412,14 @@ class RailEnv(Environment):
# is the agent malfunctioning? # is the agent malfunctioning?
malfunction = self._agent_malfunction(i_agent) malfunction = self._agent_malfunction(i_agent)
# if agent is broken, actions are ignored and agent does not move, # if agent is broken, actions are ignored and agent does not move.
# the agent is not penalized in this step! # full step penalty in this case
if malfunction: if malfunction:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return return
# Is the agent at the beginning of the cell? Then, it can take an action. # Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell, different actions may be taken!
if agent.speed_data['position_fraction'] == 0.0: if agent.speed_data['position_fraction'] == 0.0:
# No action has been supplied for this agent -> set DO_NOTHING as default # No action has been supplied for this agent -> set DO_NOTHING as default
if action is None: if action is None:
...@@ -463,9 +468,9 @@ class RailEnv(Environment): ...@@ -463,9 +468,9 @@ class RailEnv(Environment):
_action_stored = True _action_stored = True
if not _action_stored: if not _action_stored:
# If the agent cannot move due to an invalid transition, we set its state to not moving # If the agent cannot move due to an invalid transition, we set its state to not moving
self.rewards_dict[i_agent] += self.invalid_action_penalty self.rewards_dict[i_agent] += self.invalid_action_penalty
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += self.stop_penalty self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False agent.moving = False
...@@ -498,6 +503,9 @@ class RailEnv(Environment): ...@@ -498,6 +503,9 @@ class RailEnv(Environment):
agent.moving = False agent.moving = False
else: else:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
else:
# step penalty if not moving (stopped now or before)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
""" """
......
...@@ -2,8 +2,8 @@ import numpy as np ...@@ -2,8 +2,8 @@ import numpy as np
import pytest import pytest
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
from flatland.core.grid.grid4_utils import get_direction from flatland.core.grid.grid4_utils import get_direction
from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
depth_to_test = 5 depth_to_test = 5
positions_to_test = [0, 5, 1, 6, 20, 30] positions_to_test = [0, 5, 1, 6, 20, 30]
...@@ -31,4 +31,4 @@ def test_get_direction(): ...@@ -31,4 +31,4 @@ def test_get_direction():
assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
with pytest.raises(Exception, match="Could not determine direction"): with pytest.raises(Exception, match="Could not determine direction"):
get_direction((0, 0), (0, 0)) == Grid4TransitionsEnum.NORTH get_direction((0, 0), (0, 0))
import random import random
from typing import Dict
import numpy as np import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from test_utils import Replay
class SingleAgentNavigationObs(TreeObsForRailEnv): class SingleAgentNavigationObs(TreeObsForRailEnv):
...@@ -54,7 +53,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): ...@@ -54,7 +53,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
min_distances.append(np.inf) min_distances.append(np.inf)
observation = [0, 0, 0] observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1 observation[np.argmin(min_distances)[0]] = 1
return observation return observation
...@@ -83,7 +82,6 @@ def test_malfunction_process(): ...@@ -83,7 +82,6 @@ def test_malfunction_process():
agent_halts = 0 agent_halts = 0
total_down_time = 0 total_down_time = 0
agent_malfunctioning = False
agent_old_position = env.agents[0].position agent_old_position = env.agents[0].position
for step in range(100): for step in range(100):
actions = {} actions = {}
...@@ -142,12 +140,12 @@ def test_malfunction_process_statistically(): ...@@ -142,12 +140,12 @@ def test_malfunction_process_statistically():
env.reset() env.reset()
nb_malfunction = 0 nb_malfunction = 0
for step in range(100): for step in range(100):
action_dict = {} action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents: for agent in env.agents:
if agent.malfunction_data['malfunction'] > 0: if agent.malfunction_data['malfunction'] > 0:
nb_malfunction += 1 nb_malfunction += 1
# We randomly select an action # We randomly select an action
action_dict[agent.handle] = np.random.randint(4) action_dict[agent.handle] = RailEnvActions(np.random.randint(4))
env.step(action_dict) env.step(action_dict)
...@@ -155,7 +153,7 @@ def test_malfunction_process_statistically(): ...@@ -155,7 +153,7 @@ def test_malfunction_process_statistically():
assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction) assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction)
def test_initial_malfunction(rendering=True): def test_initial_malfunction():
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -189,75 +187,56 @@ def test_initial_malfunction(rendering=True): ...@@ -189,75 +187,56 @@ def test_initial_malfunction(rendering=True):
number_of_agents=1, number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
) )
set_penalties_for_replay(env)
if rendering: replay_config = ReplayConfig(
renderer = RenderTool(env) replay=[
renderer.render_env(show=True, frames=False, show_observations=False) Replay(
_action = dict() position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
replay_steps = [ action=RailEnvActions.MOVE_FORWARD,
Replay( set_malfunction=3,
position=(28, 5), malfunction=3,
direction=Grid4TransitionsEnum.EAST, reward=env.step_penalty # full step penalty when malfunctioning
action=RailEnvActions.MOVE_FORWARD, ),
malfunction=3 Replay(
), position=(28, 5),
Replay( direction=Grid4TransitionsEnum.EAST,
position=(28, 5), action=RailEnvActions.MOVE_FORWARD,
direction=Grid4TransitionsEnum.EAST, malfunction=2,
action=RailEnvActions.MOVE_FORWARD, reward=env.step_penalty # full step penalty when malfunctioning
malfunction=2 ),
), # malfunction stops in the next step and we're still at the beginning of the cell
# malfunction stops in the next step and we're still at the beginning of the cell # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell Replay(
Replay( position=(28, 5),
position=(28, 5), direction=Grid4TransitionsEnum.EAST,
direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD,
action=RailEnvActions.MOVE_FORWARD, malfunction=1,
malfunction=1 reward=env.start_penalty + env.step_penalty * 1.0
), # malfunctioning ends: starting and running at speed 1.0
Replay( ),
position=(28, 4), Replay(
direction=Grid4TransitionsEnum.WEST, position=(28, 4),
action=RailEnvActions.MOVE_FORWARD, direction=Grid4TransitionsEnum.WEST,
malfunction=0 action=RailEnvActions.MOVE_FORWARD,
), malfunction=0,
Replay( reward=env.step_penalty * 1.0 # running at speed 1.0
position=(27, 4), ),
direction=Grid4TransitionsEnum.NORTH, Replay(
action=RailEnvActions.MOVE_FORWARD, position=(27, 4),
malfunction=0 direction=Grid4TransitionsEnum.NORTH,
) action=RailEnvActions.MOVE_FORWARD,
] malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0
info_dict = { )
'action_required': [True] ],
} speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target
for i, replay in enumerate(replay_steps): )
run_replay_config(env, [replay_config])
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
def test_initial_malfunction_stop_moving():
agent: EnvAgent = env.agents[0]
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
_assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
if replay.action is not None:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
def test_initial_malfunction_stop_moving(rendering=True):
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -291,84 +270,66 @@ def test_initial_malfunction_stop_moving(rendering=True): ...@@ -291,84 +270,66 @@ def test_initial_malfunction_stop_moving(rendering=True):
number_of_agents=1, number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
) )
set_penalties_for_replay(env)
if rendering: replay_config = ReplayConfig(
renderer = RenderTool(env) replay=[
renderer.render_env(show=True, frames=False, show_observations=False) Replay(
_action = dict() position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
replay_steps = [ action=RailEnvActions.DO_NOTHING,
Replay( set_malfunction=3,
position=(28, 5), malfunction=3,
direction=Grid4TransitionsEnum.EAST, reward=env.step_penalty # full step penalty when stopped
action=RailEnvActions.DO_NOTHING, ),
malfunction=3 Replay(
), position=(28, 5),
Replay( direction=Grid4TransitionsEnum.EAST,
position=(28, 5), action=RailEnvActions.DO_NOTHING,
direction=Grid4TransitionsEnum.EAST, malfunction=2,
action=RailEnvActions.DO_NOTHING, reward=env.step_penalty # full step penalty when stopped
malfunction=2 ),
), # malfunction stops in the next step and we're still at the beginning of the cell
# malfunction stops in the next step and we're still at the beginning of the cell # --> if we take action STOP_MOVING, agent should restart without moving
# --> if we take action DO_NOTHING, agent should restart without moving #
# Replay(
Replay( position=(28, 5),
position=(28, 5), direction=Grid4TransitionsEnum.EAST,
direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.STOP_MOVING,
action=RailEnvActions.STOP_MOVING, malfunction=1,
malfunction=1 reward=env.step_penalty # full step penalty while stopped
), ),
# we have stopped and do nothing --> should stand still # we have stopped and do nothing --> should stand still
Replay( Replay(
position=(28, 5), position=(28, 5),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=0 malfunction=0,
), reward=env.step_penalty # full step penalty while stopped
# we start to move forward --> should go to next cell now ),
Replay( # we start to move forward --> should go to next cell now
position=(28, 5), Replay(
direction=Grid4TransitionsEnum.EAST, position=(28, 5),
action=RailEnvActions.MOVE_FORWARD, direction=Grid4TransitionsEnum.EAST,
malfunction=0 action=RailEnvActions.MOVE_FORWARD,
), malfunction=0,
Replay( reward=env.start_penalty + env.step_penalty * 1.0 # full step penalty while stopped
position=(28, 4), ),
direction=Grid4TransitionsEnum.WEST, Replay(
action=RailEnvActions.MOVE_FORWARD, position=(28, 4),
malfunction=0 direction=Grid4TransitionsEnum.WEST,
) action=RailEnvActions.MOVE_FORWARD,
] malfunction=0,
reward=env.step_penalty * 1.0 # full step penalty while stopped
info_dict = { )
'action_required': [True] ],
} speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target
for i, replay in enumerate(replay_steps): )
def _assert(actual, expected, msg): run_replay_config(env, [replay_config])
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
agent: EnvAgent = env.agents[0] def test_initial_malfunction_do_nothing():
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
_assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
if replay.action is not None:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
def test_initial_malfunction_do_nothing(rendering=True):
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -402,78 +363,59 @@ def test_initial_malfunction_do_nothing(rendering=True): ...@@ -402,78 +363,59 @@ def test_initial_malfunction_do_nothing(rendering=True):
number_of_agents=1, number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
) )
set_penalties_for_replay(env)
if rendering: replay_config = ReplayConfig(
renderer = RenderTool(env) replay=[Replay(
renderer.render_env(show=True, frames=False, show_observations=False)
_action = dict()
replay_steps = [
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=3
),
Replay(
position=(28, 5), position=(28, 5),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=2 set_malfunction=3,
malfunction=3,
reward=env.step_penalty # full step penalty while malfunctioning
), ),
# malfunction stops in the next step and we're still at the beginning of the cell Replay(
# --> if we take action DO_NOTHING, agent should restart without moving position=(28, 5),
# direction=Grid4TransitionsEnum.EAST,
Replay( action=RailEnvActions.DO_NOTHING,
position=(28, 5), malfunction=2,
direction=Grid4TransitionsEnum.EAST, reward=env.step_penalty # full step penalty while malfunctioning
action=RailEnvActions.DO_NOTHING, ),
malfunction=1 # malfunction stops in the next step and we're still at the beginning of the cell
), # --> if we take action DO_NOTHING, agent should restart without moving
# we haven't started moving yet --> stay here #
Replay( Replay(
position=(28, 5), position=(28, 5),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=0 malfunction=1,
), reward=env.step_penalty # full step penalty while stopped
# we start to move forward --> should go to next cell now ),
Replay( # we haven't started moving yet --> stay here
position=(28, 5), Replay(
direction=Grid4TransitionsEnum.EAST, position=(28, 5),
action=RailEnvActions.MOVE_FORWARD, direction=Grid4TransitionsEnum.EAST,
malfunction=0 action=RailEnvActions.DO_NOTHING,
), malfunction=0,
Replay( reward=env.step_penalty # full step penalty while stopped
position=(28, 4), ),
direction=Grid4TransitionsEnum.WEST, # we start to move forward --> should go to next cell now
action=RailEnvActions.MOVE_FORWARD, Replay(
malfunction=0 position=(28, 5),
) direction=Grid4TransitionsEnum.EAST,
] action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
info_dict = { reward=env.start_penalty + env.step_penalty * 1.0 # start penalty + step penalty for speed 1.0
'action_required': [True] ),
} Replay(
position=(28, 4),
for i, replay in enumerate(replay_steps): direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
def _assert(actual, expected, msg): malfunction=0,
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) reward=env.step_penalty * 1.0 # step penalty for speed 1.0
)
agent: EnvAgent = env.agents[0] ],
speed=env.agents[0].speed_data['speed'],
_assert(agent.position, replay.position, 'position') target=env.agents[0].target
_assert(agent.direction, replay.direction, 'direction') )
_assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
run_replay_config(env, [replay_config])
if replay.action is not None:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
This diff is collapsed.
"""Test Utils.""" """Test Utils."""
from typing import List, Tuple from typing import List, Tuple, Optional
import numpy as np
from attr import attrs, attrib from attr import attrs, attrib
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.rail_env import RailEnvActions from flatland.envs.agent_utils import EnvAgent
from flatland.envs.rail_env import RailEnvActions, RailEnv
from flatland.utils.rendertools import RenderTool
@attrs @attrs
...@@ -13,6 +16,8 @@ class Replay(object): ...@@ -13,6 +16,8 @@ class Replay(object):
direction = attrib(type=Grid4TransitionsEnum) direction = attrib(type=Grid4TransitionsEnum)
action = attrib(type=RailEnvActions) action = attrib(type=RailEnvActions)
malfunction = attrib(default=0, type=int) malfunction = attrib(default=0, type=int)
set_malfunction = attrib(default=None, type=Optional[int])
reward = attrib(default=None, type=Optional[float])
@attrs @attrs
...@@ -20,3 +25,89 @@ class ReplayConfig(object): ...@@ -20,3 +25,89 @@ class ReplayConfig(object):
replay = attrib(type=List[Replay]) replay = attrib(type=List[Replay])
target = attrib(type=Tuple[int, int]) target = attrib(type=Tuple[int, int])
speed = attrib(type=float) speed = attrib(type=float)
# ensure that env is working correctly with start/stop/invalidaction penalty different from 0
def set_penalties_for_replay(env: RailEnv):
env.step_penalty = -7
env.start_penalty = -13
env.stop_penalty = -19
env.invalid_action_penalty = -29
def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
"""
Runs the replay configs and checks assertions.
*Initially*
- the position, direction, target and speed of the initial step are taken to initialize the agents
*Before each step*
- action must only be provided if action_required from previous step (initally all True)
- position, direction before step are verified
- optionally, set_malfunction is applied
- malfunction is verified
*After each step*
- reward is verified after step
Parameters
----------
env
test_configs
rendering
"""
if rendering:
renderer = RenderTool(env)
renderer.render_env(show=True, frames=False, show_observations=False)
info_dict = {
'action_required': [True for _ in test_configs]
}
for step in range(len(test_configs[0].replay)):
if step == 0:
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
replay = test_config.replay[0]
# set the initial position
agent.position = replay.position
agent.direction = replay.direction
agent.target = test_config.target
agent.speed_data['speed'] = test_config.speed
def _assert(a, actual, expected, msg):
assert np.allclose(actual, expected), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
actual,
expected)
action_dict = {}
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
replay = test_config.replay[step]
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
if replay.action is not None:
assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
step, a, True)
action_dict[a] = replay.action
else:
assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
step, a, False)
if replay.set_malfunction is not None:
agent.malfunction_data['malfunction'] = replay.set_malfunction
agent.malfunction_data['moving_before_malfunction'] = agent.moving
_assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
_, rewards_dict, _, info_dict = env.step(action_dict)
if rendering:
renderer.render_env(show=True, show_observations=True)
for a, test_config in enumerate(test_configs):
replay = test_config.replay[step]
_assert(a, rewards_dict[a], replay.reward, 'reward')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment