Skip to content
Snippets Groups Projects
Commit 98e65ed7 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch '167-malfunction-bugfix-action_on_cellexit' into 'master'

#167 bugfix action_on_cellexit

Closes #167, #162, and #164

See merge request flatland/flatland!179
parents db9bc6a3 8612ec45
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import warnings
from typing import Callable, Tuple, Any, Optional
from typing import Callable, Tuple, Optional, Dict, List, Any
import msgpack
import numpy as np
......@@ -11,7 +11,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
......@@ -560,63 +560,43 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
# Generate a set of nodes for the sparse network
# Try to connect cities to nodes first
node_positions = []
city_positions = []
intersection_positions = []
# Evenly distribute cities and intersections
node_positions: List[Any] = None
nb_nodes = num_cities + num_intersections
if grid_mode:
tot_num_node = num_intersections + num_cities
nodes_ratio = height / width
nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio)))
nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
city_idx = np.random.choice(np.arange(tot_num_node), num_cities)
city_idx = np.random.choice(np.arange(nb_nodes), num_cities)
for node_idx in range(num_cities + num_intersections):
to_close = True
tries = 0
node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions,
nb_nodes,
nodes_per_row, x_positions,
y_positions)
if not grid_mode:
while to_close:
x_tmp = node_radius + np.random.randint(height - node_radius)
y_tmp = node_radius + np.random.randint(width - node_radius)
to_close = False
# Check distance to cities
for node_pos in city_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True
# Check distance to intersections
for node_pos in intersection_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True
if not to_close:
node_positions.append((x_tmp, y_tmp))
if node_idx < num_cities:
city_positions.append((x_tmp, y_tmp))
else:
intersection_positions.append((x_tmp, y_tmp))
tries += 1
if tries > 100:
warnings.warn("Could not set nodes, please change initial parameters!!!!")
break
else:
x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row]
if node_idx in city_idx:
city_positions.append((x_tmp, y_tmp))
else:
intersection_positions.append((x_tmp, y_tmp))
node_positions = city_positions + intersection_positions
else:
node_positions = _generate_node_positions_not_grid_mode(city_positions, height,
intersection_positions,
nb_nodes, width)
# reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
nb_nodes = len(node_positions)
_num_cities = len(city_positions)
_num_intersections = len(intersection_positions)
# Chose node connection
# Set up list of available nodes to connect to
available_nodes_full = np.arange(num_cities + num_intersections)
available_cities = np.arange(num_cities)
available_intersections = np.arange(num_cities, num_cities + num_intersections)
available_nodes_full = np.arange(nb_nodes)
available_cities = np.arange(_num_cities)
available_intersections = np.arange(_num_cities, nb_nodes)
# Start at some node
current_node = np.random.randint(len(available_nodes_full))
......@@ -629,13 +609,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
# Priority city to intersection connections
if current_node < num_cities and len(available_intersections) > 0:
if current_node < _num_cities and len(available_intersections) > 0:
available_nodes = available_intersections
delete_idx = np.where(available_cities == current_node)
available_cities = np.delete(available_cities, delete_idx, 0)
# Priority intersection to city connections
elif current_node >= num_cities and len(available_cities) > 0:
elif current_node >= _num_cities and len(available_cities) > 0:
available_nodes = available_cities
delete_idx = np.where(available_intersections == current_node)
available_intersections = np.delete(available_intersections, delete_idx, 0)
......@@ -669,15 +649,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
node_stack.pop(0)
# Place train stations close to the node
# We currently place them uniformly distirbuted among all cities
# We currently place them uniformly distributed among all cities
built_num_trainstation = 0
train_stations = [[] for i in range(num_cities)]
train_stations = [[] for i in range(_num_cities)]
if num_cities > 1:
if _num_cities > 1:
for station in range(num_trainstations):
spot_found = True
trainstation_node = int(station / num_trainstations * num_cities)
trainstation_node = int(station / num_trainstations * _num_cities)
station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
0,
......@@ -702,6 +682,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
if tries > 100:
warnings.warn("Could not set trainstations, please change initial parameters!!!!")
spot_found = False
break
if spot_found:
train_stations[trainstation_node].append((station_x, station_y))
......@@ -725,7 +706,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
# We currently place them uniformly distirbuted among all cities
if enhance_intersection:
for intersection in range(num_intersections):
for intersection in range(_num_intersections):
intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3),
1,
height - 2)
......@@ -762,7 +743,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
# Slot availability in node
node_available_start = []
node_available_target = []
for node_idx in range(num_cities):
for node_idx in range(_num_cities):
node_available_start.append(len(train_stations[node_idx]))
node_available_target.append(len(train_stations[node_idx]))
......@@ -797,4 +778,57 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
'train_stations': train_stations
}}
def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
width):
node_positions = []
for node_idx in range(nb_nodes):
to_close = True
tries = 0
while to_close:
x_tmp = node_radius + np.random.randint(height - node_radius)
y_tmp = node_radius + np.random.randint(width - node_radius)
to_close = False
# Check distance to cities
for node_pos in city_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True
# Check distance to intersections
for node_pos in intersection_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True
if not to_close:
node_positions.append((x_tmp, y_tmp))
if node_idx < num_cities:
city_positions.append((x_tmp, y_tmp))
else:
intersection_positions.append((x_tmp, y_tmp))
tries += 1
if tries > 100:
warnings.warn(
"Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format(
len(node_positions),
tries, nb_nodes))
break
node_positions = city_positions + intersection_positions
return node_positions
def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes,
nodes_per_row, x_positions, y_positions):
for node_idx in range(nb_nodes):
x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row]
if node_idx in city_idx:
city_positions.append((x_tmp, y_tmp))
else:
intersection_positions.append((x_tmp, y_tmp))
node_positions = city_positions + intersection_positions
return node_positions
return generator
......@@ -55,24 +55,25 @@ def test_rail_env_action_required_info():
obs_builder_object=GlobalObsForRailEnv())
np.random.seed(0)
env_only_if_action_required = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
num_intersections=10,
# Number of interesections in map
num_trainstations=50,
# Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3,
# Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities
seed=5, # Random seed
grid_mode=False
# Ordered distribution of nodes
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
num_intersections=10,
# Number of interesections in map
num_trainstations=50,
# Number of possible start/targets on map
min_node_dist=6,
# Minimal distance of nodes
node_radius=3,
# Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities
seed=5, # Random seed
grid_mode=False
# Ordered distribution of nodes
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env_renderer = RenderTool(env_always_action, gl="PILSVG", )
for step in range(100):
......@@ -87,7 +88,8 @@ def test_rail_env_action_required_info():
if step == 0 or info_only_if_action_required['action_required'][a]:
action_dict_only_if_action_required.update({a: action})
else:
print("[{}] not action_required {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data))
print("[{}] not action_required {}, speed_data={}".format(step, a,
env_always_action.agents[a].speed_data))
obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
action_dict_always_action)
......@@ -156,3 +158,23 @@ def test_rail_env_malfunction_speed_info():
if done['__all__']:
break
def test_sparse_generator_with_too_man_cities_does_not_break_down():
np.random.seed(0)
RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(
num_cities=100, # Number of cities in map
num_intersections=10, # Number of interesections in map
num_trainstations=50, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3, # Number of connections to other cities
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
......@@ -110,3 +110,38 @@ def test_malfunction_process():
# Check that malfunctioning data was standing around
assert total_down_time > 0
def test_malfunction_process_statistically():
"""Tests hat malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 2,
'min_duration': 3,
'max_duration': 3}
np.random.seed(5)
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs(),
stochastic_data=stochastic_data)
env.reset()
nb_malfunction = 0
for step in range(100):
action_dict = {}
for agent in env.agents:
if agent.malfunction_data['malfunction'] > 0:
nb_malfunction += 1
# We randomly select an action
action_dict[agent.handle] = np.random.randint(4)
env.step(action_dict)
# check that generation of malfunctions works as expected
# results are different in py36 and py37, therefore no exact test on nb_malfunction
assert nb_malfunction > 150
from typing import List
import numpy as np
from attr import attrib, attrs
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
np.random.seed(1)
......@@ -86,3 +95,505 @@ def test_multi_speed_init():
if (step + 1) % (i_agent + 1) == 0:
print(step, i_agent, env.agents[i_agent].position)
old_pos[i_agent] = env.agents[i_agent].position
@attrs
class Replay(object):
position = attrib()
direction = attrib()
action = attrib(type=RailEnvActions)
malfunction = attrib(default=0, type=int)
@attrs
class TestConfig(object):
replay = attrib(type=List[Replay])
target = attrib()
speed = attrib(type=float)
def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
"""Test that actions are correctly performed on cell exit for a single agent."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# initialize agents_static
env.reset()
# reset to set agents from agents_static
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
test_config = TestConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=None
),
Replay(
position=(5, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
],
target=(3, 0), # west dead-end
speed=0.5
)
# TODO test penalties!
agentStatic: EnvAgentStatic = env.agents_static[0]
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(test_config.replay):
if i == 0:
# set the initial position
agentStatic.position = replay.position
agentStatic.direction = replay.direction
agentStatic.target = test_config.target
agentStatic.moving = True
agentStatic.speed_data['speed'] = test_config.speed
# reset to set agents from agents_static
env.reset(False, False)
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
agent: EnvAgent = env.agents[0]
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
if replay.action:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
def test_multispeed_actions_no_malfunction_blocking(rendering=True):
"""The second agent blocks the first because it is slower."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# initialize agents_static
env.reset()
# reset to set agents from agents_static
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
test_configs = [
TestConfig(
replay=[
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None
)
],
target=(3, 0), # west dead-end
speed=1 / 3),
TestConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None
),
# blocked although fraction >= 1.0
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
# blocked although fraction >= 1.0
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
# blocked although fraction >= 1.0
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
# not blocked, action required!
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
],
target=(3, 0), # west dead-end
speed=0.5
)
]
# TODO test penalties!
info_dict = {
'action_required': [True for _ in test_configs]
}
for step in range(len(test_configs[0].replay)):
if step == 0:
for a, test_config in enumerate(test_configs):
agentStatic: EnvAgentStatic = env.agents_static[a]
replay = test_config.replay[0]
# set the initial position
agentStatic.position = replay.position
agentStatic.direction = replay.direction
agentStatic.target = test_config.target
agentStatic.moving = True
agentStatic.speed_data['speed'] = test_config.speed
# reset to set agents from agents_static
env.reset(False, False)
def _assert(a, actual, expected, msg):
assert actual == expected, "[{}] {} {}: actual={}, expected={}".format(step, a, msg, actual, expected)
action_dict = {}
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
replay = test_config.replay[step]
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
if replay.action:
assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True)
action_dict[a] = replay.action
else:
assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False)
_, _, _, info_dict = env.step(action_dict)
if rendering:
renderer.render_env(show=True, show_observations=True)
def test_multispeed_actions_malfunction_no_blocking(rendering=True):
"""Test on a single agent whether action on cell exit work correctly despite malfunction."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# initialize agents_static
env.reset()
# reset to set agents from agents_static
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
test_config = TestConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
# add additional step in the cell
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
malfunction=2 # recovers in two steps from now!
),
# agent recovers in this step
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2 # recovers in two steps from now!
),
# agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT,
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=None
),
Replay(
position=(5, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
],
target=(3, 0), # west dead-end
speed=0.5
)
# TODO test penalties!
agentStatic: EnvAgentStatic = env.agents_static[0]
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(test_config.replay):
if i == 0:
# set the initial position
agentStatic.position = replay.position
agentStatic.direction = replay.direction
agentStatic.target = test_config.target
agentStatic.moving = True
agentStatic.speed_data['speed'] = test_config.speed
# reset to set agents from agents_static
env.reset(False, False)
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
agent: EnvAgent = env.agents[0]
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
if replay.malfunction:
agent.malfunction_data['malfunction'] = 2
if replay.action:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment