Commit 0ac5cff2 authored by u214892's avatar u214892
Browse files

#147 fix tests

parent f5e70812
Pipeline #1835 passed with stages
in 30 minutes and 32 seconds
......@@ -414,13 +414,13 @@ class GridTransitionMap(TransitionMap):
# loop over available outbound directions (indices) for rcPos
self.set_transitions(rcPos, 0)
incomping_connections = np.zeros(4)
incoming_connections = np.zeros(4)
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then this transition is invalid!
# if not, then ignore it for the count of incoming connections
if np.any(gPos2 < 0):
continue
if np.any(gPos2 >= grcMax):
......@@ -432,23 +432,23 @@ class GridTransitionMap(TransitionMap):
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
incomping_connections[iDirOut] = 1
incoming_connections[iDirOut] = 1
number_of_incoming = np.sum(incomping_connections)
number_of_incoming = np.sum(incoming_connections)
# Only one incoming direction --> Straight line
if number_of_incoming == 1:
for direction in range(4):
if incomping_connections[direction] > 0:
if incoming_connections[direction] > 0:
self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
# Connect all incoming connections
if number_of_incoming == 2:
connect_directions = np.argwhere(incomping_connections > 0)
connect_directions = np.argwhere(incoming_connections > 0)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
# Find feasible connection fro three entries
if number_of_incoming == 3:
hole = np.argwhere(incomping_connections < 1)[0][0]
hole = np.argwhere(incoming_connections < 1)[0][0]
connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4]
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1)
......
......@@ -2,7 +2,7 @@
Definition of the RailEnv environment.
"""
# TODO: _ this is a global method --> utils or remove later
import warnings
from enum import IntEnum
import msgpack
......@@ -228,7 +228,7 @@ class RailEnv(Environment):
rcPos = (r, c)
check = self.rail.cell_neighbours_valid(rcPos, True)
if not check:
print("WARNING: Invalid grid at {} -> {}".format(rcPos, check))
warnings.warn("Invalid grid at {} -> {}".format(rcPos, check))
if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
......
......@@ -2,11 +2,87 @@ from typing import Tuple
import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# Note that that cells have invalid RailEnvTransitions!
# |
# |
# |
# _ _ _ _\ _ _ _ _ _ _
# /
# |
# |
# |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
simple_switch_north_left = cells[2]
simple_switch_north_right = cells[10]
simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[simple_switch_east_west_north] +
[horizontal_straight] * 2 + [simple_switch_east_west_south] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ _\ _ _ _ _ _ _
# \
# |
# |
# |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
simple_switch_north_right = cells[10]
simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[simple_switch_east_west_north] +
[horizontal_straight] * 2 + [simple_switch_west_east_south] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
......@@ -16,15 +92,9 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
# |
# |
# |
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
......
import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import TreeObsForRailEnv
......@@ -11,15 +11,8 @@ from flatland.envs.rail_env import RailEnv
def test_walker():
# _ _ _
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
transitions = RailEnvTransitions()
cells = transitions.transition_list
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
......
......@@ -10,13 +10,13 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
"""Test predictions for `flatland` package."""
def test_dummy_predictor(rendering=False):
rail, rail_map = make_simple_rail()
rail, rail_map = make_simple_rail2()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
......@@ -89,7 +89,7 @@ def test_dummy_predictor(rendering=False):
expected_actions = np.array([[0.],
[2.],
[2.],
[1.],
[2.],
[2.],
[2.],
[2.],
......@@ -226,7 +226,7 @@ def test_shortest_path_predictor(rendering=False):
def test_shortest_path_predictor_conflicts(rendering=False):
rail, rail_map = make_simple_rail()
rail, rail_map = make_invalid_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
......
......@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
......@@ -49,15 +48,6 @@ def test_save_load():
def test_rail_environment_single_agent():
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
# We instantiate the following map on a 3x3 grid
# _ _
# / \/ \
......@@ -65,6 +55,7 @@ def test_rail_environment_single_agent():
# \_/\_/
transitions = RailEnvTransitions()
cells = transitions.transition_list
vertical_line = cells[1]
south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
......@@ -139,7 +130,7 @@ test_rail_environment_single_agent()
def test_dead_end():
transitions = Grid4Transitions([])
transitions = RailEnvTransitions()
straight_vertical = int('1000000000100000', 2) # Case 1 - straight
straight_horizontal = transitions.rotate_transition(straight_vertical,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment