Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 5059 additions and 0 deletions
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
def test_vec2d_is_equal():
node_a = (1, 2)
node_b = (2, 4)
node_c = (1, 2)
res_1 = Vec2d.is_equal(node_a, node_b)
res_2 = Vec2d.is_equal(node_a, node_c)
assert not res_1
assert res_2
def test_vec2d_subtract():
node_a = (1, 2)
node_b = (2, 4)
res_1 = Vec2d.subtract(node_a, node_b)
res_2 = Vec2d.subtract(node_b, node_a)
assert res_1 != res_2
assert res_1 == (-1, -2)
assert res_2 == (1, 2)
def test_vec2d_add():
node_a = (1, 2)
node_b = (2, 3)
res_1 = Vec2d.add(node_a, node_b)
res_2 = Vec2d.add(node_b, node_a)
assert res_1 == res_2
assert res_1 == (3, 5)
def test_vec2d_make_orthogonal():
node_a = (1, 2)
res_1 = Vec2d.make_orthogonal(node_a)
assert res_1 == (2, -1)
def test_vec2d_euclidean_distance():
node_a = (3, -7)
node_0 = (0, 0)
assert Vec2d.get_euclidean_distance(node_a, node_0) == Vec2d.get_norm(node_a)
def test_vec2d_manhattan_distance():
node_a = (3, -7)
node_0 = (0, 0)
assert Vec2d.get_manhattan_distance(node_a, node_0) == 3 + 7
def test_vec2d_chebyshev_distance():
node_a = (3, -7)
node_0 = (0, 0)
assert Vec2d.get_chebyshev_distance(node_a, node_0) == 7
node_b = (-3, 7)
node_0 = (0, 0)
assert Vec2d.get_chebyshev_distance(node_b, node_0) == 7
node_c = (3, 7)
node_0 = (0, 0)
assert Vec2d.get_chebyshev_distance(node_c, node_0) == 7
def test_vec2d_norm():
node_a = (1, 2)
node_b = (1, -2)
res_1 = Vec2d.get_norm(node_a)
res_2 = Vec2d.get_norm(node_b)
assert np.sqrt(1 * 1 + 2 * 2) == res_1
assert np.sqrt(1 * 1 + (-2) * (-2)) == res_2
def test_vec2d_normalize():
node_a = (1, 2)
node_b = (1, -2)
res_1 = Vec2d.normalize(node_a)
res_2 = Vec2d.normalize(node_b)
assert np.isclose(1.0, Vec2d.get_norm(res_1))
assert np.isclose(1.0, Vec2d.get_norm(res_2))
def test_vec2d_scale():
node_a = (1, 2)
node_b = (1, -2)
res_1 = Vec2d.scale(node_a, 2)
res_2 = Vec2d.scale(node_b, -2.5)
assert res_1 == (2, 4)
assert res_2 == (-2.5, 5)
def test_vec2d_round():
node_a = (-1.95, -2.2)
node_b = (1.95, 2.2)
res_1 = Vec2d.round(node_a)
res_2 = Vec2d.round(node_b)
assert res_1 == (-2, -2)
assert res_2 == (2, 2)
def test_vec2d_ceil():
node_a = (-1.95, -2.2)
node_b = (1.95, 2.2)
res_1 = Vec2d.ceil(node_a)
res_2 = Vec2d.ceil(node_b)
assert res_1 == (-1, -2)
assert res_2 == (2, 3)
def test_vec2d_floor():
node_a = (-1.95, -2.2)
node_b = (1.95, 2.2)
res_1 = Vec2d.floor(node_a)
res_2 = Vec2d.floor(node_b)
assert res_1 == (-2, -3)
assert res_2 == (1, 2)
def test_vec2d_bound():
node_a = (-1.95, -2.2)
node_b = (1.95, 2.2)
res_1 = Vec2d.bound(node_a, -1, 0)
res_2 = Vec2d.bound(node_b, 2, 2.2)
assert res_1 == (-1, -1)
assert res_2 == (2, 2.2)
def test_vec2d_rotate():
node_a = (-1.95, -2.2)
res_1 = Vec2d.rotate(node_a, -90.0)
res_2 = Vec2d.rotate(node_a, 0.0)
res_3 = Vec2d.rotate(node_a, 90.0)
res_4 = Vec2d.rotate(node_a, 180.0)
res_5 = Vec2d.rotate(node_a, 270.0)
res_6 = Vec2d.rotate(node_a, 30.0)
res_1 = (Vec2d.get_norm(Vec2d.subtract(res_1, (-2.2, 1.95))))
res_2 = (Vec2d.get_norm(Vec2d.subtract(res_2, (-1.95, -2.2))))
res_3 = (Vec2d.get_norm(Vec2d.subtract(res_3, (2.2, -1.95))))
res_4 = (Vec2d.get_norm(Vec2d.subtract(res_4, (1.95, 2.2))))
res_5 = (Vec2d.get_norm(Vec2d.subtract(res_5, (-2.2, 1.95))))
res_6 = (Vec2d.get_norm(Vec2d.subtract(res_6, (-0.5887495373796556, -2.880255888325765))))
assert np.isclose(0, res_1)
assert np.isclose(0, res_2)
assert np.isclose(0, res_3)
assert np.isclose(0, res_4)
assert np.isclose(0, res_5)
assert np.isclose(0, res_6)
from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
def test_grid4_get_transitions():
grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
assert grid4_map.get_full_transitions(0, 0) == 0
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) # the most significant bit is on
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.WEST, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
# the most significant and the fourth most significant bits are on
assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) + pow(2, 12)
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
# the fourth most significant bits are on
assert grid4_map.get_full_transitions(0, 0) == pow(2, 12)
def test_grid8_set_transitions():
grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (1, 0, 0, 0, 0, 0, 0, 0)
grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
def check_path(env, rail, position, direction, target, expected, rendering=False):
agent = env.agents[0]
agent.position = position # south dead-end
agent.direction = direction # north
agent.target = target # east dead-end
agent.moving = True
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input("Continue?")
assert rail.check_path_exists(agent.position, agent.direction, agent.target) == expected
def test_path_exists(rendering=False):
rail, rail_map, optiionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optiionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
check_path(
env,
rail,
(5, 6), # north of south dead-end
0, # north
(3, 9), # east dead-end
True
)
check_path(
env,
rail,
(6, 6), # south dead-end
2, # south
(3, 9), # east dead-end
True
)
check_path(
env,
rail,
(3, 0), # east dead-end
3, # west
(0, 3), # north dead-end
True
)
check_path(
env,
rail,
(5, 6), # east dead-end
0, # west
(1, 3), # north dead-end
True)
check_path(
env,
rail,
(1, 3), # east dead-end
2, # south
(3, 3), # north dead-end
True
)
check_path(
env,
rail,
(1, 3), # east dead-end
0, # north
(3, 3), # north dead-end
True
)
def test_path_not_exists(rendering=False):
rail, rail_map, optionals = make_simple_rail_unconnected()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
check_path(
env,
rail,
(5, 6), # south dead-end
0, # north
(0, 3), # north dead-end
False
)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input("Continue?")
def test_get_entry_directions():
transitions = RailEnvTransitions()
cells = transitions.transition_list
vertical_line = cells[1]
south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
south_east_turn = int('0100000000000010', 2)
south_west_turn = transitions.rotate_transition(south_east_turn, 90)
north_east_turn = transitions.rotate_transition(south_east_turn, 270)
north_west_turn = transitions.rotate_transition(south_east_turn, 180)
def _assert(transition, expected):
actual = Grid4Transitions.get_entry_directions(transition)
assert actual == expected, "Found {}, expected {}.".format(actual, expected)
_assert(south_east_turn, [True, False, False, True])
_assert(south_west_turn, [True, True, False, False])
_assert(north_east_turn, [False, False, True, True])
_assert(north_west_turn, [False, True, True, False])
_assert(vertical_line, [True, False, True, False])
_assert(south_symmetrical_switch, [True, True, False, True])
_assert(north_symmetrical_switch, [False, True, True, True])
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Tests for `flatland` package."""
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid8 import Grid8Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
# remove whitespace in string; keep whitespace below for easier reading
def rw(s):
return s.replace(" ", "")
def test_rotate_railenv_transition():
rail_env_transitions = RailEnvTransitions()
# TODO test all cases
transition_cycles = [
# empty cell - Case 0
[int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2),
int('0000000000000000', 2)],
# Case 1 - straight
# |
# |
# |
[int(rw('1000 0000 0010 0000'), 2), int(rw('0000 0100 0000 0001'), 2)],
# Case 1b (8) - simple turn right
# _
# |
# |
[
int(rw('0100 0000 0000 0010'), 2),
int(rw('0001 0010 0000 0000'), 2),
int(rw('0000 1000 0001 0000'), 2),
int(rw('0000 0000 0100 1000'), 2),
],
# Case 1c (9) - simple turn left
# _
# |
# |
# int('0001001000000000', 2),\ # noqa: E800
# Case 2 - simple left switch
# _ _|
# |
# |
[
int(rw('1001 0010 0010 0000'), 2),
int(rw('0000 1100 0001 0001'), 2),
int(rw('1000 0000 0110 1000'), 2),
int(rw('0100 0100 0000 0011'), 2),
],
# Case 2b (10) - simple right switch
# |
# |
# |
# int('1100000000100010', 2) \ # noqa: E800
# Case 3 - diamond drossing
# int('1000010000100001', 2), \ # noqa: E800
# Case 4 - single slip
# int('1001011000100001', 2), \ # noqa: E800
# Case 5 - double slip
# int('1100110000110011', 2), \ # noqa: E800
# Case 6 - symmetrical
# int('0101001000000010', 2), \ # noqa: E800
# Case 7 - dead end
#
#
# |
[
int(rw('0010 0000 0000 0000'), 2),
int(rw('0000 0001 0000 0000'), 2),
int(rw('0000 0000 1000 0000'), 2),
int(rw('0000 0000 0000 0100'), 2),
],
]
for index, cycle in enumerate(transition_cycles):
for i in range(4):
actual_transition = rail_env_transitions.rotate_transition(cycle[0], i * 90)
expected_transition = cycle[i % len(cycle)]
try:
assert actual_transition == expected_transition, \
"Case {}: rotate_transition({}, {}) should equal {} but was {}.".format(
i, cycle[0], i, expected_transition, actual_transition)
except Exception as e:
print("expected:")
rail_env_transitions.print(expected_transition)
print("actual:")
rail_env_transitions.print(actual_transition)
raise e
def test_is_valid_railenv_transitions():
rail_env_trans = RailEnvTransitions()
transition_list = rail_env_trans.transitions
for t in transition_list:
assert (rail_env_trans.is_valid(t) is True)
for i in range(3):
rot_trans = rail_env_trans.rotate_transition(t, 90 * i)
assert (rail_env_trans.is_valid(rot_trans) is True)
assert (rail_env_trans.is_valid(int('1111111111110010', 2)) is False)
assert (rail_env_trans.is_valid(int('1001111111110010', 2)) is False)
assert (rail_env_trans.is_valid(int('1001111001110110', 2)) is False)
def test_adding_new_valid_transition():
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans)
# adding straight
assert (grid_map.validate_new_transition((4, 5), (5, 5), (6, 5), (10, 10)) is True)
# adding valid right turn
assert (grid_map.validate_new_transition((5, 4), (5, 5), (5, 6), (10, 10)) is True)
# adding valid left turn
assert (grid_map.validate_new_transition((5, 6), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
grid_map.grid[(5, 5)] = rail_trans.transitions[2]
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)
# should create #4 -> valid
grid_map.grid[(5, 5)] = rail_trans.transitions[3]
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
grid_map.grid[(5, 5)] = rail_trans.transitions[7]
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)
# test path start condition
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
assert (grid_map.validate_new_transition(None, (5, 5), (5, 6), (10, 10)) is True)
# test path end condition
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
assert (grid_map.validate_new_transition((5, 4), (5, 5), (6, 5), (6, 5)) is True)
def test_valid_railenv_transitions():
rail_env_trans = RailEnvTransitions()
# directions:
# 'N': 0
# 'E': 1
# 'S': 2
# 'W': 3
for i in range(2):
assert (rail_env_trans.get_transitions(
int('1100110000110011', 2), i) == (1, 1, 0, 0))
assert (rail_env_trans.get_transitions(
int('1100110000110011', 2), 2 + i) == (0, 0, 1, 1))
no_transition_cell = int('0000000000000000', 2)
for i in range(4):
assert (rail_env_trans.get_transitions(
no_transition_cell, i) == (0, 0, 0, 0))
# Facing south, going south
north_south_transition = rail_env_trans.set_transitions(no_transition_cell, 2, (0, 0, 1, 0))
assert (rail_env_trans.set_transition(
north_south_transition, 2, 2, 0) == no_transition_cell)
assert (rail_env_trans.get_transition(
north_south_transition, 2, 2))
# Facing north, going east
south_east_transition = \
rail_env_trans.set_transition(no_transition_cell, 0, 1, 1)
assert (rail_env_trans.get_transition(
south_east_transition, 0, 1))
# The opposite transitions are not feasible
assert (not rail_env_trans.get_transition(
north_south_transition, 2, 0))
assert (not rail_env_trans.get_transition(
south_east_transition, 2, 1))
east_west_transition = rail_env_trans.rotate_transition(north_south_transition, 90)
north_west_transition = rail_env_trans.rotate_transition(south_east_transition, 180)
# Facing west, going west
assert (rail_env_trans.get_transition(
east_west_transition, 3, 3))
# Facing south, going west
assert (rail_env_trans.get_transition(
north_west_transition, 2, 3))
assert (south_east_transition == rail_env_trans.rotate_transition(
south_east_transition, 360))
def test_diagonal_transitions():
diagonal_trans_env = Grid8Transitions([])
# Facing north, going north-east
south_northeast_transition = int('01000000' + '0' * 8 * 7, 2)
assert (diagonal_trans_env.get_transitions(
south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0))
# Allowing transition from north to southwest: Facing south, going SW
north_southwest_transition = \
diagonal_trans_env.set_transitions(0, 4, (0, 0, 0, 0, 0, 1, 0, 0))
assert (diagonal_trans_env.rotate_transition(
south_northeast_transition, 180) == north_southwest_transition)
def test_rail_env_has_deadend():
deadends = set([int(rw('0010 0000 0000 0000'), 2),
int(rw('0000 0001 0000 0000'), 2),
int(rw('0000 0000 1000 0000'), 2),
int(rw('0000 0000 0000 0100'), 2)])
ret = RailEnvTransitions()
transitions_all = ret.transitions_all
for t in transitions_all:
expected_has_deadend = t in deadends
actual_had_deadend = Grid4Transitions.has_deadend(t)
assert actual_had_deadend == expected_has_deadend, \
"{} should be deadend = {}, actual = {}".format(t, )
def test_rail_env_remove_deadend():
ret = Grid4Transitions([])
rail_env_deadends = set([int(rw('0010 0000 0000 0000'), 2),
int(rw('0000 0001 0000 0000'), 2),
int(rw('0000 0000 1000 0000'), 2),
int(rw('0000 0000 0000 0100'), 2)])
for t in rail_env_deadends:
expected_has_deadend = 0
actual_had_deadend = ret.remove_deadends(t)
assert actual_had_deadend == expected_has_deadend, \
"{} should be deadend = {}, actual = {}".format(t, )
assert ret.remove_deadends(int(rw('0010 0001 1000 0100'), 2)) == 0
assert ret.remove_deadends(int(rw('0010 0001 1000 0110'), 2)) == int(rw('0000 0000 0000 0010'), 2)
import pytest
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_oval_rail
def test_shortest_paths():
rail, rail_map, optionals = make_oval_rail()
speed_ratio_map = {1.: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_shortest_path = env.agents[0].get_shortest_path(env.distance_map)
agent1_shortest_path = env.agents[1].get_shortest_path(env.distance_map)
assert len(agent0_shortest_path) == 10
assert len(agent1_shortest_path) == 10
def test_travel_time_on_shortest_paths():
rail, rail_map, optionals = make_oval_rail()
speed_ratio_map = {1.: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 10
assert agent1_travel_time == 10
speed_ratio_map = {1/2: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 20
assert agent1_travel_time == 20
speed_ratio_map = {1/3: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 30
assert agent1_travel_time == 30
speed_ratio_map = {1/4: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 40
assert agent1_travel_time == 40
# def test_latest_arrival_validity():
# pass
# def test_time_remaining_until_latest_arrival():
# pass
def main():
pass
if __name__ == "__main__":
main()
import numpy as np
import pytest
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_direction
from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
depth_to_test = 5
positions_to_test = [0, 5, 1, 6, 20, 30]
coordinates_to_test = [[0, 0], [0, 1], [1, 0], [1, 1], [0, 4], [0, 6]]
def test_position_to_coordinate():
actual_coordinates = position_to_coordinate(depth_to_test, positions_to_test)
expected_coordinates = coordinates_to_test
assert np.array_equal(actual_coordinates, expected_coordinates), \
"converted positions {}, expected {}".format(actual_coordinates, expected_coordinates)
def test_coordinate_to_position():
actual_positions = coordinate_to_position(depth_to_test, coordinates_to_test)
expected_positions = positions_to_test
assert np.array_equal(actual_positions, expected_positions), \
"converted positions {}, expected {}".format(actual_positions, expected_positions)
def test_get_direction():
assert get_direction((0, 0), (0, 1)) == Grid4TransitionsEnum.EAST
assert get_direction((0, 0), (0, 2)) == Grid4TransitionsEnum.EAST
assert get_direction((0, 0), (1, 0)) == Grid4TransitionsEnum.SOUTH
assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
with pytest.raises(Exception, match="Could not determine direction"):
get_direction((0, 0), (0, 0))
def test_load():
load_flatland_environment_from_file('test_001.pkl', 'env_data.tests')
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.step_utils.states import TrainState
"""Tests for `flatland` package."""
def test_global_obs():
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
global_obs, info = env.reset()
# we have to take step for the agent to enter the grid.
global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})
assert (global_obs[0][0].shape == rail_map.shape + (16,))
rail_map_recons = np.zeros_like(rail_map)
for i in range(global_obs[0][0].shape[0]):
for j in range(global_obs[0][0].shape[1]):
rail_map_recons[i, j] = int(
''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)
assert (rail_map_recons.all() == rail_map.all())
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
obs_agents_state = global_obs[0][1]
obs_agents_state = obs_agents_state + 1
assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
def _step_along_shortest_path(env, obs_builder, rail):
actions = {}
expected_next_position = {}
for agent in env.agents:
shortest_distance = np.inf
for exit_direction in range(4):
neighbour = get_new_position(agent.position, exit_direction)
if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
desired_movement_from_new_cell = (exit_direction + 2) % 4
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `entry_direction` to the neighbour possible?
is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
desired_movement_from_new_cell)
if is_valid:
distance_to_target = obs_builder.env.distance_map.get()[
(agent.handle, *agent.position, exit_direction)]
print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
agent.direction,
exit_direction,
distance_to_target))
if distance_to_target < shortest_distance:
shortest_distance = distance_to_target
actions_to_be_taken_when_facing_north = {
Grid4TransitionsEnum.NORTH: RailEnvActions.MOVE_FORWARD,
Grid4TransitionsEnum.EAST: RailEnvActions.MOVE_RIGHT,
Grid4TransitionsEnum.WEST: RailEnvActions.MOVE_LEFT,
Grid4TransitionsEnum.SOUTH: RailEnvActions.DO_NOTHING,
}
print(" improved (direction) -> {}".format(exit_direction))
actions[agent.handle] = actions_to_be_taken_when_facing_north[
(exit_direction - agent.direction) % len(rail.transitions.get_direction_enum())]
expected_next_position[agent.handle] = neighbour
print(" improved (action) -> {}".format(actions[agent.handle]))
_, rewards, dones, _ = env.step(actions)
return rewards, dones
def test_reward_function_conflict(rendering=False):
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
obs_builder: TreeObsForRailEnv = env.obs_builder
env.reset()
# set the initial position
agent = env.agents[0]
agent.position = (5, 6) # south dead-end
agent.initial_position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.initial_direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
agent = env.agents[1]
agent.position = (3, 8) # east dead-end
agent.initial_position = (3, 8) # east dead-end
agent.direction = 3 # west
agent.initial_direction = 3 # west
agent.target = (6, 6) # south dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
env.reset(False, False)
env.agents[0].moving = True
env.agents[1].moving = True
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
env.agents[0].position = (5, 6)
env.agents[1].position = (3, 8)
print("\n")
print(env.agents[0])
print(env.agents[1])
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=True)
iteration = 0
expected_positions = {
0: {
0: (5, 6),
1: (3, 8)
},
# both can move
1: {
0: (4, 6),
1: (3, 7)
},
# first can move, second stuck
2: {
0: (3, 6),
1: (3, 7)
},
# both stuck from now on
3: {
0: (3, 6),
1: (3, 7)
},
4: {
0: (3, 6),
1: (3, 7)
},
5: {
0: (3, 6),
1: (3, 7)
},
}
while iteration < 5:
rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
if dones["__all__"]:
break
for agent in env.agents:
# assert rewards[agent.handle] == 0
expected_position = expected_positions[iteration + 1][agent.handle]
assert agent.position == expected_position, "[{}] agent {} at {}, expected {}".format(iteration + 1,
agent.handle,
agent.position,
expected_position)
if rendering:
renderer.render_env(show=True, show_observations=True)
iteration += 1
def test_reward_function_waiting(rendering=False):
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False, random_seed=1)
obs_builder: TreeObsForRailEnv = env.obs_builder
env.reset()
# set the initial position
agent = env.agents[0]
agent.initial_position = (3, 8) # east dead-end
agent.position = (3, 8) # east dead-end
agent.direction = 3 # west
agent.initial_direction = 3 # west
agent.target = (3, 1) # west dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
agent = env.agents[1]
agent.initial_position = (5, 6) # south dead-end
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.initial_direction = 0 # north
agent.target = (3, 8) # east dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
env.reset(False, False)
env.agents[0].moving = True
env.agents[1].moving = True
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
env.agents[0].position = (3, 8)
env.agents[1].position = (5, 6)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=True)
iteration = 0
expectations = {
0: {
'positions': {
0: (3, 8),
1: (5, 6),
},
'rewards': [0, 0],
},
1: {
'positions': {
0: (3, 7),
1: (4, 6),
},
'rewards': [0, 0],
},
# second agent has to wait for first, first can continue
2: {
'positions': {
0: (3, 6),
1: (4, 6),
},
'rewards': [0, 0],
},
# both can move again
3: {
'positions': {
0: (3, 5),
1: (3, 6),
},
'rewards': [0, 0],
},
4: {
'positions': {
0: (3, 4),
1: (3, 7),
},
'rewards': [0, 0],
},
# second reached target
5: {
'positions': {
0: (3, 3),
1: (3, 8),
},
'rewards': [0, 0],
},
6: {
'positions': {
0: (3, 2),
1: (3, 8),
},
'rewards': [0, 0],
},
# first reaches, target too
7: {
'positions': {
0: (3, 1),
1: (3, 8),
},
'rewards': [0, 0],
},
8: {
'positions': {
0: (3, 1),
1: (3, 8),
},
'rewards': [0, 0],
},
}
while iteration < 7:
rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
if dones["__all__"]:
break
if rendering:
renderer.render_env(show=True, show_observations=True)
print(env.dones["__all__"])
for agent in env.agents:
print("[{}] agent {} at {}, target {} ".format(iteration + 1, agent.handle, agent.position, agent.target))
print(np.all([np.array_equal(agent2.position, agent2.target) for agent2 in env.agents]))
for agent in env.agents:
expected_position = expectations[iteration + 1]['positions'][agent.handle]
assert agent.position == expected_position, \
"[{}] agent {} at {}, expected {}".format(iteration + 1,
agent.handle,
agent.position,
expected_position)
# expected_reward = expectations[iteration + 1]['rewards'][agent.handle]
# actual_reward = rewards[agent.handle]
# assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1,
# agent.handle,
# actual_reward,
# expected_reward)
iteration += 1
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
def test_load_new():
filename = "test_load_new.pkl"
rail, rail_map, optionals = make_simple_rail()
n_agents = 2
env_initial = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=n_agents)
env_initial.reset(False, False)
rails_initial = env_initial.rail.grid
agents_initial = env_initial.agents
RailEnvPersister.save(env_initial, filename)
env_loaded, _ = RailEnvPersister.load_new(filename)
rails_loaded = env_loaded.rail.grid
agents_loaded = env_loaded.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
def main():
pass
if __name__ == "__main__":
main()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pprint
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv, Node
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.step_utils.states import TrainState
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
"""Test predictions for `flatland` package."""
def test_dummy_predictor(rendering=False):
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
env.reset()
# set initial position and direction for testing...
env.agents[0].initial_position = (5, 6)
env.agents[0].initial_direction = 0
env.agents[0].direction = 0
env.agents[0].target = (3, 0)
env.reset(False, False)
env.agents[0].earliest_departure = 1
env._max_episode_steps = 100
# Make Agent 0 active
env.step({})
env.step({0: RailEnvActions.MOVE_FORWARD})
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input("Continue?")
# test assertions
predictions = env.obs_builder.predictor.get(None)
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
# compare against expected values
expected_positions = np.array([[5., 6.],
[4., 6.],
[3., 6.],
[3., 5.],
[3., 4.],
[3., 3.],
[3., 2.],
[3., 1.],
# at target (3,0): stay in this position from here on
[3., 0.],
[3., 0.],
[3., 0.],
])
expected_directions = np.array([[0.],
[0.],
[0.],
[3.],
[3.],
[3.],
[3.],
[3.],
# at target (3,0): stay in this position from here on
[3.],
[3.],
[3.]
])
expected_time_offsets = np.array([[0.],
[1.],
[2.],
[3.],
[4.],
[5.],
[6.],
[7.],
[8.],
[9.],
[10.],
])
expected_actions = np.array([[0.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
# reaching target by straight
[2.],
# at target: stopped moving
[4.],
[4.],
])
assert np.array_equal(positions, expected_positions)
assert np.array_equal(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets)
assert np.array_equal(actions, expected_actions)
def test_shortest_path_predictor(rendering=False):
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
# set the initial position
agent = env.agents[0]
agent.initial_position = (5, 6) # south dead-end
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.initial_direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
agent._set_state(TrainState.MOVING)
env.reset(False, False)
env.distance_map._compute(env.agents, env.rail)
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input("Continue?")
# compute the observations and predictions
distance_map = env.distance_map.get()
distance_on_map = distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction]
assert distance_on_map == 5.0, "found {} instead of {}".format(distance_on_map, 5.0)
paths = get_shortest_paths(env.distance_map)[0]
assert paths == [
Waypoint((5, 6), 0),
Waypoint((4, 6), 0),
Waypoint((3, 6), 0),
Waypoint((3, 7), 1),
Waypoint((3, 8), 1),
Waypoint((3, 9), 1)
]
# extract the data
predictions = env.obs_builder.predictions
positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
# test if data meets expectations
expected_positions = [
[5, 6],
[4, 6],
[3, 6],
[3, 7],
[3, 8],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
[3, 9],
]
expected_directions = [
[Grid4TransitionsEnum.NORTH], # next is [5,6] heading north
[Grid4TransitionsEnum.NORTH], # next is [4,6] heading north
[Grid4TransitionsEnum.NORTH], # next is [3,6] heading north
[Grid4TransitionsEnum.EAST], # next is [3,7] heading east
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
[Grid4TransitionsEnum.EAST],
]
expected_time_offsets = np.array([
[0.],
[1.],
[2.],
[3.],
[4.],
[5.],
[6.],
[7.],
[8.],
[9.],
[10.],
[11.],
[12.],
[13.],
[14.],
[15.],
[16.],
[17.],
[18.],
[19.],
[20.],
])
assert np.array_equal(time_offsets, expected_time_offsets), \
"time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
assert np.array_equal(positions, expected_positions), \
"positions {}, expected {}".format(positions, expected_positions)
assert np.array_equal(directions, expected_directions), \
"directions {}, expected {}".format(directions, expected_directions)
def test_shortest_path_predictor_conflicts(rendering=False):
rail, rail_map, optionals = make_invalid_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
# set the initial position
env.agents[0].initial_position = (5, 6) # south dead-end
env.agents[0].position = (5, 6) # south dead-end
env.agents[0].direction = 0 # north
env.agents[0].initial_direction = 0 # north
env.agents[0].target = (3, 9) # east dead-end
env.agents[0].moving = True
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1].initial_position = (3, 8) # east dead-end
env.agents[1].position = (3, 8) # east dead-end
env.agents[1].direction = 3 # west
env.agents[1].initial_direction = 3 # west
env.agents[1].target = (6, 6) # south dead-end
env.agents[1].moving = True
env.agents[1]._set_state(TrainState.MOVING)
observations, info = env.reset(False, False)
env.agents[0].position = (5, 6) # south dead-end
env.agent_positions[env.agents[0].position] = 0
env.agents[1].position = (3, 8) # east dead-end
env.agent_positions[env.agents[1].position] = 1
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
observations = env._get_observations()
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input("Continue?")
# get the trees to test
obs_builder: TreeObsForRailEnv = env.obs_builder
pp = pprint.PrettyPrinter(indent=4)
tree_0 = observations[0]
tree_1 = observations[1]
env.obs_builder.util_print_obs_subtree(tree_0)
env.obs_builder.util_print_obs_subtree(tree_1)
# check the expectations
expected_conflicts_0 = [('F', 'R')]
expected_conflicts_1 = [('F', 'L')]
_check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ")
_check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: Node, prompt=''):
assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explored_actions_char:
if tree.childs[a_1] == -np.inf:
assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
continue
else:
conflict = tree.childs[a_1].num_agents_opposite_direction
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explored_actions_char:
if tree.childs[a_1].childs[a_2] == -np.inf:
assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
else:
conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import numpy as np
import pytest
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.line_generators import sparse_line_generator, line_from_file
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
def test_save_load():
env = RailEnv(width=30, height=30,
rail_generator=sparse_rail_generator(seed=1),
line_generator=sparse_line_generator(), number_of_agents=2)
env.reset()
agent_1_pos = env.agents[0].position
agent_1_dir = env.agents[0].direction
agent_1_tar = env.agents[0].target
agent_2_pos = env.agents[1].position
agent_2_dir = env.agents[1].direction
agent_2_tar = env.agents[1].target
os.makedirs("tmp", exist_ok=True)
RailEnvPersister.save(env, "tmp/test_save.pkl")
env.save("tmp/test_save_2.pkl")
# env.load("test_save.dat")
env, env_dict = RailEnvPersister.load_new("tmp/test_save.pkl")
assert (env.width == 30)
assert (env.height == 30)
assert (len(env.agents) == 2)
assert (agent_1_pos == env.agents[0].position)
assert (agent_1_dir == env.agents[0].direction)
assert (agent_1_tar == env.agents[0].target)
assert (agent_2_pos == env.agents[1].position)
assert (agent_2_dir == env.agents[1].direction)
assert (agent_2_tar == env.agents[1].target)
@pytest.mark.skip("Msgpack serializing not supported")
def test_save_load_mpk():
env = RailEnv(width=30, height=30,
rail_generator=sparse_rail_generator(seed=1),
line_generator=sparse_line_generator(), number_of_agents=2)
env.reset()
os.makedirs("tmp", exist_ok=True)
RailEnvPersister.save(env, "tmp/test_save.mpk")
# env.load("test_save.dat")
env2, env_dict = RailEnvPersister.load_new("tmp/test_save.mpk")
assert (env.width == env2.width)
assert (env.height == env2.height)
assert (len(env2.agents) == len(env.agents))
for agent1, agent2 in zip(env.agents, env2.agents):
assert (agent1.position == agent2.position)
assert (agent1.direction == agent2.direction)
assert (agent1.target == agent2.target)
@pytest.mark.skip(reason="Old file used to create env, not sure how to regenerate")
def test_rail_environment_single_agent(show=False):
# We instantiate the following map on a 3x3 grid
# _ _
# / \/ \
# | | |
# \_/\_/
transitions = RailEnvTransitions()
if False:
# This env creation doesn't quite work right.
cells = transitions.transition_list
vertical_line = cells[1]
south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
south_east_turn = int('0100000000000010', 2)
south_west_turn = transitions.rotate_transition(south_east_turn, 90)
north_east_turn = transitions.rotate_transition(south_east_turn, 270)
north_west_turn = transitions.rotate_transition(south_east_turn, 180)
rail_map = np.array([[south_east_turn, south_symmetrical_switch,
south_west_turn],
[vertical_line, vertical_line, vertical_line],
[north_east_turn, north_symmetrical_switch,
north_west_turn]],
dtype=np.uint16)
rail = GridTransitionMap(width=3, height=3, transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
else:
rail_env, env_dict = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
rail_map = rail_env.rail.grid
rail_env._max_episode_steps = 1000
_ = rail_env.reset(False, False, True)
liActions = [int(a) for a in RailEnvActions]
env_renderer = RenderTool(rail_env)
# RailEnvPersister.save(rail_env, "test_env_figure8.pkl")
for _ in range(5):
# rail_env.agents[0].initial_position = (1,2)
_ = rail_env.reset(False, False, True)
# We do not care about target for the moment
agent = rail_env.agents[0]
agent.target = [-1, -1]
# Check that trains are always initialized at a consistent position
# or direction.
# They should always be able to go somewhere.
if show:
print("After reset - agent pos:", agent.position, "dir: ", agent.direction)
print(transitions.get_transitions(rail_map[agent.position], agent.direction))
# assert (transitions.get_transitions(
# rail_map[agent.position],
# agent.direction) != (0, 0, 0, 0))
# HACK - force the direction to one we know is good.
# agent.initial_position = agent.position = (2,3)
agent.initial_direction = agent.direction = 0
if show:
print("handle:", agent.handle)
# agent.initial_position = initial_pos = agent.position
valid_active_actions_done = 0
pos = agent.position
if show:
env_renderer.render_env(show=show, show_agents=True)
time.sleep(0.01)
iStep = 0
while valid_active_actions_done < 6:
# We randomly select an action
action = np.random.choice(liActions)
# action = RailEnvActions.MOVE_FORWARD
_, _, dict_done, _ = rail_env.step({0: action})
prev_pos = pos
pos = agent.position # rail_env.agents_position[0]
print("action:", action, "pos:", agent.position, "prev:", prev_pos, agent.direction)
print(dict_done)
if prev_pos != pos:
valid_active_actions_done += 1
iStep += 1
if show:
env_renderer.render_env(show=show, show_agents=True, step=iStep)
time.sleep(0.01)
assert iStep < 100, "valid actions should have been performed by now - hung agent"
# After 6 movements on this railway network, the train should be back
# to its original height on the map.
# assert (initial_pos[0] == agent.position[0])
# We check that the train always attains its target after some time
for _ in range(10):
_ = rail_env.reset()
rail_env.agents[0].direction = 0
# JW - to avoid problem with sparse_line_generator.
# rail_env.agents[0].position = (1,2)
iStep = 0
while iStep < 100:
# We randomly select an action
action = np.random.choice(liActions)
_, _, dones, _ = rail_env.step({0: action})
done = dones['__all__']
if done:
break
iStep += 1
assert iStep < 100, "agent should have finished by now"
env_renderer.render_env(show=show)
def test_dead_end():
transitions = RailEnvTransitions()
straight_vertical = int('1000000000100000', 2) # Case 1 - straight
straight_horizontal = transitions.rotate_transition(straight_vertical,
90)
dead_end_from_south = int('0010000000000000', 2) # Case 7 - dead end
# We instantiate the following railway
# O->-- where > is the train and O the target. After 6 steps,
# the train should be done.
rail_map = np.array(
[[transitions.rotate_transition(dead_end_from_south, 270)] +
[straight_horizontal] * 3 +
[transitions.rotate_transition(dead_end_from_south, 90)]],
dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0],
transitions=transitions)
rail.grid = rail_map
city_positions = [(0, 0), (0, 3)]
train_stations = [
[((0, 0), 0)],
[((0, 0), 0)],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
# We try the configuration in the 4 directions:
rail_env.reset()
rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=1, direction=1, target=(0, 0), moving=False)]
rail_env.reset()
rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=3, direction=3, target=(0, 4), moving=False)]
# In the vertical configuration:
rail_map = np.array(
[[dead_end_from_south]] + [[straight_vertical]] * 3 +
[[transitions.rotate_transition(dead_end_from_south, 180)]],
dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0],
transitions=transitions)
city_positions = [(0, 0), (0, 3)]
train_stations = [
[((0, 0), 0)],
[((0, 0), 0)],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
rail_env.reset()
rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=2, direction=2, target=(0, 0), moving=False)]
rail_env.reset()
rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=0, direction=0, target=(4, 0), moving=False)]
# TODO make assertions
def test_get_entry_directions():
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
def _assert(position, expected):
actual = env.get_valid_directions_on_grid(*position)
assert actual == expected, "[{},{}] actual={}, expected={}".format(*position, actual, expected)
# north dead end
_assert((0, 3), [True, False, False, False])
# west dead end
_assert((3, 0), [False, False, False, True])
# switch
_assert((3, 3), [False, True, True, True])
# horizontal
_assert((3, 2), [False, True, False, True])
# vertical
_assert((2, 3), [True, False, True, False])
# nowhere
_assert((0, 0), [False, False, False, False])
def test_rail_env_reset():
file_name = "test_rail_env_reset.pkl"
# Test to save and load file.
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# env.save(file_name)
RailEnvPersister.save(env, file_name)
dist_map_shape = np.shape(env.distance_map.get())
rails_initial = env.rail.grid
agents_initial = env.agents
# env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
# line_generator=line_from_file(file_name), number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
# env2.reset(False, False, False)
env2, env2_dict = RailEnvPersister.load_new(file_name)
rails_loaded = env2.rail.grid
agents_loaded = env2.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env3.reset(False, True)
rails_loaded = env3.rail.grid
agents_loaded = env3.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env4.reset(True, False)
rails_loaded = env4.rail.grid
agents_loaded = env4.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
def main():
# test_rail_environment_single_agent(show=True)
test_rail_env_reset()
if __name__ == "__main__":
main()
import sys
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shortest_paths
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives
from flatland.envs.persistence import RailEnvPersister
def test_get_shortest_paths_unreachable():
rail, rail_map, optionals = make_disconnected_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
# set the initial position
agent = env.agents[0]
agent.position = (3, 1) # west dead-end
agent.initial_position = (3, 1) # west dead-end
agent.direction = Grid4TransitionsEnum.WEST
agent.target = (3, 9) # east dead-end
agent.moving = True
env.reset(False, False)
actual = get_shortest_paths(env.distance_map)
expected = {0: None}
assert actual[0] == expected[0], "actual={},expected={}".format(actual[0], expected[0])
# todo file test_002.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
def test_get_shortest_paths():
#env = load_flatland_environment_from_file('test_002.mpk', 'env_data.tests')
env, env_dict = RailEnvPersister.load_new("test_002.mpk", "env_data.tests")
#print("env len(agents): ", len(env.agents))
#print(env.distance_map)
#print("env number_of_agents:", env.number_of_agents)
#print("env agents:", env.agents)
#env.distance_map.reset(env.agents, env.rail)
#actual = get_shortest_paths(env.distance_map)
#print("shortest paths:", actual)
#print(env.distance_map)
#print("Dist map agents:", env.distance_map.agents)
#print("\nenv reset()")
env.reset()
actual = get_shortest_paths(env.distance_map)
#print("env agents: ", len(env.agents))
#print("env number_of_agents: ", env.number_of_agents)
assert len(actual) == 2, "get_shortest_paths should return a dict of length 2"
expected = {
0: [
Waypoint(position=(1, 1), direction=1),
Waypoint(position=(1, 2), direction=1),
Waypoint(position=(1, 3), direction=1),
Waypoint(position=(2, 3), direction=2),
Waypoint(position=(2, 4), direction=1),
Waypoint(position=(2, 5), direction=1),
Waypoint(position=(2, 6), direction=1),
Waypoint(position=(2, 7), direction=1),
Waypoint(position=(2, 8), direction=1),
Waypoint(position=(2, 9), direction=1),
Waypoint(position=(2, 10), direction=1),
Waypoint(position=(2, 11), direction=1),
Waypoint(position=(2, 12), direction=1),
Waypoint(position=(2, 13), direction=1),
Waypoint(position=(2, 14), direction=1),
Waypoint(position=(2, 15), direction=1),
Waypoint(position=(2, 16), direction=1),
Waypoint(position=(2, 17), direction=1),
Waypoint(position=(2, 18), direction=1)],
1: [
Waypoint(position=(3, 18), direction=3),
Waypoint(position=(3, 17), direction=3),
Waypoint(position=(3, 16), direction=3),
Waypoint(position=(2, 16), direction=0),
Waypoint(position=(2, 15), direction=3),
Waypoint(position=(2, 14), direction=3),
Waypoint(position=(2, 13), direction=3),
Waypoint(position=(2, 12), direction=3),
Waypoint(position=(2, 11), direction=3),
Waypoint(position=(2, 10), direction=3),
Waypoint(position=(2, 9), direction=3),
Waypoint(position=(2, 8), direction=3),
Waypoint(position=(2, 7), direction=3),
Waypoint(position=(2, 6), direction=3),
Waypoint(position=(2, 5), direction=3),
Waypoint(position=(2, 4), direction=3),
Waypoint(position=(2, 3), direction=3),
Waypoint(position=(2, 2), direction=3),
Waypoint(position=(2, 1), direction=3)]
}
for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
# todo file test_002.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
def test_get_shortest_paths_max_depth():
#env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
env, _ = RailEnvPersister.load_new("test_002.mpk", "env_data.tests")
env.reset()
actual = get_shortest_paths(env.distance_map, max_depth=2)
expected = {
0: [
Waypoint(position=(1, 1), direction=1),
Waypoint(position=(1, 2), direction=1)
],
1: [
Waypoint(position=(3, 18), direction=3),
Waypoint(position=(3, 17), direction=3),
]
}
for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
# todo file Level_distance_map_shortest_path.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
def test_get_shortest_paths_agent_handle():
#env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests')
env, _ = RailEnvPersister.load_new("Level_distance_map_shortest_path.mpk", "env_data.tests")
env.reset()
actual = get_shortest_paths(env.distance_map, agent_handle=6)
print(actual, file=sys.stderr)
expected = {6:
[Waypoint(position=(5, 5),
direction=0),
Waypoint(position=(4, 5),
direction=0),
Waypoint(position=(3, 5),
direction=0),
Waypoint(position=(2, 5),
direction=0),
Waypoint(position=(1, 5),
direction=0),
Waypoint(position=(0, 5),
direction=0),
Waypoint(position=(0, 6),
direction=1),
Waypoint(position=(0, 7), direction=1),
Waypoint(position=(0, 8),
direction=1),
Waypoint(position=(0, 9),
direction=1),
Waypoint(position=(0, 10),
direction=1),
Waypoint(position=(1, 10),
direction=2),
Waypoint(position=(2, 10),
direction=2),
Waypoint(position=(3, 10),
direction=2),
Waypoint(position=(4, 10),
direction=2),
Waypoint(position=(5, 10),
direction=2),
Waypoint(position=(6, 10),
direction=2),
Waypoint(position=(7, 10),
direction=2),
Waypoint(position=(8, 10),
direction=2),
Waypoint(position=(9, 10),
direction=2),
Waypoint(position=(10, 10),
direction=2),
Waypoint(position=(11, 10),
direction=2),
Waypoint(position=(12, 10),
direction=2),
Waypoint(position=(13, 10),
direction=2),
Waypoint(position=(14, 10),
direction=2),
Waypoint(position=(15, 10),
direction=2),
Waypoint(position=(16, 10),
direction=2),
Waypoint(position=(17, 10),
direction=2),
Waypoint(position=(18, 10),
direction=2),
Waypoint(position=(19, 10),
direction=2),
Waypoint(position=(20, 10),
direction=2),
Waypoint(position=(20, 9),
direction=3),
Waypoint(position=(20, 8),
direction=3),
Waypoint(position=(21, 8),
direction=2),
Waypoint(position=(21, 7),
direction=3),
Waypoint(position=(21, 6),
direction=3),
Waypoint(position=(21, 5),
direction=3)
]}
for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
def test_get_k_shortest_paths(rendering=False):
rail, rail_map, optionals = make_simple_rail_with_alternatives()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
)
env.reset()
initial_position = (3, 1) # west dead-end
initial_direction = Grid4TransitionsEnum.WEST # west
target_position = (3, 9) # east
# set the initial position
agent = env.agents[0]
agent.position = initial_position
agent.initial_position = initial_position
agent.direction = initial_direction
agent.target = target_position # east dead-end
agent.moving = True
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
renderer.render_env(show=True, show_observations=False)
input()
actual = set(get_k_shortest_paths(
env=env,
source_position=initial_position, # west dead-end
source_direction=int(initial_direction), # east
target_position=target_position,
k=10
))
expected = set([
(
Waypoint(position=(3, 1), direction=3),
Waypoint(position=(3, 0), direction=3),
Waypoint(position=(3, 1), direction=1),
Waypoint(position=(3, 2), direction=1),
Waypoint(position=(3, 3), direction=1),
Waypoint(position=(2, 3), direction=0),
Waypoint(position=(1, 3), direction=0),
Waypoint(position=(0, 3), direction=0),
Waypoint(position=(0, 4), direction=1),
Waypoint(position=(0, 5), direction=1),
Waypoint(position=(0, 6), direction=1),
Waypoint(position=(0, 7), direction=1),
Waypoint(position=(0, 8), direction=1),
Waypoint(position=(0, 9), direction=1),
Waypoint(position=(1, 9), direction=2),
Waypoint(position=(2, 9), direction=2),
Waypoint(position=(3, 9), direction=2)),
(
Waypoint(position=(3, 1), direction=3),
Waypoint(position=(3, 0), direction=3),
Waypoint(position=(3, 1), direction=1),
Waypoint(position=(3, 2), direction=1),
Waypoint(position=(3, 3), direction=1),
Waypoint(position=(3, 4), direction=1),
Waypoint(position=(3, 5), direction=1),
Waypoint(position=(3, 6), direction=1),
Waypoint(position=(4, 6), direction=2),
Waypoint(position=(5, 6), direction=2),
Waypoint(position=(6, 6), direction=2),
Waypoint(position=(5, 6), direction=0),
Waypoint(position=(4, 6), direction=0),
Waypoint(position=(4, 7), direction=1),
Waypoint(position=(4, 8), direction=1),
Waypoint(position=(4, 9), direction=1),
Waypoint(position=(3, 9), direction=0))
])
assert actual == expected, "actual={},expected={}".format(actual, expected)
def main():
test_get_shortest_paths()
if __name__ == "__main__":
main()
import unittest
import warnings
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
def test_sparse_rail_generator():
env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10,
max_rails_between_cities=3,
seed=1,
grid_mode=False
),
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(),
random_seed=1)
env.reset(False, False)
# for r in range(env.height):
# for c in range(env.width):
# if env.rail.grid[r][c] > 0:
# print("expected_grid_map[{}][{}] = {}".format(r, c, env.rail.grid[r][c]))
expected_grid_map = env.rail.grid
expected_grid_map[4][9] = 16386
expected_grid_map[4][10] = 1025
expected_grid_map[4][11] = 1025
expected_grid_map[4][12] = 1025
expected_grid_map[4][13] = 1025
expected_grid_map[4][14] = 1025
expected_grid_map[4][15] = 1025
expected_grid_map[4][16] = 1025
expected_grid_map[4][17] = 1025
expected_grid_map[4][18] = 1025
expected_grid_map[4][19] = 1025
expected_grid_map[4][20] = 1025
expected_grid_map[4][21] = 1025
expected_grid_map[4][22] = 17411
expected_grid_map[4][23] = 17411
expected_grid_map[4][24] = 1025
expected_grid_map[4][25] = 1025
expected_grid_map[4][26] = 1025
expected_grid_map[4][27] = 1025
expected_grid_map[4][28] = 5633
expected_grid_map[4][29] = 5633
expected_grid_map[4][30] = 4608
expected_grid_map[5][9] = 49186
expected_grid_map[5][10] = 1025
expected_grid_map[5][11] = 1025
expected_grid_map[5][12] = 1025
expected_grid_map[5][13] = 1025
expected_grid_map[5][14] = 1025
expected_grid_map[5][15] = 1025
expected_grid_map[5][16] = 1025
expected_grid_map[5][17] = 1025
expected_grid_map[5][18] = 1025
expected_grid_map[5][19] = 1025
expected_grid_map[5][20] = 1025
expected_grid_map[5][21] = 1025
expected_grid_map[5][22] = 2064
expected_grid_map[5][23] = 32800
expected_grid_map[5][28] = 32800
expected_grid_map[5][29] = 32800
expected_grid_map[5][30] = 32800
expected_grid_map[6][9] = 49186
expected_grid_map[6][10] = 1025
expected_grid_map[6][11] = 1025
expected_grid_map[6][12] = 1025
expected_grid_map[6][13] = 1025
expected_grid_map[6][14] = 1025
expected_grid_map[6][15] = 1025
expected_grid_map[6][16] = 1025
expected_grid_map[6][17] = 1025
expected_grid_map[6][18] = 1025
expected_grid_map[6][19] = 1025
expected_grid_map[6][20] = 1025
expected_grid_map[6][21] = 1025
expected_grid_map[6][22] = 1025
expected_grid_map[6][23] = 2064
expected_grid_map[6][28] = 32800
expected_grid_map[6][29] = 32872
expected_grid_map[6][30] = 37408
expected_grid_map[7][9] = 32800
expected_grid_map[7][28] = 32800
expected_grid_map[7][29] = 32800
expected_grid_map[7][30] = 32800
expected_grid_map[8][9] = 32872
expected_grid_map[8][10] = 4608
expected_grid_map[8][28] = 49186
expected_grid_map[8][29] = 34864
expected_grid_map[8][30] = 32872
expected_grid_map[8][31] = 4608
expected_grid_map[9][9] = 49186
expected_grid_map[9][10] = 34864
expected_grid_map[9][28] = 32800
expected_grid_map[9][29] = 32800
expected_grid_map[9][30] = 32800
expected_grid_map[9][31] = 32800
expected_grid_map[10][9] = 32800
expected_grid_map[10][10] = 32800
expected_grid_map[10][28] = 32872
expected_grid_map[10][29] = 37408
expected_grid_map[10][30] = 49186
expected_grid_map[10][31] = 2064
expected_grid_map[11][9] = 32800
expected_grid_map[11][10] = 32800
expected_grid_map[11][28] = 32800
expected_grid_map[11][29] = 32800
expected_grid_map[11][30] = 32800
expected_grid_map[12][9] = 32800
expected_grid_map[12][10] = 32800
expected_grid_map[12][28] = 32800
expected_grid_map[12][29] = 49186
expected_grid_map[12][30] = 34864
expected_grid_map[12][33] = 16386
expected_grid_map[12][34] = 1025
expected_grid_map[12][35] = 1025
expected_grid_map[12][36] = 1025
expected_grid_map[12][37] = 1025
expected_grid_map[12][38] = 5633
expected_grid_map[12][39] = 17411
expected_grid_map[12][40] = 1025
expected_grid_map[12][41] = 1025
expected_grid_map[12][42] = 1025
expected_grid_map[12][43] = 5633
expected_grid_map[12][44] = 17411
expected_grid_map[12][45] = 1025
expected_grid_map[12][46] = 4608
expected_grid_map[13][9] = 32872
expected_grid_map[13][10] = 37408
expected_grid_map[13][28] = 32800
expected_grid_map[13][29] = 32800
expected_grid_map[13][30] = 32800
expected_grid_map[13][33] = 32800
expected_grid_map[13][38] = 72
expected_grid_map[13][39] = 3089
expected_grid_map[13][40] = 1025
expected_grid_map[13][41] = 1025
expected_grid_map[13][42] = 1025
expected_grid_map[13][43] = 1097
expected_grid_map[13][44] = 2064
expected_grid_map[13][46] = 32800
expected_grid_map[14][9] = 49186
expected_grid_map[14][10] = 2064
expected_grid_map[14][24] = 16386
expected_grid_map[14][25] = 17411
expected_grid_map[14][26] = 1025
expected_grid_map[14][27] = 1025
expected_grid_map[14][28] = 34864
expected_grid_map[14][29] = 32800
expected_grid_map[14][30] = 32872
expected_grid_map[14][31] = 1025
expected_grid_map[14][32] = 1025
expected_grid_map[14][33] = 2064
expected_grid_map[14][46] = 32800
expected_grid_map[15][9] = 32800
expected_grid_map[15][24] = 32800
expected_grid_map[15][25] = 49186
expected_grid_map[15][26] = 1025
expected_grid_map[15][27] = 1025
expected_grid_map[15][28] = 3089
expected_grid_map[15][29] = 3089
expected_grid_map[15][30] = 2064
expected_grid_map[15][46] = 32800
expected_grid_map[16][8] = 16386
expected_grid_map[16][9] = 52275
expected_grid_map[16][10] = 4608
expected_grid_map[16][24] = 32800
expected_grid_map[16][25] = 32800
expected_grid_map[16][46] = 32800
expected_grid_map[17][8] = 32800
expected_grid_map[17][9] = 32800
expected_grid_map[17][10] = 32800
expected_grid_map[17][24] = 32872
expected_grid_map[17][25] = 37408
expected_grid_map[17][44] = 16386
expected_grid_map[17][45] = 17411
expected_grid_map[17][46] = 34864
expected_grid_map[18][8] = 32800
expected_grid_map[18][9] = 32800
expected_grid_map[18][10] = 32800
expected_grid_map[18][24] = 49186
expected_grid_map[18][25] = 34864
expected_grid_map[18][44] = 32800
expected_grid_map[18][45] = 32800
expected_grid_map[18][46] = 32800
expected_grid_map[19][8] = 32800
expected_grid_map[19][9] = 32800
expected_grid_map[19][10] = 32800
expected_grid_map[19][23] = 16386
expected_grid_map[19][24] = 34864
expected_grid_map[19][25] = 32872
expected_grid_map[19][26] = 4608
expected_grid_map[19][44] = 32800
expected_grid_map[19][45] = 32800
expected_grid_map[19][46] = 32800
expected_grid_map[20][8] = 32800
expected_grid_map[20][9] = 32872
expected_grid_map[20][10] = 37408
expected_grid_map[20][23] = 32800
expected_grid_map[20][24] = 32800
expected_grid_map[20][25] = 32800
expected_grid_map[20][26] = 32800
expected_grid_map[20][44] = 32800
expected_grid_map[20][45] = 32800
expected_grid_map[20][46] = 32800
expected_grid_map[21][8] = 32800
expected_grid_map[21][9] = 32800
expected_grid_map[21][10] = 32800
expected_grid_map[21][23] = 72
expected_grid_map[21][24] = 37408
expected_grid_map[21][25] = 49186
expected_grid_map[21][26] = 2064
expected_grid_map[21][44] = 32800
expected_grid_map[21][45] = 32800
expected_grid_map[21][46] = 32800
expected_grid_map[22][8] = 49186
expected_grid_map[22][9] = 34864
expected_grid_map[22][10] = 32872
expected_grid_map[22][11] = 4608
expected_grid_map[22][24] = 32872
expected_grid_map[22][25] = 37408
expected_grid_map[22][43] = 16386
expected_grid_map[22][44] = 2064
expected_grid_map[22][45] = 32800
expected_grid_map[22][46] = 32800
expected_grid_map[23][8] = 32800
expected_grid_map[23][9] = 32800
expected_grid_map[23][10] = 32800
expected_grid_map[23][11] = 32800
expected_grid_map[23][24] = 49186
expected_grid_map[23][25] = 34864
expected_grid_map[23][42] = 16386
expected_grid_map[23][43] = 33825
expected_grid_map[23][44] = 17411
expected_grid_map[23][45] = 3089
expected_grid_map[23][46] = 2064
expected_grid_map[24][8] = 32872
expected_grid_map[24][9] = 37408
expected_grid_map[24][10] = 49186
expected_grid_map[24][11] = 2064
expected_grid_map[24][24] = 32800
expected_grid_map[24][25] = 32800
expected_grid_map[24][42] = 32800
expected_grid_map[24][43] = 32800
expected_grid_map[24][44] = 32800
expected_grid_map[25][8] = 32800
expected_grid_map[25][9] = 32800
expected_grid_map[25][10] = 32800
expected_grid_map[25][24] = 32800
expected_grid_map[25][25] = 32800
expected_grid_map[25][42] = 32800
expected_grid_map[25][43] = 32872
expected_grid_map[25][44] = 37408
expected_grid_map[26][8] = 32800
expected_grid_map[26][9] = 49186
expected_grid_map[26][10] = 34864
expected_grid_map[26][24] = 49186
expected_grid_map[26][25] = 2064
expected_grid_map[26][42] = 32800
expected_grid_map[26][43] = 32800
expected_grid_map[26][44] = 32800
expected_grid_map[27][8] = 32800
expected_grid_map[27][9] = 32800
expected_grid_map[27][10] = 32800
expected_grid_map[27][24] = 32800
expected_grid_map[27][42] = 49186
expected_grid_map[27][43] = 34864
expected_grid_map[27][44] = 32872
expected_grid_map[27][45] = 4608
expected_grid_map[28][8] = 32800
expected_grid_map[28][9] = 32800
expected_grid_map[28][10] = 32800
expected_grid_map[28][24] = 32872
expected_grid_map[28][25] = 4608
expected_grid_map[28][42] = 32800
expected_grid_map[28][43] = 32800
expected_grid_map[28][44] = 32800
expected_grid_map[28][45] = 32800
expected_grid_map[29][8] = 32800
expected_grid_map[29][9] = 32800
expected_grid_map[29][10] = 32800
expected_grid_map[29][24] = 49186
expected_grid_map[29][25] = 34864
expected_grid_map[29][42] = 32872
expected_grid_map[29][43] = 37408
expected_grid_map[29][44] = 49186
expected_grid_map[29][45] = 2064
expected_grid_map[30][8] = 32800
expected_grid_map[30][9] = 32800
expected_grid_map[30][10] = 32800
expected_grid_map[30][23] = 16386
expected_grid_map[30][24] = 34864
expected_grid_map[30][25] = 32872
expected_grid_map[30][26] = 4608
expected_grid_map[30][42] = 32800
expected_grid_map[30][43] = 32800
expected_grid_map[30][44] = 32800
expected_grid_map[31][8] = 32800
expected_grid_map[31][9] = 32872
expected_grid_map[31][10] = 37408
expected_grid_map[31][23] = 32800
expected_grid_map[31][24] = 32800
expected_grid_map[31][25] = 32800
expected_grid_map[31][26] = 32800
expected_grid_map[31][42] = 32800
expected_grid_map[31][43] = 49186
expected_grid_map[31][44] = 34864
expected_grid_map[32][8] = 32800
expected_grid_map[32][9] = 32800
expected_grid_map[32][10] = 32800
expected_grid_map[32][23] = 72
expected_grid_map[32][24] = 37408
expected_grid_map[32][25] = 49186
expected_grid_map[32][26] = 2064
expected_grid_map[32][42] = 32800
expected_grid_map[32][43] = 32800
expected_grid_map[32][44] = 32800
expected_grid_map[33][8] = 49186
expected_grid_map[33][9] = 34864
expected_grid_map[33][10] = 32872
expected_grid_map[33][11] = 4608
expected_grid_map[33][24] = 32872
expected_grid_map[33][25] = 37408
expected_grid_map[33][41] = 16386
expected_grid_map[33][42] = 34864
expected_grid_map[33][43] = 32800
expected_grid_map[33][44] = 32800
expected_grid_map[34][8] = 32800
expected_grid_map[34][9] = 32800
expected_grid_map[34][10] = 32800
expected_grid_map[34][11] = 32800
expected_grid_map[34][24] = 49186
expected_grid_map[34][25] = 2064
expected_grid_map[34][41] = 32800
expected_grid_map[34][42] = 49186
expected_grid_map[34][43] = 2064
expected_grid_map[34][44] = 32800
expected_grid_map[35][8] = 32872
expected_grid_map[35][9] = 37408
expected_grid_map[35][10] = 49186
expected_grid_map[35][11] = 2064
expected_grid_map[35][24] = 32800
expected_grid_map[35][41] = 32800
expected_grid_map[35][42] = 32800
expected_grid_map[35][43] = 16386
expected_grid_map[35][44] = 2064
expected_grid_map[36][8] = 32800
expected_grid_map[36][9] = 32800
expected_grid_map[36][10] = 32800
expected_grid_map[36][18] = 16386
expected_grid_map[36][19] = 17411
expected_grid_map[36][20] = 1025
expected_grid_map[36][21] = 1025
expected_grid_map[36][22] = 1025
expected_grid_map[36][23] = 17411
expected_grid_map[36][24] = 52275
expected_grid_map[36][25] = 5633
expected_grid_map[36][26] = 5633
expected_grid_map[36][27] = 4608
expected_grid_map[36][41] = 32800
expected_grid_map[36][42] = 32800
expected_grid_map[36][43] = 32800
expected_grid_map[37][8] = 32800
expected_grid_map[37][9] = 49186
expected_grid_map[37][10] = 34864
expected_grid_map[37][13] = 16386
expected_grid_map[37][14] = 1025
expected_grid_map[37][15] = 1025
expected_grid_map[37][16] = 1025
expected_grid_map[37][17] = 1025
expected_grid_map[37][18] = 2064
expected_grid_map[37][19] = 32800
expected_grid_map[37][20] = 16386
expected_grid_map[37][21] = 1025
expected_grid_map[37][22] = 1025
expected_grid_map[37][23] = 2064
expected_grid_map[37][24] = 72
expected_grid_map[37][25] = 37408
expected_grid_map[37][26] = 32800
expected_grid_map[37][27] = 32800
expected_grid_map[37][41] = 32800
expected_grid_map[37][42] = 32800
expected_grid_map[37][43] = 32800
expected_grid_map[38][8] = 32800
expected_grid_map[38][9] = 32800
expected_grid_map[38][10] = 32800
expected_grid_map[38][13] = 49186
expected_grid_map[38][14] = 1025
expected_grid_map[38][15] = 1025
expected_grid_map[38][16] = 1025
expected_grid_map[38][17] = 1025
expected_grid_map[38][18] = 1025
expected_grid_map[38][19] = 2064
expected_grid_map[38][20] = 32800
expected_grid_map[38][25] = 32800
expected_grid_map[38][26] = 32800
expected_grid_map[38][27] = 32800
expected_grid_map[38][41] = 32800
expected_grid_map[38][42] = 32800
expected_grid_map[38][43] = 32800
expected_grid_map[39][8] = 72
expected_grid_map[39][9] = 1097
expected_grid_map[39][10] = 1097
expected_grid_map[39][11] = 1025
expected_grid_map[39][12] = 1025
expected_grid_map[39][13] = 3089
expected_grid_map[39][14] = 1025
expected_grid_map[39][15] = 1025
expected_grid_map[39][16] = 1025
expected_grid_map[39][17] = 1025
expected_grid_map[39][18] = 1025
expected_grid_map[39][19] = 1025
expected_grid_map[39][20] = 2064
expected_grid_map[39][25] = 32800
expected_grid_map[39][26] = 32872
expected_grid_map[39][27] = 37408
expected_grid_map[39][41] = 32800
expected_grid_map[39][42] = 32800
expected_grid_map[39][43] = 32800
expected_grid_map[40][25] = 32800
expected_grid_map[40][26] = 32800
expected_grid_map[40][27] = 32800
expected_grid_map[40][41] = 32800
expected_grid_map[40][42] = 32800
expected_grid_map[40][43] = 32800
expected_grid_map[41][25] = 49186
expected_grid_map[41][26] = 34864
expected_grid_map[41][27] = 32872
expected_grid_map[41][28] = 4608
expected_grid_map[41][41] = 32800
expected_grid_map[41][42] = 32800
expected_grid_map[41][43] = 32800
expected_grid_map[42][25] = 32800
expected_grid_map[42][26] = 32800
expected_grid_map[42][27] = 32800
expected_grid_map[42][28] = 32800
expected_grid_map[42][41] = 32800
expected_grid_map[42][42] = 32800
expected_grid_map[42][43] = 32800
expected_grid_map[43][25] = 32872
expected_grid_map[43][26] = 37408
expected_grid_map[43][27] = 49186
expected_grid_map[43][28] = 2064
expected_grid_map[43][41] = 32800
expected_grid_map[43][42] = 32800
expected_grid_map[43][43] = 32800
expected_grid_map[44][25] = 32800
expected_grid_map[44][26] = 32800
expected_grid_map[44][27] = 32800
expected_grid_map[44][30] = 16386
expected_grid_map[44][31] = 17411
expected_grid_map[44][32] = 1025
expected_grid_map[44][33] = 5633
expected_grid_map[44][34] = 17411
expected_grid_map[44][35] = 1025
expected_grid_map[44][36] = 1025
expected_grid_map[44][37] = 1025
expected_grid_map[44][38] = 5633
expected_grid_map[44][39] = 17411
expected_grid_map[44][40] = 1025
expected_grid_map[44][41] = 3089
expected_grid_map[44][42] = 3089
expected_grid_map[44][43] = 2064
expected_grid_map[45][25] = 32800
expected_grid_map[45][26] = 49186
expected_grid_map[45][27] = 34864
expected_grid_map[45][30] = 32800
expected_grid_map[45][31] = 32800
expected_grid_map[45][33] = 72
expected_grid_map[45][34] = 3089
expected_grid_map[45][35] = 1025
expected_grid_map[45][36] = 1025
expected_grid_map[45][37] = 1025
expected_grid_map[45][38] = 1097
expected_grid_map[45][39] = 2064
expected_grid_map[46][25] = 32800
expected_grid_map[46][26] = 32800
expected_grid_map[46][27] = 32800
expected_grid_map[46][30] = 32800
expected_grid_map[46][31] = 32800
expected_grid_map[47][25] = 72
expected_grid_map[47][26] = 1097
expected_grid_map[47][27] = 1097
expected_grid_map[47][28] = 1025
expected_grid_map[47][29] = 1025
expected_grid_map[47][30] = 3089
expected_grid_map[47][31] = 2064
# Attention, once we have fixed the generator this needs to be changed!!!!
expected_grid_map = env.rail.grid
assert np.array_equal(env.rail.grid, expected_grid_map), "actual={}, expected={}".format(env.rail.grid,
expected_grid_map)
s0 = 0
s1 = 0
for a in range(env.get_num_agents()):
s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0))
s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0))
assert s0 == 36, "actual={}".format(s0)
assert s1 == 27, "actual={}".format(s1)
def test_sparse_rail_generator_deterministic():
"""Check that sparse_rail_generator runs deterministic over different python versions!"""
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=215545, # Random seed
grid_mode=True
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1, random_seed=1)
env.reset()
# for r in range(env.height):
# for c in range(env.width):
# print("assert env.rail.get_full_transitions({}, {}) == {}, \"[{}][{}]\"".format(r, c,
# env.rail.get_full_transitions(
# r, c), r, c))
assert env.rail.get_full_transitions(0, 1) == 0, "[0][1]"
assert env.rail.get_full_transitions(0, 2) == 0, "[0][2]"
assert env.rail.get_full_transitions(0, 3) == 0, "[0][3]"
assert env.rail.get_full_transitions(0, 4) == 0, "[0][4]"
assert env.rail.get_full_transitions(0, 5) == 0, "[0][5]"
assert env.rail.get_full_transitions(0, 6) == 0, "[0][6]"
assert env.rail.get_full_transitions(0, 7) == 0, "[0][7]"
assert env.rail.get_full_transitions(0, 8) == 0, "[0][8]"
assert env.rail.get_full_transitions(0, 9) == 0, "[0][9]"
assert env.rail.get_full_transitions(0, 10) == 0, "[0][10]"
assert env.rail.get_full_transitions(0, 11) == 0, "[0][11]"
assert env.rail.get_full_transitions(0, 12) == 0, "[0][12]"
assert env.rail.get_full_transitions(0, 13) == 0, "[0][13]"
assert env.rail.get_full_transitions(0, 14) == 0, "[0][14]"
assert env.rail.get_full_transitions(0, 15) == 0, "[0][15]"
assert env.rail.get_full_transitions(0, 16) == 0, "[0][16]"
assert env.rail.get_full_transitions(0, 17) == 0, "[0][17]"
assert env.rail.get_full_transitions(0, 18) == 0, "[0][18]"
assert env.rail.get_full_transitions(0, 19) == 0, "[0][19]"
assert env.rail.get_full_transitions(0, 20) == 0, "[0][20]"
assert env.rail.get_full_transitions(0, 21) == 0, "[0][21]"
assert env.rail.get_full_transitions(0, 22) == 0, "[0][22]"
assert env.rail.get_full_transitions(0, 23) == 0, "[0][23]"
assert env.rail.get_full_transitions(0, 24) == 0, "[0][24]"
assert env.rail.get_full_transitions(1, 0) == 0, "[1][0]"
assert env.rail.get_full_transitions(1, 1) == 0, "[1][1]"
assert env.rail.get_full_transitions(1, 2) == 0, "[1][2]"
assert env.rail.get_full_transitions(1, 3) == 0, "[1][3]"
assert env.rail.get_full_transitions(1, 4) == 0, "[1][4]"
assert env.rail.get_full_transitions(1, 5) == 0, "[1][5]"
assert env.rail.get_full_transitions(1, 6) == 0, "[1][6]"
assert env.rail.get_full_transitions(1, 7) == 0, "[1][7]"
assert env.rail.get_full_transitions(1, 8) == 0, "[1][8]"
assert env.rail.get_full_transitions(1, 9) == 0, "[1][9]"
assert env.rail.get_full_transitions(1, 10) == 0, "[1][10]"
assert env.rail.get_full_transitions(1, 11) == 16386, "[1][11]"
assert env.rail.get_full_transitions(1, 12) == 1025, "[1][12]"
assert env.rail.get_full_transitions(1, 13) == 17411, "[1][13]"
assert env.rail.get_full_transitions(1, 14) == 17411, "[1][14]"
assert env.rail.get_full_transitions(1, 15) == 1025, "[1][15]"
assert env.rail.get_full_transitions(1, 16) == 1025, "[1][16]"
assert env.rail.get_full_transitions(1, 17) == 1025, "[1][17]"
assert env.rail.get_full_transitions(1, 18) == 1025, "[1][18]"
assert env.rail.get_full_transitions(1, 19) == 5633, "[1][19]"
assert env.rail.get_full_transitions(1, 20) == 5633, "[1][20]"
assert env.rail.get_full_transitions(1, 21) == 4608, "[1][21]"
assert env.rail.get_full_transitions(1, 22) == 0, "[1][22]"
assert env.rail.get_full_transitions(1, 23) == 0, "[1][23]"
assert env.rail.get_full_transitions(1, 24) == 0, "[1][24]"
assert env.rail.get_full_transitions(2, 0) == 0, "[2][0]"
assert env.rail.get_full_transitions(2, 1) == 0, "[2][1]"
assert env.rail.get_full_transitions(2, 2) == 0, "[2][2]"
assert env.rail.get_full_transitions(2, 3) == 0, "[2][3]"
assert env.rail.get_full_transitions(2, 4) == 0, "[2][4]"
assert env.rail.get_full_transitions(2, 5) == 0, "[2][5]"
assert env.rail.get_full_transitions(2, 6) == 0, "[2][6]"
assert env.rail.get_full_transitions(2, 7) == 0, "[2][7]"
assert env.rail.get_full_transitions(2, 8) == 0, "[2][8]"
assert env.rail.get_full_transitions(2, 9) == 0, "[2][9]"
assert env.rail.get_full_transitions(2, 10) == 0, "[2][10]"
assert env.rail.get_full_transitions(2, 11) == 32800, "[2][11]"
assert env.rail.get_full_transitions(2, 12) == 16386, "[2][12]"
assert env.rail.get_full_transitions(2, 13) == 34864, "[2][13]"
assert env.rail.get_full_transitions(2, 14) == 32800, "[2][14]"
assert env.rail.get_full_transitions(2, 15) == 0, "[2][15]"
assert env.rail.get_full_transitions(2, 16) == 0, "[2][16]"
assert env.rail.get_full_transitions(2, 17) == 0, "[2][17]"
assert env.rail.get_full_transitions(2, 18) == 0, "[2][18]"
assert env.rail.get_full_transitions(2, 19) == 32800, "[2][19]"
assert env.rail.get_full_transitions(2, 20) == 32800, "[2][20]"
assert env.rail.get_full_transitions(2, 21) == 32800, "[2][21]"
assert env.rail.get_full_transitions(2, 22) == 0, "[2][22]"
assert env.rail.get_full_transitions(2, 23) == 0, "[2][23]"
assert env.rail.get_full_transitions(2, 24) == 0, "[2][24]"
assert env.rail.get_full_transitions(3, 0) == 0, "[3][0]"
assert env.rail.get_full_transitions(3, 1) == 0, "[3][1]"
assert env.rail.get_full_transitions(3, 2) == 0, "[3][2]"
assert env.rail.get_full_transitions(3, 3) == 0, "[3][3]"
assert env.rail.get_full_transitions(3, 4) == 0, "[3][4]"
assert env.rail.get_full_transitions(3, 5) == 0, "[3][5]"
assert env.rail.get_full_transitions(3, 6) == 0, "[3][6]"
assert env.rail.get_full_transitions(3, 7) == 0, "[3][7]"
assert env.rail.get_full_transitions(3, 8) == 0, "[3][8]"
assert env.rail.get_full_transitions(3, 9) == 0, "[3][9]"
assert env.rail.get_full_transitions(3, 10) == 0, "[3][10]"
assert env.rail.get_full_transitions(3, 11) == 32800, "[3][11]"
assert env.rail.get_full_transitions(3, 12) == 32800, "[3][12]"
assert env.rail.get_full_transitions(3, 13) == 32800, "[3][13]"
assert env.rail.get_full_transitions(3, 14) == 32800, "[3][14]"
assert env.rail.get_full_transitions(3, 15) == 0, "[3][15]"
assert env.rail.get_full_transitions(3, 16) == 0, "[3][16]"
assert env.rail.get_full_transitions(3, 17) == 0, "[3][17]"
assert env.rail.get_full_transitions(3, 18) == 0, "[3][18]"
assert env.rail.get_full_transitions(3, 19) == 32800, "[3][19]"
assert env.rail.get_full_transitions(3, 20) == 32872, "[3][20]"
assert env.rail.get_full_transitions(3, 21) == 37408, "[3][21]"
assert env.rail.get_full_transitions(3, 22) == 0, "[3][22]"
assert env.rail.get_full_transitions(3, 23) == 0, "[3][23]"
assert env.rail.get_full_transitions(3, 24) == 0, "[3][24]"
assert env.rail.get_full_transitions(4, 0) == 0, "[4][0]"
assert env.rail.get_full_transitions(4, 1) == 0, "[4][1]"
assert env.rail.get_full_transitions(4, 2) == 0, "[4][2]"
assert env.rail.get_full_transitions(4, 3) == 0, "[4][3]"
assert env.rail.get_full_transitions(4, 4) == 0, "[4][4]"
assert env.rail.get_full_transitions(4, 5) == 0, "[4][5]"
assert env.rail.get_full_transitions(4, 6) == 0, "[4][6]"
assert env.rail.get_full_transitions(4, 7) == 0, "[4][7]"
assert env.rail.get_full_transitions(4, 8) == 0, "[4][8]"
assert env.rail.get_full_transitions(4, 9) == 0, "[4][9]"
assert env.rail.get_full_transitions(4, 10) == 0, "[4][10]"
assert env.rail.get_full_transitions(4, 11) == 32800, "[4][11]"
assert env.rail.get_full_transitions(4, 12) == 32800, "[4][12]"
assert env.rail.get_full_transitions(4, 13) == 32800, "[4][13]"
assert env.rail.get_full_transitions(4, 14) == 32800, "[4][14]"
assert env.rail.get_full_transitions(4, 15) == 0, "[4][15]"
assert env.rail.get_full_transitions(4, 16) == 0, "[4][16]"
assert env.rail.get_full_transitions(4, 17) == 0, "[4][17]"
assert env.rail.get_full_transitions(4, 18) == 0, "[4][18]"
assert env.rail.get_full_transitions(4, 19) == 32800, "[4][19]"
assert env.rail.get_full_transitions(4, 20) == 32800, "[4][20]"
assert env.rail.get_full_transitions(4, 21) == 32800, "[4][21]"
assert env.rail.get_full_transitions(4, 22) == 0, "[4][22]"
assert env.rail.get_full_transitions(4, 23) == 0, "[4][23]"
assert env.rail.get_full_transitions(4, 24) == 0, "[4][24]"
assert env.rail.get_full_transitions(5, 0) == 0, "[5][0]"
assert env.rail.get_full_transitions(5, 1) == 0, "[5][1]"
assert env.rail.get_full_transitions(5, 2) == 0, "[5][2]"
assert env.rail.get_full_transitions(5, 3) == 0, "[5][3]"
assert env.rail.get_full_transitions(5, 4) == 0, "[5][4]"
assert env.rail.get_full_transitions(5, 5) == 0, "[5][5]"
assert env.rail.get_full_transitions(5, 6) == 0, "[5][6]"
assert env.rail.get_full_transitions(5, 7) == 0, "[5][7]"
assert env.rail.get_full_transitions(5, 8) == 0, "[5][8]"
assert env.rail.get_full_transitions(5, 9) == 0, "[5][9]"
assert env.rail.get_full_transitions(5, 10) == 0, "[5][10]"
assert env.rail.get_full_transitions(5, 11) == 49186, "[5][11]"
assert env.rail.get_full_transitions(5, 12) == 3089, "[5][12]"
assert env.rail.get_full_transitions(5, 13) == 2064, "[5][13]"
assert env.rail.get_full_transitions(5, 14) == 32800, "[5][14]"
assert env.rail.get_full_transitions(5, 15) == 0, "[5][15]"
assert env.rail.get_full_transitions(5, 16) == 0, "[5][16]"
assert env.rail.get_full_transitions(5, 17) == 0, "[5][17]"
assert env.rail.get_full_transitions(5, 18) == 0, "[5][18]"
assert env.rail.get_full_transitions(5, 19) == 49186, "[5][19]"
assert env.rail.get_full_transitions(5, 20) == 34864, "[5][20]"
assert env.rail.get_full_transitions(5, 21) == 32872, "[5][21]"
assert env.rail.get_full_transitions(5, 22) == 4608, "[5][22]"
assert env.rail.get_full_transitions(5, 23) == 0, "[5][23]"
assert env.rail.get_full_transitions(5, 24) == 0, "[5][24]"
assert env.rail.get_full_transitions(6, 0) == 16386, "[6][0]"
assert env.rail.get_full_transitions(6, 1) == 17411, "[6][1]"
assert env.rail.get_full_transitions(6, 2) == 1025, "[6][2]"
assert env.rail.get_full_transitions(6, 3) == 5633, "[6][3]"
assert env.rail.get_full_transitions(6, 4) == 17411, "[6][4]"
assert env.rail.get_full_transitions(6, 5) == 1025, "[6][5]"
assert env.rail.get_full_transitions(6, 6) == 1025, "[6][6]"
assert env.rail.get_full_transitions(6, 7) == 1025, "[6][7]"
assert env.rail.get_full_transitions(6, 8) == 5633, "[6][8]"
assert env.rail.get_full_transitions(6, 9) == 17411, "[6][9]"
assert env.rail.get_full_transitions(6, 10) == 1025, "[6][10]"
assert env.rail.get_full_transitions(6, 11) == 3089, "[6][11]"
assert env.rail.get_full_transitions(6, 12) == 1025, "[6][12]"
assert env.rail.get_full_transitions(6, 13) == 1025, "[6][13]"
assert env.rail.get_full_transitions(6, 14) == 2064, "[6][14]"
assert env.rail.get_full_transitions(6, 15) == 0, "[6][15]"
assert env.rail.get_full_transitions(6, 16) == 0, "[6][16]"
assert env.rail.get_full_transitions(6, 17) == 0, "[6][17]"
assert env.rail.get_full_transitions(6, 18) == 0, "[6][18]"
assert env.rail.get_full_transitions(6, 19) == 32800, "[6][19]"
assert env.rail.get_full_transitions(6, 20) == 32800, "[6][20]"
assert env.rail.get_full_transitions(6, 21) == 32800, "[6][21]"
assert env.rail.get_full_transitions(6, 22) == 32800, "[6][22]"
assert env.rail.get_full_transitions(6, 23) == 0, "[6][23]"
assert env.rail.get_full_transitions(6, 24) == 0, "[6][24]"
assert env.rail.get_full_transitions(7, 0) == 32800, "[7][0]"
assert env.rail.get_full_transitions(7, 1) == 32800, "[7][1]"
assert env.rail.get_full_transitions(7, 2) == 0, "[7][2]"
assert env.rail.get_full_transitions(7, 3) == 72, "[7][3]"
assert env.rail.get_full_transitions(7, 4) == 3089, "[7][4]"
assert env.rail.get_full_transitions(7, 5) == 1025, "[7][5]"
assert env.rail.get_full_transitions(7, 6) == 1025, "[7][6]"
assert env.rail.get_full_transitions(7, 7) == 1025, "[7][7]"
assert env.rail.get_full_transitions(7, 8) == 1097, "[7][8]"
assert env.rail.get_full_transitions(7, 9) == 2064, "[7][9]"
assert env.rail.get_full_transitions(7, 10) == 0, "[7][10]"
assert env.rail.get_full_transitions(7, 11) == 0, "[7][11]"
assert env.rail.get_full_transitions(7, 12) == 0, "[7][12]"
assert env.rail.get_full_transitions(7, 13) == 0, "[7][13]"
assert env.rail.get_full_transitions(7, 14) == 0, "[7][14]"
assert env.rail.get_full_transitions(7, 15) == 0, "[7][15]"
assert env.rail.get_full_transitions(7, 16) == 0, "[7][16]"
assert env.rail.get_full_transitions(7, 17) == 0, "[7][17]"
assert env.rail.get_full_transitions(7, 18) == 0, "[7][18]"
assert env.rail.get_full_transitions(7, 19) == 32872, "[7][19]"
assert env.rail.get_full_transitions(7, 20) == 37408, "[7][20]"
assert env.rail.get_full_transitions(7, 21) == 49186, "[7][21]"
assert env.rail.get_full_transitions(7, 22) == 2064, "[7][22]"
assert env.rail.get_full_transitions(7, 23) == 0, "[7][23]"
assert env.rail.get_full_transitions(7, 24) == 0, "[7][24]"
assert env.rail.get_full_transitions(8, 0) == 32800, "[8][0]"
assert env.rail.get_full_transitions(8, 1) == 32800, "[8][1]"
assert env.rail.get_full_transitions(8, 2) == 0, "[8][2]"
assert env.rail.get_full_transitions(8, 3) == 0, "[8][3]"
assert env.rail.get_full_transitions(8, 4) == 0, "[8][4]"
assert env.rail.get_full_transitions(8, 5) == 0, "[8][5]"
assert env.rail.get_full_transitions(8, 6) == 0, "[8][6]"
assert env.rail.get_full_transitions(8, 7) == 0, "[8][7]"
assert env.rail.get_full_transitions(8, 8) == 0, "[8][8]"
assert env.rail.get_full_transitions(8, 9) == 0, "[8][9]"
assert env.rail.get_full_transitions(8, 10) == 0, "[8][10]"
assert env.rail.get_full_transitions(8, 11) == 0, "[8][11]"
assert env.rail.get_full_transitions(8, 12) == 0, "[8][12]"
assert env.rail.get_full_transitions(8, 13) == 0, "[8][13]"
assert env.rail.get_full_transitions(8, 14) == 0, "[8][14]"
assert env.rail.get_full_transitions(8, 15) == 0, "[8][15]"
assert env.rail.get_full_transitions(8, 16) == 0, "[8][16]"
assert env.rail.get_full_transitions(8, 17) == 0, "[8][17]"
assert env.rail.get_full_transitions(8, 18) == 0, "[8][18]"
assert env.rail.get_full_transitions(8, 19) == 32800, "[8][19]"
assert env.rail.get_full_transitions(8, 20) == 32800, "[8][20]"
assert env.rail.get_full_transitions(8, 21) == 32800, "[8][21]"
assert env.rail.get_full_transitions(8, 22) == 0, "[8][22]"
assert env.rail.get_full_transitions(8, 23) == 0, "[8][23]"
assert env.rail.get_full_transitions(8, 24) == 0, "[8][24]"
assert env.rail.get_full_transitions(9, 0) == 32800, "[9][0]"
assert env.rail.get_full_transitions(9, 1) == 32800, "[9][1]"
assert env.rail.get_full_transitions(9, 2) == 0, "[9][2]"
assert env.rail.get_full_transitions(9, 3) == 0, "[9][3]"
assert env.rail.get_full_transitions(9, 4) == 0, "[9][4]"
assert env.rail.get_full_transitions(9, 5) == 0, "[9][5]"
assert env.rail.get_full_transitions(9, 6) == 0, "[9][6]"
assert env.rail.get_full_transitions(9, 7) == 0, "[9][7]"
assert env.rail.get_full_transitions(9, 8) == 0, "[9][8]"
assert env.rail.get_full_transitions(9, 9) == 0, "[9][9]"
assert env.rail.get_full_transitions(9, 10) == 0, "[9][10]"
assert env.rail.get_full_transitions(9, 11) == 0, "[9][11]"
assert env.rail.get_full_transitions(9, 12) == 0, "[9][12]"
assert env.rail.get_full_transitions(9, 13) == 0, "[9][13]"
assert env.rail.get_full_transitions(9, 14) == 0, "[9][14]"
assert env.rail.get_full_transitions(9, 15) == 0, "[9][15]"
assert env.rail.get_full_transitions(9, 16) == 0, "[9][16]"
assert env.rail.get_full_transitions(9, 17) == 0, "[9][17]"
assert env.rail.get_full_transitions(9, 18) == 0, "[9][18]"
assert env.rail.get_full_transitions(9, 19) == 32800, "[9][19]"
assert env.rail.get_full_transitions(9, 20) == 49186, "[9][20]"
assert env.rail.get_full_transitions(9, 21) == 34864, "[9][21]"
assert env.rail.get_full_transitions(9, 22) == 0, "[9][22]"
assert env.rail.get_full_transitions(9, 23) == 0, "[9][23]"
assert env.rail.get_full_transitions(9, 24) == 0, "[9][24]"
assert env.rail.get_full_transitions(10, 0) == 32800, "[10][0]"
assert env.rail.get_full_transitions(10, 1) == 32800, "[10][1]"
assert env.rail.get_full_transitions(10, 2) == 0, "[10][2]"
assert env.rail.get_full_transitions(10, 3) == 0, "[10][3]"
assert env.rail.get_full_transitions(10, 4) == 0, "[10][4]"
assert env.rail.get_full_transitions(10, 5) == 0, "[10][5]"
assert env.rail.get_full_transitions(10, 6) == 0, "[10][6]"
assert env.rail.get_full_transitions(10, 7) == 0, "[10][7]"
assert env.rail.get_full_transitions(10, 8) == 0, "[10][8]"
assert env.rail.get_full_transitions(10, 9) == 0, "[10][9]"
assert env.rail.get_full_transitions(10, 10) == 0, "[10][10]"
assert env.rail.get_full_transitions(10, 11) == 0, "[10][11]"
assert env.rail.get_full_transitions(10, 12) == 0, "[10][12]"
assert env.rail.get_full_transitions(10, 13) == 0, "[10][13]"
assert env.rail.get_full_transitions(10, 14) == 0, "[10][14]"
assert env.rail.get_full_transitions(10, 15) == 0, "[10][15]"
assert env.rail.get_full_transitions(10, 16) == 0, "[10][16]"
assert env.rail.get_full_transitions(10, 17) == 0, "[10][17]"
assert env.rail.get_full_transitions(10, 18) == 0, "[10][18]"
assert env.rail.get_full_transitions(10, 19) == 32800, "[10][19]"
assert env.rail.get_full_transitions(10, 20) == 32800, "[10][20]"
assert env.rail.get_full_transitions(10, 21) == 32800, "[10][21]"
assert env.rail.get_full_transitions(10, 22) == 0, "[10][22]"
assert env.rail.get_full_transitions(10, 23) == 0, "[10][23]"
assert env.rail.get_full_transitions(10, 24) == 0, "[10][24]"
assert env.rail.get_full_transitions(11, 0) == 32800, "[11][0]"
assert env.rail.get_full_transitions(11, 1) == 32800, "[11][1]"
assert env.rail.get_full_transitions(11, 2) == 0, "[11][2]"
assert env.rail.get_full_transitions(11, 3) == 0, "[11][3]"
assert env.rail.get_full_transitions(11, 4) == 0, "[11][4]"
assert env.rail.get_full_transitions(11, 5) == 0, "[11][5]"
assert env.rail.get_full_transitions(11, 6) == 0, "[11][6]"
assert env.rail.get_full_transitions(11, 7) == 0, "[11][7]"
assert env.rail.get_full_transitions(11, 8) == 0, "[11][8]"
assert env.rail.get_full_transitions(11, 9) == 0, "[11][9]"
assert env.rail.get_full_transitions(11, 10) == 0, "[11][10]"
assert env.rail.get_full_transitions(11, 11) == 0, "[11][11]"
assert env.rail.get_full_transitions(11, 12) == 0, "[11][12]"
assert env.rail.get_full_transitions(11, 13) == 0, "[11][13]"
assert env.rail.get_full_transitions(11, 14) == 0, "[11][14]"
assert env.rail.get_full_transitions(11, 15) == 0, "[11][15]"
assert env.rail.get_full_transitions(11, 16) == 0, "[11][16]"
assert env.rail.get_full_transitions(11, 17) == 0, "[11][17]"
assert env.rail.get_full_transitions(11, 18) == 0, "[11][18]"
assert env.rail.get_full_transitions(11, 19) == 32800, "[11][19]"
assert env.rail.get_full_transitions(11, 20) == 32800, "[11][20]"
assert env.rail.get_full_transitions(11, 21) == 32800, "[11][21]"
assert env.rail.get_full_transitions(11, 22) == 0, "[11][22]"
assert env.rail.get_full_transitions(11, 23) == 0, "[11][23]"
assert env.rail.get_full_transitions(11, 24) == 0, "[11][24]"
assert env.rail.get_full_transitions(12, 0) == 32800, "[12][0]"
assert env.rail.get_full_transitions(12, 1) == 32800, "[12][1]"
assert env.rail.get_full_transitions(12, 2) == 0, "[12][2]"
assert env.rail.get_full_transitions(12, 3) == 0, "[12][3]"
assert env.rail.get_full_transitions(12, 4) == 0, "[12][4]"
assert env.rail.get_full_transitions(12, 5) == 0, "[12][5]"
assert env.rail.get_full_transitions(12, 6) == 0, "[12][6]"
assert env.rail.get_full_transitions(12, 7) == 0, "[12][7]"
assert env.rail.get_full_transitions(12, 8) == 0, "[12][8]"
assert env.rail.get_full_transitions(12, 9) == 0, "[12][9]"
assert env.rail.get_full_transitions(12, 10) == 0, "[12][10]"
assert env.rail.get_full_transitions(12, 11) == 0, "[12][11]"
assert env.rail.get_full_transitions(12, 12) == 0, "[12][12]"
assert env.rail.get_full_transitions(12, 13) == 0, "[12][13]"
assert env.rail.get_full_transitions(12, 14) == 0, "[12][14]"
assert env.rail.get_full_transitions(12, 15) == 0, "[12][15]"
assert env.rail.get_full_transitions(12, 16) == 0, "[12][16]"
assert env.rail.get_full_transitions(12, 17) == 0, "[12][17]"
assert env.rail.get_full_transitions(12, 18) == 0, "[12][18]"
assert env.rail.get_full_transitions(12, 19) == 32800, "[12][19]"
assert env.rail.get_full_transitions(12, 20) == 32800, "[12][20]"
assert env.rail.get_full_transitions(12, 21) == 32800, "[12][21]"
assert env.rail.get_full_transitions(12, 22) == 0, "[12][22]"
assert env.rail.get_full_transitions(12, 23) == 0, "[12][23]"
assert env.rail.get_full_transitions(12, 24) == 0, "[12][24]"
assert env.rail.get_full_transitions(13, 0) == 32800, "[13][0]"
assert env.rail.get_full_transitions(13, 1) == 32800, "[13][1]"
assert env.rail.get_full_transitions(13, 2) == 0, "[13][2]"
assert env.rail.get_full_transitions(13, 3) == 0, "[13][3]"
assert env.rail.get_full_transitions(13, 4) == 0, "[13][4]"
assert env.rail.get_full_transitions(13, 5) == 0, "[13][5]"
assert env.rail.get_full_transitions(13, 6) == 0, "[13][6]"
assert env.rail.get_full_transitions(13, 7) == 0, "[13][7]"
assert env.rail.get_full_transitions(13, 8) == 0, "[13][8]"
assert env.rail.get_full_transitions(13, 9) == 0, "[13][9]"
assert env.rail.get_full_transitions(13, 10) == 0, "[13][10]"
assert env.rail.get_full_transitions(13, 11) == 0, "[13][11]"
assert env.rail.get_full_transitions(13, 12) == 0, "[13][12]"
assert env.rail.get_full_transitions(13, 13) == 0, "[13][13]"
assert env.rail.get_full_transitions(13, 14) == 0, "[13][14]"
assert env.rail.get_full_transitions(13, 15) == 0, "[13][15]"
assert env.rail.get_full_transitions(13, 16) == 0, "[13][16]"
assert env.rail.get_full_transitions(13, 17) == 0, "[13][17]"
assert env.rail.get_full_transitions(13, 18) == 0, "[13][18]"
assert env.rail.get_full_transitions(13, 19) == 32800, "[13][19]"
assert env.rail.get_full_transitions(13, 20) == 32800, "[13][20]"
assert env.rail.get_full_transitions(13, 21) == 32800, "[13][21]"
assert env.rail.get_full_transitions(13, 22) == 0, "[13][22]"
assert env.rail.get_full_transitions(13, 23) == 0, "[13][23]"
assert env.rail.get_full_transitions(13, 24) == 0, "[13][24]"
assert env.rail.get_full_transitions(14, 0) == 32800, "[14][0]"
assert env.rail.get_full_transitions(14, 1) == 32800, "[14][1]"
assert env.rail.get_full_transitions(14, 2) == 0, "[14][2]"
assert env.rail.get_full_transitions(14, 3) == 0, "[14][3]"
assert env.rail.get_full_transitions(14, 4) == 0, "[14][4]"
assert env.rail.get_full_transitions(14, 5) == 0, "[14][5]"
assert env.rail.get_full_transitions(14, 6) == 0, "[14][6]"
assert env.rail.get_full_transitions(14, 7) == 0, "[14][7]"
assert env.rail.get_full_transitions(14, 8) == 0, "[14][8]"
assert env.rail.get_full_transitions(14, 9) == 0, "[14][9]"
assert env.rail.get_full_transitions(14, 10) == 0, "[14][10]"
assert env.rail.get_full_transitions(14, 11) == 0, "[14][11]"
assert env.rail.get_full_transitions(14, 12) == 0, "[14][12]"
assert env.rail.get_full_transitions(14, 13) == 0, "[14][13]"
assert env.rail.get_full_transitions(14, 14) == 0, "[14][14]"
assert env.rail.get_full_transitions(14, 15) == 0, "[14][15]"
assert env.rail.get_full_transitions(14, 16) == 0, "[14][16]"
assert env.rail.get_full_transitions(14, 17) == 0, "[14][17]"
assert env.rail.get_full_transitions(14, 18) == 0, "[14][18]"
assert env.rail.get_full_transitions(14, 19) == 32800, "[14][19]"
assert env.rail.get_full_transitions(14, 20) == 32800, "[14][20]"
assert env.rail.get_full_transitions(14, 21) == 32800, "[14][21]"
assert env.rail.get_full_transitions(14, 22) == 0, "[14][22]"
assert env.rail.get_full_transitions(14, 23) == 0, "[14][23]"
assert env.rail.get_full_transitions(14, 24) == 0, "[14][24]"
assert env.rail.get_full_transitions(15, 0) == 32800, "[15][0]"
assert env.rail.get_full_transitions(15, 1) == 32800, "[15][1]"
assert env.rail.get_full_transitions(15, 2) == 0, "[15][2]"
assert env.rail.get_full_transitions(15, 3) == 0, "[15][3]"
assert env.rail.get_full_transitions(15, 4) == 0, "[15][4]"
assert env.rail.get_full_transitions(15, 5) == 0, "[15][5]"
assert env.rail.get_full_transitions(15, 6) == 0, "[15][6]"
assert env.rail.get_full_transitions(15, 7) == 0, "[15][7]"
assert env.rail.get_full_transitions(15, 8) == 0, "[15][8]"
assert env.rail.get_full_transitions(15, 9) == 0, "[15][9]"
assert env.rail.get_full_transitions(15, 10) == 0, "[15][10]"
assert env.rail.get_full_transitions(15, 11) == 0, "[15][11]"
assert env.rail.get_full_transitions(15, 12) == 0, "[15][12]"
assert env.rail.get_full_transitions(15, 13) == 0, "[15][13]"
assert env.rail.get_full_transitions(15, 14) == 0, "[15][14]"
assert env.rail.get_full_transitions(15, 15) == 0, "[15][15]"
assert env.rail.get_full_transitions(15, 16) == 0, "[15][16]"
assert env.rail.get_full_transitions(15, 17) == 0, "[15][17]"
assert env.rail.get_full_transitions(15, 18) == 0, "[15][18]"
assert env.rail.get_full_transitions(15, 19) == 32800, "[15][19]"
assert env.rail.get_full_transitions(15, 20) == 32800, "[15][20]"
assert env.rail.get_full_transitions(15, 21) == 32800, "[15][21]"
assert env.rail.get_full_transitions(15, 22) == 0, "[15][22]"
assert env.rail.get_full_transitions(15, 23) == 0, "[15][23]"
assert env.rail.get_full_transitions(15, 24) == 0, "[15][24]"
assert env.rail.get_full_transitions(16, 0) == 32800, "[16][0]"
assert env.rail.get_full_transitions(16, 1) == 32800, "[16][1]"
assert env.rail.get_full_transitions(16, 2) == 0, "[16][2]"
assert env.rail.get_full_transitions(16, 3) == 0, "[16][3]"
assert env.rail.get_full_transitions(16, 4) == 0, "[16][4]"
assert env.rail.get_full_transitions(16, 5) == 0, "[16][5]"
assert env.rail.get_full_transitions(16, 6) == 0, "[16][6]"
assert env.rail.get_full_transitions(16, 7) == 0, "[16][7]"
assert env.rail.get_full_transitions(16, 8) == 0, "[16][8]"
assert env.rail.get_full_transitions(16, 9) == 0, "[16][9]"
assert env.rail.get_full_transitions(16, 10) == 0, "[16][10]"
assert env.rail.get_full_transitions(16, 11) == 0, "[16][11]"
assert env.rail.get_full_transitions(16, 12) == 0, "[16][12]"
assert env.rail.get_full_transitions(16, 13) == 0, "[16][13]"
assert env.rail.get_full_transitions(16, 14) == 0, "[16][14]"
assert env.rail.get_full_transitions(16, 15) == 0, "[16][15]"
assert env.rail.get_full_transitions(16, 16) == 0, "[16][16]"
assert env.rail.get_full_transitions(16, 17) == 0, "[16][17]"
assert env.rail.get_full_transitions(16, 18) == 0, "[16][18]"
assert env.rail.get_full_transitions(16, 19) == 32800, "[16][19]"
assert env.rail.get_full_transitions(16, 20) == 32800, "[16][20]"
assert env.rail.get_full_transitions(16, 21) == 32800, "[16][21]"
assert env.rail.get_full_transitions(16, 22) == 0, "[16][22]"
assert env.rail.get_full_transitions(16, 23) == 0, "[16][23]"
assert env.rail.get_full_transitions(16, 24) == 0, "[16][24]"
assert env.rail.get_full_transitions(17, 0) == 32800, "[17][0]"
assert env.rail.get_full_transitions(17, 1) == 32800, "[17][1]"
assert env.rail.get_full_transitions(17, 2) == 0, "[17][2]"
assert env.rail.get_full_transitions(17, 3) == 0, "[17][3]"
assert env.rail.get_full_transitions(17, 4) == 0, "[17][4]"
assert env.rail.get_full_transitions(17, 5) == 0, "[17][5]"
assert env.rail.get_full_transitions(17, 6) == 0, "[17][6]"
assert env.rail.get_full_transitions(17, 7) == 0, "[17][7]"
assert env.rail.get_full_transitions(17, 8) == 0, "[17][8]"
assert env.rail.get_full_transitions(17, 9) == 0, "[17][9]"
assert env.rail.get_full_transitions(17, 10) == 0, "[17][10]"
assert env.rail.get_full_transitions(17, 11) == 0, "[17][11]"
assert env.rail.get_full_transitions(17, 12) == 0, "[17][12]"
assert env.rail.get_full_transitions(17, 13) == 0, "[17][13]"
assert env.rail.get_full_transitions(17, 14) == 0, "[17][14]"
assert env.rail.get_full_transitions(17, 15) == 0, "[17][15]"
assert env.rail.get_full_transitions(17, 16) == 0, "[17][16]"
assert env.rail.get_full_transitions(17, 17) == 0, "[17][17]"
assert env.rail.get_full_transitions(17, 18) == 0, "[17][18]"
assert env.rail.get_full_transitions(17, 19) == 32800, "[17][19]"
assert env.rail.get_full_transitions(17, 20) == 32800, "[17][20]"
assert env.rail.get_full_transitions(17, 21) == 32800, "[17][21]"
assert env.rail.get_full_transitions(17, 22) == 0, "[17][22]"
assert env.rail.get_full_transitions(17, 23) == 0, "[17][23]"
assert env.rail.get_full_transitions(17, 24) == 0, "[17][24]"
assert env.rail.get_full_transitions(18, 0) == 72, "[18][0]"
assert env.rail.get_full_transitions(18, 1) == 37408, "[18][1]"
assert env.rail.get_full_transitions(18, 2) == 0, "[18][2]"
assert env.rail.get_full_transitions(18, 3) == 0, "[18][3]"
assert env.rail.get_full_transitions(18, 4) == 0, "[18][4]"
assert env.rail.get_full_transitions(18, 5) == 0, "[18][5]"
assert env.rail.get_full_transitions(18, 6) == 0, "[18][6]"
assert env.rail.get_full_transitions(18, 7) == 0, "[18][7]"
assert env.rail.get_full_transitions(18, 8) == 0, "[18][8]"
assert env.rail.get_full_transitions(18, 9) == 0, "[18][9]"
assert env.rail.get_full_transitions(18, 10) == 0, "[18][10]"
assert env.rail.get_full_transitions(18, 11) == 0, "[18][11]"
assert env.rail.get_full_transitions(18, 12) == 0, "[18][12]"
assert env.rail.get_full_transitions(18, 13) == 0, "[18][13]"
assert env.rail.get_full_transitions(18, 14) == 0, "[18][14]"
assert env.rail.get_full_transitions(18, 15) == 0, "[18][15]"
assert env.rail.get_full_transitions(18, 16) == 0, "[18][16]"
assert env.rail.get_full_transitions(18, 17) == 0, "[18][17]"
assert env.rail.get_full_transitions(18, 18) == 0, "[18][18]"
assert env.rail.get_full_transitions(18, 19) == 32800, "[18][19]"
assert env.rail.get_full_transitions(18, 20) == 32800, "[18][20]"
assert env.rail.get_full_transitions(18, 21) == 32800, "[18][21]"
assert env.rail.get_full_transitions(18, 22) == 0, "[18][22]"
assert env.rail.get_full_transitions(18, 23) == 0, "[18][23]"
assert env.rail.get_full_transitions(18, 24) == 0, "[18][24]"
assert env.rail.get_full_transitions(19, 0) == 0, "[19][0]"
assert env.rail.get_full_transitions(19, 1) == 32800, "[19][1]"
assert env.rail.get_full_transitions(19, 2) == 0, "[19][2]"
assert env.rail.get_full_transitions(19, 3) == 0, "[19][3]"
assert env.rail.get_full_transitions(19, 4) == 0, "[19][4]"
assert env.rail.get_full_transitions(19, 5) == 0, "[19][5]"
assert env.rail.get_full_transitions(19, 6) == 0, "[19][6]"
assert env.rail.get_full_transitions(19, 7) == 0, "[19][7]"
assert env.rail.get_full_transitions(19, 8) == 0, "[19][8]"
assert env.rail.get_full_transitions(19, 9) == 0, "[19][9]"
assert env.rail.get_full_transitions(19, 10) == 0, "[19][10]"
assert env.rail.get_full_transitions(19, 11) == 0, "[19][11]"
assert env.rail.get_full_transitions(19, 12) == 0, "[19][12]"
assert env.rail.get_full_transitions(19, 13) == 0, "[19][13]"
assert env.rail.get_full_transitions(19, 14) == 16386, "[19][14]"
assert env.rail.get_full_transitions(19, 15) == 1025, "[19][15]"
assert env.rail.get_full_transitions(19, 16) == 1025, "[19][16]"
assert env.rail.get_full_transitions(19, 17) == 1025, "[19][17]"
assert env.rail.get_full_transitions(19, 18) == 1025, "[19][18]"
assert env.rail.get_full_transitions(19, 19) == 38505, "[19][19]"
assert env.rail.get_full_transitions(19, 20) == 3089, "[19][20]"
assert env.rail.get_full_transitions(19, 21) == 2064, "[19][21]"
assert env.rail.get_full_transitions(19, 22) == 0, "[19][22]"
assert env.rail.get_full_transitions(19, 23) == 0, "[19][23]"
assert env.rail.get_full_transitions(19, 24) == 0, "[19][24]"
assert env.rail.get_full_transitions(20, 0) == 0, "[20][0]"
assert env.rail.get_full_transitions(20, 1) == 32800, "[20][1]"
assert env.rail.get_full_transitions(20, 2) == 0, "[20][2]"
assert env.rail.get_full_transitions(20, 3) == 0, "[20][3]"
assert env.rail.get_full_transitions(20, 4) == 0, "[20][4]"
assert env.rail.get_full_transitions(20, 5) == 0, "[20][5]"
assert env.rail.get_full_transitions(20, 6) == 0, "[20][6]"
assert env.rail.get_full_transitions(20, 7) == 0, "[20][7]"
assert env.rail.get_full_transitions(20, 8) == 0, "[20][8]"
assert env.rail.get_full_transitions(20, 9) == 0, "[20][9]"
assert env.rail.get_full_transitions(20, 10) == 0, "[20][10]"
assert env.rail.get_full_transitions(20, 11) == 0, "[20][11]"
assert env.rail.get_full_transitions(20, 12) == 0, "[20][12]"
assert env.rail.get_full_transitions(20, 13) == 0, "[20][13]"
assert env.rail.get_full_transitions(20, 14) == 32800, "[20][14]"
assert env.rail.get_full_transitions(20, 15) == 0, "[20][15]"
assert env.rail.get_full_transitions(20, 16) == 0, "[20][16]"
assert env.rail.get_full_transitions(20, 17) == 0, "[20][17]"
assert env.rail.get_full_transitions(20, 18) == 0, "[20][18]"
assert env.rail.get_full_transitions(20, 19) == 32800, "[20][19]"
assert env.rail.get_full_transitions(20, 20) == 0, "[20][20]"
assert env.rail.get_full_transitions(20, 21) == 0, "[20][21]"
assert env.rail.get_full_transitions(20, 22) == 0, "[20][22]"
assert env.rail.get_full_transitions(20, 23) == 0, "[20][23]"
assert env.rail.get_full_transitions(20, 24) == 0, "[20][24]"
assert env.rail.get_full_transitions(21, 0) == 0, "[21][0]"
assert env.rail.get_full_transitions(21, 1) == 32800, "[21][1]"
assert env.rail.get_full_transitions(21, 2) == 0, "[21][2]"
assert env.rail.get_full_transitions(21, 3) == 0, "[21][3]"
assert env.rail.get_full_transitions(21, 4) == 0, "[21][4]"
assert env.rail.get_full_transitions(21, 5) == 0, "[21][5]"
assert env.rail.get_full_transitions(21, 6) == 0, "[21][6]"
assert env.rail.get_full_transitions(21, 7) == 0, "[21][7]"
assert env.rail.get_full_transitions(21, 8) == 0, "[21][8]"
assert env.rail.get_full_transitions(21, 9) == 0, "[21][9]"
assert env.rail.get_full_transitions(21, 10) == 0, "[21][10]"
assert env.rail.get_full_transitions(21, 11) == 0, "[21][11]"
assert env.rail.get_full_transitions(21, 12) == 0, "[21][12]"
assert env.rail.get_full_transitions(21, 13) == 0, "[21][13]"
assert env.rail.get_full_transitions(21, 14) == 32800, "[21][14]"
assert env.rail.get_full_transitions(21, 15) == 0, "[21][15]"
assert env.rail.get_full_transitions(21, 16) == 0, "[21][16]"
assert env.rail.get_full_transitions(21, 17) == 0, "[21][17]"
assert env.rail.get_full_transitions(21, 18) == 0, "[21][18]"
assert env.rail.get_full_transitions(21, 19) == 32872, "[21][19]"
assert env.rail.get_full_transitions(21, 20) == 4608, "[21][20]"
assert env.rail.get_full_transitions(21, 21) == 0, "[21][21]"
assert env.rail.get_full_transitions(21, 22) == 0, "[21][22]"
assert env.rail.get_full_transitions(21, 23) == 0, "[21][23]"
assert env.rail.get_full_transitions(21, 24) == 0, "[21][24]"
assert env.rail.get_full_transitions(22, 0) == 0, "[22][0]"
assert env.rail.get_full_transitions(22, 1) == 32800, "[22][1]"
assert env.rail.get_full_transitions(22, 2) == 0, "[22][2]"
assert env.rail.get_full_transitions(22, 3) == 0, "[22][3]"
assert env.rail.get_full_transitions(22, 4) == 0, "[22][4]"
assert env.rail.get_full_transitions(22, 5) == 0, "[22][5]"
assert env.rail.get_full_transitions(22, 6) == 0, "[22][6]"
assert env.rail.get_full_transitions(22, 7) == 0, "[22][7]"
assert env.rail.get_full_transitions(22, 8) == 0, "[22][8]"
assert env.rail.get_full_transitions(22, 9) == 0, "[22][9]"
assert env.rail.get_full_transitions(22, 10) == 0, "[22][10]"
assert env.rail.get_full_transitions(22, 11) == 0, "[22][11]"
assert env.rail.get_full_transitions(22, 12) == 0, "[22][12]"
assert env.rail.get_full_transitions(22, 13) == 0, "[22][13]"
assert env.rail.get_full_transitions(22, 14) == 32800, "[22][14]"
assert env.rail.get_full_transitions(22, 15) == 0, "[22][15]"
assert env.rail.get_full_transitions(22, 16) == 0, "[22][16]"
assert env.rail.get_full_transitions(22, 17) == 0, "[22][17]"
assert env.rail.get_full_transitions(22, 18) == 0, "[22][18]"
assert env.rail.get_full_transitions(22, 19) == 49186, "[22][19]"
assert env.rail.get_full_transitions(22, 20) == 34864, "[22][20]"
assert env.rail.get_full_transitions(22, 21) == 0, "[22][21]"
assert env.rail.get_full_transitions(22, 22) == 0, "[22][22]"
assert env.rail.get_full_transitions(22, 23) == 0, "[22][23]"
assert env.rail.get_full_transitions(22, 24) == 0, "[22][24]"
assert env.rail.get_full_transitions(23, 0) == 0, "[23][0]"
assert env.rail.get_full_transitions(23, 1) == 32800, "[23][1]"
assert env.rail.get_full_transitions(23, 2) == 0, "[23][2]"
assert env.rail.get_full_transitions(23, 3) == 0, "[23][3]"
assert env.rail.get_full_transitions(23, 4) == 0, "[23][4]"
assert env.rail.get_full_transitions(23, 5) == 16386, "[23][5]"
assert env.rail.get_full_transitions(23, 6) == 1025, "[23][6]"
assert env.rail.get_full_transitions(23, 7) == 4608, "[23][7]"
assert env.rail.get_full_transitions(23, 8) == 0, "[23][8]"
assert env.rail.get_full_transitions(23, 9) == 0, "[23][9]"
assert env.rail.get_full_transitions(23, 10) == 0, "[23][10]"
assert env.rail.get_full_transitions(23, 11) == 0, "[23][11]"
assert env.rail.get_full_transitions(23, 12) == 0, "[23][12]"
assert env.rail.get_full_transitions(23, 13) == 0, "[23][13]"
assert env.rail.get_full_transitions(23, 14) == 32800, "[23][14]"
assert env.rail.get_full_transitions(23, 15) == 0, "[23][15]"
assert env.rail.get_full_transitions(23, 16) == 0, "[23][16]"
assert env.rail.get_full_transitions(23, 17) == 0, "[23][17]"
assert env.rail.get_full_transitions(23, 18) == 16386, "[23][18]"
assert env.rail.get_full_transitions(23, 19) == 34864, "[23][19]"
assert env.rail.get_full_transitions(23, 20) == 32872, "[23][20]"
assert env.rail.get_full_transitions(23, 21) == 4608, "[23][21]"
assert env.rail.get_full_transitions(23, 22) == 0, "[23][22]"
assert env.rail.get_full_transitions(23, 23) == 0, "[23][23]"
assert env.rail.get_full_transitions(23, 24) == 0, "[23][24]"
assert env.rail.get_full_transitions(24, 0) == 0, "[24][0]"
assert env.rail.get_full_transitions(24, 1) == 72, "[24][1]"
assert env.rail.get_full_transitions(24, 2) == 1025, "[24][2]"
assert env.rail.get_full_transitions(24, 3) == 5633, "[24][3]"
assert env.rail.get_full_transitions(24, 4) == 17411, "[24][4]"
assert env.rail.get_full_transitions(24, 5) == 3089, "[24][5]"
assert env.rail.get_full_transitions(24, 6) == 1025, "[24][6]"
assert env.rail.get_full_transitions(24, 7) == 1097, "[24][7]"
assert env.rail.get_full_transitions(24, 8) == 5633, "[24][8]"
assert env.rail.get_full_transitions(24, 9) == 17411, "[24][9]"
assert env.rail.get_full_transitions(24, 10) == 1025, "[24][10]"
assert env.rail.get_full_transitions(24, 11) == 5633, "[24][11]"
assert env.rail.get_full_transitions(24, 12) == 1025, "[24][12]"
assert env.rail.get_full_transitions(24, 13) == 1025, "[24][13]"
assert env.rail.get_full_transitions(24, 14) == 2064, "[24][14]"
assert env.rail.get_full_transitions(24, 15) == 0, "[24][15]"
assert env.rail.get_full_transitions(24, 16) == 0, "[24][16]"
assert env.rail.get_full_transitions(24, 17) == 0, "[24][17]"
assert env.rail.get_full_transitions(24, 18) == 32800, "[24][18]"
assert env.rail.get_full_transitions(24, 19) == 32800, "[24][19]"
assert env.rail.get_full_transitions(24, 20) == 32800, "[24][20]"
assert env.rail.get_full_transitions(24, 21) == 32800, "[24][21]"
assert env.rail.get_full_transitions(24, 22) == 0, "[24][22]"
assert env.rail.get_full_transitions(24, 23) == 0, "[24][23]"
assert env.rail.get_full_transitions(24, 24) == 0, "[24][24]"
assert env.rail.get_full_transitions(25, 0) == 0, "[25][0]"
assert env.rail.get_full_transitions(25, 1) == 0, "[25][1]"
assert env.rail.get_full_transitions(25, 2) == 0, "[25][2]"
assert env.rail.get_full_transitions(25, 3) == 72, "[25][3]"
assert env.rail.get_full_transitions(25, 4) == 3089, "[25][4]"
assert env.rail.get_full_transitions(25, 5) == 5633, "[25][5]"
assert env.rail.get_full_transitions(25, 6) == 1025, "[25][6]"
assert env.rail.get_full_transitions(25, 7) == 17411, "[25][7]"
assert env.rail.get_full_transitions(25, 8) == 1097, "[25][8]"
assert env.rail.get_full_transitions(25, 9) == 2064, "[25][9]"
assert env.rail.get_full_transitions(25, 10) == 0, "[25][10]"
assert env.rail.get_full_transitions(25, 11) == 32800, "[25][11]"
assert env.rail.get_full_transitions(25, 12) == 0, "[25][12]"
assert env.rail.get_full_transitions(25, 13) == 0, "[25][13]"
assert env.rail.get_full_transitions(25, 14) == 0, "[25][14]"
assert env.rail.get_full_transitions(25, 15) == 0, "[25][15]"
assert env.rail.get_full_transitions(25, 16) == 0, "[25][16]"
assert env.rail.get_full_transitions(25, 17) == 0, "[25][17]"
assert env.rail.get_full_transitions(25, 18) == 72, "[25][18]"
assert env.rail.get_full_transitions(25, 19) == 37408, "[25][19]"
assert env.rail.get_full_transitions(25, 20) == 49186, "[25][20]"
assert env.rail.get_full_transitions(25, 21) == 2064, "[25][21]"
assert env.rail.get_full_transitions(25, 22) == 0, "[25][22]"
assert env.rail.get_full_transitions(25, 23) == 0, "[25][23]"
assert env.rail.get_full_transitions(25, 24) == 0, "[25][24]"
assert env.rail.get_full_transitions(26, 0) == 0, "[26][0]"
assert env.rail.get_full_transitions(26, 1) == 0, "[26][1]"
assert env.rail.get_full_transitions(26, 2) == 0, "[26][2]"
assert env.rail.get_full_transitions(26, 3) == 0, "[26][3]"
assert env.rail.get_full_transitions(26, 4) == 0, "[26][4]"
assert env.rail.get_full_transitions(26, 5) == 72, "[26][5]"
assert env.rail.get_full_transitions(26, 6) == 1025, "[26][6]"
assert env.rail.get_full_transitions(26, 7) == 2064, "[26][7]"
assert env.rail.get_full_transitions(26, 8) == 0, "[26][8]"
assert env.rail.get_full_transitions(26, 9) == 0, "[26][9]"
assert env.rail.get_full_transitions(26, 10) == 0, "[26][10]"
assert env.rail.get_full_transitions(26, 11) == 32800, "[26][11]"
assert env.rail.get_full_transitions(26, 12) == 0, "[26][12]"
assert env.rail.get_full_transitions(26, 13) == 0, "[26][13]"
assert env.rail.get_full_transitions(26, 14) == 0, "[26][14]"
assert env.rail.get_full_transitions(26, 15) == 0, "[26][15]"
assert env.rail.get_full_transitions(26, 16) == 0, "[26][16]"
assert env.rail.get_full_transitions(26, 17) == 0, "[26][17]"
assert env.rail.get_full_transitions(26, 18) == 0, "[26][18]"
assert env.rail.get_full_transitions(26, 19) == 32872, "[26][19]"
assert env.rail.get_full_transitions(26, 20) == 37408, "[26][20]"
assert env.rail.get_full_transitions(26, 21) == 0, "[26][21]"
assert env.rail.get_full_transitions(26, 22) == 0, "[26][22]"
assert env.rail.get_full_transitions(26, 23) == 0, "[26][23]"
assert env.rail.get_full_transitions(26, 24) == 0, "[26][24]"
assert env.rail.get_full_transitions(27, 0) == 0, "[27][0]"
assert env.rail.get_full_transitions(27, 1) == 0, "[27][1]"
assert env.rail.get_full_transitions(27, 2) == 0, "[27][2]"
assert env.rail.get_full_transitions(27, 3) == 0, "[27][3]"
assert env.rail.get_full_transitions(27, 4) == 0, "[27][4]"
assert env.rail.get_full_transitions(27, 5) == 0, "[27][5]"
assert env.rail.get_full_transitions(27, 6) == 0, "[27][6]"
assert env.rail.get_full_transitions(27, 7) == 0, "[27][7]"
assert env.rail.get_full_transitions(27, 8) == 0, "[27][8]"
assert env.rail.get_full_transitions(27, 9) == 0, "[27][9]"
assert env.rail.get_full_transitions(27, 10) == 0, "[27][10]"
assert env.rail.get_full_transitions(27, 11) == 32800, "[27][11]"
assert env.rail.get_full_transitions(27, 12) == 0, "[27][12]"
assert env.rail.get_full_transitions(27, 13) == 0, "[27][13]"
assert env.rail.get_full_transitions(27, 14) == 0, "[27][14]"
assert env.rail.get_full_transitions(27, 15) == 0, "[27][15]"
assert env.rail.get_full_transitions(27, 16) == 0, "[27][16]"
assert env.rail.get_full_transitions(27, 17) == 0, "[27][17]"
assert env.rail.get_full_transitions(27, 18) == 0, "[27][18]"
assert env.rail.get_full_transitions(27, 19) == 49186, "[27][19]"
assert env.rail.get_full_transitions(27, 20) == 2064, "[27][20]"
assert env.rail.get_full_transitions(27, 21) == 0, "[27][21]"
assert env.rail.get_full_transitions(27, 22) == 0, "[27][22]"
assert env.rail.get_full_transitions(27, 23) == 0, "[27][23]"
assert env.rail.get_full_transitions(27, 24) == 0, "[27][24]"
assert env.rail.get_full_transitions(28, 0) == 0, "[28][0]"
assert env.rail.get_full_transitions(28, 1) == 0, "[28][1]"
assert env.rail.get_full_transitions(28, 2) == 0, "[28][2]"
assert env.rail.get_full_transitions(28, 3) == 0, "[28][3]"
assert env.rail.get_full_transitions(28, 4) == 0, "[28][4]"
assert env.rail.get_full_transitions(28, 5) == 0, "[28][5]"
assert env.rail.get_full_transitions(28, 6) == 0, "[28][6]"
assert env.rail.get_full_transitions(28, 7) == 0, "[28][7]"
assert env.rail.get_full_transitions(28, 8) == 0, "[28][8]"
assert env.rail.get_full_transitions(28, 9) == 0, "[28][9]"
assert env.rail.get_full_transitions(28, 10) == 0, "[28][10]"
assert env.rail.get_full_transitions(28, 11) == 32800, "[28][11]"
assert env.rail.get_full_transitions(28, 12) == 0, "[28][12]"
assert env.rail.get_full_transitions(28, 13) == 0, "[28][13]"
assert env.rail.get_full_transitions(28, 14) == 0, "[28][14]"
assert env.rail.get_full_transitions(28, 15) == 0, "[28][15]"
assert env.rail.get_full_transitions(28, 16) == 0, "[28][16]"
assert env.rail.get_full_transitions(28, 17) == 0, "[28][17]"
assert env.rail.get_full_transitions(28, 18) == 0, "[28][18]"
assert env.rail.get_full_transitions(28, 19) == 32800, "[28][19]"
assert env.rail.get_full_transitions(28, 20) == 0, "[28][20]"
assert env.rail.get_full_transitions(28, 21) == 0, "[28][21]"
assert env.rail.get_full_transitions(28, 22) == 0, "[28][22]"
assert env.rail.get_full_transitions(28, 23) == 0, "[28][23]"
assert env.rail.get_full_transitions(28, 24) == 0, "[28][24]"
assert env.rail.get_full_transitions(29, 0) == 0, "[29][0]"
assert env.rail.get_full_transitions(29, 1) == 0, "[29][1]"
assert env.rail.get_full_transitions(29, 2) == 0, "[29][2]"
assert env.rail.get_full_transitions(29, 3) == 0, "[29][3]"
assert env.rail.get_full_transitions(29, 4) == 0, "[29][4]"
assert env.rail.get_full_transitions(29, 5) == 0, "[29][5]"
assert env.rail.get_full_transitions(29, 6) == 0, "[29][6]"
assert env.rail.get_full_transitions(29, 7) == 0, "[29][7]"
assert env.rail.get_full_transitions(29, 8) == 0, "[29][8]"
assert env.rail.get_full_transitions(29, 9) == 0, "[29][9]"
assert env.rail.get_full_transitions(29, 10) == 0, "[29][10]"
assert env.rail.get_full_transitions(29, 11) == 72, "[29][11]"
assert env.rail.get_full_transitions(29, 12) == 1025, "[29][12]"
assert env.rail.get_full_transitions(29, 13) == 1025, "[29][13]"
assert env.rail.get_full_transitions(29, 14) == 1025, "[29][14]"
assert env.rail.get_full_transitions(29, 15) == 1025, "[29][15]"
assert env.rail.get_full_transitions(29, 16) == 1025, "[29][16]"
assert env.rail.get_full_transitions(29, 17) == 1025, "[29][17]"
assert env.rail.get_full_transitions(29, 18) == 1025, "[29][18]"
assert env.rail.get_full_transitions(29, 19) == 2064, "[29][19]"
assert env.rail.get_full_transitions(29, 20) == 0, "[29][20]"
assert env.rail.get_full_transitions(29, 21) == 0, "[29][21]"
assert env.rail.get_full_transitions(29, 22) == 0, "[29][22]"
assert env.rail.get_full_transitions(29, 23) == 0, "[29][23]"
assert env.rail.get_full_transitions(29, 24) == 0, "[29][24]"
def test_rail_env_action_required_info():
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env_always_action = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
max_num_cities=10,
max_rails_between_cities=3,
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False)
env_only_if_action_required = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
max_num_cities=10,
max_rails_between_cities=3,
seed=5, # Random seed
grid_mode=False
# Ordered distribution of nodes
), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False)
env_renderer = RenderTool(env_always_action, gl="PILSVG", )
# Reset the envs
env_always_action.reset(False, False, random_seed=5)
env_only_if_action_required.reset(False, False, random_seed=5)
assert env_only_if_action_required.rail.grid.tolist() == env_always_action.rail.grid.tolist()
for step in range(50):
print("step {}".format(step))
action_dict_always_action = dict()
action_dict_only_if_action_required = dict()
# Chose an action for each agent in the environment
for a in range(env_always_action.get_num_agents()):
action = np.random.choice(np.arange(4))
action_dict_always_action.update({a: action})
if step == 0 or info_only_if_action_required['action_required'][a]:
action_dict_only_if_action_required.update({a: action})
else:
print("[{}] not action_required {}, speed_counter={}".format(step, a,
env_always_action.agents[a].speed_counter))
obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
action_dict_always_action)
obs_only_if_action_required, rewards_only_if_action_required, done_only_if_action_required, info_only_if_action_required = env_only_if_action_required.step(
action_dict_only_if_action_required)
for a in range(env_always_action.get_num_agents()):
assert len(obs_always_action[a]) == len(obs_only_if_action_required[a])
for i in range(len(obs_always_action[a])):
assert len(obs_always_action[a][i]) == len(obs_only_if_action_required[a][i])
equal = np.array_equal(obs_always_action[a][i], obs_only_if_action_required[a][i])
if not equal:
for r in range(50):
for c in range(50):
assert np.array_equal(obs_always_action[a][i][(r, c)], obs_only_if_action_required[a][i][
(r, c)]), "[{}] a={},i={},{}\n{}\n\nvs.\n\n{}".format(step, a, i, (r, c),
obs_always_action[a][i][(r, c)],
obs_only_if_action_required[a][
i][(r, c)])
assert equal, \
"[{}] [{}][{}] {} vs. {}".format(step, a, i, obs_always_action[a][i],
obs_only_if_action_required[a][i])
assert np.array_equal(rewards_always_action[a], rewards_only_if_action_required[a])
assert np.array_equal(done_always_action[a], done_only_if_action_required[a])
assert info_always_action['action_required'][a] == info_only_if_action_required['action_required'][a]
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
if done_always_action['__all__']:
break
env_renderer.close_window()
def test_rail_env_malfunction_speed_info():
env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10,
max_rails_between_cities=3,
seed=5,
grid_mode=False
),
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(False, False)
env_renderer = RenderTool(env, gl="PILSVG", )
for step in range(100):
action_dict = dict()
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = np.random.choice(np.arange(4))
action_dict.update({a: action})
obs, rewards, done, info = env.step(
action_dict)
assert 'malfunction' in info
for a in range(env.get_num_agents()):
assert info['malfunction'][a] >= 0
assert info['speed'][a] >= 0 and info['speed'][a] <= 1
assert info['speed'][a] == env.agents[a].speed_counter.speed
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
if done['__all__']:
break
env_renderer.close_window()
def test_sparse_generator_with_too_man_cities_does_not_break_down():
RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
max_num_cities=100,
max_rails_between_cities=3,
seed=5,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv())
def test_sparse_generator_with_illegal_params_aborts():
"""
Test that the constructor aborts if the initial parameters don't allow more than one city to be built.
"""
with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, ValueError):
RailEnv(width=6, height=6, rail_generator=sparse_rail_generator(
max_num_cities=100,
max_rails_between_cities=3,
seed=5,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()).reset()
with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, ValueError):
RailEnv(width=60, height=60, rail_generator=sparse_rail_generator(
max_num_cities=1,
max_rails_between_cities=3,
seed=5,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()).reset()
def test_sparse_generator_changes_to_grid_mode():
"""
Test that grid mode is evoked and two cities are created when env is too small to find random cities.
We set the limit of the env such that two cities fit in grid mode but unlikely under random mode
we initiate random seed to be sure that we never create random cities.
"""
rail_env = RailEnv(width=10, height=20, rail_generator=sparse_rail_generator(
max_num_cities=100,
max_rails_between_cities=2,
max_rail_pairs_in_city=1,
seed=15,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
with warnings.catch_warnings(record=True) as w:
rail_env.reset(True, True, random_seed=15)
assert "[WARNING]" in str(w[-1].message)
from test_utils import create_and_save_env
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
from flatland.envs.line_generators import sparse_line_generator, line_from_file
def test_line_from_file_sparse():
"""
Test to see that all parameters are loaded as expected
Returns
-------
"""
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
# Generate Sparse test env
rail_generator = sparse_rail_generator(max_num_cities=5,
seed=1,
grid_mode=False,
max_rails_between_cities=3,
max_rail_pairs_in_city=3,
)
line_generator = sparse_line_generator(speed_ration_map)
env = create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator,
line_generator=line_generator)
old_num_steps = env._max_episode_steps
old_num_agents = len(env.agents)
# Sparse generator
rail_generator = rail_from_file("./sparse_env_test.pkl")
line_generator = line_from_file("./sparse_env_test.pkl")
sparse_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator,
line_generator=line_generator)
sparse_env_from_file.reset(True, True)
# Assert loaded agent number is correct
assert sparse_env_from_file.get_num_agents() == old_num_agents
# Assert max steps is correct
assert sparse_env_from_file._max_episode_steps == old_num_steps
\ No newline at end of file
import random
from typing import Dict, List
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter
class SingleAgentNavigationObs(ObservationBuilder):
"""
We build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__()
def reset(self):
pass
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(agent_virtual_position, direction)
min_distances.append(
self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
def test_malfunction_process():
# Set fixed malfunction duration for this test
stochastic_data = MalfunctionParameters(malfunction_rate=1, # Rate of malfunction occurence
min_duration=3, # Minimal duration of malfunction
max_duration=3 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
obs, info = env.reset(False, False, random_seed=10)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].state = TrainState.MOVING
agent_halts = 0
total_down_time = 0
agent_old_position = env.agents[0].position
# Move target to unreachable position in order to not interfere with test
env.agents[0].target = (0, 0)
# Add in max episode steps because scheudule generator sets it to 0 for dummy data
env._max_episode_steps = 200
for step in range(100):
actions = {}
for i in range(len(obs)):
actions[i] = np.argmax(obs[i]) + 1
obs, all_rewards, done, _ = env.step(actions)
if done["__all__"]:
break
if env.agents[0].malfunction_handler.malfunction_down_counter > 0:
agent_malfunctioning = True
else:
agent_malfunctioning = False
if agent_malfunctioning:
# Check that agent is not moving while malfunctioning
assert agent_old_position == env.agents[0].position
agent_old_position = env.agents[0].position
total_down_time += env.agents[0].malfunction_handler.malfunction_down_counter
# Check that the appropriate number of malfunctions is achieved
# Dipam: The number of malfunctions varies by seed
assert env.agents[0].malfunction_handler.num_malfunctions == 28, "Actual {}".format(
env.agents[0].malfunction_handler.num_malfunctions)
# Check that malfunctioning data was standing around
assert total_down_time > 0
def test_malfunction_process_statistically():
"""Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = MalfunctionParameters(malfunction_rate=1/5, # Rate of malfunction occurence
min_duration=5, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=2,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(True, True, random_seed=10)
env._max_episode_steps = 1000
env.agents[0].target = (0, 0)
# Next line only for test generation
agent_malfunction_list = [[] for i in range(2)]
agent_malfunction_list = [[0, 0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1],
[0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent_idx in range(env.get_num_agents()):
# We randomly select an action
action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
# For generating tests only:
# agent_malfunction_list[agent_idx].append(
# env.agents[agent_idx].malfunction_handler.malfunction_down_counter)
assert env.agents[agent_idx].malfunction_handler.malfunction_down_counter == \
agent_malfunction_list[agent_idx][step]
env.step(action_dict)
def test_malfunction_before_entry():
"""Tests that malfunctions are working properly for agents before entering the environment!"""
# Set fixed malfunction duration for this test
stochastic_data = MalfunctionParameters(malfunction_rate=1/2, # Rate of malfunction occurrence
min_duration=10, # Minimal duration of malfunction
max_duration=10 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=2,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(False, False, random_seed=10)
env.agents[0].target = (0, 0)
# Test initial malfunction values for all agents
# we want some agents to be malfuncitoning already and some to be working
# we want different next_malfunction values for the agents
malfunction_values = [env.malfunction_generator(env.np_random).num_broken_steps for _ in range(1000)]
expected_value = (1 - np.exp(-0.5)) * 10
assert np.allclose(np.mean(malfunction_values), expected_value, rtol=0.1), "Mean values of malfunction don't match rate"
def test_malfunction_values_and_behavior():
"""
Test the malfunction counts down as desired
Returns
-------
"""
# Set fixed malfunction duration for this test
rail, rail_map, optionals = make_simple_rail2()
action_dict: Dict[int, RailEnvActions] = {}
stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001, # Rate of malfunction occurence
min_duration=10, # Minimal duration of malfunction
max_duration=10 # Max duration of malfunction
)
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(False, False, random_seed=10)
env._max_episode_steps = 20
# Assertions
assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5]
for time_step in range(15):
# Move in the env
_, _, dones,_ = env.step(action_dict)
# Check that next_step decreases as expected
assert env.agents[0].malfunction_handler.malfunction_down_counter == assert_list[time_step]
if dones['__all__']:
break
def test_initial_malfunction():
stochastic_data = MalfunctionParameters(malfunction_rate=1/1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, random_seed=10)
env._max_episode_steps = 1000
print(env.agents[0].malfunction_handler)
env.agents[0].target = (0, 5)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
Replay( # 0
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
reward=env.step_penalty # full step penalty when malfunctioning
),
Replay( # 1
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2,
reward=env.step_penalty # full step penalty when malfunctioning
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay( # 2
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=1,
reward=env.step_penalty
), # malfunctioning ends: starting and running at speed 1.0
Replay( # 3
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0 # running at speed 1.0
),
Replay( # 4
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty # running at speed 1.0
)
],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [replay_config], skip_reward_check=True)
def test_initial_malfunction_stop_moving():
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
env.reset()
env._max_episode_steps = 1000
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].state)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
Replay( # 0
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
reward=env.step_penalty, # full step penalty when stopped
state=TrainState.READY_TO_DEPART
),
Replay( # 1
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
reward=env.step_penalty, # full step penalty when stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action STOP_MOVING, agent should restart without moving
#
Replay( # 2
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
# we have stopped and do nothing --> should stand still
Replay( # 3
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
# we start to move forward --> should go to next cell now
Replay( # 4
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.MOVING
),
Replay( # 5
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.STOPPED
),
Replay( # 6
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.MOVING
),
Replay( # 6
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.STOPPED
)
],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [replay_config], activate_agents=False,
skip_reward_check=True, set_ready_to_depart=True, skip_action_required_check=True)
def test_initial_malfunction_do_nothing():
stochastic_data = MalfunctionParameters(malfunction_rate=1/70, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
)
env.reset()
env._max_episode_steps = 1000
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
Replay(
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
reward=env.step_penalty, # full step penalty while malfunctioning
state=TrainState.READY_TO_DEPART
),
Replay(
position=None,
direction=Grid4TransitionsEnum.EAST,
action=None,
malfunction=2,
reward=env.step_penalty, # full step penalty while malfunctioning
state=TrainState.MALFUNCTION_OFF_MAP
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
#
Replay(
position=None,
direction=Grid4TransitionsEnum.EAST,
action=None,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
# we haven't started moving yet --> stay here
Replay(
position=None,
direction=Grid4TransitionsEnum.EAST,
action=None,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MALFUNCTION_OFF_MAP
),
Replay(
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0
state=TrainState.MOVING
), # we start to move forward --> should go to next cell now
Replay(
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # step penalty for speed 1.0
state=TrainState.MOVING
)
],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [replay_config], activate_agents=False,
skip_reward_check=True, set_ready_to_depart=True)
def tests_random_interference_from_outside():
"""Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_counter = SpeedCounter(speed=0.33)
env.reset(False, False, random_seed=10)
env_data = []
for step in range(200):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# We randomly select an action
action_dict[agent.handle] = RailEnvActions(2)
_, reward, dones, _ = env.step(action_dict)
# Append the rewards of the first trial
env_data.append((reward[0], env.agents[0].position))
assert reward[0] == env_data[step][0]
assert env.agents[0].position == env_data[step][1]
if dones['__all__']:
break
# Run the same test as above but with an external random generator running
# Check that the reward stays the same
rail, rail_map, optionals = make_simple_rail2()
random.seed(47)
np.random.seed(1234)
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_counter = SpeedCounter(speed=0.33)
env.reset(False, False, random_seed=10)
dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
for step in range(200):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# We randomly select an action
action_dict[agent.handle] = RailEnvActions(2)
# Do dummy random number generations
random.shuffle(dummy_list)
np.random.rand()
_, reward, dones, _ = env.step(action_dict)
assert reward[0] == env_data[step][0]
assert env.agents[0].position == env_data[step][1]
if dones['__all__']:
break
def test_last_malfunction_step():
"""
Test to check that agent moves when it is not malfunctioning
"""
# Set fixed malfunction duration for this test
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_counter = SpeedCounter(speed=1./3.)
env.agents[0].initial_position = (6, 6)
env.agents[0].initial_direction = 2
env.agents[0].target = (0, 3)
env._max_episode_steps = 1000
env.reset(False, False)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].state = TrainState.MOVING
# Force malfunction to be off at beginning and next malfunction to happen in 2 steps
# env.agents[0].malfunction_data['next_malfunction'] = 2
env.agents[0].malfunction_handler.malfunction_down_counter = 0
env_data = []
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# Go forward all the time
action_dict[agent.handle] = RailEnvActions(2)
if env.agents[0].malfunction_handler.malfunction_down_counter < 1:
agent_can_move = True
# Store the position before and after the step
pre_position = env.agents[0].speed_counter.counter
_, reward, _, _ = env.step(action_dict)
# Check if the agent is still allowed to move in this step
if env.agents[0].malfunction_handler.malfunction_down_counter > 0:
agent_can_move = False
post_position = env.agents[0].speed_counter.counter
# Assert that the agent moved while it was still allowed
if agent_can_move:
assert pre_position != post_position
else:
assert post_position == pre_position
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from multiprocessing.pool import Pool
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
def test_multiprocessing_tree_obs():
number_of_agents = 5
rail, rail_map, optionals = make_simple_rail()
optionals['agents_hints']['num_agents'] = number_of_agents
obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=number_of_agents,
obs_builder_object=obs_builder)
env.reset(True, True)
pool = Pool()
pool.map(obs_builder.get, range(number_of_agents))
def main():
test_multiprocessing_tree_obs()
if __name__ == "__main__":
main()
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
def test_initial_status():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
env.reset()
env._max_episode_steps = 1000
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay(
position=None, # not entered grid yet
direction=Grid4TransitionsEnum.EAST,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.DO_NOTHING,
reward=env.step_penalty * 0.5,
),
Replay(
position=None, # not entered grid yet before step
direction=Grid4TransitionsEnum.EAST,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.MOVE_LEFT,
reward=env.step_penalty * 0.5, # auto-correction left to forward without penalty!
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_LEFT,
reward=env.start_penalty + env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
state=TrainState.MOVING
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty!
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5, #
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, #
state=TrainState.MOVING
),
# Replay(
# position=(3, 5),
# direction=Grid4TransitionsEnum.WEST,
# action=None,
# reward=env.global_reward, # already done
# status=RailAgentStatus.DONE
# ),
# Replay(
# position=(3, 5),
# direction=Grid4TransitionsEnum.WEST,
# action=None,
# reward=env.global_reward, # already done
# status=RailAgentStatus.DONE
# )
],
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
target=(3, 5),
speed=0.5
)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True,
set_ready_to_depart=True)
assert env.agents[0].state == TrainState.DONE
def test_status_done_remove():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=True)
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
env._max_episode_steps = 1000
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay(
position=None, # not entered grid yet
direction=Grid4TransitionsEnum.EAST,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.DO_NOTHING,
reward=env.step_penalty * 0.5,
),
Replay(
position=None, # not entered grid yet before step
direction=Grid4TransitionsEnum.EAST,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.MOVE_LEFT,
reward=env.step_penalty * 0.5, # auto-correction left to forward without penalty!
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5, # running at speed 0.5
state=TrainState.MOVING
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty!
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # done
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, # already done
state=TrainState.MOVING
),
# Replay(
# position=None,
# direction=Grid4TransitionsEnum.WEST,
# action=None,
# reward=env.global_reward, # already done
# status=RailAgentStatus.DONE_REMOVED
# ),
# Replay(
# position=None,
# direction=Grid4TransitionsEnum.WEST,
# action=None,
# reward=env.global_reward, # already done
# status=RailAgentStatus.DONE_REMOVED
# )
],
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
target=(3, 5),
speed=0.5
)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True,
set_ready_to_depart=True)
assert env.agents[0].state == TrainState.DONE
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.step_utils.states import TrainState
def test_return_to_ready_to_depart():
"""
When going from ready to depart to malfunction off map, if do nothing is provided, should return to ready to depart
"""
stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence
min_duration=0, # Minimal duration of malfunction
max_duration=0 # Max duration of malfunction
)
rail, _, optionals = make_simple_rail()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
)
env.reset(False, False, random_seed=10)
env._max_episode_steps = 100
for _ in range(3):
env.step({0: RailEnvActions.DO_NOTHING})
env.agents[0].malfunction_handler._set_malfunction_down_counter(2)
env.step({0: RailEnvActions.DO_NOTHING})
assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP
for _ in range(2):
env.step({0: RailEnvActions.DO_NOTHING})
assert env.agents[0].state == TrainState.READY_TO_DEPART
def test_ready_to_depart_to_stopped():
"""
When going from ready to depart to malfunction off map, if stopped is provided, should go to stopped
"""
stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence
min_duration=0, # Minimal duration of malfunction
max_duration=0 # Max duration of malfunction
)
rail, _, optionals = make_simple_rail()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
)
env.reset(False, False, random_seed=10)
env._max_episode_steps = 100
for _ in range(3):
env.step({0: RailEnvActions.STOP_MOVING})
assert env.agents[0].state == TrainState.READY_TO_DEPART
env.agents[0].malfunction_handler._set_malfunction_down_counter(2)
env.step({0: RailEnvActions.STOP_MOVING})
assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP
for _ in range(2):
env.step({0: RailEnvActions.STOP_MOVING})
assert env.agents[0].state == TrainState.STOPPED
def test_malfunction_no_phase_through():
"""
A moving train shouldn't phase through a malfunctioning train
"""
stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence
min_duration=0, # Minimal duration of malfunction
max_duration=0 # Max duration of malfunction
)
rail, _, optionals = make_simple_rail()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=2,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
)
env.reset(False, False, random_seed=10)
for _ in range(5):
env.step({0: RailEnvActions.MOVE_FORWARD, 1: RailEnvActions.MOVE_FORWARD})
env.agents[1].malfunction_handler._set_malfunction_down_counter(10)
for _ in range(3):
env.step({0: RailEnvActions.MOVE_FORWARD, 1: RailEnvActions.DO_NOTHING})
assert env.agents[0].state == TrainState.STOPPED
assert env.agents[0].position == (3, 5)
\ No newline at end of file
......@@ -4,86 +4,59 @@
Tests for `flatland` package.
"""
from flatland.envs.rail_env import RailEnv, random_rail_generator
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
from importlib_resources import path
import flatland.utils.rendertools as rt
from flatland.core.env_observation_builder import TreeObsForRailEnv
import images.test
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import empty_rail_generator
import pytest
def checkFrozenImage(oRT, sFileImage, resave=False):
sDirRoot = "."
sDirImages = sDirRoot + "/images/"
img_test = oRT.getImage()
img_test = oRT.get_image()
if resave:
np.savez_compressed(sDirImages + sFileImage, img=img_test)
return
# this is now just for convenience - the file is not read back
np.savez_compressed(sDirImages + "test/" + sFileImage, img=img_test)
image_store = np.load(sDirImages + sFileImage)
img_expected = image_store["img"]
assert (img_test.shape == img_expected.shape)
assert ((np.sum(np.square(img_test - img_expected)) / img_expected.size / 256) < 1e-3), \
"Image {} does not match".format(sFileImage)
with path(images, sFileImage) as file_in:
np.load(file_in)
# TODO fails!
# assert (img_test.shape == img_expected.shape) \ # noqa: E800
# assert ((np.sum(np.square(img_test - img_expected)) / img_expected.size / 256) < 1e-3), \ # noqa: E800
# "Image {} does not match".format(sFileImage) \ # noqa: E800
@pytest.mark.skip("Only needed for visual editor, Flatland 3 line generator won't allow empty enviroment")
def test_render_env(save_new_images=False):
# random.seed(100)
np.random.seed(100)
oEnv = RailEnv(width=10, height=10,
rail_generator=random_rail_generator(),
number_of_agents=0,
# obs_builder_object=GlobalObsForRailEnv())
obs_builder_object=TreeObsForRailEnv(max_depth=2)
)
sfTestEnv = "env-data/tests/test1.npy"
oEnv.rail.load_transition_map(sfTestEnv)
oRT = rt.RenderTool(oEnv)
oRT.renderEnv()
oEnv = RailEnv(width=10, height=10, rail_generator=empty_rail_generator(), number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
oEnv.reset()
oEnv.rail.load_transition_map('env_data.tests', "test1.npy")
oRT = rt.RenderTool(oEnv, gl="PILSVG")
oRT.render_env(show=False)
checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
oRT = rt.RenderTool(oEnv, gl="PIL")
oRT.renderEnv()
oRT.render_env()
checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
# disable the tree / observation tests until env-agent save/load is available
if False:
lVisits = oRT.getTreeFromRail(
oEnv.agents_position[0],
oEnv.agents_direction[0],
nDepth=17, bPlot=True)
checkFrozenImage("env-tree-spatial.png")
plt.figure(figsize=(8, 8))
xyTarg = oRT.env.agents_target[0]
visitDest = oRT.plotTree(lVisits, xyTarg)
checkFrozenImage("env-tree-graph.png")
plt.figure(figsize=(10, 10))
oRT.renderEnv()
oRT.plotPath(visitDest)
checkFrozenImage("env-path.png")
def main():
if len(sys.argv) == 2 and sys.argv[1] == "save":
test_render_env(save_new_images=True)
else:
print("Run 'python test_rendertools.py save' to regenerate images")
print("Run 'python test_flatland_utils_rendertools.py save' to regenerate images")
test_render_env()
if __name__ == "__main__":
main()
\ No newline at end of file
main()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, empty_rail_generator
from flatland.envs.line_generators import sparse_line_generator, line_from_file
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.step_utils.states import TrainState
def test_empty_rail_generator():
n_agents = 2
x_dim = 5
y_dim = 10
# Check that a random level at with correct parameters is generated
rail, _ = empty_rail_generator().generate(width=x_dim, height=y_dim, num_agents=n_agents)
# Check the dimensions
assert rail.grid.shape == (y_dim, x_dim)
# Check that no grid was generated
assert np.count_nonzero(rail.grid) == 0
def test_rail_from_grid_transition_map():
rail, rail_map, optionals = make_simple_rail()
n_agents = 2
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=n_agents)
env.reset(False, False)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx]._set_state(TrainState.MOVING)
nr_rail_elements = np.count_nonzero(env.rail.grid)
# Check if the number of non-empty rail cells is ok
assert nr_rail_elements == 16
# Check that agents are placed on a rail
for a in env.agents:
assert env.rail.grid[a.position] != 0
assert env.get_num_agents() == n_agents
def tests_rail_from_file():
file_name = "test_with_distance_map.pkl"
# Test to save and load file with distance map.
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
#env.save(file_name)
RailEnvPersister.save(env, file_name)
dist_map_shape = np.shape(env.distance_map.get())
rails_initial = env.rail.grid
agents_initial = env.agents
env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
rails_loaded = env.rail.grid
agents_loaded = env.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
# Check that distance map was not recomputed
assert np.shape(env.distance_map.get()) == dist_map_shape
assert env.distance_map.get() is not None
# Test to save and load file without distance map.
file_name_2 = "test_without_distance_map.pkl"
env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(),
number_of_agents=3, obs_builder_object=GlobalObsForRailEnv())
env2.reset()
#env2.save(file_name_2)
RailEnvPersister.save(env2, file_name_2)
rails_initial_2 = env2.rail.grid
agents_initial_2 = env2.agents
env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2),
line_generator=line_from_file(file_name_2), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env2.reset()
rails_loaded_2 = env2.rail.grid
agents_loaded_2 = env2.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial_2, agents_loaded_2):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial_2, rails_loaded_2))
assert agents_initial_2 == agents_loaded_2
assert not hasattr(env2.obs_builder, "distance_map")
# Test to save with distance map and load without
env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env3.reset()
rails_loaded_3 = env3.rail.grid
agents_loaded_3 = env3.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded_3):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded_3))
assert agents_initial == agents_loaded_3
assert not hasattr(env2.obs_builder, "distance_map")
# Test to save without distance map and load with generating distance map
env4 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name_2),
line_generator=line_from_file(file_name_2),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
)
env4.reset()
rails_loaded_4 = env4.rail.grid
agents_loaded_4 = env4.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial_2, agents_loaded_4):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
# Check that no distance map was saved
assert not hasattr(env2.obs_builder, "distance_map")
assert np.all(np.array_equal(rails_initial_2, rails_loaded_4))
assert agents_initial_2 == agents_loaded_4
# Check that distance map was generated with correct shape
assert env4.distance_map.get() is not None
assert np.shape(env4.distance_map.get()) == dist_map_shape
def main():
tests_rail_from_file()
if __name__ == "__main__":
main()
import numpy as np
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.step_utils.states import TrainState
def test_get_global_observation():
number_of_agents = 20
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=6,
max_rails_between_cities=4,
seed=15,
grid_mode=False
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=number_of_agents,
obs_builder_object=GlobalObsForRailEnv())
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
for i in range(len(env.agents)):
agent: EnvAgent = env.agents[i]
print("[{}] state={}, position={}, target={}, initial_position={}".format(i, agent.state, agent.position,
agent.target,
agent.initial_position))
for i, agent in enumerate(env.agents):
obs_agents_state = obs[i][1]
obs_targets = obs[i][2]
# test first channel of obs_targets: own target
nr_agents = np.count_nonzero(obs_targets[:, :, 0])
assert nr_agents == 1, "agent {}: something wrong with own target, found {}".format(i, nr_agents)
# test second channel of obs_targets: other agent's target
for r in range(env.height):
for c in range(env.width):
_other_agent_target = 0
for other_i, other_agent in enumerate(env.agents):
if other_agent.target == (r, c):
_other_agent_target = 1
break
assert obs_targets[(r, c)][
1] == _other_agent_target, "agent {}: at {} expected to be other agent's target = {}".format(
i, (r, c),
_other_agent_target)
# test first channel of obs_agents_state: direction at own position
for r in range(env.height):
for c in range(env.width):
if (agent.state.is_on_map_state() or agent.state == TrainState.DONE) and (
r, c) == agent.position:
assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
"agent {} in state {} at {} expected to contain own direction {}, found {}" \
.format(i, agent.state, (r, c), agent.direction, obs_agents_state[(r, c)][0])
elif (agent.state == TrainState.READY_TO_DEPART) and (r, c) == agent.initial_position:
assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
"agent {} in state {} at {} expected to contain own direction {}, found {}" \
.format(i, agent.state, (r, c), agent.direction, obs_agents_state[(r, c)][0])
else:
assert np.isclose(obs_agents_state[(r, c)][0], -1), \
"agent {} in state {} at {} expected contain -1 found {}" \
.format(i, agent.state, (r, c), obs_agents_state[(r, c)][0])
# test second channel of obs_agents_state: direction at other agents position
for r in range(env.height):
for c in range(env.width):
has_agent = False
for other_i, other_agent in enumerate(env.agents):
if i == other_i:
continue
if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED, TrainState.DONE] and (
r, c) == other_agent.position:
assert np.isclose(obs_agents_state[(r, c)][1], other_agent.direction), \
"agent {} in state {} at {} should see other agent with direction {}, found = {}" \
.format(i, agent.state, (r, c), other_agent.direction, obs_agents_state[(r, c)][1])
has_agent = True
if not has_agent:
assert np.isclose(obs_agents_state[(r, c)][1], -1), \
"agent {} in state {} at {} should see no other agent direction (-1), found = {}" \
.format(i, agent.state, (r, c), obs_agents_state[(r, c)][1])
# test third and fourth channel of obs_agents_state: malfunction and speed of own or other agent in the grid
for r in range(env.height):
for c in range(env.width):
has_agent = False
for other_i, other_agent in enumerate(env.agents):
if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED,
TrainState.DONE] and other_agent.position == (r, c):
assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_handler.malfunction_down_counter), \
"agent {} in state {} at {} should see agent malfunction {}, found = {}" \
.format(i, agent.state, (r, c), other_agent.malfunction_handler.malfunction_down_counter,
obs_agents_state[(r, c)][2])
assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_counter.speed)
has_agent = True
if not has_agent:
assert np.isclose(obs_agents_state[(r, c)][2], -1), \
"agent {} in state {} at {} should see no agent malfunction (-1), found = {}" \
.format(i, agent.state, (r, c), obs_agents_state[(r, c)][2])
assert np.isclose(obs_agents_state[(r, c)][3], -1), \
"agent {} in state {} at {} should see no agent speed (-1), found = {}" \
.format(i, agent.state, (r, c), obs_agents_state[(r, c)][3])
# test fifth channel of obs_agents_state: number of agents ready to depart in to this cell
for r in range(env.height):
for c in range(env.width):
count = 0
for other_i, other_agent in enumerate(env.agents):
if other_agent.state == TrainState.READY_TO_DEPART and other_agent.initial_position == (r, c):
count += 1
assert np.isclose(obs_agents_state[(r, c)][4], count), \
"agent {} in state {} at {} should see {} agents ready to depart, found{}" \
.format(i, agent.state, (r, c), count, obs_agents_state[(r, c)][4])
from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file, \
single_malfunction_generator, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from flatland.envs.persistence import RailEnvPersister
import pytest
def test_malfanction_from_params():
"""
Test loading malfunction from
Returns
-------
"""
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env.reset()
assert env.malfunction_process_data.malfunction_rate == 1000
assert env.malfunction_process_data.min_duration == 2
assert env.malfunction_process_data.max_duration == 5
def test_malfanction_to_and_from_file():
"""
Test loading malfunction from
Returns
-------
"""
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env.reset()
#env.save("./malfunction_saving_loading_tests.pkl")
RailEnvPersister.save(env, "./malfunction_saving_loading_tests.pkl")
malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl")
env2 = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env2.reset()
assert env2.malfunction_process_data == env.malfunction_process_data
assert env2.malfunction_process_data.malfunction_rate == 1000
assert env2.malfunction_process_data.min_duration == 2
assert env2.malfunction_process_data.max_duration == 5
@pytest.mark.skip("Single malfunction generator is deprecated")
def test_single_malfunction_generator():
"""
Test single malfunction generator
Returns
-------
"""
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=3,
malfunction_duration=5)
)
for test in range(10):
env.reset()
action_dict = dict()
tot_malfunctions = 0
print(test)
for i in range(10):
for agent in env.agents:
# Go forward all the time
action_dict[agent.handle] = RailEnvActions(2)
_, _, dones, _ = env.step(action_dict)
if dones['__all__']:
break
for agent in env.agents:
# Go forward all the time
tot_malfunctions += agent.malfunction_handler.num_malfunctions
assert tot_malfunctions == 1