Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 4784 additions and 265 deletions
import pytest
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_oval_rail
def test_shortest_paths():
rail, rail_map, optionals = make_oval_rail()
speed_ratio_map = {1.: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_shortest_path = env.agents[0].get_shortest_path(env.distance_map)
agent1_shortest_path = env.agents[1].get_shortest_path(env.distance_map)
assert len(agent0_shortest_path) == 10
assert len(agent1_shortest_path) == 10
def test_travel_time_on_shortest_paths():
rail, rail_map, optionals = make_oval_rail()
speed_ratio_map = {1.: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 10
assert agent1_travel_time == 10
speed_ratio_map = {1/2: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 20
assert agent1_travel_time == 20
speed_ratio_map = {1/3: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 30
assert agent1_travel_time == 30
speed_ratio_map = {1/4: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 40
assert agent1_travel_time == 40
# def test_latest_arrival_validity():
# pass
# def test_time_remaining_until_latest_arrival():
# pass
def main():
pass
if __name__ == "__main__":
main()
......@@ -2,8 +2,9 @@ import numpy as np
import pytest
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
from flatland.core.grid.grid4_utils import get_direction
from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
depth_to_test = 5
positions_to_test = [0, 5, 1, 6, 20, 30]
......@@ -31,4 +32,8 @@ def test_get_direction():
assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
with pytest.raises(Exception, match="Could not determine direction"):
get_direction((0, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
get_direction((0, 0), (0, 0))
def test_load():
load_flatland_environment_from_file('test_001.pkl', 'env_data.tests')
......@@ -4,27 +4,30 @@
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.step_utils.states import TrainState
"""Tests for `flatland` package."""
def test_global_obs():
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
global_obs = env.reset()
global_obs, info = env.reset()
# we have to take step for the agent to enter the grid.
global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})
assert (global_obs[0][0].shape == rail_map.shape + (16,))
......@@ -38,29 +41,30 @@ def test_global_obs():
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
obs_agents_state = global_obs[0][1]
obs_agents_state = obs_agents_state + 1
assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
def _step_along_shortest_path(env, obs_builder, rail):
actions = {}
expected_next_position = {}
for agent in env.agents:
agent: EnvAgent
shortest_distance = np.inf
for exit_direction in range(4):
neighbour = obs_builder._new_position(agent.position, exit_direction)
neighbour = get_new_position(agent.position, exit_direction)
if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
desired_movement_from_new_cell = (exit_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `entry_direction' to the neighbour possible?
# Is a transition along movement `entry_direction` to the neighbour possible?
is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
distance_to_target = obs_builder.distance_map[
distance_to_target = obs_builder.env.distance_map.get()[
(agent.handle, *agent.position, exit_direction)]
print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
agent.direction,
......@@ -82,40 +86,50 @@ def _step_along_shortest_path(env, obs_builder, rail):
expected_next_position[agent.handle] = neighbour
print(" improved (action) -> {}".format(actions[agent.handle]))
_, rewards, dones, _ = env.step(actions)
return rewards
return rewards, dones
def test_reward_function_conflict(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
obs_builder: TreeObsForRailEnv = env.obs_builder
# initialize agents_static
env.reset()
# set the initial position
agent = env.agents_static[0]
agent = env.agents[0]
agent.position = (5, 6) # south dead-end
agent.initial_position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.initial_direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
agent = env.agents_static[1]
agent = env.agents[1]
agent.position = (3, 8) # east dead-end
agent.initial_position = (3, 8) # east dead-end
agent.direction = 3 # west
agent.initial_direction = 3 # west
agent.target = (6, 6) # south dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
# reset to set agents from agents_static
env.reset(False, False)
env.agents[0].moving = True
env.agents[1].moving = True
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
env.agents[0].position = (5, 6)
env.agents[1].position = (3, 8)
print("\n")
print(env.agents[0])
print(env.agents[1])
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=True)
renderer.render_env(show=True, show_observations=True)
iteration = 0
expected_positions = {
......@@ -148,52 +162,61 @@ def test_reward_function_conflict(rendering=False):
},
}
while iteration < 5:
rewards = _step_along_shortest_path(env, obs_builder, rail)
rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
if dones["__all__"]:
break
for agent in env.agents:
assert rewards[agent.handle] == -1
# assert rewards[agent.handle] == 0
expected_position = expected_positions[iteration + 1][agent.handle]
assert agent.position == expected_position, "[{}] agent {} at {}, expected {}".format(iteration + 1,
agent.handle,
agent.position,
expected_position)
if rendering:
renderer.renderEnv(show=True, show_observations=True)
renderer.render_env(show=True, show_observations=True)
iteration += 1
def test_reward_function_waiting(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=2,
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
remove_agents_at_target=False, random_seed=1)
obs_builder: TreeObsForRailEnv = env.obs_builder
# initialize agents_static
env.reset()
# set the initial position
agent = env.agents_static[0]
agent = env.agents[0]
agent.initial_position = (3, 8) # east dead-end
agent.position = (3, 8) # east dead-end
agent.direction = 3 # west
agent.initial_direction = 3 # west
agent.target = (3, 1) # west dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
agent = env.agents_static[1]
agent = env.agents[1]
agent.initial_position = (5, 6) # south dead-end
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.initial_direction = 0 # north
agent.target = (3, 8) # east dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
# reset to set agents from agents_static
env.reset(False, False)
env.agents[0].moving = True
env.agents[1].moving = True
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
env.agents[0].position = (3, 8)
env.agents[1].position = (5, 6)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=True)
renderer.render_env(show=True, show_observations=True)
iteration = 0
expectations = {
......@@ -202,14 +225,14 @@ def test_reward_function_waiting(rendering=False):
0: (3, 8),
1: (5, 6),
},
'rewards': [-1, -1],
'rewards': [0, 0],
},
1: {
'positions': {
0: (3, 7),
1: (4, 6),
},
'rewards': [-1, -1],
'rewards': [0, 0],
},
# second agent has to wait for first, first can continue
2: {
......@@ -217,7 +240,7 @@ def test_reward_function_waiting(rendering=False):
0: (3, 6),
1: (4, 6),
},
'rewards': [-1, -1],
'rewards': [0, 0],
},
# both can move again
3: {
......@@ -225,14 +248,14 @@ def test_reward_function_waiting(rendering=False):
0: (3, 5),
1: (3, 6),
},
'rewards': [-1, -1],
'rewards': [0, 0],
},
4: {
'positions': {
0: (3, 4),
1: (3, 7),
},
'rewards': [-1, -1],
'rewards': [0, 0],
},
# second reached target
5: {
......@@ -240,14 +263,14 @@ def test_reward_function_waiting(rendering=False):
0: (3, 3),
1: (3, 8),
},
'rewards': [-1, 0],
'rewards': [0, 0],
},
6: {
'positions': {
0: (3, 2),
1: (3, 8),
},
'rewards': [-1, 0],
'rewards': [0, 0],
},
# first reaches, target too
7: {
......@@ -255,26 +278,27 @@ def test_reward_function_waiting(rendering=False):
0: (3, 1),
1: (3, 8),
},
'rewards': [1, 1],
'rewards': [0, 0],
},
8: {
'positions': {
0: (3, 1),
1: (3, 8),
},
'rewards': [1, 1],
'rewards': [0, 0],
},
}
while iteration < 7:
rewards = _step_along_shortest_path(env, obs_builder, rail)
rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
if dones["__all__"]:
break
if rendering:
renderer.renderEnv(show=True, show_observations=True)
renderer.render_env(show=True, show_observations=True)
print(env.dones["__all__"])
for agent in env.agents:
agent: EnvAgent
print("[{}] agent {} at {}, target {} ".format(iteration + 1, agent.handle, agent.position, agent.target))
print(np.all([np.array_equal(agent2.position, agent2.target) for agent2 in env.agents]))
for agent in env.agents:
......@@ -284,10 +308,10 @@ def test_reward_function_waiting(rendering=False):
agent.handle,
agent.position,
expected_position)
expected_reward = expectations[iteration + 1]['rewards'][agent.handle]
actual_reward = rewards[agent.handle]
assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1,
agent.handle,
actual_reward,
expected_reward)
# expected_reward = expectations[iteration + 1]['rewards'][agent.handle]
# actual_reward = rewards[agent.handle]
# assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1,
# agent.handle,
# actual_reward,
# expected_reward)
iteration += 1
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
def test_load_new():
filename = "test_load_new.pkl"
rail, rail_map, optionals = make_simple_rail()
n_agents = 2
env_initial = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=n_agents)
env_initial.reset(False, False)
rails_initial = env_initial.rail.grid
agents_initial = env_initial.agents
RailEnvPersister.save(env_initial, filename)
env_loaded, _ = RailEnvPersister.load_new(filename)
rails_loaded = env_loaded.rail.grid
agents_loaded = env_loaded.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
def main():
pass
if __name__ == "__main__":
main()
......@@ -5,39 +5,49 @@ import pprint
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv, Node
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.step_utils.states import TrainState
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
"""Test predictions for `flatland` package."""
def test_dummy_predictor(rendering=False):
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
# reset to initialize agents_static
env.reset()
# set initial position and direction for testing...
env.agents_static[0].position = (5, 6)
env.agents_static[0].direction = 0
env.agents_static[0].target = (3, 0)
env.agents[0].initial_position = (5, 6)
env.agents[0].initial_direction = 0
env.agents[0].direction = 0
env.agents[0].target = (3, 0)
# reset to set agents from agents_static
env.reset(False, False)
env.agents[0].earliest_departure = 1
env._max_episode_steps = 100
# Make Agent 0 active
env.step({})
env.step({0: RailEnvActions.MOVE_FORWARD})
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
renderer.render_env(show=True, show_observations=False)
input("Continue?")
# test assertions
......@@ -89,7 +99,7 @@ def test_dummy_predictor(rendering=False):
expected_actions = np.array([[0.],
[2.],
[2.],
[1.],
[2.],
[2.],
[2.],
[2.],
......@@ -107,37 +117,52 @@ def test_dummy_predictor(rendering=False):
def test_shortest_path_predictor(rendering=False):
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# reset to initialize agents_static
env.reset()
# set the initial position
agent = env.agents_static[0]
agent = env.agents[0]
agent.initial_position = (5, 6) # south dead-end
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.initial_direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
# reset to set agents from agents_static
env.reset(False, False)
env.distance_map._compute(env.agents, env.rail)
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
renderer.render_env(show=True, show_observations=False)
input("Continue?")
# compute the observations and predictions
distance_map = env.obs_builder.distance_map
assert distance_map[0, agent.position[0], agent.position[
1], agent.direction] == 5.0, "found {} instead of {}".format(
distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
distance_map = env.distance_map.get()
distance_on_map = distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction]
assert distance_on_map == 5.0, "found {} instead of {}".format(distance_on_map, 5.0)
paths = get_shortest_paths(env.distance_map)[0]
assert paths == [
Waypoint((5, 6), 0),
Waypoint((4, 6), 0),
Waypoint((3, 6), 0),
Waypoint((3, 7), 1),
Waypoint((3, 8), 1),
Waypoint((3, 9), 1)
]
# extract the data
predictions = env.obs_builder.predictions
......@@ -217,52 +242,66 @@ def test_shortest_path_predictor(rendering=False):
[20.],
])
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
assert np.array_equal(positions, expected_positions), \
"positions {}, expected {}".format(positions, expected_positions)
assert np.array_equal(directions, expected_directions), \
"directions {}, expected {}".format(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
def test_shortest_path_predictor_conflicts(rendering=False):
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_invalid_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# initialize agents_static
env.reset()
# set the initial position
agent = env.agents_static[0]
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
agent = env.agents_static[1]
agent.position = (3, 8) # east dead-end
agent.direction = 3 # west
agent.target = (6, 6) # south dead-end
agent.moving = True
# reset to set agents from agents_static
observations = env.reset(False, False)
env.agents[0].initial_position = (5, 6) # south dead-end
env.agents[0].position = (5, 6) # south dead-end
env.agents[0].direction = 0 # north
env.agents[0].initial_direction = 0 # north
env.agents[0].target = (3, 9) # east dead-end
env.agents[0].moving = True
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1].initial_position = (3, 8) # east dead-end
env.agents[1].position = (3, 8) # east dead-end
env.agents[1].direction = 3 # west
env.agents[1].initial_direction = 3 # west
env.agents[1].target = (6, 6) # south dead-end
env.agents[1].moving = True
env.agents[1]._set_state(TrainState.MOVING)
observations, info = env.reset(False, False)
env.agents[0].position = (5, 6) # south dead-end
env.agent_positions[env.agents[0].position] = 0
env.agents[1].position = (3, 8) # east dead-end
env.agent_positions[env.agents[1].position] = 1
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
observations = env._get_observations()
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=False)
renderer.render_env(show=True, show_observations=False)
input("Continue?")
# get the trees to test
obs_builder: TreeObsForRailEnv = env.obs_builder
pp = pprint.PrettyPrinter(indent=4)
tree_0 = obs_builder.unfold_observation_tree(observations[0])
tree_1 = obs_builder.unfold_observation_tree(observations[1])
pp.pprint(tree_0)
tree_0 = observations[0]
tree_1 = observations[1]
env.obs_builder.util_print_obs_subtree(tree_0)
env.obs_builder.util_print_obs_subtree(tree_1)
# check the expectations
expected_conflicts_0 = [('F', 'R')]
......@@ -271,11 +310,18 @@ def test_shortest_path_predictor_conflicts(rendering=False):
_check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][''][8]
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][a_2][''][8]
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: Node, prompt=''):
assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explored_actions_char:
if tree.childs[a_1] == -np.inf:
assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
continue
else:
conflict = tree.childs[a_1].num_agents_opposite_direction
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explored_actions_char:
if tree.childs[a_1].childs[a_2] == -np.inf:
assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
else:
conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import numpy as np
import pytest
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.line_generators import sparse_line_generator, line_from_file
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
def test_load_env():
env = RailEnv(10, 10)
env.load_resource('env_data.tests', 'test-10x10.mpk')
def test_save_load():
env = RailEnv(width=30, height=30,
rail_generator=sparse_rail_generator(seed=1),
line_generator=sparse_line_generator(), number_of_agents=2)
env.reset()
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static)
assert env.get_num_agents() == 1
agent_1_pos = env.agents[0].position
agent_1_dir = env.agents[0].direction
agent_1_tar = env.agents[0].target
agent_2_pos = env.agents[1].position
agent_2_dir = env.agents[1].direction
agent_2_tar = env.agents[1].target
os.makedirs("tmp", exist_ok=True)
def test_save_load():
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0),
number_of_agents=2)
env.reset()
agent_1_pos = env.agents_static[0].position
agent_1_dir = env.agents_static[0].direction
agent_1_tar = env.agents_static[0].target
agent_2_pos = env.agents_static[1].position
agent_2_dir = env.agents_static[1].direction
agent_2_tar = env.agents_static[1].target
env.save("test_save.dat")
env.load("test_save.dat")
assert (env.width == 10)
assert (env.height == 10)
RailEnvPersister.save(env, "tmp/test_save.pkl")
env.save("tmp/test_save_2.pkl")
# env.load("test_save.dat")
env, env_dict = RailEnvPersister.load_new("tmp/test_save.pkl")
assert (env.width == 30)
assert (env.height == 30)
assert (len(env.agents) == 2)
assert (agent_1_pos == env.agents_static[0].position)
assert (agent_1_dir == env.agents_static[0].direction)
assert (agent_1_tar == env.agents_static[0].target)
assert (agent_2_pos == env.agents_static[1].position)
assert (agent_2_dir == env.agents_static[1].direction)
assert (agent_2_tar == env.agents_static[1].target)
def test_rail_environment_single_agent():
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
assert (agent_1_pos == env.agents[0].position)
assert (agent_1_dir == env.agents[0].direction)
assert (agent_1_tar == env.agents[0].target)
assert (agent_2_pos == env.agents[1].position)
assert (agent_2_dir == env.agents[1].direction)
assert (agent_2_tar == env.agents[1].target)
@pytest.mark.skip("Msgpack serializing not supported")
def test_save_load_mpk():
env = RailEnv(width=30, height=30,
rail_generator=sparse_rail_generator(seed=1),
line_generator=sparse_line_generator(), number_of_agents=2)
env.reset()
os.makedirs("tmp", exist_ok=True)
RailEnvPersister.save(env, "tmp/test_save.mpk")
# env.load("test_save.dat")
env2, env_dict = RailEnvPersister.load_new("tmp/test_save.mpk")
assert (env.width == env2.width)
assert (env.height == env2.height)
assert (len(env2.agents) == len(env.agents))
for agent1, agent2 in zip(env.agents, env2.agents):
assert (agent1.position == agent2.position)
assert (agent1.direction == agent2.direction)
assert (agent1.target == agent2.target)
@pytest.mark.skip(reason="Old file used to create env, not sure how to regenerate")
def test_rail_environment_single_agent(show=False):
# We instantiate the following map on a 3x3 grid
# _ _
# / \/ \
......@@ -65,32 +85,48 @@ def test_rail_environment_single_agent():
# \_/\_/
transitions = RailEnvTransitions()
vertical_line = cells[1]
south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
# Simple turn not in the base transitions ?
south_east_turn = int('0100000000000010', 2)
south_west_turn = transitions.rotate_transition(south_east_turn, 90)
north_east_turn = transitions.rotate_transition(south_east_turn, 270)
north_west_turn = transitions.rotate_transition(south_east_turn, 180)
rail_map = np.array([[south_east_turn, south_symmetrical_switch,
south_west_turn],
[vertical_line, vertical_line, vertical_line],
[north_east_turn, north_symmetrical_switch,
north_west_turn]],
dtype=np.uint16)
rail = GridTransitionMap(width=3, height=3, transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(width=3,
height=3,
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
for _ in range(200):
_ = rail_env.reset()
if False:
# This env creation doesn't quite work right.
cells = transitions.transition_list
vertical_line = cells[1]
south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
south_east_turn = int('0100000000000010', 2)
south_west_turn = transitions.rotate_transition(south_east_turn, 90)
north_east_turn = transitions.rotate_transition(south_east_turn, 270)
north_west_turn = transitions.rotate_transition(south_east_turn, 180)
rail_map = np.array([[south_east_turn, south_symmetrical_switch,
south_west_turn],
[vertical_line, vertical_line, vertical_line],
[north_east_turn, north_symmetrical_switch,
north_west_turn]],
dtype=np.uint16)
rail = GridTransitionMap(width=3, height=3, transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
else:
rail_env, env_dict = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
rail_map = rail_env.rail.grid
rail_env._max_episode_steps = 1000
_ = rail_env.reset(False, False, True)
liActions = [int(a) for a in RailEnvActions]
env_renderer = RenderTool(rail_env)
# RailEnvPersister.save(rail_env, "test_env_figure8.pkl")
for _ in range(5):
# rail_env.agents[0].initial_position = (1,2)
_ = rail_env.reset(False, False, True)
# We do not care about target for the moment
agent = rail_env.agents[0]
......@@ -99,47 +135,80 @@ def test_rail_environment_single_agent():
# Check that trains are always initialized at a consistent position
# or direction.
# They should always be able to go somewhere.
assert (transitions.get_transitions(
rail_map[agent.position],
agent.direction) != (0, 0, 0, 0))
if show:
print("After reset - agent pos:", agent.position, "dir: ", agent.direction)
print(transitions.get_transitions(rail_map[agent.position], agent.direction))
# assert (transitions.get_transitions(
# rail_map[agent.position],
# agent.direction) != (0, 0, 0, 0))
# HACK - force the direction to one we know is good.
# agent.initial_position = agent.position = (2,3)
agent.initial_direction = agent.direction = 0
initial_pos = agent.position
if show:
print("handle:", agent.handle)
# agent.initial_position = initial_pos = agent.position
valid_active_actions_done = 0
pos = initial_pos
pos = agent.position
if show:
env_renderer.render_env(show=show, show_agents=True)
time.sleep(0.01)
iStep = 0
while valid_active_actions_done < 6:
# We randomly select an action
action = np.random.randint(4)
action = np.random.choice(liActions)
# action = RailEnvActions.MOVE_FORWARD
_, _, _, _ = rail_env.step({0: action})
_, _, dict_done, _ = rail_env.step({0: action})
prev_pos = pos
pos = agent.position # rail_env.agents_position[0]
print("action:", action, "pos:", agent.position, "prev:", prev_pos, agent.direction)
print(dict_done)
if prev_pos != pos:
valid_active_actions_done += 1
iStep += 1
if show:
env_renderer.render_env(show=show, show_agents=True, step=iStep)
time.sleep(0.01)
assert iStep < 100, "valid actions should have been performed by now - hung agent"
# After 6 movements on this railway network, the train should be back
# to its original height on the map.
assert (initial_pos[0] == agent.position[0])
# assert (initial_pos[0] == agent.position[0])
# We check that the train always attains its target after some time
for _ in range(10):
_ = rail_env.reset()
done = False
while not done:
rail_env.agents[0].direction = 0
# JW - to avoid problem with sparse_line_generator.
# rail_env.agents[0].position = (1,2)
iStep = 0
while iStep < 100:
# We randomly select an action
action = np.random.randint(4)
action = np.random.choice(liActions)
_, _, dones, _ = rail_env.step({0: action})
done = dones['__all__']
test_rail_environment_single_agent()
if done:
break
iStep += 1
assert iStep < 100, "agent should have finished by now"
env_renderer.render_env(show=show)
def test_dead_end():
transitions = Grid4Transitions([])
transitions = RailEnvTransitions()
straight_vertical = int('1000000000100000', 2) # Case 1 - straight
straight_horizontal = transitions.rotate_transition(straight_vertical,
......@@ -162,38 +231,31 @@ def test_dead_end():
transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
def check_consistency(rail_env):
# We run step to check that trains do not move anymore
# after being done.
# TODO: GIACOMO: this is deprecated and should be updated; thenew behavior is that agents keep moving
# until they are manually stopped.
for i in range(7):
prev_pos = rail_env.agents[0].position
# The train cannot turn, so we check that when it tries,
# it stays where it is.
_ = rail_env.step({0: 1})
_ = rail_env.step({0: 3})
assert (rail_env.agents[0].position == prev_pos)
_, _, dones, _ = rail_env.step({0: 2})
if i < 5:
assert (not dones[0] and not dones['__all__'])
else:
assert (dones[0] and dones['__all__'])
city_positions = [(0, 0), (0, 3)]
train_stations = [
[((0, 0), 0)],
[((0, 0), 0)],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
# We try the configuration in the 4 directions:
rail_env.reset()
rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)]
rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=1, direction=1, target=(0, 0), moving=False)]
rail_env.reset()
rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)]
rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=3, direction=3, target=(0, 4), moving=False)]
# In the vertical configuration:
rail_map = np.array(
......@@ -205,15 +267,130 @@ def test_dead_end():
height=rail_map.shape[0],
transitions=transitions)
city_positions = [(0, 0), (0, 3)]
train_stations = [
[((0, 0), 0)],
[((0, 0), 0)],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
rail_env.reset()
rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0), moving=False)]
rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=2, direction=2, target=(0, 0), moving=False)]
rail_env.reset()
rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=0, direction=0, target=(4, 0), moving=False)]
# TODO make assertions
def test_get_entry_directions():
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
def _assert(position, expected):
actual = env.get_valid_directions_on_grid(*position)
assert actual == expected, "[{},{}] actual={}, expected={}".format(*position, actual, expected)
# north dead end
_assert((0, 3), [True, False, False, False])
# west dead end
_assert((3, 0), [False, False, False, True])
# switch
_assert((3, 3), [False, True, True, True])
# horizontal
_assert((3, 2), [False, True, False, True])
# vertical
_assert((2, 3), [True, False, True, False])
# nowhere
_assert((0, 0), [False, False, False, False])
def test_rail_env_reset():
file_name = "test_rail_env_reset.pkl"
# Test to save and load file.
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# env.save(file_name)
RailEnvPersister.save(env, file_name)
dist_map_shape = np.shape(env.distance_map.get())
rails_initial = env.rail.grid
agents_initial = env.agents
# env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
# line_generator=line_from_file(file_name), number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
# env2.reset(False, False, False)
env2, env2_dict = RailEnvPersister.load_new(file_name)
rails_loaded = env2.rail.grid
agents_loaded = env2.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env3.reset(False, True)
rails_loaded = env3.rail.grid
agents_loaded = env3.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env4.reset(True, False)
rails_loaded = env4.rail.grid
agents_loaded = env4.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
def main():
# test_rail_environment_single_agent(show=True)
test_rail_env_reset()
if __name__ == "__main__":
main()
import sys
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shortest_paths
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives
from flatland.envs.persistence import RailEnvPersister
def test_get_shortest_paths_unreachable():
rail, rail_map, optionals = make_disconnected_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
# set the initial position
agent = env.agents[0]
agent.position = (3, 1) # west dead-end
agent.initial_position = (3, 1) # west dead-end
agent.direction = Grid4TransitionsEnum.WEST
agent.target = (3, 9) # east dead-end
agent.moving = True
env.reset(False, False)
actual = get_shortest_paths(env.distance_map)
expected = {0: None}
assert actual[0] == expected[0], "actual={},expected={}".format(actual[0], expected[0])
# todo file test_002.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
def test_get_shortest_paths():
#env = load_flatland_environment_from_file('test_002.mpk', 'env_data.tests')
env, env_dict = RailEnvPersister.load_new("test_002.mpk", "env_data.tests")
#print("env len(agents): ", len(env.agents))
#print(env.distance_map)
#print("env number_of_agents:", env.number_of_agents)
#print("env agents:", env.agents)
#env.distance_map.reset(env.agents, env.rail)
#actual = get_shortest_paths(env.distance_map)
#print("shortest paths:", actual)
#print(env.distance_map)
#print("Dist map agents:", env.distance_map.agents)
#print("\nenv reset()")
env.reset()
actual = get_shortest_paths(env.distance_map)
#print("env agents: ", len(env.agents))
#print("env number_of_agents: ", env.number_of_agents)
assert len(actual) == 2, "get_shortest_paths should return a dict of length 2"
expected = {
0: [
Waypoint(position=(1, 1), direction=1),
Waypoint(position=(1, 2), direction=1),
Waypoint(position=(1, 3), direction=1),
Waypoint(position=(2, 3), direction=2),
Waypoint(position=(2, 4), direction=1),
Waypoint(position=(2, 5), direction=1),
Waypoint(position=(2, 6), direction=1),
Waypoint(position=(2, 7), direction=1),
Waypoint(position=(2, 8), direction=1),
Waypoint(position=(2, 9), direction=1),
Waypoint(position=(2, 10), direction=1),
Waypoint(position=(2, 11), direction=1),
Waypoint(position=(2, 12), direction=1),
Waypoint(position=(2, 13), direction=1),
Waypoint(position=(2, 14), direction=1),
Waypoint(position=(2, 15), direction=1),
Waypoint(position=(2, 16), direction=1),
Waypoint(position=(2, 17), direction=1),
Waypoint(position=(2, 18), direction=1)],
1: [
Waypoint(position=(3, 18), direction=3),
Waypoint(position=(3, 17), direction=3),
Waypoint(position=(3, 16), direction=3),
Waypoint(position=(2, 16), direction=0),
Waypoint(position=(2, 15), direction=3),
Waypoint(position=(2, 14), direction=3),
Waypoint(position=(2, 13), direction=3),
Waypoint(position=(2, 12), direction=3),
Waypoint(position=(2, 11), direction=3),
Waypoint(position=(2, 10), direction=3),
Waypoint(position=(2, 9), direction=3),
Waypoint(position=(2, 8), direction=3),
Waypoint(position=(2, 7), direction=3),
Waypoint(position=(2, 6), direction=3),
Waypoint(position=(2, 5), direction=3),
Waypoint(position=(2, 4), direction=3),
Waypoint(position=(2, 3), direction=3),
Waypoint(position=(2, 2), direction=3),
Waypoint(position=(2, 1), direction=3)]
}
for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
# todo file test_002.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
def test_get_shortest_paths_max_depth():
#env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
env, _ = RailEnvPersister.load_new("test_002.mpk", "env_data.tests")
env.reset()
actual = get_shortest_paths(env.distance_map, max_depth=2)
expected = {
0: [
Waypoint(position=(1, 1), direction=1),
Waypoint(position=(1, 2), direction=1)
],
1: [
Waypoint(position=(3, 18), direction=3),
Waypoint(position=(3, 17), direction=3),
]
}
for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
# todo file Level_distance_map_shortest_path.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
def test_get_shortest_paths_agent_handle():
#env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests')
env, _ = RailEnvPersister.load_new("Level_distance_map_shortest_path.mpk", "env_data.tests")
env.reset()
actual = get_shortest_paths(env.distance_map, agent_handle=6)
print(actual, file=sys.stderr)
expected = {6:
[Waypoint(position=(5, 5),
direction=0),
Waypoint(position=(4, 5),
direction=0),
Waypoint(position=(3, 5),
direction=0),
Waypoint(position=(2, 5),
direction=0),
Waypoint(position=(1, 5),
direction=0),
Waypoint(position=(0, 5),
direction=0),
Waypoint(position=(0, 6),
direction=1),
Waypoint(position=(0, 7), direction=1),
Waypoint(position=(0, 8),
direction=1),
Waypoint(position=(0, 9),
direction=1),
Waypoint(position=(0, 10),
direction=1),
Waypoint(position=(1, 10),
direction=2),
Waypoint(position=(2, 10),
direction=2),
Waypoint(position=(3, 10),
direction=2),
Waypoint(position=(4, 10),
direction=2),
Waypoint(position=(5, 10),
direction=2),
Waypoint(position=(6, 10),
direction=2),
Waypoint(position=(7, 10),
direction=2),
Waypoint(position=(8, 10),
direction=2),
Waypoint(position=(9, 10),
direction=2),
Waypoint(position=(10, 10),
direction=2),
Waypoint(position=(11, 10),
direction=2),
Waypoint(position=(12, 10),
direction=2),
Waypoint(position=(13, 10),
direction=2),
Waypoint(position=(14, 10),
direction=2),
Waypoint(position=(15, 10),
direction=2),
Waypoint(position=(16, 10),
direction=2),
Waypoint(position=(17, 10),
direction=2),
Waypoint(position=(18, 10),
direction=2),
Waypoint(position=(19, 10),
direction=2),
Waypoint(position=(20, 10),
direction=2),
Waypoint(position=(20, 9),
direction=3),
Waypoint(position=(20, 8),
direction=3),
Waypoint(position=(21, 8),
direction=2),
Waypoint(position=(21, 7),
direction=3),
Waypoint(position=(21, 6),
direction=3),
Waypoint(position=(21, 5),
direction=3)
]}
for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
def test_get_k_shortest_paths(rendering=False):
rail, rail_map, optionals = make_simple_rail_with_alternatives()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
)
env.reset()
initial_position = (3, 1) # west dead-end
initial_direction = Grid4TransitionsEnum.WEST # west
target_position = (3, 9) # east
# set the initial position
agent = env.agents[0]
agent.position = initial_position
agent.initial_position = initial_position
agent.direction = initial_direction
agent.target = target_position # east dead-end
agent.moving = True
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input()
actual = set(get_k_shortest_paths(
env=env,
source_position=initial_position, # west dead-end
source_direction=int(initial_direction), # east
target_position=target_position,
k=10
))
expected = set([
(
Waypoint(position=(3, 1), direction=3),
Waypoint(position=(3, 0), direction=3),
Waypoint(position=(3, 1), direction=1),
Waypoint(position=(3, 2), direction=1),
Waypoint(position=(3, 3), direction=1),
Waypoint(position=(2, 3), direction=0),
Waypoint(position=(1, 3), direction=0),
Waypoint(position=(0, 3), direction=0),
Waypoint(position=(0, 4), direction=1),
Waypoint(position=(0, 5), direction=1),
Waypoint(position=(0, 6), direction=1),
Waypoint(position=(0, 7), direction=1),
Waypoint(position=(0, 8), direction=1),
Waypoint(position=(0, 9), direction=1),
Waypoint(position=(1, 9), direction=2),
Waypoint(position=(2, 9), direction=2),
Waypoint(position=(3, 9), direction=2)),
(
Waypoint(position=(3, 1), direction=3),
Waypoint(position=(3, 0), direction=3),
Waypoint(position=(3, 1), direction=1),
Waypoint(position=(3, 2), direction=1),
Waypoint(position=(3, 3), direction=1),
Waypoint(position=(3, 4), direction=1),
Waypoint(position=(3, 5), direction=1),
Waypoint(position=(3, 6), direction=1),
Waypoint(position=(4, 6), direction=2),
Waypoint(position=(5, 6), direction=2),
Waypoint(position=(6, 6), direction=2),
Waypoint(position=(5, 6), direction=0),
Waypoint(position=(4, 6), direction=0),
Waypoint(position=(4, 7), direction=1),
Waypoint(position=(4, 8), direction=1),
Waypoint(position=(4, 9), direction=1),
Waypoint(position=(3, 9), direction=0))
])
assert actual == expected, "actual={},expected={}".format(actual, expected)
def main():
test_get_shortest_paths()
if __name__ == "__main__":
main()
import unittest
import warnings
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
def test_sparse_rail_generator():
env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10,
max_rails_between_cities=3,
seed=1,
grid_mode=False
),
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(),
random_seed=1)
env.reset(False, False)
# for r in range(env.height):
# for c in range(env.width):
# if env.rail.grid[r][c] > 0:
# print("expected_grid_map[{}][{}] = {}".format(r, c, env.rail.grid[r][c]))
expected_grid_map = env.rail.grid
expected_grid_map[4][9] = 16386
expected_grid_map[4][10] = 1025
expected_grid_map[4][11] = 1025
expected_grid_map[4][12] = 1025
expected_grid_map[4][13] = 1025
expected_grid_map[4][14] = 1025
expected_grid_map[4][15] = 1025
expected_grid_map[4][16] = 1025
expected_grid_map[4][17] = 1025
expected_grid_map[4][18] = 1025
expected_grid_map[4][19] = 1025
expected_grid_map[4][20] = 1025
expected_grid_map[4][21] = 1025
expected_grid_map[4][22] = 17411
expected_grid_map[4][23] = 17411
expected_grid_map[4][24] = 1025
expected_grid_map[4][25] = 1025
expected_grid_map[4][26] = 1025
expected_grid_map[4][27] = 1025
expected_grid_map[4][28] = 5633
expected_grid_map[4][29] = 5633
expected_grid_map[4][30] = 4608
expected_grid_map[5][9] = 49186
expected_grid_map[5][10] = 1025
expected_grid_map[5][11] = 1025
expected_grid_map[5][12] = 1025
expected_grid_map[5][13] = 1025
expected_grid_map[5][14] = 1025
expected_grid_map[5][15] = 1025
expected_grid_map[5][16] = 1025
expected_grid_map[5][17] = 1025
expected_grid_map[5][18] = 1025
expected_grid_map[5][19] = 1025
expected_grid_map[5][20] = 1025
expected_grid_map[5][21] = 1025
expected_grid_map[5][22] = 2064
expected_grid_map[5][23] = 32800
expected_grid_map[5][28] = 32800
expected_grid_map[5][29] = 32800
expected_grid_map[5][30] = 32800
expected_grid_map[6][9] = 49186
expected_grid_map[6][10] = 1025
expected_grid_map[6][11] = 1025
expected_grid_map[6][12] = 1025
expected_grid_map[6][13] = 1025
expected_grid_map[6][14] = 1025
expected_grid_map[6][15] = 1025
expected_grid_map[6][16] = 1025
expected_grid_map[6][17] = 1025
expected_grid_map[6][18] = 1025
expected_grid_map[6][19] = 1025
expected_grid_map[6][20] = 1025
expected_grid_map[6][21] = 1025
expected_grid_map[6][22] = 1025
expected_grid_map[6][23] = 2064
expected_grid_map[6][28] = 32800
expected_grid_map[6][29] = 32872
expected_grid_map[6][30] = 37408
expected_grid_map[7][9] = 32800
expected_grid_map[7][28] = 32800
expected_grid_map[7][29] = 32800
expected_grid_map[7][30] = 32800
expected_grid_map[8][9] = 32872
expected_grid_map[8][10] = 4608
expected_grid_map[8][28] = 49186
expected_grid_map[8][29] = 34864
expected_grid_map[8][30] = 32872
expected_grid_map[8][31] = 4608
expected_grid_map[9][9] = 49186
expected_grid_map[9][10] = 34864
expected_grid_map[9][28] = 32800
expected_grid_map[9][29] = 32800
expected_grid_map[9][30] = 32800
expected_grid_map[9][31] = 32800
expected_grid_map[10][9] = 32800
expected_grid_map[10][10] = 32800
expected_grid_map[10][28] = 32872
expected_grid_map[10][29] = 37408
expected_grid_map[10][30] = 49186
expected_grid_map[10][31] = 2064
expected_grid_map[11][9] = 32800
expected_grid_map[11][10] = 32800
expected_grid_map[11][28] = 32800
expected_grid_map[11][29] = 32800
expected_grid_map[11][30] = 32800
expected_grid_map[12][9] = 32800
expected_grid_map[12][10] = 32800
expected_grid_map[12][28] = 32800
expected_grid_map[12][29] = 49186
expected_grid_map[12][30] = 34864
expected_grid_map[12][33] = 16386
expected_grid_map[12][34] = 1025
expected_grid_map[12][35] = 1025
expected_grid_map[12][36] = 1025
expected_grid_map[12][37] = 1025
expected_grid_map[12][38] = 5633
expected_grid_map[12][39] = 17411
expected_grid_map[12][40] = 1025
expected_grid_map[12][41] = 1025
expected_grid_map[12][42] = 1025
expected_grid_map[12][43] = 5633
expected_grid_map[12][44] = 17411
expected_grid_map[12][45] = 1025
expected_grid_map[12][46] = 4608
expected_grid_map[13][9] = 32872
expected_grid_map[13][10] = 37408
expected_grid_map[13][28] = 32800
expected_grid_map[13][29] = 32800
expected_grid_map[13][30] = 32800
expected_grid_map[13][33] = 32800
expected_grid_map[13][38] = 72
expected_grid_map[13][39] = 3089
expected_grid_map[13][40] = 1025
expected_grid_map[13][41] = 1025
expected_grid_map[13][42] = 1025
expected_grid_map[13][43] = 1097
expected_grid_map[13][44] = 2064
expected_grid_map[13][46] = 32800
expected_grid_map[14][9] = 49186
expected_grid_map[14][10] = 2064
expected_grid_map[14][24] = 16386
expected_grid_map[14][25] = 17411
expected_grid_map[14][26] = 1025
expected_grid_map[14][27] = 1025
expected_grid_map[14][28] = 34864
expected_grid_map[14][29] = 32800
expected_grid_map[14][30] = 32872
expected_grid_map[14][31] = 1025
expected_grid_map[14][32] = 1025
expected_grid_map[14][33] = 2064
expected_grid_map[14][46] = 32800
expected_grid_map[15][9] = 32800
expected_grid_map[15][24] = 32800
expected_grid_map[15][25] = 49186
expected_grid_map[15][26] = 1025
expected_grid_map[15][27] = 1025
expected_grid_map[15][28] = 3089
expected_grid_map[15][29] = 3089
expected_grid_map[15][30] = 2064
expected_grid_map[15][46] = 32800
expected_grid_map[16][8] = 16386
expected_grid_map[16][9] = 52275
expected_grid_map[16][10] = 4608
expected_grid_map[16][24] = 32800
expected_grid_map[16][25] = 32800
expected_grid_map[16][46] = 32800
expected_grid_map[17][8] = 32800
expected_grid_map[17][9] = 32800
expected_grid_map[17][10] = 32800
expected_grid_map[17][24] = 32872
expected_grid_map[17][25] = 37408
expected_grid_map[17][44] = 16386
expected_grid_map[17][45] = 17411
expected_grid_map[17][46] = 34864
expected_grid_map[18][8] = 32800
expected_grid_map[18][9] = 32800
expected_grid_map[18][10] = 32800
expected_grid_map[18][24] = 49186
expected_grid_map[18][25] = 34864
expected_grid_map[18][44] = 32800
expected_grid_map[18][45] = 32800
expected_grid_map[18][46] = 32800
expected_grid_map[19][8] = 32800
expected_grid_map[19][9] = 32800
expected_grid_map[19][10] = 32800
expected_grid_map[19][23] = 16386
expected_grid_map[19][24] = 34864
expected_grid_map[19][25] = 32872
expected_grid_map[19][26] = 4608
expected_grid_map[19][44] = 32800
expected_grid_map[19][45] = 32800
expected_grid_map[19][46] = 32800
expected_grid_map[20][8] = 32800
expected_grid_map[20][9] = 32872
expected_grid_map[20][10] = 37408
expected_grid_map[20][23] = 32800
expected_grid_map[20][24] = 32800
expected_grid_map[20][25] = 32800
expected_grid_map[20][26] = 32800
expected_grid_map[20][44] = 32800
expected_grid_map[20][45] = 32800
expected_grid_map[20][46] = 32800
expected_grid_map[21][8] = 32800
expected_grid_map[21][9] = 32800
expected_grid_map[21][10] = 32800
expected_grid_map[21][23] = 72
expected_grid_map[21][24] = 37408
expected_grid_map[21][25] = 49186
expected_grid_map[21][26] = 2064
expected_grid_map[21][44] = 32800
expected_grid_map[21][45] = 32800
expected_grid_map[21][46] = 32800
expected_grid_map[22][8] = 49186
expected_grid_map[22][9] = 34864
expected_grid_map[22][10] = 32872
expected_grid_map[22][11] = 4608
expected_grid_map[22][24] = 32872
expected_grid_map[22][25] = 37408
expected_grid_map[22][43] = 16386
expected_grid_map[22][44] = 2064
expected_grid_map[22][45] = 32800
expected_grid_map[22][46] = 32800
expected_grid_map[23][8] = 32800
expected_grid_map[23][9] = 32800
expected_grid_map[23][10] = 32800
expected_grid_map[23][11] = 32800
expected_grid_map[23][24] = 49186
expected_grid_map[23][25] = 34864
expected_grid_map[23][42] = 16386
expected_grid_map[23][43] = 33825
expected_grid_map[23][44] = 17411
expected_grid_map[23][45] = 3089
expected_grid_map[23][46] = 2064
expected_grid_map[24][8] = 32872
expected_grid_map[24][9] = 37408
expected_grid_map[24][10] = 49186
expected_grid_map[24][11] = 2064
expected_grid_map[24][24] = 32800
expected_grid_map[24][25] = 32800
expected_grid_map[24][42] = 32800
expected_grid_map[24][43] = 32800
expected_grid_map[24][44] = 32800
expected_grid_map[25][8] = 32800
expected_grid_map[25][9] = 32800
expected_grid_map[25][10] = 32800
expected_grid_map[25][24] = 32800
expected_grid_map[25][25] = 32800
expected_grid_map[25][42] = 32800
expected_grid_map[25][43] = 32872
expected_grid_map[25][44] = 37408
expected_grid_map[26][8] = 32800
expected_grid_map[26][9] = 49186
expected_grid_map[26][10] = 34864
expected_grid_map[26][24] = 49186
expected_grid_map[26][25] = 2064
expected_grid_map[26][42] = 32800
expected_grid_map[26][43] = 32800
expected_grid_map[26][44] = 32800
expected_grid_map[27][8] = 32800
expected_grid_map[27][9] = 32800
expected_grid_map[27][10] = 32800
expected_grid_map[27][24] = 32800
expected_grid_map[27][42] = 49186
expected_grid_map[27][43] = 34864
expected_grid_map[27][44] = 32872
expected_grid_map[27][45] = 4608
expected_grid_map[28][8] = 32800
expected_grid_map[28][9] = 32800
expected_grid_map[28][10] = 32800
expected_grid_map[28][24] = 32872
expected_grid_map[28][25] = 4608
expected_grid_map[28][42] = 32800
expected_grid_map[28][43] = 32800
expected_grid_map[28][44] = 32800
expected_grid_map[28][45] = 32800
expected_grid_map[29][8] = 32800
expected_grid_map[29][9] = 32800
expected_grid_map[29][10] = 32800
expected_grid_map[29][24] = 49186
expected_grid_map[29][25] = 34864
expected_grid_map[29][42] = 32872
expected_grid_map[29][43] = 37408
expected_grid_map[29][44] = 49186
expected_grid_map[29][45] = 2064
expected_grid_map[30][8] = 32800
expected_grid_map[30][9] = 32800
expected_grid_map[30][10] = 32800
expected_grid_map[30][23] = 16386
expected_grid_map[30][24] = 34864
expected_grid_map[30][25] = 32872
expected_grid_map[30][26] = 4608
expected_grid_map[30][42] = 32800
expected_grid_map[30][43] = 32800
expected_grid_map[30][44] = 32800
expected_grid_map[31][8] = 32800
expected_grid_map[31][9] = 32872
expected_grid_map[31][10] = 37408
expected_grid_map[31][23] = 32800
expected_grid_map[31][24] = 32800
expected_grid_map[31][25] = 32800
expected_grid_map[31][26] = 32800
expected_grid_map[31][42] = 32800
expected_grid_map[31][43] = 49186
expected_grid_map[31][44] = 34864
expected_grid_map[32][8] = 32800
expected_grid_map[32][9] = 32800
expected_grid_map[32][10] = 32800
expected_grid_map[32][23] = 72
expected_grid_map[32][24] = 37408
expected_grid_map[32][25] = 49186
expected_grid_map[32][26] = 2064
expected_grid_map[32][42] = 32800
expected_grid_map[32][43] = 32800
expected_grid_map[32][44] = 32800
expected_grid_map[33][8] = 49186
expected_grid_map[33][9] = 34864
expected_grid_map[33][10] = 32872
expected_grid_map[33][11] = 4608
expected_grid_map[33][24] = 32872
expected_grid_map[33][25] = 37408
expected_grid_map[33][41] = 16386
expected_grid_map[33][42] = 34864
expected_grid_map[33][43] = 32800
expected_grid_map[33][44] = 32800
expected_grid_map[34][8] = 32800
expected_grid_map[34][9] = 32800
expected_grid_map[34][10] = 32800
expected_grid_map[34][11] = 32800
expected_grid_map[34][24] = 49186
expected_grid_map[34][25] = 2064
expected_grid_map[34][41] = 32800
expected_grid_map[34][42] = 49186
expected_grid_map[34][43] = 2064
expected_grid_map[34][44] = 32800
expected_grid_map[35][8] = 32872
expected_grid_map[35][9] = 37408
expected_grid_map[35][10] = 49186
expected_grid_map[35][11] = 2064
expected_grid_map[35][24] = 32800
expected_grid_map[35][41] = 32800
expected_grid_map[35][42] = 32800
expected_grid_map[35][43] = 16386
expected_grid_map[35][44] = 2064
expected_grid_map[36][8] = 32800
expected_grid_map[36][9] = 32800
expected_grid_map[36][10] = 32800
expected_grid_map[36][18] = 16386
expected_grid_map[36][19] = 17411
expected_grid_map[36][20] = 1025
expected_grid_map[36][21] = 1025
expected_grid_map[36][22] = 1025
expected_grid_map[36][23] = 17411
expected_grid_map[36][24] = 52275
expected_grid_map[36][25] = 5633
expected_grid_map[36][26] = 5633
expected_grid_map[36][27] = 4608
expected_grid_map[36][41] = 32800
expected_grid_map[36][42] = 32800
expected_grid_map[36][43] = 32800
expected_grid_map[37][8] = 32800
expected_grid_map[37][9] = 49186
expected_grid_map[37][10] = 34864
expected_grid_map[37][13] = 16386
expected_grid_map[37][14] = 1025
expected_grid_map[37][15] = 1025
expected_grid_map[37][16] = 1025
expected_grid_map[37][17] = 1025
expected_grid_map[37][18] = 2064
expected_grid_map[37][19] = 32800
expected_grid_map[37][20] = 16386
expected_grid_map[37][21] = 1025
expected_grid_map[37][22] = 1025
expected_grid_map[37][23] = 2064
expected_grid_map[37][24] = 72
expected_grid_map[37][25] = 37408
expected_grid_map[37][26] = 32800
expected_grid_map[37][27] = 32800
expected_grid_map[37][41] = 32800
expected_grid_map[37][42] = 32800
expected_grid_map[37][43] = 32800
expected_grid_map[38][8] = 32800
expected_grid_map[38][9] = 32800
expected_grid_map[38][10] = 32800
expected_grid_map[38][13] = 49186
expected_grid_map[38][14] = 1025
expected_grid_map[38][15] = 1025
expected_grid_map[38][16] = 1025
expected_grid_map[38][17] = 1025
expected_grid_map[38][18] = 1025
expected_grid_map[38][19] = 2064
expected_grid_map[38][20] = 32800
expected_grid_map[38][25] = 32800
expected_grid_map[38][26] = 32800
expected_grid_map[38][27] = 32800
expected_grid_map[38][41] = 32800
expected_grid_map[38][42] = 32800
expected_grid_map[38][43] = 32800
expected_grid_map[39][8] = 72
expected_grid_map[39][9] = 1097
expected_grid_map[39][10] = 1097
expected_grid_map[39][11] = 1025
expected_grid_map[39][12] = 1025
expected_grid_map[39][13] = 3089
expected_grid_map[39][14] = 1025
expected_grid_map[39][15] = 1025
expected_grid_map[39][16] = 1025
expected_grid_map[39][17] = 1025
expected_grid_map[39][18] = 1025
expected_grid_map[39][19] = 1025
expected_grid_map[39][20] = 2064
expected_grid_map[39][25] = 32800
expected_grid_map[39][26] = 32872
expected_grid_map[39][27] = 37408
expected_grid_map[39][41] = 32800
expected_grid_map[39][42] = 32800
expected_grid_map[39][43] = 32800
expected_grid_map[40][25] = 32800
expected_grid_map[40][26] = 32800
expected_grid_map[40][27] = 32800
expected_grid_map[40][41] = 32800
expected_grid_map[40][42] = 32800
expected_grid_map[40][43] = 32800
expected_grid_map[41][25] = 49186
expected_grid_map[41][26] = 34864
expected_grid_map[41][27] = 32872
expected_grid_map[41][28] = 4608
expected_grid_map[41][41] = 32800
expected_grid_map[41][42] = 32800
expected_grid_map[41][43] = 32800
expected_grid_map[42][25] = 32800
expected_grid_map[42][26] = 32800
expected_grid_map[42][27] = 32800
expected_grid_map[42][28] = 32800
expected_grid_map[42][41] = 32800
expected_grid_map[42][42] = 32800
expected_grid_map[42][43] = 32800
expected_grid_map[43][25] = 32872
expected_grid_map[43][26] = 37408
expected_grid_map[43][27] = 49186
expected_grid_map[43][28] = 2064
expected_grid_map[43][41] = 32800
expected_grid_map[43][42] = 32800
expected_grid_map[43][43] = 32800
expected_grid_map[44][25] = 32800
expected_grid_map[44][26] = 32800
expected_grid_map[44][27] = 32800
expected_grid_map[44][30] = 16386
expected_grid_map[44][31] = 17411
expected_grid_map[44][32] = 1025
expected_grid_map[44][33] = 5633
expected_grid_map[44][34] = 17411
expected_grid_map[44][35] = 1025
expected_grid_map[44][36] = 1025
expected_grid_map[44][37] = 1025
expected_grid_map[44][38] = 5633
expected_grid_map[44][39] = 17411
expected_grid_map[44][40] = 1025
expected_grid_map[44][41] = 3089
expected_grid_map[44][42] = 3089
expected_grid_map[44][43] = 2064
expected_grid_map[45][25] = 32800
expected_grid_map[45][26] = 49186
expected_grid_map[45][27] = 34864
expected_grid_map[45][30] = 32800
expected_grid_map[45][31] = 32800
expected_grid_map[45][33] = 72
expected_grid_map[45][34] = 3089
expected_grid_map[45][35] = 1025
expected_grid_map[45][36] = 1025
expected_grid_map[45][37] = 1025
expected_grid_map[45][38] = 1097
expected_grid_map[45][39] = 2064
expected_grid_map[46][25] = 32800
expected_grid_map[46][26] = 32800
expected_grid_map[46][27] = 32800
expected_grid_map[46][30] = 32800
expected_grid_map[46][31] = 32800
expected_grid_map[47][25] = 72
expected_grid_map[47][26] = 1097
expected_grid_map[47][27] = 1097
expected_grid_map[47][28] = 1025
expected_grid_map[47][29] = 1025
expected_grid_map[47][30] = 3089
expected_grid_map[47][31] = 2064
# Attention, once we have fixed the generator this needs to be changed!!!!
expected_grid_map = env.rail.grid
assert np.array_equal(env.rail.grid, expected_grid_map), "actual={}, expected={}".format(env.rail.grid,
expected_grid_map)
s0 = 0
s1 = 0
for a in range(env.get_num_agents()):
s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0))
s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0))
assert s0 == 36, "actual={}".format(s0)
assert s1 == 27, "actual={}".format(s1)
def test_sparse_rail_generator_deterministic():
"""Check that sparse_rail_generator runs deterministic over different python versions!"""
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=215545, # Random seed
grid_mode=True
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1, random_seed=1)
env.reset()
# for r in range(env.height):
# for c in range(env.width):
# print("assert env.rail.get_full_transitions({}, {}) == {}, \"[{}][{}]\"".format(r, c,
# env.rail.get_full_transitions(
# r, c), r, c))
assert env.rail.get_full_transitions(0, 1) == 0, "[0][1]"
assert env.rail.get_full_transitions(0, 2) == 0, "[0][2]"
assert env.rail.get_full_transitions(0, 3) == 0, "[0][3]"
assert env.rail.get_full_transitions(0, 4) == 0, "[0][4]"
assert env.rail.get_full_transitions(0, 5) == 0, "[0][5]"
assert env.rail.get_full_transitions(0, 6) == 0, "[0][6]"
assert env.rail.get_full_transitions(0, 7) == 0, "[0][7]"
assert env.rail.get_full_transitions(0, 8) == 0, "[0][8]"
assert env.rail.get_full_transitions(0, 9) == 0, "[0][9]"
assert env.rail.get_full_transitions(0, 10) == 0, "[0][10]"
assert env.rail.get_full_transitions(0, 11) == 0, "[0][11]"
assert env.rail.get_full_transitions(0, 12) == 0, "[0][12]"
assert env.rail.get_full_transitions(0, 13) == 0, "[0][13]"
assert env.rail.get_full_transitions(0, 14) == 0, "[0][14]"
assert env.rail.get_full_transitions(0, 15) == 0, "[0][15]"
assert env.rail.get_full_transitions(0, 16) == 0, "[0][16]"
assert env.rail.get_full_transitions(0, 17) == 0, "[0][17]"
assert env.rail.get_full_transitions(0, 18) == 0, "[0][18]"
assert env.rail.get_full_transitions(0, 19) == 0, "[0][19]"
assert env.rail.get_full_transitions(0, 20) == 0, "[0][20]"
assert env.rail.get_full_transitions(0, 21) == 0, "[0][21]"
assert env.rail.get_full_transitions(0, 22) == 0, "[0][22]"
assert env.rail.get_full_transitions(0, 23) == 0, "[0][23]"
assert env.rail.get_full_transitions(0, 24) == 0, "[0][24]"
assert env.rail.get_full_transitions(1, 0) == 0, "[1][0]"
assert env.rail.get_full_transitions(1, 1) == 0, "[1][1]"
assert env.rail.get_full_transitions(1, 2) == 0, "[1][2]"
assert env.rail.get_full_transitions(1, 3) == 0, "[1][3]"
assert env.rail.get_full_transitions(1, 4) == 0, "[1][4]"
assert env.rail.get_full_transitions(1, 5) == 0, "[1][5]"
assert env.rail.get_full_transitions(1, 6) == 0, "[1][6]"
assert env.rail.get_full_transitions(1, 7) == 0, "[1][7]"
assert env.rail.get_full_transitions(1, 8) == 0, "[1][8]"
assert env.rail.get_full_transitions(1, 9) == 0, "[1][9]"
assert env.rail.get_full_transitions(1, 10) == 0, "[1][10]"
assert env.rail.get_full_transitions(1, 11) == 16386, "[1][11]"
assert env.rail.get_full_transitions(1, 12) == 1025, "[1][12]"
assert env.rail.get_full_transitions(1, 13) == 17411, "[1][13]"
assert env.rail.get_full_transitions(1, 14) == 17411, "[1][14]"
assert env.rail.get_full_transitions(1, 15) == 1025, "[1][15]"
assert env.rail.get_full_transitions(1, 16) == 1025, "[1][16]"
assert env.rail.get_full_transitions(1, 17) == 1025, "[1][17]"
assert env.rail.get_full_transitions(1, 18) == 1025, "[1][18]"
assert env.rail.get_full_transitions(1, 19) == 5633, "[1][19]"
assert env.rail.get_full_transitions(1, 20) == 5633, "[1][20]"
assert env.rail.get_full_transitions(1, 21) == 4608, "[1][21]"
assert env.rail.get_full_transitions(1, 22) == 0, "[1][22]"
assert env.rail.get_full_transitions(1, 23) == 0, "[1][23]"
assert env.rail.get_full_transitions(1, 24) == 0, "[1][24]"
assert env.rail.get_full_transitions(2, 0) == 0, "[2][0]"
assert env.rail.get_full_transitions(2, 1) == 0, "[2][1]"
assert env.rail.get_full_transitions(2, 2) == 0, "[2][2]"
assert env.rail.get_full_transitions(2, 3) == 0, "[2][3]"
assert env.rail.get_full_transitions(2, 4) == 0, "[2][4]"
assert env.rail.get_full_transitions(2, 5) == 0, "[2][5]"
assert env.rail.get_full_transitions(2, 6) == 0, "[2][6]"
assert env.rail.get_full_transitions(2, 7) == 0, "[2][7]"
assert env.rail.get_full_transitions(2, 8) == 0, "[2][8]"
assert env.rail.get_full_transitions(2, 9) == 0, "[2][9]"
assert env.rail.get_full_transitions(2, 10) == 0, "[2][10]"
assert env.rail.get_full_transitions(2, 11) == 32800, "[2][11]"
assert env.rail.get_full_transitions(2, 12) == 16386, "[2][12]"
assert env.rail.get_full_transitions(2, 13) == 34864, "[2][13]"
assert env.rail.get_full_transitions(2, 14) == 32800, "[2][14]"
assert env.rail.get_full_transitions(2, 15) == 0, "[2][15]"
assert env.rail.get_full_transitions(2, 16) == 0, "[2][16]"
assert env.rail.get_full_transitions(2, 17) == 0, "[2][17]"
assert env.rail.get_full_transitions(2, 18) == 0, "[2][18]"
assert env.rail.get_full_transitions(2, 19) == 32800, "[2][19]"
assert env.rail.get_full_transitions(2, 20) == 32800, "[2][20]"
assert env.rail.get_full_transitions(2, 21) == 32800, "[2][21]"
assert env.rail.get_full_transitions(2, 22) == 0, "[2][22]"
assert env.rail.get_full_transitions(2, 23) == 0, "[2][23]"
assert env.rail.get_full_transitions(2, 24) == 0, "[2][24]"
assert env.rail.get_full_transitions(3, 0) == 0, "[3][0]"
assert env.rail.get_full_transitions(3, 1) == 0, "[3][1]"
assert env.rail.get_full_transitions(3, 2) == 0, "[3][2]"
assert env.rail.get_full_transitions(3, 3) == 0, "[3][3]"
assert env.rail.get_full_transitions(3, 4) == 0, "[3][4]"
assert env.rail.get_full_transitions(3, 5) == 0, "[3][5]"
assert env.rail.get_full_transitions(3, 6) == 0, "[3][6]"
assert env.rail.get_full_transitions(3, 7) == 0, "[3][7]"
assert env.rail.get_full_transitions(3, 8) == 0, "[3][8]"
assert env.rail.get_full_transitions(3, 9) == 0, "[3][9]"
assert env.rail.get_full_transitions(3, 10) == 0, "[3][10]"
assert env.rail.get_full_transitions(3, 11) == 32800, "[3][11]"
assert env.rail.get_full_transitions(3, 12) == 32800, "[3][12]"
assert env.rail.get_full_transitions(3, 13) == 32800, "[3][13]"
assert env.rail.get_full_transitions(3, 14) == 32800, "[3][14]"
assert env.rail.get_full_transitions(3, 15) == 0, "[3][15]"
assert env.rail.get_full_transitions(3, 16) == 0, "[3][16]"
assert env.rail.get_full_transitions(3, 17) == 0, "[3][17]"
assert env.rail.get_full_transitions(3, 18) == 0, "[3][18]"
assert env.rail.get_full_transitions(3, 19) == 32800, "[3][19]"
assert env.rail.get_full_transitions(3, 20) == 32872, "[3][20]"
assert env.rail.get_full_transitions(3, 21) == 37408, "[3][21]"
assert env.rail.get_full_transitions(3, 22) == 0, "[3][22]"
assert env.rail.get_full_transitions(3, 23) == 0, "[3][23]"
assert env.rail.get_full_transitions(3, 24) == 0, "[3][24]"
assert env.rail.get_full_transitions(4, 0) == 0, "[4][0]"
assert env.rail.get_full_transitions(4, 1) == 0, "[4][1]"
assert env.rail.get_full_transitions(4, 2) == 0, "[4][2]"
assert env.rail.get_full_transitions(4, 3) == 0, "[4][3]"
assert env.rail.get_full_transitions(4, 4) == 0, "[4][4]"
assert env.rail.get_full_transitions(4, 5) == 0, "[4][5]"
assert env.rail.get_full_transitions(4, 6) == 0, "[4][6]"
assert env.rail.get_full_transitions(4, 7) == 0, "[4][7]"
assert env.rail.get_full_transitions(4, 8) == 0, "[4][8]"
assert env.rail.get_full_transitions(4, 9) == 0, "[4][9]"
assert env.rail.get_full_transitions(4, 10) == 0, "[4][10]"
assert env.rail.get_full_transitions(4, 11) == 32800, "[4][11]"
assert env.rail.get_full_transitions(4, 12) == 32800, "[4][12]"
assert env.rail.get_full_transitions(4, 13) == 32800, "[4][13]"
assert env.rail.get_full_transitions(4, 14) == 32800, "[4][14]"
assert env.rail.get_full_transitions(4, 15) == 0, "[4][15]"
assert env.rail.get_full_transitions(4, 16) == 0, "[4][16]"
assert env.rail.get_full_transitions(4, 17) == 0, "[4][17]"
assert env.rail.get_full_transitions(4, 18) == 0, "[4][18]"
assert env.rail.get_full_transitions(4, 19) == 32800, "[4][19]"
assert env.rail.get_full_transitions(4, 20) == 32800, "[4][20]"
assert env.rail.get_full_transitions(4, 21) == 32800, "[4][21]"
assert env.rail.get_full_transitions(4, 22) == 0, "[4][22]"
assert env.rail.get_full_transitions(4, 23) == 0, "[4][23]"
assert env.rail.get_full_transitions(4, 24) == 0, "[4][24]"
assert env.rail.get_full_transitions(5, 0) == 0, "[5][0]"
assert env.rail.get_full_transitions(5, 1) == 0, "[5][1]"
assert env.rail.get_full_transitions(5, 2) == 0, "[5][2]"
assert env.rail.get_full_transitions(5, 3) == 0, "[5][3]"
assert env.rail.get_full_transitions(5, 4) == 0, "[5][4]"
assert env.rail.get_full_transitions(5, 5) == 0, "[5][5]"
assert env.rail.get_full_transitions(5, 6) == 0, "[5][6]"
assert env.rail.get_full_transitions(5, 7) == 0, "[5][7]"
assert env.rail.get_full_transitions(5, 8) == 0, "[5][8]"
assert env.rail.get_full_transitions(5, 9) == 0, "[5][9]"
assert env.rail.get_full_transitions(5, 10) == 0, "[5][10]"
assert env.rail.get_full_transitions(5, 11) == 49186, "[5][11]"
assert env.rail.get_full_transitions(5, 12) == 3089, "[5][12]"
assert env.rail.get_full_transitions(5, 13) == 2064, "[5][13]"
assert env.rail.get_full_transitions(5, 14) == 32800, "[5][14]"
assert env.rail.get_full_transitions(5, 15) == 0, "[5][15]"
assert env.rail.get_full_transitions(5, 16) == 0, "[5][16]"
assert env.rail.get_full_transitions(5, 17) == 0, "[5][17]"
assert env.rail.get_full_transitions(5, 18) == 0, "[5][18]"
assert env.rail.get_full_transitions(5, 19) == 49186, "[5][19]"
assert env.rail.get_full_transitions(5, 20) == 34864, "[5][20]"
assert env.rail.get_full_transitions(5, 21) == 32872, "[5][21]"
assert env.rail.get_full_transitions(5, 22) == 4608, "[5][22]"
assert env.rail.get_full_transitions(5, 23) == 0, "[5][23]"
assert env.rail.get_full_transitions(5, 24) == 0, "[5][24]"
assert env.rail.get_full_transitions(6, 0) == 16386, "[6][0]"
assert env.rail.get_full_transitions(6, 1) == 17411, "[6][1]"
assert env.rail.get_full_transitions(6, 2) == 1025, "[6][2]"
assert env.rail.get_full_transitions(6, 3) == 5633, "[6][3]"
assert env.rail.get_full_transitions(6, 4) == 17411, "[6][4]"
assert env.rail.get_full_transitions(6, 5) == 1025, "[6][5]"
assert env.rail.get_full_transitions(6, 6) == 1025, "[6][6]"
assert env.rail.get_full_transitions(6, 7) == 1025, "[6][7]"
assert env.rail.get_full_transitions(6, 8) == 5633, "[6][8]"
assert env.rail.get_full_transitions(6, 9) == 17411, "[6][9]"
assert env.rail.get_full_transitions(6, 10) == 1025, "[6][10]"
assert env.rail.get_full_transitions(6, 11) == 3089, "[6][11]"
assert env.rail.get_full_transitions(6, 12) == 1025, "[6][12]"
assert env.rail.get_full_transitions(6, 13) == 1025, "[6][13]"
assert env.rail.get_full_transitions(6, 14) == 2064, "[6][14]"
assert env.rail.get_full_transitions(6, 15) == 0, "[6][15]"
assert env.rail.get_full_transitions(6, 16) == 0, "[6][16]"
assert env.rail.get_full_transitions(6, 17) == 0, "[6][17]"
assert env.rail.get_full_transitions(6, 18) == 0, "[6][18]"
assert env.rail.get_full_transitions(6, 19) == 32800, "[6][19]"
assert env.rail.get_full_transitions(6, 20) == 32800, "[6][20]"
assert env.rail.get_full_transitions(6, 21) == 32800, "[6][21]"
assert env.rail.get_full_transitions(6, 22) == 32800, "[6][22]"
assert env.rail.get_full_transitions(6, 23) == 0, "[6][23]"
assert env.rail.get_full_transitions(6, 24) == 0, "[6][24]"
assert env.rail.get_full_transitions(7, 0) == 32800, "[7][0]"
assert env.rail.get_full_transitions(7, 1) == 32800, "[7][1]"
assert env.rail.get_full_transitions(7, 2) == 0, "[7][2]"
assert env.rail.get_full_transitions(7, 3) == 72, "[7][3]"
assert env.rail.get_full_transitions(7, 4) == 3089, "[7][4]"
assert env.rail.get_full_transitions(7, 5) == 1025, "[7][5]"
assert env.rail.get_full_transitions(7, 6) == 1025, "[7][6]"
assert env.rail.get_full_transitions(7, 7) == 1025, "[7][7]"
assert env.rail.get_full_transitions(7, 8) == 1097, "[7][8]"
assert env.rail.get_full_transitions(7, 9) == 2064, "[7][9]"
assert env.rail.get_full_transitions(7, 10) == 0, "[7][10]"
assert env.rail.get_full_transitions(7, 11) == 0, "[7][11]"
assert env.rail.get_full_transitions(7, 12) == 0, "[7][12]"
assert env.rail.get_full_transitions(7, 13) == 0, "[7][13]"
assert env.rail.get_full_transitions(7, 14) == 0, "[7][14]"
assert env.rail.get_full_transitions(7, 15) == 0, "[7][15]"
assert env.rail.get_full_transitions(7, 16) == 0, "[7][16]"
assert env.rail.get_full_transitions(7, 17) == 0, "[7][17]"
assert env.rail.get_full_transitions(7, 18) == 0, "[7][18]"
assert env.rail.get_full_transitions(7, 19) == 32872, "[7][19]"
assert env.rail.get_full_transitions(7, 20) == 37408, "[7][20]"
assert env.rail.get_full_transitions(7, 21) == 49186, "[7][21]"
assert env.rail.get_full_transitions(7, 22) == 2064, "[7][22]"
assert env.rail.get_full_transitions(7, 23) == 0, "[7][23]"
assert env.rail.get_full_transitions(7, 24) == 0, "[7][24]"
assert env.rail.get_full_transitions(8, 0) == 32800, "[8][0]"
assert env.rail.get_full_transitions(8, 1) == 32800, "[8][1]"
assert env.rail.get_full_transitions(8, 2) == 0, "[8][2]"
assert env.rail.get_full_transitions(8, 3) == 0, "[8][3]"
assert env.rail.get_full_transitions(8, 4) == 0, "[8][4]"
assert env.rail.get_full_transitions(8, 5) == 0, "[8][5]"
assert env.rail.get_full_transitions(8, 6) == 0, "[8][6]"
assert env.rail.get_full_transitions(8, 7) == 0, "[8][7]"
assert env.rail.get_full_transitions(8, 8) == 0, "[8][8]"
assert env.rail.get_full_transitions(8, 9) == 0, "[8][9]"
assert env.rail.get_full_transitions(8, 10) == 0, "[8][10]"
assert env.rail.get_full_transitions(8, 11) == 0, "[8][11]"
assert env.rail.get_full_transitions(8, 12) == 0, "[8][12]"
assert env.rail.get_full_transitions(8, 13) == 0, "[8][13]"
assert env.rail.get_full_transitions(8, 14) == 0, "[8][14]"
assert env.rail.get_full_transitions(8, 15) == 0, "[8][15]"
assert env.rail.get_full_transitions(8, 16) == 0, "[8][16]"
assert env.rail.get_full_transitions(8, 17) == 0, "[8][17]"
assert env.rail.get_full_transitions(8, 18) == 0, "[8][18]"
assert env.rail.get_full_transitions(8, 19) == 32800, "[8][19]"
assert env.rail.get_full_transitions(8, 20) == 32800, "[8][20]"
assert env.rail.get_full_transitions(8, 21) == 32800, "[8][21]"
assert env.rail.get_full_transitions(8, 22) == 0, "[8][22]"
assert env.rail.get_full_transitions(8, 23) == 0, "[8][23]"
assert env.rail.get_full_transitions(8, 24) == 0, "[8][24]"
assert env.rail.get_full_transitions(9, 0) == 32800, "[9][0]"
assert env.rail.get_full_transitions(9, 1) == 32800, "[9][1]"
assert env.rail.get_full_transitions(9, 2) == 0, "[9][2]"
assert env.rail.get_full_transitions(9, 3) == 0, "[9][3]"
assert env.rail.get_full_transitions(9, 4) == 0, "[9][4]"
assert env.rail.get_full_transitions(9, 5) == 0, "[9][5]"
assert env.rail.get_full_transitions(9, 6) == 0, "[9][6]"
assert env.rail.get_full_transitions(9, 7) == 0, "[9][7]"
assert env.rail.get_full_transitions(9, 8) == 0, "[9][8]"
assert env.rail.get_full_transitions(9, 9) == 0, "[9][9]"
assert env.rail.get_full_transitions(9, 10) == 0, "[9][10]"
assert env.rail.get_full_transitions(9, 11) == 0, "[9][11]"
assert env.rail.get_full_transitions(9, 12) == 0, "[9][12]"
assert env.rail.get_full_transitions(9, 13) == 0, "[9][13]"
assert env.rail.get_full_transitions(9, 14) == 0, "[9][14]"
assert env.rail.get_full_transitions(9, 15) == 0, "[9][15]"
assert env.rail.get_full_transitions(9, 16) == 0, "[9][16]"
assert env.rail.get_full_transitions(9, 17) == 0, "[9][17]"
assert env.rail.get_full_transitions(9, 18) == 0, "[9][18]"
assert env.rail.get_full_transitions(9, 19) == 32800, "[9][19]"
assert env.rail.get_full_transitions(9, 20) == 49186, "[9][20]"
assert env.rail.get_full_transitions(9, 21) == 34864, "[9][21]"
assert env.rail.get_full_transitions(9, 22) == 0, "[9][22]"
assert env.rail.get_full_transitions(9, 23) == 0, "[9][23]"
assert env.rail.get_full_transitions(9, 24) == 0, "[9][24]"
assert env.rail.get_full_transitions(10, 0) == 32800, "[10][0]"
assert env.rail.get_full_transitions(10, 1) == 32800, "[10][1]"
assert env.rail.get_full_transitions(10, 2) == 0, "[10][2]"
assert env.rail.get_full_transitions(10, 3) == 0, "[10][3]"
assert env.rail.get_full_transitions(10, 4) == 0, "[10][4]"
assert env.rail.get_full_transitions(10, 5) == 0, "[10][5]"
assert env.rail.get_full_transitions(10, 6) == 0, "[10][6]"
assert env.rail.get_full_transitions(10, 7) == 0, "[10][7]"
assert env.rail.get_full_transitions(10, 8) == 0, "[10][8]"
assert env.rail.get_full_transitions(10, 9) == 0, "[10][9]"
assert env.rail.get_full_transitions(10, 10) == 0, "[10][10]"
assert env.rail.get_full_transitions(10, 11) == 0, "[10][11]"
assert env.rail.get_full_transitions(10, 12) == 0, "[10][12]"
assert env.rail.get_full_transitions(10, 13) == 0, "[10][13]"
assert env.rail.get_full_transitions(10, 14) == 0, "[10][14]"
assert env.rail.get_full_transitions(10, 15) == 0, "[10][15]"
assert env.rail.get_full_transitions(10, 16) == 0, "[10][16]"
assert env.rail.get_full_transitions(10, 17) == 0, "[10][17]"
assert env.rail.get_full_transitions(10, 18) == 0, "[10][18]"
assert env.rail.get_full_transitions(10, 19) == 32800, "[10][19]"
assert env.rail.get_full_transitions(10, 20) == 32800, "[10][20]"
assert env.rail.get_full_transitions(10, 21) == 32800, "[10][21]"
assert env.rail.get_full_transitions(10, 22) == 0, "[10][22]"
assert env.rail.get_full_transitions(10, 23) == 0, "[10][23]"
assert env.rail.get_full_transitions(10, 24) == 0, "[10][24]"
assert env.rail.get_full_transitions(11, 0) == 32800, "[11][0]"
assert env.rail.get_full_transitions(11, 1) == 32800, "[11][1]"
assert env.rail.get_full_transitions(11, 2) == 0, "[11][2]"
assert env.rail.get_full_transitions(11, 3) == 0, "[11][3]"
assert env.rail.get_full_transitions(11, 4) == 0, "[11][4]"
assert env.rail.get_full_transitions(11, 5) == 0, "[11][5]"
assert env.rail.get_full_transitions(11, 6) == 0, "[11][6]"
assert env.rail.get_full_transitions(11, 7) == 0, "[11][7]"
assert env.rail.get_full_transitions(11, 8) == 0, "[11][8]"
assert env.rail.get_full_transitions(11, 9) == 0, "[11][9]"
assert env.rail.get_full_transitions(11, 10) == 0, "[11][10]"
assert env.rail.get_full_transitions(11, 11) == 0, "[11][11]"
assert env.rail.get_full_transitions(11, 12) == 0, "[11][12]"
assert env.rail.get_full_transitions(11, 13) == 0, "[11][13]"
assert env.rail.get_full_transitions(11, 14) == 0, "[11][14]"
assert env.rail.get_full_transitions(11, 15) == 0, "[11][15]"
assert env.rail.get_full_transitions(11, 16) == 0, "[11][16]"
assert env.rail.get_full_transitions(11, 17) == 0, "[11][17]"
assert env.rail.get_full_transitions(11, 18) == 0, "[11][18]"
assert env.rail.get_full_transitions(11, 19) == 32800, "[11][19]"
assert env.rail.get_full_transitions(11, 20) == 32800, "[11][20]"
assert env.rail.get_full_transitions(11, 21) == 32800, "[11][21]"
assert env.rail.get_full_transitions(11, 22) == 0, "[11][22]"
assert env.rail.get_full_transitions(11, 23) == 0, "[11][23]"
assert env.rail.get_full_transitions(11, 24) == 0, "[11][24]"
assert env.rail.get_full_transitions(12, 0) == 32800, "[12][0]"
assert env.rail.get_full_transitions(12, 1) == 32800, "[12][1]"
assert env.rail.get_full_transitions(12, 2) == 0, "[12][2]"
assert env.rail.get_full_transitions(12, 3) == 0, "[12][3]"
assert env.rail.get_full_transitions(12, 4) == 0, "[12][4]"
assert env.rail.get_full_transitions(12, 5) == 0, "[12][5]"
assert env.rail.get_full_transitions(12, 6) == 0, "[12][6]"
assert env.rail.get_full_transitions(12, 7) == 0, "[12][7]"
assert env.rail.get_full_transitions(12, 8) == 0, "[12][8]"
assert env.rail.get_full_transitions(12, 9) == 0, "[12][9]"
assert env.rail.get_full_transitions(12, 10) == 0, "[12][10]"
assert env.rail.get_full_transitions(12, 11) == 0, "[12][11]"
assert env.rail.get_full_transitions(12, 12) == 0, "[12][12]"
assert env.rail.get_full_transitions(12, 13) == 0, "[12][13]"
assert env.rail.get_full_transitions(12, 14) == 0, "[12][14]"
assert env.rail.get_full_transitions(12, 15) == 0, "[12][15]"
assert env.rail.get_full_transitions(12, 16) == 0, "[12][16]"
assert env.rail.get_full_transitions(12, 17) == 0, "[12][17]"
assert env.rail.get_full_transitions(12, 18) == 0, "[12][18]"
assert env.rail.get_full_transitions(12, 19) == 32800, "[12][19]"
assert env.rail.get_full_transitions(12, 20) == 32800, "[12][20]"
assert env.rail.get_full_transitions(12, 21) == 32800, "[12][21]"
assert env.rail.get_full_transitions(12, 22) == 0, "[12][22]"
assert env.rail.get_full_transitions(12, 23) == 0, "[12][23]"
assert env.rail.get_full_transitions(12, 24) == 0, "[12][24]"
assert env.rail.get_full_transitions(13, 0) == 32800, "[13][0]"
assert env.rail.get_full_transitions(13, 1) == 32800, "[13][1]"
assert env.rail.get_full_transitions(13, 2) == 0, "[13][2]"
assert env.rail.get_full_transitions(13, 3) == 0, "[13][3]"
assert env.rail.get_full_transitions(13, 4) == 0, "[13][4]"
assert env.rail.get_full_transitions(13, 5) == 0, "[13][5]"
assert env.rail.get_full_transitions(13, 6) == 0, "[13][6]"
assert env.rail.get_full_transitions(13, 7) == 0, "[13][7]"
assert env.rail.get_full_transitions(13, 8) == 0, "[13][8]"
assert env.rail.get_full_transitions(13, 9) == 0, "[13][9]"
assert env.rail.get_full_transitions(13, 10) == 0, "[13][10]"
assert env.rail.get_full_transitions(13, 11) == 0, "[13][11]"
assert env.rail.get_full_transitions(13, 12) == 0, "[13][12]"
assert env.rail.get_full_transitions(13, 13) == 0, "[13][13]"
assert env.rail.get_full_transitions(13, 14) == 0, "[13][14]"
assert env.rail.get_full_transitions(13, 15) == 0, "[13][15]"
assert env.rail.get_full_transitions(13, 16) == 0, "[13][16]"
assert env.rail.get_full_transitions(13, 17) == 0, "[13][17]"
assert env.rail.get_full_transitions(13, 18) == 0, "[13][18]"
assert env.rail.get_full_transitions(13, 19) == 32800, "[13][19]"
assert env.rail.get_full_transitions(13, 20) == 32800, "[13][20]"
assert env.rail.get_full_transitions(13, 21) == 32800, "[13][21]"
assert env.rail.get_full_transitions(13, 22) == 0, "[13][22]"
assert env.rail.get_full_transitions(13, 23) == 0, "[13][23]"
assert env.rail.get_full_transitions(13, 24) == 0, "[13][24]"
assert env.rail.get_full_transitions(14, 0) == 32800, "[14][0]"
assert env.rail.get_full_transitions(14, 1) == 32800, "[14][1]"
assert env.rail.get_full_transitions(14, 2) == 0, "[14][2]"
assert env.rail.get_full_transitions(14, 3) == 0, "[14][3]"
assert env.rail.get_full_transitions(14, 4) == 0, "[14][4]"
assert env.rail.get_full_transitions(14, 5) == 0, "[14][5]"
assert env.rail.get_full_transitions(14, 6) == 0, "[14][6]"
assert env.rail.get_full_transitions(14, 7) == 0, "[14][7]"
assert env.rail.get_full_transitions(14, 8) == 0, "[14][8]"
assert env.rail.get_full_transitions(14, 9) == 0, "[14][9]"
assert env.rail.get_full_transitions(14, 10) == 0, "[14][10]"
assert env.rail.get_full_transitions(14, 11) == 0, "[14][11]"
assert env.rail.get_full_transitions(14, 12) == 0, "[14][12]"
assert env.rail.get_full_transitions(14, 13) == 0, "[14][13]"
assert env.rail.get_full_transitions(14, 14) == 0, "[14][14]"
assert env.rail.get_full_transitions(14, 15) == 0, "[14][15]"
assert env.rail.get_full_transitions(14, 16) == 0, "[14][16]"
assert env.rail.get_full_transitions(14, 17) == 0, "[14][17]"
assert env.rail.get_full_transitions(14, 18) == 0, "[14][18]"
assert env.rail.get_full_transitions(14, 19) == 32800, "[14][19]"
assert env.rail.get_full_transitions(14, 20) == 32800, "[14][20]"
assert env.rail.get_full_transitions(14, 21) == 32800, "[14][21]"
assert env.rail.get_full_transitions(14, 22) == 0, "[14][22]"
assert env.rail.get_full_transitions(14, 23) == 0, "[14][23]"
assert env.rail.get_full_transitions(14, 24) == 0, "[14][24]"
assert env.rail.get_full_transitions(15, 0) == 32800, "[15][0]"
assert env.rail.get_full_transitions(15, 1) == 32800, "[15][1]"
assert env.rail.get_full_transitions(15, 2) == 0, "[15][2]"
assert env.rail.get_full_transitions(15, 3) == 0, "[15][3]"
assert env.rail.get_full_transitions(15, 4) == 0, "[15][4]"
assert env.rail.get_full_transitions(15, 5) == 0, "[15][5]"
assert env.rail.get_full_transitions(15, 6) == 0, "[15][6]"
assert env.rail.get_full_transitions(15, 7) == 0, "[15][7]"
assert env.rail.get_full_transitions(15, 8) == 0, "[15][8]"
assert env.rail.get_full_transitions(15, 9) == 0, "[15][9]"
assert env.rail.get_full_transitions(15, 10) == 0, "[15][10]"
assert env.rail.get_full_transitions(15, 11) == 0, "[15][11]"
assert env.rail.get_full_transitions(15, 12) == 0, "[15][12]"
assert env.rail.get_full_transitions(15, 13) == 0, "[15][13]"
assert env.rail.get_full_transitions(15, 14) == 0, "[15][14]"
assert env.rail.get_full_transitions(15, 15) == 0, "[15][15]"
assert env.rail.get_full_transitions(15, 16) == 0, "[15][16]"
assert env.rail.get_full_transitions(15, 17) == 0, "[15][17]"
assert env.rail.get_full_transitions(15, 18) == 0, "[15][18]"
assert env.rail.get_full_transitions(15, 19) == 32800, "[15][19]"
assert env.rail.get_full_transitions(15, 20) == 32800, "[15][20]"
assert env.rail.get_full_transitions(15, 21) == 32800, "[15][21]"
assert env.rail.get_full_transitions(15, 22) == 0, "[15][22]"
assert env.rail.get_full_transitions(15, 23) == 0, "[15][23]"
assert env.rail.get_full_transitions(15, 24) == 0, "[15][24]"
assert env.rail.get_full_transitions(16, 0) == 32800, "[16][0]"
assert env.rail.get_full_transitions(16, 1) == 32800, "[16][1]"
assert env.rail.get_full_transitions(16, 2) == 0, "[16][2]"
assert env.rail.get_full_transitions(16, 3) == 0, "[16][3]"
assert env.rail.get_full_transitions(16, 4) == 0, "[16][4]"
assert env.rail.get_full_transitions(16, 5) == 0, "[16][5]"
assert env.rail.get_full_transitions(16, 6) == 0, "[16][6]"
assert env.rail.get_full_transitions(16, 7) == 0, "[16][7]"
assert env.rail.get_full_transitions(16, 8) == 0, "[16][8]"
assert env.rail.get_full_transitions(16, 9) == 0, "[16][9]"
assert env.rail.get_full_transitions(16, 10) == 0, "[16][10]"
assert env.rail.get_full_transitions(16, 11) == 0, "[16][11]"
assert env.rail.get_full_transitions(16, 12) == 0, "[16][12]"
assert env.rail.get_full_transitions(16, 13) == 0, "[16][13]"
assert env.rail.get_full_transitions(16, 14) == 0, "[16][14]"
assert env.rail.get_full_transitions(16, 15) == 0, "[16][15]"
assert env.rail.get_full_transitions(16, 16) == 0, "[16][16]"
assert env.rail.get_full_transitions(16, 17) == 0, "[16][17]"
assert env.rail.get_full_transitions(16, 18) == 0, "[16][18]"
assert env.rail.get_full_transitions(16, 19) == 32800, "[16][19]"
assert env.rail.get_full_transitions(16, 20) == 32800, "[16][20]"
assert env.rail.get_full_transitions(16, 21) == 32800, "[16][21]"
assert env.rail.get_full_transitions(16, 22) == 0, "[16][22]"
assert env.rail.get_full_transitions(16, 23) == 0, "[16][23]"
assert env.rail.get_full_transitions(16, 24) == 0, "[16][24]"
assert env.rail.get_full_transitions(17, 0) == 32800, "[17][0]"
assert env.rail.get_full_transitions(17, 1) == 32800, "[17][1]"
assert env.rail.get_full_transitions(17, 2) == 0, "[17][2]"
assert env.rail.get_full_transitions(17, 3) == 0, "[17][3]"
assert env.rail.get_full_transitions(17, 4) == 0, "[17][4]"
assert env.rail.get_full_transitions(17, 5) == 0, "[17][5]"
assert env.rail.get_full_transitions(17, 6) == 0, "[17][6]"
assert env.rail.get_full_transitions(17, 7) == 0, "[17][7]"
assert env.rail.get_full_transitions(17, 8) == 0, "[17][8]"
assert env.rail.get_full_transitions(17, 9) == 0, "[17][9]"
assert env.rail.get_full_transitions(17, 10) == 0, "[17][10]"
assert env.rail.get_full_transitions(17, 11) == 0, "[17][11]"
assert env.rail.get_full_transitions(17, 12) == 0, "[17][12]"
assert env.rail.get_full_transitions(17, 13) == 0, "[17][13]"
assert env.rail.get_full_transitions(17, 14) == 0, "[17][14]"
assert env.rail.get_full_transitions(17, 15) == 0, "[17][15]"
assert env.rail.get_full_transitions(17, 16) == 0, "[17][16]"
assert env.rail.get_full_transitions(17, 17) == 0, "[17][17]"
assert env.rail.get_full_transitions(17, 18) == 0, "[17][18]"
assert env.rail.get_full_transitions(17, 19) == 32800, "[17][19]"
assert env.rail.get_full_transitions(17, 20) == 32800, "[17][20]"
assert env.rail.get_full_transitions(17, 21) == 32800, "[17][21]"
assert env.rail.get_full_transitions(17, 22) == 0, "[17][22]"
assert env.rail.get_full_transitions(17, 23) == 0, "[17][23]"
assert env.rail.get_full_transitions(17, 24) == 0, "[17][24]"
assert env.rail.get_full_transitions(18, 0) == 72, "[18][0]"
assert env.rail.get_full_transitions(18, 1) == 37408, "[18][1]"
assert env.rail.get_full_transitions(18, 2) == 0, "[18][2]"
assert env.rail.get_full_transitions(18, 3) == 0, "[18][3]"
assert env.rail.get_full_transitions(18, 4) == 0, "[18][4]"
assert env.rail.get_full_transitions(18, 5) == 0, "[18][5]"
assert env.rail.get_full_transitions(18, 6) == 0, "[18][6]"
assert env.rail.get_full_transitions(18, 7) == 0, "[18][7]"
assert env.rail.get_full_transitions(18, 8) == 0, "[18][8]"
assert env.rail.get_full_transitions(18, 9) == 0, "[18][9]"
assert env.rail.get_full_transitions(18, 10) == 0, "[18][10]"
assert env.rail.get_full_transitions(18, 11) == 0, "[18][11]"
assert env.rail.get_full_transitions(18, 12) == 0, "[18][12]"
assert env.rail.get_full_transitions(18, 13) == 0, "[18][13]"
assert env.rail.get_full_transitions(18, 14) == 0, "[18][14]"
assert env.rail.get_full_transitions(18, 15) == 0, "[18][15]"
assert env.rail.get_full_transitions(18, 16) == 0, "[18][16]"
assert env.rail.get_full_transitions(18, 17) == 0, "[18][17]"
assert env.rail.get_full_transitions(18, 18) == 0, "[18][18]"
assert env.rail.get_full_transitions(18, 19) == 32800, "[18][19]"
assert env.rail.get_full_transitions(18, 20) == 32800, "[18][20]"
assert env.rail.get_full_transitions(18, 21) == 32800, "[18][21]"
assert env.rail.get_full_transitions(18, 22) == 0, "[18][22]"
assert env.rail.get_full_transitions(18, 23) == 0, "[18][23]"
assert env.rail.get_full_transitions(18, 24) == 0, "[18][24]"
assert env.rail.get_full_transitions(19, 0) == 0, "[19][0]"
assert env.rail.get_full_transitions(19, 1) == 32800, "[19][1]"
assert env.rail.get_full_transitions(19, 2) == 0, "[19][2]"
assert env.rail.get_full_transitions(19, 3) == 0, "[19][3]"
assert env.rail.get_full_transitions(19, 4) == 0, "[19][4]"
assert env.rail.get_full_transitions(19, 5) == 0, "[19][5]"
assert env.rail.get_full_transitions(19, 6) == 0, "[19][6]"
assert env.rail.get_full_transitions(19, 7) == 0, "[19][7]"
assert env.rail.get_full_transitions(19, 8) == 0, "[19][8]"
assert env.rail.get_full_transitions(19, 9) == 0, "[19][9]"
assert env.rail.get_full_transitions(19, 10) == 0, "[19][10]"
assert env.rail.get_full_transitions(19, 11) == 0, "[19][11]"
assert env.rail.get_full_transitions(19, 12) == 0, "[19][12]"
assert env.rail.get_full_transitions(19, 13) == 0, "[19][13]"
assert env.rail.get_full_transitions(19, 14) == 16386, "[19][14]"
assert env.rail.get_full_transitions(19, 15) == 1025, "[19][15]"
assert env.rail.get_full_transitions(19, 16) == 1025, "[19][16]"
assert env.rail.get_full_transitions(19, 17) == 1025, "[19][17]"
assert env.rail.get_full_transitions(19, 18) == 1025, "[19][18]"
assert env.rail.get_full_transitions(19, 19) == 38505, "[19][19]"
assert env.rail.get_full_transitions(19, 20) == 3089, "[19][20]"
assert env.rail.get_full_transitions(19, 21) == 2064, "[19][21]"
assert env.rail.get_full_transitions(19, 22) == 0, "[19][22]"
assert env.rail.get_full_transitions(19, 23) == 0, "[19][23]"
assert env.rail.get_full_transitions(19, 24) == 0, "[19][24]"
assert env.rail.get_full_transitions(20, 0) == 0, "[20][0]"
assert env.rail.get_full_transitions(20, 1) == 32800, "[20][1]"
assert env.rail.get_full_transitions(20, 2) == 0, "[20][2]"
assert env.rail.get_full_transitions(20, 3) == 0, "[20][3]"
assert env.rail.get_full_transitions(20, 4) == 0, "[20][4]"
assert env.rail.get_full_transitions(20, 5) == 0, "[20][5]"
assert env.rail.get_full_transitions(20, 6) == 0, "[20][6]"
assert env.rail.get_full_transitions(20, 7) == 0, "[20][7]"
assert env.rail.get_full_transitions(20, 8) == 0, "[20][8]"
assert env.rail.get_full_transitions(20, 9) == 0, "[20][9]"
assert env.rail.get_full_transitions(20, 10) == 0, "[20][10]"
assert env.rail.get_full_transitions(20, 11) == 0, "[20][11]"
assert env.rail.get_full_transitions(20, 12) == 0, "[20][12]"
assert env.rail.get_full_transitions(20, 13) == 0, "[20][13]"
assert env.rail.get_full_transitions(20, 14) == 32800, "[20][14]"
assert env.rail.get_full_transitions(20, 15) == 0, "[20][15]"
assert env.rail.get_full_transitions(20, 16) == 0, "[20][16]"
assert env.rail.get_full_transitions(20, 17) == 0, "[20][17]"
assert env.rail.get_full_transitions(20, 18) == 0, "[20][18]"
assert env.rail.get_full_transitions(20, 19) == 32800, "[20][19]"
assert env.rail.get_full_transitions(20, 20) == 0, "[20][20]"
assert env.rail.get_full_transitions(20, 21) == 0, "[20][21]"
assert env.rail.get_full_transitions(20, 22) == 0, "[20][22]"
assert env.rail.get_full_transitions(20, 23) == 0, "[20][23]"
assert env.rail.get_full_transitions(20, 24) == 0, "[20][24]"
assert env.rail.get_full_transitions(21, 0) == 0, "[21][0]"
assert env.rail.get_full_transitions(21, 1) == 32800, "[21][1]"
assert env.rail.get_full_transitions(21, 2) == 0, "[21][2]"
assert env.rail.get_full_transitions(21, 3) == 0, "[21][3]"
assert env.rail.get_full_transitions(21, 4) == 0, "[21][4]"
assert env.rail.get_full_transitions(21, 5) == 0, "[21][5]"
assert env.rail.get_full_transitions(21, 6) == 0, "[21][6]"
assert env.rail.get_full_transitions(21, 7) == 0, "[21][7]"
assert env.rail.get_full_transitions(21, 8) == 0, "[21][8]"
assert env.rail.get_full_transitions(21, 9) == 0, "[21][9]"
assert env.rail.get_full_transitions(21, 10) == 0, "[21][10]"
assert env.rail.get_full_transitions(21, 11) == 0, "[21][11]"
assert env.rail.get_full_transitions(21, 12) == 0, "[21][12]"
assert env.rail.get_full_transitions(21, 13) == 0, "[21][13]"
assert env.rail.get_full_transitions(21, 14) == 32800, "[21][14]"
assert env.rail.get_full_transitions(21, 15) == 0, "[21][15]"
assert env.rail.get_full_transitions(21, 16) == 0, "[21][16]"
assert env.rail.get_full_transitions(21, 17) == 0, "[21][17]"
assert env.rail.get_full_transitions(21, 18) == 0, "[21][18]"
assert env.rail.get_full_transitions(21, 19) == 32872, "[21][19]"
assert env.rail.get_full_transitions(21, 20) == 4608, "[21][20]"
assert env.rail.get_full_transitions(21, 21) == 0, "[21][21]"
assert env.rail.get_full_transitions(21, 22) == 0, "[21][22]"
assert env.rail.get_full_transitions(21, 23) == 0, "[21][23]"
assert env.rail.get_full_transitions(21, 24) == 0, "[21][24]"
assert env.rail.get_full_transitions(22, 0) == 0, "[22][0]"
assert env.rail.get_full_transitions(22, 1) == 32800, "[22][1]"
assert env.rail.get_full_transitions(22, 2) == 0, "[22][2]"
assert env.rail.get_full_transitions(22, 3) == 0, "[22][3]"
assert env.rail.get_full_transitions(22, 4) == 0, "[22][4]"
assert env.rail.get_full_transitions(22, 5) == 0, "[22][5]"
assert env.rail.get_full_transitions(22, 6) == 0, "[22][6]"
assert env.rail.get_full_transitions(22, 7) == 0, "[22][7]"
assert env.rail.get_full_transitions(22, 8) == 0, "[22][8]"
assert env.rail.get_full_transitions(22, 9) == 0, "[22][9]"
assert env.rail.get_full_transitions(22, 10) == 0, "[22][10]"
assert env.rail.get_full_transitions(22, 11) == 0, "[22][11]"
assert env.rail.get_full_transitions(22, 12) == 0, "[22][12]"
assert env.rail.get_full_transitions(22, 13) == 0, "[22][13]"
assert env.rail.get_full_transitions(22, 14) == 32800, "[22][14]"
assert env.rail.get_full_transitions(22, 15) == 0, "[22][15]"
assert env.rail.get_full_transitions(22, 16) == 0, "[22][16]"
assert env.rail.get_full_transitions(22, 17) == 0, "[22][17]"
assert env.rail.get_full_transitions(22, 18) == 0, "[22][18]"
assert env.rail.get_full_transitions(22, 19) == 49186, "[22][19]"
assert env.rail.get_full_transitions(22, 20) == 34864, "[22][20]"
assert env.rail.get_full_transitions(22, 21) == 0, "[22][21]"
assert env.rail.get_full_transitions(22, 22) == 0, "[22][22]"
assert env.rail.get_full_transitions(22, 23) == 0, "[22][23]"
assert env.rail.get_full_transitions(22, 24) == 0, "[22][24]"
assert env.rail.get_full_transitions(23, 0) == 0, "[23][0]"
assert env.rail.get_full_transitions(23, 1) == 32800, "[23][1]"
assert env.rail.get_full_transitions(23, 2) == 0, "[23][2]"
assert env.rail.get_full_transitions(23, 3) == 0, "[23][3]"
assert env.rail.get_full_transitions(23, 4) == 0, "[23][4]"
assert env.rail.get_full_transitions(23, 5) == 16386, "[23][5]"
assert env.rail.get_full_transitions(23, 6) == 1025, "[23][6]"
assert env.rail.get_full_transitions(23, 7) == 4608, "[23][7]"
assert env.rail.get_full_transitions(23, 8) == 0, "[23][8]"
assert env.rail.get_full_transitions(23, 9) == 0, "[23][9]"
assert env.rail.get_full_transitions(23, 10) == 0, "[23][10]"
assert env.rail.get_full_transitions(23, 11) == 0, "[23][11]"
assert env.rail.get_full_transitions(23, 12) == 0, "[23][12]"
assert env.rail.get_full_transitions(23, 13) == 0, "[23][13]"
assert env.rail.get_full_transitions(23, 14) == 32800, "[23][14]"
assert env.rail.get_full_transitions(23, 15) == 0, "[23][15]"
assert env.rail.get_full_transitions(23, 16) == 0, "[23][16]"
assert env.rail.get_full_transitions(23, 17) == 0, "[23][17]"
assert env.rail.get_full_transitions(23, 18) == 16386, "[23][18]"
assert env.rail.get_full_transitions(23, 19) == 34864, "[23][19]"
assert env.rail.get_full_transitions(23, 20) == 32872, "[23][20]"
assert env.rail.get_full_transitions(23, 21) == 4608, "[23][21]"
assert env.rail.get_full_transitions(23, 22) == 0, "[23][22]"
assert env.rail.get_full_transitions(23, 23) == 0, "[23][23]"
assert env.rail.get_full_transitions(23, 24) == 0, "[23][24]"
assert env.rail.get_full_transitions(24, 0) == 0, "[24][0]"
assert env.rail.get_full_transitions(24, 1) == 72, "[24][1]"
assert env.rail.get_full_transitions(24, 2) == 1025, "[24][2]"
assert env.rail.get_full_transitions(24, 3) == 5633, "[24][3]"
assert env.rail.get_full_transitions(24, 4) == 17411, "[24][4]"
assert env.rail.get_full_transitions(24, 5) == 3089, "[24][5]"
assert env.rail.get_full_transitions(24, 6) == 1025, "[24][6]"
assert env.rail.get_full_transitions(24, 7) == 1097, "[24][7]"
assert env.rail.get_full_transitions(24, 8) == 5633, "[24][8]"
assert env.rail.get_full_transitions(24, 9) == 17411, "[24][9]"
assert env.rail.get_full_transitions(24, 10) == 1025, "[24][10]"
assert env.rail.get_full_transitions(24, 11) == 5633, "[24][11]"
assert env.rail.get_full_transitions(24, 12) == 1025, "[24][12]"
assert env.rail.get_full_transitions(24, 13) == 1025, "[24][13]"
assert env.rail.get_full_transitions(24, 14) == 2064, "[24][14]"
assert env.rail.get_full_transitions(24, 15) == 0, "[24][15]"
assert env.rail.get_full_transitions(24, 16) == 0, "[24][16]"
assert env.rail.get_full_transitions(24, 17) == 0, "[24][17]"
assert env.rail.get_full_transitions(24, 18) == 32800, "[24][18]"
assert env.rail.get_full_transitions(24, 19) == 32800, "[24][19]"
assert env.rail.get_full_transitions(24, 20) == 32800, "[24][20]"
assert env.rail.get_full_transitions(24, 21) == 32800, "[24][21]"
assert env.rail.get_full_transitions(24, 22) == 0, "[24][22]"
assert env.rail.get_full_transitions(24, 23) == 0, "[24][23]"
assert env.rail.get_full_transitions(24, 24) == 0, "[24][24]"
assert env.rail.get_full_transitions(25, 0) == 0, "[25][0]"
assert env.rail.get_full_transitions(25, 1) == 0, "[25][1]"
assert env.rail.get_full_transitions(25, 2) == 0, "[25][2]"
assert env.rail.get_full_transitions(25, 3) == 72, "[25][3]"
assert env.rail.get_full_transitions(25, 4) == 3089, "[25][4]"
assert env.rail.get_full_transitions(25, 5) == 5633, "[25][5]"
assert env.rail.get_full_transitions(25, 6) == 1025, "[25][6]"
assert env.rail.get_full_transitions(25, 7) == 17411, "[25][7]"
assert env.rail.get_full_transitions(25, 8) == 1097, "[25][8]"
assert env.rail.get_full_transitions(25, 9) == 2064, "[25][9]"
assert env.rail.get_full_transitions(25, 10) == 0, "[25][10]"
assert env.rail.get_full_transitions(25, 11) == 32800, "[25][11]"
assert env.rail.get_full_transitions(25, 12) == 0, "[25][12]"
assert env.rail.get_full_transitions(25, 13) == 0, "[25][13]"
assert env.rail.get_full_transitions(25, 14) == 0, "[25][14]"
assert env.rail.get_full_transitions(25, 15) == 0, "[25][15]"
assert env.rail.get_full_transitions(25, 16) == 0, "[25][16]"
assert env.rail.get_full_transitions(25, 17) == 0, "[25][17]"
assert env.rail.get_full_transitions(25, 18) == 72, "[25][18]"
assert env.rail.get_full_transitions(25, 19) == 37408, "[25][19]"
assert env.rail.get_full_transitions(25, 20) == 49186, "[25][20]"
assert env.rail.get_full_transitions(25, 21) == 2064, "[25][21]"
assert env.rail.get_full_transitions(25, 22) == 0, "[25][22]"
assert env.rail.get_full_transitions(25, 23) == 0, "[25][23]"
assert env.rail.get_full_transitions(25, 24) == 0, "[25][24]"
assert env.rail.get_full_transitions(26, 0) == 0, "[26][0]"
assert env.rail.get_full_transitions(26, 1) == 0, "[26][1]"
assert env.rail.get_full_transitions(26, 2) == 0, "[26][2]"
assert env.rail.get_full_transitions(26, 3) == 0, "[26][3]"
assert env.rail.get_full_transitions(26, 4) == 0, "[26][4]"
assert env.rail.get_full_transitions(26, 5) == 72, "[26][5]"
assert env.rail.get_full_transitions(26, 6) == 1025, "[26][6]"
assert env.rail.get_full_transitions(26, 7) == 2064, "[26][7]"
assert env.rail.get_full_transitions(26, 8) == 0, "[26][8]"
assert env.rail.get_full_transitions(26, 9) == 0, "[26][9]"
assert env.rail.get_full_transitions(26, 10) == 0, "[26][10]"
assert env.rail.get_full_transitions(26, 11) == 32800, "[26][11]"
assert env.rail.get_full_transitions(26, 12) == 0, "[26][12]"
assert env.rail.get_full_transitions(26, 13) == 0, "[26][13]"
assert env.rail.get_full_transitions(26, 14) == 0, "[26][14]"
assert env.rail.get_full_transitions(26, 15) == 0, "[26][15]"
assert env.rail.get_full_transitions(26, 16) == 0, "[26][16]"
assert env.rail.get_full_transitions(26, 17) == 0, "[26][17]"
assert env.rail.get_full_transitions(26, 18) == 0, "[26][18]"
assert env.rail.get_full_transitions(26, 19) == 32872, "[26][19]"
assert env.rail.get_full_transitions(26, 20) == 37408, "[26][20]"
assert env.rail.get_full_transitions(26, 21) == 0, "[26][21]"
assert env.rail.get_full_transitions(26, 22) == 0, "[26][22]"
assert env.rail.get_full_transitions(26, 23) == 0, "[26][23]"
assert env.rail.get_full_transitions(26, 24) == 0, "[26][24]"
assert env.rail.get_full_transitions(27, 0) == 0, "[27][0]"
assert env.rail.get_full_transitions(27, 1) == 0, "[27][1]"
assert env.rail.get_full_transitions(27, 2) == 0, "[27][2]"
assert env.rail.get_full_transitions(27, 3) == 0, "[27][3]"
assert env.rail.get_full_transitions(27, 4) == 0, "[27][4]"
assert env.rail.get_full_transitions(27, 5) == 0, "[27][5]"
assert env.rail.get_full_transitions(27, 6) == 0, "[27][6]"
assert env.rail.get_full_transitions(27, 7) == 0, "[27][7]"
assert env.rail.get_full_transitions(27, 8) == 0, "[27][8]"
assert env.rail.get_full_transitions(27, 9) == 0, "[27][9]"
assert env.rail.get_full_transitions(27, 10) == 0, "[27][10]"
assert env.rail.get_full_transitions(27, 11) == 32800, "[27][11]"
assert env.rail.get_full_transitions(27, 12) == 0, "[27][12]"
assert env.rail.get_full_transitions(27, 13) == 0, "[27][13]"
assert env.rail.get_full_transitions(27, 14) == 0, "[27][14]"
assert env.rail.get_full_transitions(27, 15) == 0, "[27][15]"
assert env.rail.get_full_transitions(27, 16) == 0, "[27][16]"
assert env.rail.get_full_transitions(27, 17) == 0, "[27][17]"
assert env.rail.get_full_transitions(27, 18) == 0, "[27][18]"
assert env.rail.get_full_transitions(27, 19) == 49186, "[27][19]"
assert env.rail.get_full_transitions(27, 20) == 2064, "[27][20]"
assert env.rail.get_full_transitions(27, 21) == 0, "[27][21]"
assert env.rail.get_full_transitions(27, 22) == 0, "[27][22]"
assert env.rail.get_full_transitions(27, 23) == 0, "[27][23]"
assert env.rail.get_full_transitions(27, 24) == 0, "[27][24]"
assert env.rail.get_full_transitions(28, 0) == 0, "[28][0]"
assert env.rail.get_full_transitions(28, 1) == 0, "[28][1]"
assert env.rail.get_full_transitions(28, 2) == 0, "[28][2]"
assert env.rail.get_full_transitions(28, 3) == 0, "[28][3]"
assert env.rail.get_full_transitions(28, 4) == 0, "[28][4]"
assert env.rail.get_full_transitions(28, 5) == 0, "[28][5]"
assert env.rail.get_full_transitions(28, 6) == 0, "[28][6]"
assert env.rail.get_full_transitions(28, 7) == 0, "[28][7]"
assert env.rail.get_full_transitions(28, 8) == 0, "[28][8]"
assert env.rail.get_full_transitions(28, 9) == 0, "[28][9]"
assert env.rail.get_full_transitions(28, 10) == 0, "[28][10]"
assert env.rail.get_full_transitions(28, 11) == 32800, "[28][11]"
assert env.rail.get_full_transitions(28, 12) == 0, "[28][12]"
assert env.rail.get_full_transitions(28, 13) == 0, "[28][13]"
assert env.rail.get_full_transitions(28, 14) == 0, "[28][14]"
assert env.rail.get_full_transitions(28, 15) == 0, "[28][15]"
assert env.rail.get_full_transitions(28, 16) == 0, "[28][16]"
assert env.rail.get_full_transitions(28, 17) == 0, "[28][17]"
assert env.rail.get_full_transitions(28, 18) == 0, "[28][18]"
assert env.rail.get_full_transitions(28, 19) == 32800, "[28][19]"
assert env.rail.get_full_transitions(28, 20) == 0, "[28][20]"
assert env.rail.get_full_transitions(28, 21) == 0, "[28][21]"
assert env.rail.get_full_transitions(28, 22) == 0, "[28][22]"
assert env.rail.get_full_transitions(28, 23) == 0, "[28][23]"
assert env.rail.get_full_transitions(28, 24) == 0, "[28][24]"
assert env.rail.get_full_transitions(29, 0) == 0, "[29][0]"
assert env.rail.get_full_transitions(29, 1) == 0, "[29][1]"
assert env.rail.get_full_transitions(29, 2) == 0, "[29][2]"
assert env.rail.get_full_transitions(29, 3) == 0, "[29][3]"
assert env.rail.get_full_transitions(29, 4) == 0, "[29][4]"
assert env.rail.get_full_transitions(29, 5) == 0, "[29][5]"
assert env.rail.get_full_transitions(29, 6) == 0, "[29][6]"
assert env.rail.get_full_transitions(29, 7) == 0, "[29][7]"
assert env.rail.get_full_transitions(29, 8) == 0, "[29][8]"
assert env.rail.get_full_transitions(29, 9) == 0, "[29][9]"
assert env.rail.get_full_transitions(29, 10) == 0, "[29][10]"
assert env.rail.get_full_transitions(29, 11) == 72, "[29][11]"
assert env.rail.get_full_transitions(29, 12) == 1025, "[29][12]"
assert env.rail.get_full_transitions(29, 13) == 1025, "[29][13]"
assert env.rail.get_full_transitions(29, 14) == 1025, "[29][14]"
assert env.rail.get_full_transitions(29, 15) == 1025, "[29][15]"
assert env.rail.get_full_transitions(29, 16) == 1025, "[29][16]"
assert env.rail.get_full_transitions(29, 17) == 1025, "[29][17]"
assert env.rail.get_full_transitions(29, 18) == 1025, "[29][18]"
assert env.rail.get_full_transitions(29, 19) == 2064, "[29][19]"
assert env.rail.get_full_transitions(29, 20) == 0, "[29][20]"
assert env.rail.get_full_transitions(29, 21) == 0, "[29][21]"
assert env.rail.get_full_transitions(29, 22) == 0, "[29][22]"
assert env.rail.get_full_transitions(29, 23) == 0, "[29][23]"
assert env.rail.get_full_transitions(29, 24) == 0, "[29][24]"
def test_rail_env_action_required_info():
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env_always_action = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
max_num_cities=10,
max_rails_between_cities=3,
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False)
env_only_if_action_required = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
max_num_cities=10,
max_rails_between_cities=3,
seed=5, # Random seed
grid_mode=False
# Ordered distribution of nodes
), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False)
env_renderer = RenderTool(env_always_action, gl="PILSVG", )
# Reset the envs
env_always_action.reset(False, False, random_seed=5)
env_only_if_action_required.reset(False, False, random_seed=5)
assert env_only_if_action_required.rail.grid.tolist() == env_always_action.rail.grid.tolist()
for step in range(50):
print("step {}".format(step))
action_dict_always_action = dict()
action_dict_only_if_action_required = dict()
# Chose an action for each agent in the environment
for a in range(env_always_action.get_num_agents()):
action = np.random.choice(np.arange(4))
action_dict_always_action.update({a: action})
if step == 0 or info_only_if_action_required['action_required'][a]:
action_dict_only_if_action_required.update({a: action})
else:
print("[{}] not action_required {}, speed_counter={}".format(step, a,
env_always_action.agents[a].speed_counter))
obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
action_dict_always_action)
obs_only_if_action_required, rewards_only_if_action_required, done_only_if_action_required, info_only_if_action_required = env_only_if_action_required.step(
action_dict_only_if_action_required)
for a in range(env_always_action.get_num_agents()):
assert len(obs_always_action[a]) == len(obs_only_if_action_required[a])
for i in range(len(obs_always_action[a])):
assert len(obs_always_action[a][i]) == len(obs_only_if_action_required[a][i])
equal = np.array_equal(obs_always_action[a][i], obs_only_if_action_required[a][i])
if not equal:
for r in range(50):
for c in range(50):
assert np.array_equal(obs_always_action[a][i][(r, c)], obs_only_if_action_required[a][i][
(r, c)]), "[{}] a={},i={},{}\n{}\n\nvs.\n\n{}".format(step, a, i, (r, c),
obs_always_action[a][i][(r, c)],
obs_only_if_action_required[a][
i][(r, c)])
assert equal, \
"[{}] [{}][{}] {} vs. {}".format(step, a, i, obs_always_action[a][i],
obs_only_if_action_required[a][i])
assert np.array_equal(rewards_always_action[a], rewards_only_if_action_required[a])
assert np.array_equal(done_always_action[a], done_only_if_action_required[a])
assert info_always_action['action_required'][a] == info_only_if_action_required['action_required'][a]
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
if done_always_action['__all__']:
break
env_renderer.close_window()
def test_rail_env_malfunction_speed_info():
env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10,
max_rails_between_cities=3,
seed=5,
grid_mode=False
),
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(False, False)
env_renderer = RenderTool(env, gl="PILSVG", )
for step in range(100):
action_dict = dict()
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = np.random.choice(np.arange(4))
action_dict.update({a: action})
obs, rewards, done, info = env.step(
action_dict)
assert 'malfunction' in info
for a in range(env.get_num_agents()):
assert info['malfunction'][a] >= 0
assert info['speed'][a] >= 0 and info['speed'][a] <= 1
assert info['speed'][a] == env.agents[a].speed_counter.speed
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
if done['__all__']:
break
env_renderer.close_window()
def test_sparse_generator_with_too_man_cities_does_not_break_down():
RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
max_num_cities=100,
max_rails_between_cities=3,
seed=5,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv())
def test_sparse_generator_with_illegal_params_aborts():
"""
Test that the constructor aborts if the initial parameters don't allow more than one city to be built.
"""
with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, ValueError):
RailEnv(width=6, height=6, rail_generator=sparse_rail_generator(
max_num_cities=100,
max_rails_between_cities=3,
seed=5,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()).reset()
with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, ValueError):
RailEnv(width=60, height=60, rail_generator=sparse_rail_generator(
max_num_cities=1,
max_rails_between_cities=3,
seed=5,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()).reset()
def test_sparse_generator_changes_to_grid_mode():
"""
Test that grid mode is evoked and two cities are created when env is too small to find random cities.
We set the limit of the env such that two cities fit in grid mode but unlikely under random mode
we initiate random seed to be sure that we never create random cities.
"""
rail_env = RailEnv(width=10, height=20, rail_generator=sparse_rail_generator(
max_num_cities=100,
max_rails_between_cities=2,
max_rail_pairs_in_city=1,
seed=15,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
with warnings.catch_warnings(record=True) as w:
rail_env.reset(True, True, random_seed=15)
assert "[WARNING]" in str(w[-1].message)
from test_utils import create_and_save_env
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
from flatland.envs.line_generators import sparse_line_generator, line_from_file
def test_line_from_file_sparse():
"""
Test to see that all parameters are loaded as expected
Returns
-------
"""
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
# Generate Sparse test env
rail_generator = sparse_rail_generator(max_num_cities=5,
seed=1,
grid_mode=False,
max_rails_between_cities=3,
max_rail_pairs_in_city=3,
)
line_generator = sparse_line_generator(speed_ration_map)
env = create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator,
line_generator=line_generator)
old_num_steps = env._max_episode_steps
old_num_agents = len(env.agents)
# Sparse generator
rail_generator = rail_from_file("./sparse_env_test.pkl")
line_generator = line_from_file("./sparse_env_test.pkl")
sparse_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator,
line_generator=line_generator)
sparse_env_from_file.reset(True, True)
# Assert loaded agent number is correct
assert sparse_env_from_file.get_num_agents() == old_num_agents
# Assert max steps is correct
assert sparse_env_from_file._max_episode_steps == old_num_steps
\ No newline at end of file
import random
from typing import Dict, List
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter
class SingleAgentNavigationObs(ObservationBuilder):
"""
We build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__()
def reset(self):
pass
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(agent_virtual_position, direction)
min_distances.append(
self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
def test_malfunction_process():
# Set fixed malfunction duration for this test
stochastic_data = MalfunctionParameters(malfunction_rate=1, # Rate of malfunction occurence
min_duration=3, # Minimal duration of malfunction
max_duration=3 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
obs, info = env.reset(False, False, random_seed=10)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].state = TrainState.MOVING
agent_halts = 0
total_down_time = 0
agent_old_position = env.agents[0].position
# Move target to unreachable position in order to not interfere with test
env.agents[0].target = (0, 0)
# Add in max episode steps because scheudule generator sets it to 0 for dummy data
env._max_episode_steps = 200
for step in range(100):
actions = {}
for i in range(len(obs)):
actions[i] = np.argmax(obs[i]) + 1
obs, all_rewards, done, _ = env.step(actions)
if done["__all__"]:
break
if env.agents[0].malfunction_handler.malfunction_down_counter > 0:
agent_malfunctioning = True
else:
agent_malfunctioning = False
if agent_malfunctioning:
# Check that agent is not moving while malfunctioning
assert agent_old_position == env.agents[0].position
agent_old_position = env.agents[0].position
total_down_time += env.agents[0].malfunction_handler.malfunction_down_counter
# Check that the appropriate number of malfunctions is achieved
# Dipam: The number of malfunctions varies by seed
assert env.agents[0].malfunction_handler.num_malfunctions == 28, "Actual {}".format(
env.agents[0].malfunction_handler.num_malfunctions)
# Check that malfunctioning data was standing around
assert total_down_time > 0
def test_malfunction_process_statistically():
"""Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = MalfunctionParameters(malfunction_rate=1/5, # Rate of malfunction occurence
min_duration=5, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=2,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(True, True, random_seed=10)
env._max_episode_steps = 1000
env.agents[0].target = (0, 0)
# Next line only for test generation
agent_malfunction_list = [[] for i in range(2)]
agent_malfunction_list = [[0, 0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1],
[0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent_idx in range(env.get_num_agents()):
# We randomly select an action
action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
# For generating tests only:
# agent_malfunction_list[agent_idx].append(
# env.agents[agent_idx].malfunction_handler.malfunction_down_counter)
assert env.agents[agent_idx].malfunction_handler.malfunction_down_counter == \
agent_malfunction_list[agent_idx][step]
env.step(action_dict)
def test_malfunction_before_entry():
"""Tests that malfunctions are working properly for agents before entering the environment!"""
# Set fixed malfunction duration for this test
stochastic_data = MalfunctionParameters(malfunction_rate=1/2, # Rate of malfunction occurrence
min_duration=10, # Minimal duration of malfunction
max_duration=10 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=2,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(False, False, random_seed=10)
env.agents[0].target = (0, 0)
# Test initial malfunction values for all agents
# we want some agents to be malfuncitoning already and some to be working
# we want different next_malfunction values for the agents
malfunction_values = [env.malfunction_generator(env.np_random).num_broken_steps for _ in range(1000)]
expected_value = (1 - np.exp(-0.5)) * 10
assert np.allclose(np.mean(malfunction_values), expected_value, rtol=0.1), "Mean values of malfunction don't match rate"
def test_malfunction_values_and_behavior():
"""
Test the malfunction counts down as desired
Returns
-------
"""
# Set fixed malfunction duration for this test
rail, rail_map, optionals = make_simple_rail2()
action_dict: Dict[int, RailEnvActions] = {}
stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001, # Rate of malfunction occurence
min_duration=10, # Minimal duration of malfunction
max_duration=10 # Max duration of malfunction
)
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(False, False, random_seed=10)
env._max_episode_steps = 20
# Assertions
assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5]
for time_step in range(15):
# Move in the env
_, _, dones,_ = env.step(action_dict)
# Check that next_step decreases as expected
assert env.agents[0].malfunction_handler.malfunction_down_counter == assert_list[time_step]
if dones['__all__']:
break
def test_initial_malfunction():
stochastic_data = MalfunctionParameters(malfunction_rate=1/1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, random_seed=10)
env._max_episode_steps = 1000
print(env.agents[0].malfunction_handler)
env.agents[0].target = (0, 5)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
Replay( # 0
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
reward=env.step_penalty # full step penalty when malfunctioning
),
Replay( # 1
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2,
reward=env.step_penalty # full step penalty when malfunctioning
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay( # 2
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=1,
reward=env.step_penalty
), # malfunctioning ends: starting and running at speed 1.0
Replay( # 3
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0 # running at speed 1.0
),
Replay( # 4
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty # running at speed 1.0
)
],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [replay_config], skip_reward_check=True)
def test_initial_malfunction_stop_moving():
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
env.reset()
env._max_episode_steps = 1000
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].state)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
Replay( # 0
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
reward=env.step_penalty, # full step penalty when stopped
state=TrainState.READY_TO_DEPART
),
Replay( # 1
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
reward=env.step_penalty, # full step penalty when stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action STOP_MOVING, agent should restart without moving
#
Replay( # 2
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
# we have stopped and do nothing --> should stand still
Replay( # 3
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
# we start to move forward --> should go to next cell now
Replay( # 4
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.MOVING
),
Replay( # 5
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.STOPPED
),
Replay( # 6
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.MOVING
),
Replay( # 6
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.STOPPED
)
],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [replay_config], activate_agents=False,
skip_reward_check=True, set_ready_to_depart=True, skip_action_required_check=True)
def test_initial_malfunction_do_nothing():
stochastic_data = MalfunctionParameters(malfunction_rate=1/70, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
)
env.reset()
env._max_episode_steps = 1000
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
Replay(
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
reward=env.step_penalty, # full step penalty while malfunctioning
state=TrainState.READY_TO_DEPART
),
Replay(
position=None,
direction=Grid4TransitionsEnum.EAST,
action=None,
malfunction=2,
reward=env.step_penalty, # full step penalty while malfunctioning
state=TrainState.MALFUNCTION_OFF_MAP
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
#
Replay(
position=None,
direction=Grid4TransitionsEnum.EAST,
action=None,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
# we haven't started moving yet --> stay here
Replay(
position=None,
direction=Grid4TransitionsEnum.EAST,
action=None,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
Replay(
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0
state=TrainState.MOVING
), # we start to move forward --> should go to next cell now
Replay(
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # step penalty for speed 1.0
state=TrainState.MOVING
)
],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [replay_config], activate_agents=False,
skip_reward_check=True, set_ready_to_depart=True)
def tests_random_interference_from_outside():
"""Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_counter = SpeedCounter(speed=0.33)
env.reset(False, False, random_seed=10)
env_data = []
for step in range(200):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# We randomly select an action
action_dict[agent.handle] = RailEnvActions(2)
_, reward, dones, _ = env.step(action_dict)
# Append the rewards of the first trial
env_data.append((reward[0], env.agents[0].position))
assert reward[0] == env_data[step][0]
assert env.agents[0].position == env_data[step][1]
if dones['__all__']:
break
# Run the same test as above but with an external random generator running
# Check that the reward stays the same
rail, rail_map, optionals = make_simple_rail2()
random.seed(47)
np.random.seed(1234)
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_counter = SpeedCounter(speed=0.33)
env.reset(False, False, random_seed=10)
dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
for step in range(200):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# We randomly select an action
action_dict[agent.handle] = RailEnvActions(2)
# Do dummy random number generations
random.shuffle(dummy_list)
np.random.rand()
_, reward, dones, _ = env.step(action_dict)
assert reward[0] == env_data[step][0]
assert env.agents[0].position == env_data[step][1]
if dones['__all__']:
break
def test_last_malfunction_step():
"""
Test to check that agent moves when it is not malfunctioning
"""
# Set fixed malfunction duration for this test
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_counter = SpeedCounter(speed=1./3.)
env.agents[0].initial_position = (6, 6)
env.agents[0].initial_direction = 2
env.agents[0].target = (0, 3)
env._max_episode_steps = 1000
env.reset(False, False)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].state = TrainState.MOVING
# Force malfunction to be off at beginning and next malfunction to happen in 2 steps
# env.agents[0].malfunction_data['next_malfunction'] = 2
env.agents[0].malfunction_handler.malfunction_down_counter = 0
env_data = []
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# Go forward all the time
action_dict[agent.handle] = RailEnvActions(2)
if env.agents[0].malfunction_handler.malfunction_down_counter < 1:
agent_can_move = True
# Store the position before and after the step
pre_position = env.agents[0].speed_counter.counter
_, reward, _, _ = env.step(action_dict)
# Check if the agent is still allowed to move in this step
if env.agents[0].malfunction_handler.malfunction_down_counter > 0:
agent_can_move = False
post_position = env.agents[0].speed_counter.counter
# Assert that the agent moved while it was still allowed
if agent_can_move:
assert pre_position != post_position
else:
assert post_position == pre_position
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from multiprocessing.pool import Pool
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
def test_multiprocessing_tree_obs():
number_of_agents = 5
rail, rail_map, optionals = make_simple_rail()
optionals['agents_hints']['num_agents'] = number_of_agents
obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=number_of_agents,
obs_builder_object=obs_builder)
env.reset(True, True)
pool = Pool()
pool.map(obs_builder.get, range(number_of_agents))
def main():
test_multiprocessing_tree_obs()
if __name__ == "__main__":
main()
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
def test_initial_status():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
env.reset()
env._max_episode_steps = 1000
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay(
position=None, # not entered grid yet
direction=Grid4TransitionsEnum.EAST,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.DO_NOTHING,
reward=env.step_penalty * 0.5,
),
Replay(
position=None, # not entered grid yet before step
direction=Grid4TransitionsEnum.EAST,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.MOVE_LEFT,
reward=env.step_penalty * 0.5, # auto-correction left to forward without penalty!
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_LEFT,
reward=env.start_penalty + env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
state=TrainState.MOVING
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty!
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5, #
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, #
state=TrainState.MOVING
),
# Replay(
# position=(3, 5),
# direction=Grid4TransitionsEnum.WEST,
# action=None,
# reward=env.global_reward, # already done
# status=RailAgentStatus.DONE
# ),
# Replay(
# position=(3, 5),
# direction=Grid4TransitionsEnum.WEST,
# action=None,
# reward=env.global_reward, # already done
# status=RailAgentStatus.DONE
# )
],
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
target=(3, 5),
speed=0.5
)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True,
set_ready_to_depart=True)
assert env.agents[0].state == TrainState.DONE
def test_status_done_remove():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=True)
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
env._max_episode_steps = 1000
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay(
position=None, # not entered grid yet
direction=Grid4TransitionsEnum.EAST,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.DO_NOTHING,
reward=env.step_penalty * 0.5,
),
Replay(
position=None, # not entered grid yet before step
direction=Grid4TransitionsEnum.EAST,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.MOVE_LEFT,
reward=env.step_penalty * 0.5, # auto-correction left to forward without penalty!
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5, # running at speed 0.5
state=TrainState.MOVING
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty!
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # done
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, # already done
state=TrainState.MOVING
),
# Replay(
# position=None,
# direction=Grid4TransitionsEnum.WEST,
# action=None,
# reward=env.global_reward, # already done
# status=RailAgentStatus.DONE_REMOVED
# ),
# Replay(
# position=None,
# direction=Grid4TransitionsEnum.WEST,
# action=None,
# reward=env.global_reward, # already done
# status=RailAgentStatus.DONE_REMOVED
# )
],
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
target=(3, 5),
speed=0.5
)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True,
set_ready_to_depart=True)
assert env.agents[0].state == TrainState.DONE
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.step_utils.states import TrainState
def test_return_to_ready_to_depart():
"""
When going from ready to depart to malfunction off map, if do nothing is provided, should return to ready to depart
"""
stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence
min_duration=0, # Minimal duration of malfunction
max_duration=0 # Max duration of malfunction
)
rail, _, optionals = make_simple_rail()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
)
env.reset(False, False, random_seed=10)
env._max_episode_steps = 100
for _ in range(3):
env.step({0: RailEnvActions.DO_NOTHING})
env.agents[0].malfunction_handler._set_malfunction_down_counter(2)
env.step({0: RailEnvActions.DO_NOTHING})
assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP
for _ in range(2):
env.step({0: RailEnvActions.DO_NOTHING})
assert env.agents[0].state == TrainState.READY_TO_DEPART
def test_ready_to_depart_to_stopped():
"""
When going from ready to depart to malfunction off map, if stopped is provided, should go to stopped
"""
stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence
min_duration=0, # Minimal duration of malfunction
max_duration=0 # Max duration of malfunction
)
rail, _, optionals = make_simple_rail()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
)
env.reset(False, False, random_seed=10)
env._max_episode_steps = 100
for _ in range(3):
env.step({0: RailEnvActions.STOP_MOVING})
assert env.agents[0].state == TrainState.READY_TO_DEPART
env.agents[0].malfunction_handler._set_malfunction_down_counter(2)
env.step({0: RailEnvActions.STOP_MOVING})
assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP
for _ in range(2):
env.step({0: RailEnvActions.STOP_MOVING})
assert env.agents[0].state == TrainState.STOPPED
def test_malfunction_no_phase_through():
"""
A moving train shouldn't phase through a malfunctioning train
"""
stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence
min_duration=0, # Minimal duration of malfunction
max_duration=0 # Max duration of malfunction
)
rail, _, optionals = make_simple_rail()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=2,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
)
env.reset(False, False, random_seed=10)
for _ in range(5):
env.step({0: RailEnvActions.MOVE_FORWARD, 1: RailEnvActions.MOVE_FORWARD})
env.agents[1].malfunction_handler._set_malfunction_down_counter(10)
for _ in range(3):
env.step({0: RailEnvActions.MOVE_FORWARD, 1: RailEnvActions.DO_NOTHING})
assert env.agents[0].state == TrainState.STOPPED
assert env.agents[0].position == (3, 5)
\ No newline at end of file
......@@ -6,22 +6,22 @@ Tests for `flatland` package.
import sys
import matplotlib.pyplot as plt
import numpy as np
from importlib_resources import path
import flatland.utils.rendertools as rt
import images.test
from flatland.envs.generators import empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import empty_rail_generator
import pytest
def checkFrozenImage(oRT, sFileImage, resave=False):
sDirRoot = "."
sDirImages = sDirRoot + "/images/"
img_test = oRT.getImage()
img_test = oRT.get_image()
if resave:
np.savez_compressed(sDirImages + sFileImage, img=img_test)
......@@ -35,45 +35,20 @@ def checkFrozenImage(oRT, sFileImage, resave=False):
# assert ((np.sum(np.square(img_test - img_expected)) / img_expected.size / 256) < 1e-3), \ # noqa: E800
# "Image {} does not match".format(sFileImage) \ # noqa: E800
@pytest.mark.skip("Only needed for visual editor, Flatland 3 line generator won't allow empty enviroment")
def test_render_env(save_new_images=False):
np.random.seed(100)
oEnv = RailEnv(width=10, height=10,
rail_generator=empty_rail_generator(),
number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2)
)
oEnv = RailEnv(width=10, height=10, rail_generator=empty_rail_generator(), number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
oEnv.reset()
oEnv.rail.load_transition_map('env_data.tests', "test1.npy")
oRT = rt.RenderTool(oEnv, gl="PILSVG")
oRT.renderEnv(show=False)
oRT.render_env(show=False)
checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
oRT = rt.RenderTool(oEnv, gl="PIL")
oRT.renderEnv()
oRT.render_env()
checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
# disable the tree / observation tests until env-agent save/load is available
if False:
lVisits = oRT.getTreeFromRail(
oEnv.agents_position[0],
oEnv.agents_direction[0],
nDepth=17, bPlot=True)
checkFrozenImage("env-tree-spatial.png")
plt.figure(figsize=(8, 8))
xyTarg = oRT.env.agents_target[0]
visitDest = oRT.plotTree(lVisits, xyTarg)
checkFrozenImage("env-tree-graph.png")
plt.figure(figsize=(10, 10))
oRT.renderEnv()
oRT.plotPath(visitDest)
checkFrozenImage("env-path.png")
def main():
if len(sys.argv) == 2 and sys.argv[1] == "save":
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, empty_rail_generator
from flatland.envs.line_generators import sparse_line_generator, line_from_file
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.step_utils.states import TrainState
def test_empty_rail_generator():
n_agents = 2
x_dim = 5
y_dim = 10
# Check that a random level at with correct parameters is generated
rail, _ = empty_rail_generator().generate(width=x_dim, height=y_dim, num_agents=n_agents)
# Check the dimensions
assert rail.grid.shape == (y_dim, x_dim)
# Check that no grid was generated
assert np.count_nonzero(rail.grid) == 0
def test_rail_from_grid_transition_map():
rail, rail_map, optionals = make_simple_rail()
n_agents = 2
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=n_agents)
env.reset(False, False)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx]._set_state(TrainState.MOVING)
nr_rail_elements = np.count_nonzero(env.rail.grid)
# Check if the number of non-empty rail cells is ok
assert nr_rail_elements == 16
# Check that agents are placed on a rail
for a in env.agents:
assert env.rail.grid[a.position] != 0
assert env.get_num_agents() == n_agents
def tests_rail_from_file():
file_name = "test_with_distance_map.pkl"
# Test to save and load file with distance map.
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
#env.save(file_name)
RailEnvPersister.save(env, file_name)
dist_map_shape = np.shape(env.distance_map.get())
rails_initial = env.rail.grid
agents_initial = env.agents
env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
rails_loaded = env.rail.grid
agents_loaded = env.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
# Check that distance map was not recomputed
assert np.shape(env.distance_map.get()) == dist_map_shape
assert env.distance_map.get() is not None
# Test to save and load file without distance map.
file_name_2 = "test_without_distance_map.pkl"
env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(),
number_of_agents=3, obs_builder_object=GlobalObsForRailEnv())
env2.reset()
#env2.save(file_name_2)
RailEnvPersister.save(env2, file_name_2)
rails_initial_2 = env2.rail.grid
agents_initial_2 = env2.agents
env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2),
line_generator=line_from_file(file_name_2), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env2.reset()
rails_loaded_2 = env2.rail.grid
agents_loaded_2 = env2.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial_2, agents_loaded_2):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial_2, rails_loaded_2))
assert agents_initial_2 == agents_loaded_2
assert not hasattr(env2.obs_builder, "distance_map")
# Test to save with distance map and load without
env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env3.reset()
rails_loaded_3 = env3.rail.grid
agents_loaded_3 = env3.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded_3):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded_3))
assert agents_initial == agents_loaded_3
assert not hasattr(env2.obs_builder, "distance_map")
# Test to save without distance map and load with generating distance map
env4 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name_2),
line_generator=line_from_file(file_name_2),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
)
env4.reset()
rails_loaded_4 = env4.rail.grid
agents_loaded_4 = env4.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial_2, agents_loaded_4):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
# Check that no distance map was saved
assert not hasattr(env2.obs_builder, "distance_map")
assert np.all(np.array_equal(rails_initial_2, rails_loaded_4))
assert agents_initial_2 == agents_loaded_4
# Check that distance map was generated with correct shape
assert env4.distance_map.get() is not None
assert np.shape(env4.distance_map.get()) == dist_map_shape
def main():
tests_rail_from_file()
if __name__ == "__main__":
main()
import numpy as np
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.step_utils.states import TrainState
def test_get_global_observation():
number_of_agents = 20
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=6,
max_rails_between_cities=4,
seed=15,
grid_mode=False
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=number_of_agents,
obs_builder_object=GlobalObsForRailEnv())
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
for i in range(len(env.agents)):
agent: EnvAgent = env.agents[i]
print("[{}] state={}, position={}, target={}, initial_position={}".format(i, agent.state, agent.position,
agent.target,
agent.initial_position))
for i, agent in enumerate(env.agents):
obs_agents_state = obs[i][1]
obs_targets = obs[i][2]
# test first channel of obs_targets: own target
nr_agents = np.count_nonzero(obs_targets[:, :, 0])
assert nr_agents == 1, "agent {}: something wrong with own target, found {}".format(i, nr_agents)
# test second channel of obs_targets: other agent's target
for r in range(env.height):
for c in range(env.width):
_other_agent_target = 0
for other_i, other_agent in enumerate(env.agents):
if other_agent.target == (r, c):
_other_agent_target = 1
break
assert obs_targets[(r, c)][
1] == _other_agent_target, "agent {}: at {} expected to be other agent's target = {}".format(
i, (r, c),
_other_agent_target)
# test first channel of obs_agents_state: direction at own position
for r in range(env.height):
for c in range(env.width):
if (agent.state.is_on_map_state() or agent.state == TrainState.DONE) and (
r, c) == agent.position:
assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
"agent {} in state {} at {} expected to contain own direction {}, found {}" \
.format(i, agent.state, (r, c), agent.direction, obs_agents_state[(r, c)][0])
elif (agent.state == TrainState.READY_TO_DEPART) and (r, c) == agent.initial_position:
assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
"agent {} in state {} at {} expected to contain own direction {}, found {}" \
.format(i, agent.state, (r, c), agent.direction, obs_agents_state[(r, c)][0])
else:
assert np.isclose(obs_agents_state[(r, c)][0], -1), \
"agent {} in state {} at {} expected contain -1 found {}" \
.format(i, agent.state, (r, c), obs_agents_state[(r, c)][0])
# test second channel of obs_agents_state: direction at other agents position
for r in range(env.height):
for c in range(env.width):
has_agent = False
for other_i, other_agent in enumerate(env.agents):
if i == other_i:
continue
if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED, TrainState.DONE] and (
r, c) == other_agent.position:
assert np.isclose(obs_agents_state[(r, c)][1], other_agent.direction), \
"agent {} in state {} at {} should see other agent with direction {}, found = {}" \
.format(i, agent.state, (r, c), other_agent.direction, obs_agents_state[(r, c)][1])
has_agent = True
if not has_agent:
assert np.isclose(obs_agents_state[(r, c)][1], -1), \
"agent {} in state {} at {} should see no other agent direction (-1), found = {}" \
.format(i, agent.state, (r, c), obs_agents_state[(r, c)][1])
# test third and fourth channel of obs_agents_state: malfunction and speed of own or other agent in the grid
for r in range(env.height):
for c in range(env.width):
has_agent = False
for other_i, other_agent in enumerate(env.agents):
if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED,
TrainState.DONE] and other_agent.position == (r, c):
assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_handler.malfunction_down_counter), \
"agent {} in state {} at {} should see agent malfunction {}, found = {}" \
.format(i, agent.state, (r, c), other_agent.malfunction_handler.malfunction_down_counter,
obs_agents_state[(r, c)][2])
assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_counter.speed)
has_agent = True
if not has_agent:
assert np.isclose(obs_agents_state[(r, c)][2], -1), \
"agent {} in state {} at {} should see no agent malfunction (-1), found = {}" \
.format(i, agent.state, (r, c), obs_agents_state[(r, c)][2])
assert np.isclose(obs_agents_state[(r, c)][3], -1), \
"agent {} in state {} at {} should see no agent speed (-1), found = {}" \
.format(i, agent.state, (r, c), obs_agents_state[(r, c)][3])
# test fifth channel of obs_agents_state: number of agents ready to depart in to this cell
for r in range(env.height):
for c in range(env.width):
count = 0
for other_i, other_agent in enumerate(env.agents):
if other_agent.state == TrainState.READY_TO_DEPART and other_agent.initial_position == (r, c):
count += 1
assert np.isclose(obs_agents_state[(r, c)][4], count), \
"agent {} in state {} at {} should see {} agents ready to depart, found{}" \
.format(i, agent.state, (r, c), count, obs_agents_state[(r, c)][4])
from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file, \
single_malfunction_generator, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from flatland.envs.persistence import RailEnvPersister
import pytest
def test_malfanction_from_params():
"""
Test loading malfunction from
Returns
-------
"""
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env.reset()
assert env.malfunction_process_data.malfunction_rate == 1000
assert env.malfunction_process_data.min_duration == 2
assert env.malfunction_process_data.max_duration == 5
def test_malfanction_to_and_from_file():
"""
Test loading malfunction from
Returns
-------
"""
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env.reset()
#env.save("./malfunction_saving_loading_tests.pkl")
RailEnvPersister.save(env, "./malfunction_saving_loading_tests.pkl")
malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl")
env2 = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env2.reset()
assert env2.malfunction_process_data == env.malfunction_process_data
assert env2.malfunction_process_data.malfunction_rate == 1000
assert env2.malfunction_process_data.min_duration == 2
assert env2.malfunction_process_data.max_duration == 5
@pytest.mark.skip("Single malfunction generator is deprecated")
def test_single_malfunction_generator():
"""
Test single malfunction generator
Returns
-------
"""
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=3,
malfunction_duration=5)
)
for test in range(10):
env.reset()
action_dict = dict()
tot_malfunctions = 0
print(test)
for i in range(10):
for agent in env.agents:
# Go forward all the time
action_dict[agent.handle] = RailEnvActions(2)
_, _, dones, _ = env.step(action_dict)
if dones['__all__']:
break
for agent in env.agents:
# Go forward all the time
tot_malfunctions += agent.malfunction_handler.num_malfunctions
assert tot_malfunctions == 1
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter
# Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
#
class RandomAgent:
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
self.np_random = np.random.RandomState(seed=42)
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return self.np_random.choice([1, 2, 3])
def step(self, memories):
"""
Step function to improve agent by adjusting policy given the observations
:param memories: SARS Tuple to be
:return:
"""
return
def save(self, filename):
# Store the current policy
return
def load(self, filename):
# Load a policy
return
def test_multi_speed_init():
env = RailEnv(width=50, height=50,
rail_generator=sparse_rail_generator(seed=2), line_generator=sparse_line_generator(),
random_seed=3,
number_of_agents=3)
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 4)
# Empty dictionary for all agent action
action_dict = dict()
# Set all the different speeds
# Reset environment and get initial observations for all agents
env.reset(False, False)
env._max_episode_steps = 1000
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx]._set_state(TrainState.MOVING)
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
old_pos = []
for i_agent in range(env.get_num_agents()):
env.agents[i_agent].speed_counter = SpeedCounter(speed = 1. / (i_agent + 1))
old_pos.append(env.agents[i_agent].position)
print(env.agents[i_agent].position)
# Run episode
for step in range(100):
# Choose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(0)
action_dict.update({a: action})
# Check that agent did not move in between its speed updates
assert old_pos[a] == env.agents[a].position
# Environment step which returns the observations for all agents, their corresponding
# reward and whether they are done
_, _, _, _ = env.step(action_dict)
# Update old position whenever an agent was allowed to move
for i_agent in range(env.get_num_agents()):
if (step + 1) % (i_agent + 1) == 0:
print(step, i_agent, env.agents[i_agent].position)
old_pos[i_agent] = env.agents[i_agent].position
def test_multispeed_actions_no_malfunction_no_blocking():
"""Test that actions are correctly performed on cell exit for a single agent."""
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
env._max_episode_steps = 1000
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING,
reward=env.stop_penalty + env.step_penalty * 0.5 # stopping and step penalty
),
#
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING,
reward=env.step_penalty * 0.5 # step penalty for speed 0.5 when stopped
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5 # starting + running at speed 0.5
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(5, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
],
target=(3, 0), # west dead-end
speed=0.5,
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [test_config], skip_reward_check=True, skip_action_required_check=True)
def test_multispeed_actions_no_malfunction_blocking():
"""The second agent blocks the first because it is slower."""
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
random_seed=1)
env.reset()
set_penalties_for_replay(env)
test_configs = [
ReplayConfig(
replay=[
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 1.0 / 3.0 # starting and running at speed 1/3
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 1.0 / 3.0 # running at speed 1/3
)
],
target=(3, 0), # west dead-end
speed=1 / 3,
initial_position=(3, 8),
initial_direction=Grid4TransitionsEnum.WEST,
),
ReplayConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
# blocked although fraction >= 1.0
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
# blocked although fraction >= 1.0
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
# blocked although fraction >= 1.0
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
# not blocked, action required!
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
],
target=(3, 0), # west dead-end
speed=0.5,
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
)
]
run_replay_config(env, test_configs, skip_reward_check=True)
def test_multispeed_actions_malfunction_no_blocking():
"""Test on a single agent whether action on cell exit work correctly despite malfunction."""
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents]) + 1):
env.step({}) # DO_NOTHING for all agents
env._max_episode_steps = 10000
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay( # 0
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5
),
Replay( # 1
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay( # 2
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
# add additional step in the cell
Replay( # 3
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
set_malfunction=2, # recovers in two steps from now!,
malfunction=2,
reward=env.step_penalty * 0.5 # step penalty for speed 0.5 when malfunctioning
),
# agent recovers in this step
Replay( # 4
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
malfunction=1,
reward=env.step_penalty * 0.5 # recovered: running at speed 0.5
),
Replay( # 5
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay( # 6
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay( # 7
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
set_malfunction=2, # recovers in two steps from now!
malfunction=2,
reward=env.step_penalty * 0.5 # step penalty for speed 0.5 when malfunctioning
),
# agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
Replay( # 8
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
malfunction=1,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay( # 9
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay( # 10
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.STOP_MOVING,
reward=env.stop_penalty + env.step_penalty * 0.5 # stopping and step penalty for speed 0.5
),
Replay( # 11
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.STOP_MOVING,
reward=env.step_penalty * 0.5 # step penalty for speed 0.5 while stopped
),
Replay( # 12
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5
),
Replay( # 13
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
# DO_NOTHING keeps moving!
Replay( # 14
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.DO_NOTHING,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay( # 15
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay( # 16
position=(3, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
],
target=(3, 0), # west dead-end
speed=0.5,
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [test_config], skip_reward_check=True)
# TODO invalid action penalty seems only given when forward is not possible - is this the intended behaviour?
def test_multispeed_actions_no_malfunction_invalid_actions():
"""Test that actions are correctly performed on cell exit for a single agent."""
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
env._max_episode_steps = 10000
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_LEFT,
reward=env.start_penalty + env.step_penalty * 0.5 # auto-correction left to forward without penalty!
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5 # wrong action is corrected to forward without penalty!
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5 # wrong action is corrected to forward without penalty!
), Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
],
target=(3, 0), # west dead-end
speed=0.5,
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [test_config], skip_reward_check=True)
import pytest
@pytest.mark.skip(reason="Only for testing pettingzoo interface and wrappers")
def test_petting_zoo_interface_env():
import numpy as np
import os
import PIL
import shutil
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# First of all we import the Flatland rail environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper
from flatland.contrib.wrappers.flatland_wrappers import ShortestPathActionWrapper # noqa
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 11
save = True
np.random.seed(seed)
experiment_name = "flatland_pettingzoo"
total_episodes = 2
if save:
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
rail_env.reset(random_seed=seed)
# For Shortest Path Action Wrapper, change action to 1
# rail_env = ShortestPathActionWrapper(rail_env)
rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
dones = {}
dones['__all__'] = False
step = 0
ep_no = 0
frame_list = []
all_actions_env = []
all_actions_pettingzoo_env = []
# while not dones['__all__']:
while ep_no < total_episodes:
action_dict = {}
# Chose an action for each agent
for a in range(rail_env.get_num_agents()):
# action = env_generators.get_shortest_path_action(rail_env, a)
action = 2
all_actions_env.append(action)
action_dict.update({a: action})
step += 1
# Do the environment step
observations, rewards, dones, information = rail_env.step(action_dict)
frame_list.append(PIL.Image.fromarray(rail_env.render(mode="rgb_array")))
if dones['__all__']:
completion = env_generators.perc_completion(rail_env)
print("Final Agents Completed:", completion)
ep_no += 1
if save:
frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
rail_env.reset(random_seed=seed+ep_no)
# __sphinx_doc_begin__
env = flatland_env.env(environment=rail_env)
seed = 11
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < total_episodes:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
# act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
act = 2
all_actions_pettingzoo_env.append(act)
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
# __sphinx_doc_end__
completion = env_generators.perc_completion(env)
print("Final Agents Completed:", completion)
ep_no += 1
if save:
frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
min_len = min(len(all_actions_pettingzoo_env), len(all_actions_env))
assert all_actions_pettingzoo_env[:min_len] == all_actions_env[:min_len], "actions do not match"
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-sv", __file__]))
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map, sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
def ndom_seeding():
# Set fixed malfunction duration for this test
rail, rail_map, optionals = make_simple_rail2()
# Move target to unreachable position in order to not interfere with test
for idx in range(100):
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=12), number_of_agents=10)
env.reset(True, True, random_seed=1)
env.agents[0].target = (0, 0)
for step in range(10):
actions = {}
actions[0] = 2
env.step(actions)
agent_positions = []
env.agents[0].initial_position == (3, 2)
env.agents[1].initial_position == (3, 5)
env.agents[2].initial_position == (3, 6)
env.agents[3].initial_position == (5, 6)
env.agents[4].initial_position == (3, 4)
env.agents[5].initial_position == (3, 1)
env.agents[6].initial_position == (3, 9)
env.agents[7].initial_position == (4, 6)
env.agents[8].initial_position == (0, 3)
env.agents[9].initial_position == (3, 7)
# Test generation print
# for a in range(env.get_num_agents()):
# print("env.agents[{}].initial_position == {}".format(a,env.agents[a].initial_position))
# print("env.agents[0].initial_position == {}".format(env.agents[0].initial_position))
# print("assert env.agents[0].position == {}".format(env.agents[0].position))
def test_seeding_and_observations():
# Test if two different instances diverge with different observations
rail, rail_map, optionals = make_simple_rail2()
optionals['agents_hints']['num_agents'] = 10
# Make two seperate envs with different observation builders
# Global Observation
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=12), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
# Tree Observation
env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=12), number_of_agents=10,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset(False, False, random_seed=12)
env2.reset(False, False, random_seed=12)
# Check that both environments produce the same initial start positions
assert env.agents[0].initial_position == env2.agents[0].initial_position
assert env.agents[1].initial_position == env2.agents[1].initial_position
assert env.agents[2].initial_position == env2.agents[2].initial_position
assert env.agents[3].initial_position == env2.agents[3].initial_position
assert env.agents[4].initial_position == env2.agents[4].initial_position
assert env.agents[5].initial_position == env2.agents[5].initial_position
assert env.agents[6].initial_position == env2.agents[6].initial_position
assert env.agents[7].initial_position == env2.agents[7].initial_position
assert env.agents[8].initial_position == env2.agents[8].initial_position
assert env.agents[9].initial_position == env2.agents[9].initial_position
action_dict = {}
for step in range(10):
for a in range(env.get_num_agents()):
action = np.random.randint(4)
action_dict[a] = action
env.step(action_dict)
env2.step(action_dict)
# Check that both environments end up in the same position
assert env.agents[0].position == env2.agents[0].position
assert env.agents[1].position == env2.agents[1].position
assert env.agents[2].position == env2.agents[2].position
assert env.agents[3].position == env2.agents[3].position
assert env.agents[4].position == env2.agents[4].position
assert env.agents[5].position == env2.agents[5].position
assert env.agents[6].position == env2.agents[6].position
assert env.agents[7].position == env2.agents[7].position
assert env.agents[8].position == env2.agents[8].position
assert env.agents[9].position == env2.agents[9].position
for a in range(env.get_num_agents()):
print("assert env.agents[{}].position == env2.agents[{}].position".format(a, a))
def test_seeding_and_malfunction():
# Test if two different instances diverge with different observations
rail, rail_map, optionals = make_simple_rail2()
optionals['agents_hints']['num_agents'] = 10
stochastic_data = {'prop_malfunction': 0.4,
'malfunction_rate': 2,
'min_duration': 10,
'max_duration': 10}
# Make two seperate envs with different and see if the exhibit the same malfunctions
# Global Observation
for tests in range(1, 100):
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
# Tree Observation
env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(True, False, random_seed=tests)
env2.reset(True, False, random_seed=tests)
# Check that both environments produce the same initial start positions
assert env.agents[0].initial_position == env2.agents[0].initial_position
assert env.agents[1].initial_position == env2.agents[1].initial_position
assert env.agents[2].initial_position == env2.agents[2].initial_position
assert env.agents[3].initial_position == env2.agents[3].initial_position
assert env.agents[4].initial_position == env2.agents[4].initial_position
assert env.agents[5].initial_position == env2.agents[5].initial_position
assert env.agents[6].initial_position == env2.agents[6].initial_position
assert env.agents[7].initial_position == env2.agents[7].initial_position
assert env.agents[8].initial_position == env2.agents[8].initial_position
assert env.agents[9].initial_position == env2.agents[9].initial_position
action_dict = {}
for step in range(10):
for a in range(env.get_num_agents()):
action = np.random.randint(4)
action_dict[a] = action
# print("----------------------")
# print(env.agents[a].malfunction_handler, env.agents[a].status)
# print(env2.agents[a].malfunction_handler, env2.agents[a].status)
_, reward1, done1, _ = env.step(action_dict)
_, reward2, done2, _ = env2.step(action_dict)
for a in range(env.get_num_agents()):
assert reward1[a] == reward2[a]
assert done1[a] == done2[a]
# Check that both environments end up in the same position
assert env.agents[0].position == env2.agents[0].position
assert env.agents[1].position == env2.agents[1].position
assert env.agents[2].position == env2.agents[2].position
assert env.agents[3].position == env2.agents[3].position
assert env.agents[4].position == env2.agents[4].position
assert env.agents[5].position == env2.agents[5].position
assert env.agents[6].position == env2.agents[6].position
assert env.agents[7].position == env2.agents[7].position
assert env.agents[8].position == env2.agents[8].position
assert env.agents[9].position == env2.agents[9].position
def test_reproducability_env():
"""
Test that no random generators are present within the env that get influenced by external np random
"""
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=10, # Random seed
grid_mode=True
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1)
env.reset(True, True, random_seed=1)
excpeted_grid = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[16386, 17411, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608],
[32800, 32800, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 72, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408],
[32800, 32800, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 34864],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[72, 37408, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[0, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 37408],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32872, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 34864],
[0, 72, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 2064],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
assert env.rail.grid.tolist() == excpeted_grid
# Test that we don't have interference from calling mulitple function outisde
env2 = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=10, # Random seed
grid_mode=True
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1)
np.random.seed(1)
for i in range(10):
np.random.randn()
env2.reset(True, True, random_seed=1)
assert env2.rail.grid.tolist() == excpeted_grid