diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 9550b71348c177adbaa30b4bd3e5307ba2e855ce..7a673bcf9ba46c574db0983e3c52257e4a07358e 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -414,13 +414,13 @@ class GridTransitionMap(TransitionMap): # loop over available outbound directions (indices) for rcPos self.set_transitions(rcPos, 0) - incomping_connections = np.zeros(4) + incoming_connections = np.zeros(4) for iDirOut in np.arange(4): gdRC = gDir2dRC[iDirOut] # row,col increment gPos2 = grcPos + gdRC # next cell in that direction # Check the adjacent cell is within bounds - # if not, then this transition is invalid! + # if not, then ignore it for the count of incoming connections if np.any(gPos2 < 0): continue if np.any(gPos2 >= grcMax): @@ -432,23 +432,23 @@ class GridTransitionMap(TransitionMap): for orientation in range(4): connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut)) if connected > 0: - incomping_connections[iDirOut] = 1 + incoming_connections[iDirOut] = 1 - number_of_incoming = np.sum(incomping_connections) + number_of_incoming = np.sum(incoming_connections) # Only one incoming direction --> Straight line if number_of_incoming == 1: for direction in range(4): - if incomping_connections[direction] > 0: + if incoming_connections[direction] > 0: self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) # Connect all incoming connections if number_of_incoming == 2: - connect_directions = np.argwhere(incomping_connections > 0) + connect_directions = np.argwhere(incoming_connections > 0) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) # Find feasible connection fro three entries if number_of_incoming == 3: - hole = np.argwhere(incomping_connections < 1)[0][0] + hole = np.argwhere(incoming_connections < 1)[0][0] connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4] self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index deaabd02fc1f842fae8c36d9056db632791e67a8..6e6665af88c9e8b31e1a689815edb7aaada342f9 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -2,7 +2,7 @@ Definition of the RailEnv environment. """ # TODO: _ this is a global method --> utils or remove later - +import warnings from enum import IntEnum import msgpack @@ -228,7 +228,7 @@ class RailEnv(Environment): rcPos = (r, c) check = self.rail.cell_neighbours_valid(rcPos, True) if not check: - print("WARNING: Invalid grid at {} -> {}".format(rcPos, check)) + warnings.warn("Invalid grid at {} -> {}".format(rcPos, check)) if replace_agents: self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 28978ca3e5b958a51c578f6be5e8c87b77baaa97..c5fe4860783f242f21c97c55a9119d8918454a96 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -2,11 +2,87 @@ from typing import Tuple import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # Note that that cells have invalid RailEnvTransitions! + # | + # | + # | + # _ _ _ _\ _ _ _ _ _ _ + # / + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_left = cells[2] + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [simple_switch_east_west_north] + + [horizontal_straight] * 2 + [simple_switch_east_west_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + + +def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ _\ _ _ _ _ _ _ + # \ + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [simple_switch_east_west_north] + + [horizontal_straight] * 2 + [simple_switch_west_east_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + + +def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]: # We instantiate a very simple rail network on a 7x10 grid: # | # | @@ -16,15 +92,9 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: # | # | # | - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] dead_end_from_south = cells[7] dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 12e0c092a37a475ab6e7dde21c665778e06f5e59..62df397d34b06755465ca0c9f664b9117c87243f 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -1,6 +1,6 @@ import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv @@ -11,15 +11,8 @@ from flatland.envs.rail_env import RailEnv def test_walker(): # _ _ _ - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() + cells = transitions.transition_list dead_end_from_south = cells[7] dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 7acd58ed0337745f645db6dcc24a70ecb0b64305..98c276f894b51685ce0edf43f6bd1b1137d46eb0 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -10,13 +10,13 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -from flatland.utils.simple_rail import make_simple_rail +from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail """Test predictions for `flatland` package.""" def test_dummy_predictor(rendering=False): - rail, rail_map = make_simple_rail() + rail, rail_map = make_simple_rail2() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], @@ -89,7 +89,7 @@ def test_dummy_predictor(rendering=False): expected_actions = np.array([[0.], [2.], [2.], - [1.], + [2.], [2.], [2.], [2.], @@ -226,7 +226,7 @@ def test_shortest_path_predictor(rendering=False): def test_shortest_path_predictor_conflicts(rendering=False): - rail, rail_map = make_simple_rail() + rail, rail_map = make_invalid_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 71dc87ceddde986be763491d28dd2b70673632f4..7ebbbb1461e24aae4c2319f51a9bb4abb2d3b25c 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent @@ -49,15 +48,6 @@ def test_save_load(): def test_rail_environment_single_agent(): - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - # We instantiate the following map on a 3x3 grid # _ _ # / \/ \ @@ -65,6 +55,7 @@ def test_rail_environment_single_agent(): # \_/\_/ transitions = RailEnvTransitions() + cells = transitions.transition_list vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) @@ -139,7 +130,7 @@ test_rail_environment_single_agent() def test_dead_end(): - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() straight_vertical = int('1000000000100000', 2) # Case 1 - straight straight_horizontal = transitions.rotate_transition(straight_vertical,