diff --git a/benchmarks/benchmark_all_examples.py b/benchmarks/benchmark_all_examples.py index e45537a4dd2dca52482f53f9b865f2302cb660c6..8af61ef4907a234c8e3ae37eb29692e89e079beb 100644 --- a/benchmarks/benchmark_all_examples.py +++ b/benchmarks/benchmark_all_examples.py @@ -14,7 +14,7 @@ for entry in [entry for entry in importlib_resources.contents('examples') if not pkg_resources.resource_isdir('examples', entry) and entry.endswith(".py") and '__init__' not in entry - and 'demo.py' not in entry + and 'DELETE' not in entry ]: print("*****************************************************************") print("Benchmarking {}".format(entry)) diff --git a/benchmarks/profile_all_examples.py b/benchmarks/profile_all_examples.py index 8015e23a04e8096d3302b3b3a717221d0d872773..53b40bfc7acfc9c56bc2ced9c8d2c4a968867ee2 100644 --- a/benchmarks/profile_all_examples.py +++ b/benchmarks/profile_all_examples.py @@ -29,5 +29,6 @@ for entry in [entry for entry in importlib_resources.contents('examples') if and entry.endswith(".py") and '__init__' not in entry and 'demo.py' not in entry + and 'DELETE' not in entry ]: profile('examples', entry) diff --git a/benchmarks/run_all_examples.py b/benchmarks/run_all_examples.py index 1b3e3be066989e18af3f36e1dd73ded37c0bc6cf..b169902d096c1150a071f330eb174a2c22a5c78f 100644 --- a/benchmarks/run_all_examples.py +++ b/benchmarks/run_all_examples.py @@ -8,6 +8,8 @@ from importlib_resources import path from benchmarks.benchmark_utils import swap_attr +print("GRRRRRRRR run_all_examples.py") + for entry in [entry for entry in importlib_resources.contents('examples') if not pkg_resources.resource_isdir('examples', entry) and entry.endswith(".py") @@ -24,6 +26,11 @@ for entry in [entry for entry in importlib_resources.contents('examples') if print("Running {}".format(entry)) print("*****************************************************************") with swap_attr(sys, "stdin", StringIO("q")): - runpy.run_path(file_in, run_name="__main__", init_globals={ - 'argv': ['--sleep-for-animation=False'] - }) + try: + runpy.run_path(file_in, run_name="__main__", init_globals={ + 'argv': ['--sleep-for-animation=False'] + }) + except Exception as e: + print(e) + print("runpy done.") + print("Done with {}".format(entry)) diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index 317372da390693dd51b53d411c4d5615582183b0..6c6add683bf5b7694d939bbe1a590617fb069d3e 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -10,6 +10,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.utils.misc import str2bool from flatland.utils.rendertools import RenderTool random.seed(100) @@ -51,7 +52,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: new_position = get_new_position(agent.position, direction) - min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) + min_distances.append( + self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) @@ -70,7 +72,7 @@ def main(args): sleep_for_animation = True for o, a in opts: if o in ("--sleep-for-animation"): - sleep_for_animation = bool(a) + sleep_for_animation = str2bool(a) else: assert False, "unhandled option" diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 9238a2af4137e37e9d79bc3c1aaade2bb987403e..92fbb37b2a0d8ff583b3056e418d11026f261032 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -11,6 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.utils.misc import str2bool from flatland.utils.ordered_set import OrderedSet from flatland.utils.rendertools import RenderTool @@ -110,7 +111,7 @@ def main(args): sleep_for_animation = True for o, a in opts: if o in ("--sleep-for-animation"): - sleep_for_animation = bool(a) + sleep_for_animation = str2bool(a) else: assert False, "unhandled option" diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 6162b918734eb311752675e75e203d90e5558c1c..d33d49252618e2a22c6e5c13ae1d434807224ea2 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -51,4 +51,5 @@ env.reset() env_renderer = RenderTool(env) env_renderer.render_env(show=True) -input("Press Enter to continue...") +# uncomment to keep the renderer open +# input("Press Enter to continue...") diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py index fbadbd657c36fa1dadf0bca65cff3e9cccd269ea..388128d0d246d73f0236b054a3228ec20c46864e 100644 --- a/examples/simple_example_1.py +++ b/examples/simple_example_1.py @@ -19,4 +19,5 @@ env.reset() env_renderer = RenderTool(env) env_renderer.render_env(show=True, show_predictions=False, show_observations=False) -input("Press Enter to continue...") +# uncomment to keep the renderer open +#input("Press Enter to continue...") diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index 6db9ba5abbd0999ef3896e733516ed6b3e498bae..34abee096a043b73f53de8eed42a2e2b73ec1cc5 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -33,4 +33,5 @@ env.reset() env_renderer = RenderTool(env, gl="PIL") env_renderer.render_env(show=True) -input("Press Enter to continue...") +# uncomment to keep the renderer open +#input("Press Enter to continue...") diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 6df6d4af3076b3d9659aadbd55296b667dc7d6db..61ca23840d6c59e75d75d5a15fa9fb8118f45b95 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -37,9 +37,7 @@ for step in range(100): i = 0 while i < len(cmds): if cmds[i] == 'q': - import sys - - sys.exit() + break elif cmds[i] == 's': obs, all_rewards, done, _ = env.step(action_dict) action_dict = {} @@ -50,5 +48,4 @@ for step in range(100): action_dict[agent_id] = action i = i + 1 i += 1 - - env_renderer.render_env(show=True, frames=True) + env_renderer.render_env(show=True, frames=True) diff --git a/examples/simple_example_city_railway_generator.py b/examples/simple_example_city_railway_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..ce96932ccf49c9d71a97431bd0bdabd4cdeaf576 --- /dev/null +++ b/examples/simple_example_city_railway_generator.py @@ -0,0 +1,60 @@ +import os + +import numpy as np + +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators_city_generator import city_generator +from flatland.envs.schedule_generators import city_schedule_generator +from flatland.utils.rendertools import RenderTool, AgentRenderVariant + +os.mkdir("./../render_output/") + +for itrials in np.arange(1, 15, 1): + print(itrials, "generate new city") + + # init seed + np.random.seed(itrials) + + # select distance function used in a-star path finding + dist_fun = Vec2d.get_manhattan_distance + dfsel = (itrials - 1) % 3 + if dfsel == 1: + dist_fun = Vec2d.get_euclidean_distance + elif dfsel == 2: + dist_fun = Vec2d.get_chebyshev_distance + + # create RailEnv and use the city_generator to create a map + env = RailEnv(width=40 + np.random.choice(100), + height=40 + np.random.choice(100), + rail_generator=city_generator(num_cities=5 + np.random.choice(10), + city_size=10 + np.random.choice(5), + allowed_rotation_angles=np.arange(0, 360, 6), + max_number_of_station_tracks=4 + np.random.choice(4), + nbr_of_switches_per_station_track=2 + np.random.choice(2), + connect_max_nbr_of_shortes_city=2 + np.random.choice(4), + do_random_connect_stations=itrials % 2 == 0, + a_star_distance_function=dist_fun, + seed=itrials, + print_out_info=False + ), + schedule_generator=city_schedule_generator(), + number_of_agents=10000, + obs_builder_object=GlobalObsForRailEnv()) + + # reset to initialize agents_static + env_renderer = RenderTool(env, gl="PILSVG", screen_width=1400, screen_height=1000, + agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX) + + env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + + # store rendered file into render_output if the path exists + env_renderer.gl.save_image( + os.path.join( + "./../render_output/", + "flatland_frame_{:04d}.png".format(itrials) + )) + + # close the renderer / window + env_renderer.close_window() diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index d1652a38c7ecd45cbdd28522b6aeeb28683c4736..a049ae260860e946ac8b27f9e83ab11bf4ed2920 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,18 +1,29 @@ -from flatland.core.grid.grid4_utils import validate_new_transition +import numpy as np + +from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance +from flatland.core.grid.grid_utils import IntVector2DArray +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.core.transition_map import GridTransitionMap from flatland.utils.ordered_set import OrderedSet -class AStarNode(): +class AStarNode: """A node class for A* Pathfinding""" - def __init__(self, parent=None, pos=None): + def __init__(self, pos: IntVector2D, parent=None): self.parent = parent - self.pos = pos - self.g = 0 - self.h = 0 - self.f = 0 + self.pos: IntVector2D = pos + self.g = 0.0 + self.h = 0.0 + self.f = 0.0 def __eq__(self, other): + """ + + Parameters + ---------- + other : AStarNode + """ return self.pos == other.pos def __hash__(self): @@ -26,20 +37,25 @@ class AStarNode(): self.f = other.f -def a_star(rail_trans, rail_array, start, end): +def a_star(grid_map: GridTransitionMap, + start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: """ Returns a list of tuples as a path from the given start to end. If no path is found, returns path to closest point to end. """ - rail_shape = rail_array.shape - start_node = AStarNode(None, start) - end_node = AStarNode(None, end) + rail_shape = grid_map.grid.shape + + tmp = np.zeros(rail_shape) - 10 + + start_node = AStarNode(start, None) + end_node = AStarNode(end, None) open_nodes = OrderedSet() closed_nodes = OrderedSet() open_nodes.add(start_node) while len(open_nodes) > 0: - # get node with current shortest path (lowest f) + # get node with current shortest est. path (lowest f) current_node = None for item in open_nodes: if current_node is None: @@ -59,6 +75,7 @@ def a_star(rail_trans, rail_array, start, end): while current is not None: path.append(current.pos) current = current.parent + # return reversed path return path[::-1] @@ -68,17 +85,21 @@ def a_star(rail_trans, rail_array, start, end): prev_pos = current_node.parent.pos else: prev_pos = None + for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: - node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) + # update the "current" pos + node_pos: IntVector2D = Vec2d.add(current_node.pos, new_pos) + + # is node_pos inside the grid? if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0: continue # validate positions - if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): + if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos): continue # create new node - new_node = AStarNode(current_node, node_pos) + new_node = AStarNode(node_pos, current_node) children.append(new_node) # loop through children @@ -88,13 +109,13 @@ def a_star(rail_trans, rail_array, start, end): continue # create the f, g, and h values - child.g = current_node.g + 1 - # this heuristic favors diagonal paths: - # child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + ((child.pos[1] - end_node.pos[1]) ** 2) \# noqa: E800 + child.g = current_node.g + 1.0 # this heuristic avoids diagonal paths - child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1]) + child.h = a_star_distance_function(child.pos, end_node.pos) child.f = child.g + child.h + tmp[child.pos[0]][child.pos[1]] = child.f + # already in the open list? if child in open_nodes: continue diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py index d64a160b6b422b5984ae675ab4544604a74b1337..98652459d7a7ac7f1694ac53fe1d0a12880ab8b2 100644 --- a/flatland/core/grid/grid4_utils.py +++ b/flatland/core/grid/grid4_utils.py @@ -1,7 +1,8 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid_utils import IntVector2DArray -def get_direction(pos1, pos2) -> Grid4TransitionsEnum: +def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4TransitionsEnum: """ Assumes pos1 and pos2 are adjacent location on grid. Returns direction (int) that can be used with transitions. @@ -23,45 +24,6 @@ def mirror(dir): return (dir + 2) % 4 -def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos): - # start by getting direction used to get to current node - # and direction from current node to possible child node - new_dir = get_direction(current_pos, new_pos) - if prev_pos is not None: - current_dir = get_direction(prev_pos, current_pos) - else: - current_dir = new_dir - # create new transition that would go to child - new_trans = rail_array[current_pos] - if prev_pos is None: - if new_trans == 0: - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # check if matches existing layout - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - if new_pos == end_pos: - # need to validate end pos setup as well - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # need to flip direction because of how end points are defined - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # check if matches existing layout - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - - if not rail_trans.is_valid(new_trans_e): - return False - - # is transition is valid? - return rail_trans.is_valid(new_trans) - - def get_new_position(position, movement): """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ if movement == Grid4TransitionsEnum.NORTH: diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py index 09e664f2c85f5919f8769653d91ccd7f6e621ec2..fe4a381fb8458aa61bd54be9f2ba105297400402 100644 --- a/flatland/core/grid/grid_utils.py +++ b/flatland/core/grid/grid_utils.py @@ -1,7 +1,244 @@ +from typing import Tuple, Callable, List, Type + import numpy as np +Vector2D: Type = Tuple[float, float] +IntVector2D: Type = Tuple[int, int] + +IntVector2DArray: Type = List[IntVector2D] +IntVector2DArrayArray: Type = List[List[IntVector2D]] + +Vector2DArray: Type = List[Vector2D] +Vector2DArrayArray: Type = List[List[Vector2D]] + +IntVector2DDistance: Type = Callable[[IntVector2D, IntVector2D], float] + + +class Vec2dOperations: + + @staticmethod + def is_equal(node_a: Vector2D, node_b: Vector2D) -> bool: + """ + vector operation : node_a + node_b + + :param node_a: tuple with coordinate (x,y) or 2d vector + :param node_b: tuple with coordinate (x,y) or 2d vector + :return: + ------- + check if node_a and nobe_b are equal + """ + return node_a[0] == node_b[0] and node_a[1] == node_b[1] + + @staticmethod + def subtract(node_a: Vector2D, node_b: Vector2D) -> Vector2D: + """ + vector operation : node_a - node_b + + :param node_a: tuple with coordinate (x,y) or 2d vector + :param node_b: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return node_a[0] - node_b[0], node_a[1] - node_b[1] + + @staticmethod + def add(node_a: Vector2D, node_b: Vector2D) -> Vector2D: + """ + vector operation : node_a + node_b + + :param node_a: tuple with coordinate (x,y) or 2d vector + :param node_b: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return node_a[0] + node_b[0], node_a[1] + node_b[1] + + @staticmethod + def make_orthogonal(node: Vector2D) -> Vector2D: + """ + vector operation : rotates the 2D vector +90° + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return node[1], -node[0] + + @staticmethod + def get_norm(node: Vector2D) -> float: + """ + calculates the euclidean norm of the 2d vector + [see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/] + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return np.sqrt(node[0] * node[0] + node[1] * node[1]) + + @staticmethod + def get_euclidean_distance(node_a: Vector2D, node_b: Vector2D) -> float: + """ + calculates the euclidean norm of the 2d vector + + Parameters + ---------- + node_a + tuple with coordinate (x,y) or 2d vector + node_b + tuple with coordinate (x,y) or 2d vector + + Returns + ------- + float + Euclidean distance + """ + return Vec2dOperations.get_norm(Vec2dOperations.subtract(node_b, node_a)) + + @staticmethod + def get_manhattan_distance(node_a: Vector2D, node_b: Vector2D) -> float: + """ + calculates the manhattan distance of the 2d vector + [see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/] + + Parameters + ---------- + node_a + tuple with coordinate (x,y) or 2d vector + node_b + tuple with coordinate (x,y) or 2d vector + + Returns + ------- + float + Mahnhattan distance + """ + delta = (Vec2dOperations.subtract(node_b, node_a)) + return np.abs(delta[0]) + np.abs(delta[1]) + + @staticmethod + def get_chebyshev_distance(node_a: Vector2D, node_b: Vector2D) -> float: + """ + calculates the chebyshev norm of the 2d vector + [see: https://lyfat.wordpress.com/2012/05/22/euclidean-vs-chebyshev-vs-manhattan-distance/] + + :Parameters + ---------- + node_a + tuple with coordinate (x,y) or 2d vector + node_b + tuple with coordinate (x,y) or 2d vector + + Returns + ------- + float + the chebyshev distance + """ + delta = (Vec2dOperations.subtract(node_b, node_a)) + return max(np.abs(delta[0]), np.abs(delta[1])) + + @staticmethod + def normalize(node: Vector2D) -> Tuple[float, float]: + """ + normalize the 2d vector = v/|v| + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + n = Vec2dOperations.get_norm(node) + if n > 0.0: + n = 1 / n + return Vec2dOperations.scale(node, n) + + @staticmethod + def scale(node: Vector2D, scale: float) -> Vector2D: + """ + scales the 2d vector = node * scale + + :param node: tuple with coordinate (x,y) or 2d vector + :param scale: scalar to scale + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return node[0] * scale, node[1] * scale + + @staticmethod + def round(node: Vector2D) -> IntVector2D: + """ + rounds the x and y coordinate and convert them to an integer values + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return int(np.round(node[0])), int(np.round(node[1])) + + @staticmethod + def ceil(node: Vector2D) -> IntVector2D: + """ + ceiling the x and y coordinate and convert them to an integer values + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return int(np.ceil(node[0])), int(np.ceil(node[1])) + + @staticmethod + def floor(node: Vector2D) -> IntVector2D: + """ + floor the x and y coordinate and convert them to an integer values + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return int(np.floor(node[0])), int(np.floor(node[1])) + + @staticmethod + def bound(node: Vector2D, min_value: float, max_value: float) -> Vector2D: + """ + force the values x and y to be between min_value and max_value + + :param node: tuple with coordinate (x,y) or 2d vector + :param min_value: scalar value + :param max_value: scalar value + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return max(min_value, min(max_value, node[0])), max(min_value, min(max_value, node[1])) + + @staticmethod + def rotate(node: Vector2D, rot_in_degree: float) -> Vector2D: + """ + rotate the 2d vector with given angle in degree + + :param node: tuple with coordinate (x,y) or 2d vector + :param rot_in_degree: angle in degree + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + alpha = rot_in_degree / 180.0 * np.pi + x0 = node[0] + y0 = node[1] + x1 = x0 * np.cos(alpha) - y0 * np.sin(alpha) + y1 = x0 * np.sin(alpha) + y0 * np.cos(alpha) + return x1, y1 + -def position_to_coordinate(depth: int, positions): +def position_to_coordinate(depth: int, positions: List[int]): """Converts coordinates to positions:: [ (0,0) (0,1) .. (0,w-1) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index fc5bb1d3514b55d805a8323385749cd7349f4c17..07678add5549c3ac13df876132ed3bcdbf5bec5e 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -7,7 +7,9 @@ from importlib_resources import path from numpy import array from flatland.core.grid.grid4 import Grid4Transitions -from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.grid4_utils import get_new_position, get_direction +from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions from flatland.utils.ordered_set import OrderedSet @@ -301,7 +303,7 @@ class GridTransitionMap(TransitionMap): self.height = new_height self.grid = new_grid - def is_dead_end(self, rcPos): + def is_dead_end(self, rcPos: IntVector2DArray): """ Check if the cell is a dead-end. @@ -321,7 +323,7 @@ class GridTransitionMap(TransitionMap): tmp = tmp >> 1 return nbits == 1 - def is_simple_turn(self, rcPos): + def is_simple_turn(self, rcPos: IntVector2DArray): """ Check if the cell is a left/right simple turn @@ -348,7 +350,7 @@ class GridTransitionMap(TransitionMap): return is_simple_turn(tmp) - def check_path_exists(self, start, direction, end): + def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray): # print("_path_exists({},{},{}".format(start, direction, end)) # BFS - Check if a path exists between the 2 nodes @@ -358,7 +360,8 @@ class GridTransitionMap(TransitionMap): node = stack.pop() node_position = node[0] node_direction = node[1] - if node_position[0] == end[0] and node_position[1] == end[1]: + + if Vec2d.is_equal(node_position, end): return True if node not in visited: visited.add(node) @@ -371,7 +374,7 @@ class GridTransitionMap(TransitionMap): return False - def cell_neighbours_valid(self, rcPos, check_this_cell=False): + def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False): """ Check validity of cell at rcPos = tuple(row, column) Checks that: @@ -422,7 +425,7 @@ class GridTransitionMap(TransitionMap): return True - def fix_neighbours(self, rcPos, check_this_cell=False): + def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False): """ Check validity of cell at rcPos = tuple(row, column) Checks that: @@ -474,7 +477,7 @@ class GridTransitionMap(TransitionMap): return True - def fix_transitions(self, rcPos): + def fix_transitions(self, rcPos: IntVector2DArray): """ Fixes broken transitions """ @@ -539,6 +542,46 @@ class GridTransitionMap(TransitionMap): self.set_transitions((rcPos[0], rcPos[1]), transition) return True + def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D, + new_pos: IntVector2D, end_pos: IntVector2D): + + # start by getting direction used to get to current node + # and direction from current node to possible child node + new_dir = get_direction(current_pos, new_pos) + if prev_pos is not None: + current_dir = get_direction(prev_pos, current_pos) + else: + current_dir = new_dir + # create new transition that would go to child + new_trans = self.grid[current_pos] + if prev_pos is None: + if new_trans == 0: + # need to flip direction because of how end points are defined + new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # check if matches existing layout + new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + if Vec2d.is_equal(new_pos, end_pos): + # need to validate end pos setup as well + new_trans_e = self.grid[end_pos] + if new_trans_e == 0: + # need to flip direction because of how end points are defined + new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # check if matches existing layout + new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1) + + if not self.transitions.is_valid(new_trans_e): + return False + + # is transition is valid? + return self.transitions.is_valid(new_trans) + def mirror(dir): return (dir + 2) % 4 diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 996bd73ad9de598eb162a937c135681675119ad3..d6f47abfd85cfa1cc7e72e27aeb4f7ededa975dd 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -7,14 +7,22 @@ a GridTransitionMap object. from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions -def connect_rail(rail_trans, rail_array, start, end): +def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start: IntVector2D, + end: IntVector2D, + flip_start_node_trans=False, + flip_end_node_trans=False, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): """ - Creates a new path [start,end] in rail_array, based on rail_trans. + Creates a new path [start,end] in grid_map, based on rail_trans. """ # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) + path = a_star(grid_map, start, end, a_star_distance_function) if len(path) < 2: return [] current_dir = get_direction(path[0], path[1]) @@ -24,12 +32,15 @@ def connect_rail(rail_trans, rail_array, start, end): new_pos = path[index + 1] new_dir = get_direction(current_pos, new_pos) - new_trans = rail_array[current_pos] + new_trans = grid_map.grid[current_pos] if index == 0: if new_trans == 0: # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + if flip_start_node_trans: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + new_trans = 0 else: # into existing rail new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -38,158 +49,45 @@ def connect_rail(rail_trans, rail_array, start, end): new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) # set the backwards path new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans + grid_map.grid[current_pos] = new_trans if new_pos == end_pos: # setup end pos setup - new_trans_e = rail_array[end_pos] + new_trans_e = grid_map.grid[end_pos] if new_trans_e == 0: # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + if flip_end_node_trans: + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + new_trans_e = 0 else: # into existing rail new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e + grid_map.grid[end_pos] = new_trans_e current_dir = new_dir return path -def connect_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) +def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): + return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function) - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # don't set any transition at node yet - new_trans = 0 - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - # don't set any transition at node yet +def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): + return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function) - new_trans_e = 0 - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - current_dir = new_dir - return path +def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): + return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function) -def connect_from_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # need to flip direction because of how end points are defined - new_trans = 0 - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - - current_dir = new_dir - return path - - -def connect_to_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - new_trans_e = 0 - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - - current_dir = new_dir - return path +def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): + return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index f2dc5c46dfdf11c31b865483f20dfbc19f4964e6..60c606f789f0d83f04fd5c549e155f437c977b7d 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -123,7 +123,7 @@ def complex_rail_generator(nr_start_goal=1, # we might as well give up at this point break - new_path = connect_rail(rail_trans, rail_array, start, goal) + new_path = connect_rail(rail_trans, grid_map, start, goal) if len(new_path) >= 2: nr_created += 1 start_goal.append([start, goal]) @@ -148,7 +148,7 @@ def complex_rail_generator(nr_start_goal=1, break if not all_ok: break - new_path = connect_rail(rail_trans, rail_array, start, goal) + new_path = connect_rail(rail_trans, grid_map, start, goal) if len(new_path) >= 2: nr_created += 1 @@ -645,7 +645,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 for neighb in connected_neighb_idx: if neighb not in node_stack: node_stack.append(neighb) - connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) + connect_nodes(rail_trans, grid_map, node_positions[current_node], node_positions[neighb]) node_stack.pop(0) # Place train stations close to the node @@ -688,7 +688,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 train_stations[trainstation_node].append((station_x, station_y)) # Connect train station to the correct node - connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], + connection = connect_from_nodes(rail_trans, grid_map, node_positions[trainstation_node], (station_x, station_y)) # Check if connection was made if len(connection) == 0: @@ -723,11 +723,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 width - 2) # Connect train station to the correct node - connect_nodes(rail_trans, rail_array, (intersect_x_1, intersect_y_1), + connect_nodes(rail_trans, grid_map, (intersect_x_1, intersect_y_1), (intersect_x_2, intersect_y_2)) - connect_nodes(rail_trans, rail_array, intersection_positions[intersection], + connect_nodes(rail_trans, grid_map, intersection_positions[intersection], (intersect_x_1, intersect_y_1)) - connect_nodes(rail_trans, rail_array, intersection_positions[intersection], + connect_nodes(rail_trans, grid_map, intersection_positions[intersection], (intersect_x_2, intersect_y_2)) grid_map.fix_transitions((intersect_x_1, intersect_y_1)) grid_map.fix_transitions((intersect_x_2, intersect_y_2)) diff --git a/flatland/envs/rail_generators_city_generator.py b/flatland/envs/rail_generators_city_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..53cdaef1e0bbe794f38f72cb77089f296e0c9cf5 --- /dev/null +++ b/flatland/envs/rail_generators_city_generator.py @@ -0,0 +1,499 @@ +import copy +import warnings +from typing import Sequence, Optional + +import numpy as np + +from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2DDistance, IntVector2DArrayArray +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail +from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct + +FloatArrayType = Sequence[float] + + +def city_generator(num_cities: int = 5, + city_size: int = 10, + allowed_rotation_angles: Optional[Sequence[float]] = None, + max_number_of_station_tracks: int = 4, + nbr_of_switches_per_station_track: int = 2, + connect_max_nbr_of_shortes_city: int = 4, + do_random_connect_stations: bool = False, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, + seed: int = 0, + print_out_info: bool = True) -> RailGenerator: + """ + This is a level generator which generates a realistic rail configurations + + :param num_cities: Number of city node + :param city_size: Length of city measure in cells + :param allowed_rotation_angles: Rotate the city (around center) + :param max_number_of_station_tracks: max number of tracks per station + :param nbr_of_switches_per_station_track: number of switches per track (max) + :param connect_max_nbr_of_shortes_city: max number of connecting track between stations + :param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand + :param a_star_distance_function: Heuristic how the distance between two nodes get estimated in the "a-star" path + :param seed: Random Seed + :param print_out_info: print debug info if True + :return: + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def do_generate_city_locations(width: int, + height: int, + intern_city_size: int, + intern_max_number_of_station_tracks: int) -> (IntVector2DArray, int): + + X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) + Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) + + max_num_cities = min(num_cities, X * Y) + + cities_at = np.random.choice(X * Y, max_num_cities, False) + cities_at = np.sort(cities_at) + if print_out_info: + print("max nbr of cities with given configuration is:", max_num_cities) + + x = np.floor(cities_at / Y) + y = cities_at - x * Y + xs = (x * intern_city_size + intern_max_number_of_station_tracks) + intern_city_size / 2 + ys = (y * intern_city_size + intern_max_number_of_station_tracks) + intern_city_size / 2 + + generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))] + return generate_city_locations, max_num_cities + + def do_orient_cities(generate_city_locations: IntVector2DArrayArray, intern_city_size: int, + rotation_angles_set: FloatArrayType): + for i in range(len(generate_city_locations)): + # station main orientation (horizontal or vertical + rot_angle = np.random.choice(rotation_angles_set) + add_pos_val = Vec2d.scale(Vec2d.rotate((1, 0), rot_angle), + int(max(1.0, (intern_city_size - 3) / 2))) + # noinspection PyTypeChecker + generate_city_locations[i][0] = Vec2d.add(generate_city_locations[i][1], add_pos_val) + add_pos_val = Vec2d.scale(Vec2d.rotate((1, 0), 180 + rot_angle), + int(max(1.0, (intern_city_size - 3) / 2))) + # noinspection PyTypeChecker + generate_city_locations[i][1] = Vec2d.add(generate_city_locations[i][1], add_pos_val) + return generate_city_locations + + # noinspection PyTypeChecker + def create_stations_from_city_locations(rail_trans: RailEnvTransitions, + grid_map: GridTransitionMap, + generate_city_locations: IntVector2DArrayArray, + intern_max_number_of_station_tracks: int) -> (IntVector2DArray, + IntVector2DArray, + IntVector2DArray, + IntVector2DArray, + IntVector2DArray): + + nodes_added = [] + start_nodes_added: IntVector2DArrayArray = [[] for _ in range(len(generate_city_locations))] + end_nodes_added: IntVector2DArrayArray = [[] for _ in range(len(generate_city_locations))] + station_slots = [[] for _ in range(len(generate_city_locations))] + station_tracks = [[[] for _ in range(intern_max_number_of_station_tracks)] for _ in range(len( + generate_city_locations))] + + station_slots_cnt = 0 + + for city_loop in range(len(generate_city_locations)): + # Connect train station to the correct node + number_of_connecting_tracks = np.random.choice(max(0, intern_max_number_of_station_tracks)) + 1 + track_id = 0 + for ct in range(number_of_connecting_tracks): + org_start_node = generate_city_locations[city_loop][0] + org_end_node = generate_city_locations[city_loop][1] + + ortho_trans = Vec2d.make_orthogonal( + Vec2d.normalize(Vec2d.subtract(org_start_node, org_end_node))) + s = (ct - number_of_connecting_tracks / 2.0) + start_node = Vec2d.ceil( + Vec2d.add(org_start_node, Vec2d.scale(ortho_trans, s))) + end_node = Vec2d.ceil( + Vec2d.add(org_end_node, Vec2d.scale(ortho_trans, s))) + + connection = connect_from_nodes(rail_trans, grid_map, start_node, end_node, a_star_distance_function) + if len(connection) > 0: + nodes_added.append(start_node) + nodes_added.append(end_node) + + start_nodes_added[city_loop].append(start_node) + end_nodes_added[city_loop].append(end_node) + + # place in the center of path a station slot + # station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))]) + for c_loop in range(len(connection)): + station_slots[city_loop].append(connection[c_loop]) + station_slots_cnt += len(connection) + + station_tracks[city_loop][track_id] = connection + track_id += 1 + else: + if print_out_info: + print("create_stations_from_city_locations : connect_from_nodes -> no path found") + + if print_out_info: + print("max nbr of station slots with given configuration is:", station_slots_cnt) + + return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks + + # noinspection PyTypeChecker + def create_switches_at_stations(rail_trans: RailEnvTransitions, + grid_map: GridTransitionMap, + station_tracks: IntVector2DArrayArray, + nodes_added: IntVector2DArray, + intern_nbr_of_switches_per_station_track: int) -> IntVector2DArray: + + for k_loop in range(intern_nbr_of_switches_per_station_track): + for city_loop in range(len(station_tracks)): + k = k_loop + city_loop + datas = station_tracks[city_loop] + if len(datas) > 1: + + track = datas[0] + if len(track) > 0: + if k % 2 == 0: + x = int(np.random.choice(int(len(track) / 2)) + 1) + else: + x = len(track) - int(np.random.choice(int(len(track) / 2)) + 1) + start_node = track[x] + for i in np.arange(1, len(datas)): + track = datas[i] + if len(track) > 1: + if k % 2 == 0: + x = x + 2 + if len(track) <= x: + x = 1 + else: + x = x - 2 + if x < 2: + x = len(track) - 1 + end_node = track[x] + connection = connect_rail(rail_trans, grid_map, start_node, end_node, + a_star_distance_function) + if len(connection) == 0: + if print_out_info: + print("create_switches_at_stations : connect_rail -> no path found") + start_node = datas[i][0] + end_node = datas[i - 1][0] + connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function) + + nodes_added.append(start_node) + nodes_added.append(end_node) + + if k % 2 == 0: + x = x + 2 + if len(track) <= x: + x = 1 + else: + x = x - 2 + if x < 2: + x = len(track) - 2 + start_node = track[x] + + return nodes_added + + def create_graph_edge(from_city_index: int, to_city_index: int) -> (int, int, int): + return from_city_index, to_city_index, np.inf + + def calc_nbr_of_graphs(graph: []) -> ([], []): + for i in range(len(graph)): + for j in range(len(graph)): + a = graph[i] + b = graph[j] + connected = False + if a[0] == b[0] or a[1] == b[0]: + connected = True + if a[0] == b[1] or a[1] == b[1]: + connected = True + + if connected: + a = [graph[i][0], graph[i][1], graph[i][2]] + b = [graph[j][0], graph[j][1], graph[j][2]] + graph[i] = (graph[i][0], graph[i][1], min(np.min(a), np.min(b))) + graph[j] = (graph[j][0], graph[j][1], min(np.min(a), np.min(b))) + else: + a = [graph[i][0], graph[i][1], graph[i][2]] + graph[i] = (graph[i][0], graph[i][1], np.min(a)) + b = [graph[j][0], graph[j][1], graph[j][2]] + graph[j] = (graph[j][0], graph[j][1], np.min(b)) + + graph_ids = [] + for i in range(len(graph)): + graph_ids.append(graph[i][2]) + if print_out_info: + print("************* NBR of graphs:", len(np.unique(graph_ids))) + return graph, np.unique(graph_ids).astype(int) + + def connect_sub_graphs(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + org_s_nodes: IntVector2DArrayArray, + org_e_nodes: IntVector2DArrayArray, + city_edges: IntVector2DArray, + nodes_added: IntVector2DArray): + _, graphids = calc_nbr_of_graphs(city_edges) + if len(graphids) > 0: + for i in range(len(graphids) - 1): + connection = [] + iteration_counter = 0 + while len(connection) == 0 and iteration_counter < 100: + s_nodes = copy.deepcopy(org_s_nodes) + e_nodes = copy.deepcopy(org_e_nodes) + start_nodes = s_nodes[graphids[i]] + end_nodes = e_nodes[graphids[i + 1]] + start_node = start_nodes[np.random.choice(len(start_nodes))] + end_node = end_nodes[np.random.choice(len(end_nodes))] + # TODO : removing, what the hell is going on, why we have to set rail_array -> transition to zero + # TODO : before we can call connect_rail. If we don't reset the transistion to zero -> no rail + # TODO : will be generated. + grid_map.grid[start_node] = 0 + grid_map.grid[end_node] = 0 + connection = connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function) + if len(connection) > 0: + nodes_added.append(start_node) + nodes_added.append(end_node) + else: + if print_out_info: + print("connect_sub_graphs : connect_rail -> no path found") + + iteration_counter += 1 + + def connect_stations(rail_trans: RailEnvTransitions, + grid_map: GridTransitionMap, + org_s_nodes: IntVector2DArrayArray, + org_e_nodes: IntVector2DArrayArray, + nodes_added: IntVector2DArray, + intern_connect_max_nbr_of_shortes_city: int): + city_edges = [] + + s_nodes:IntVector2DArrayArray = copy.deepcopy(org_s_nodes) + e_nodes:IntVector2DArrayArray = copy.deepcopy(org_e_nodes) + + for nbr_connected in range(intern_connect_max_nbr_of_shortes_city): + for city_loop in range(len(s_nodes)): + sns = s_nodes[city_loop] + for start_node in sns: + min_distance = np.inf + end_node = None + cl = 0 + for city_loop_find_shortest in range(len(e_nodes)): + if city_loop_find_shortest == city_loop: + continue + ens = e_nodes[city_loop_find_shortest] + for en in ens: + d = Vec2d.get_euclidean_distance(start_node, en) + if d < min_distance: + min_distance = d + end_node = en + cl = city_loop_find_shortest + + if end_node is not None: + tmp_trans_sn = grid_map.grid[start_node] + tmp_trans_en = grid_map.grid[end_node] + grid_map.grid[start_node] = 0 + grid_map.grid[end_node] = 0 + connection = connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function) + if len(connection) > 0: + s_nodes[city_loop].remove(start_node) + e_nodes[cl].remove(end_node) + + edge = create_graph_edge(city_loop, cl) + if city_loop > cl: + edge = create_graph_edge(cl, city_loop) + if not (edge in city_edges): + city_edges.append(edge) + nodes_added.append(start_node) + nodes_added.append(end_node) + else: + if print_out_info: + print("connect_stations : connect_rail -> no path found") + + grid_map.grid[start_node] = tmp_trans_sn + grid_map.grid[end_node] = tmp_trans_en + + connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added) + + def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start_nodes_added: IntVector2DArray, + end_nodes_added: IntVector2DArray, + nodes_added: IntVector2DArray, + intern_connect_max_nbr_of_shortes_city: int): + if len(start_nodes_added) < 1: + return + x = np.arange(len(start_nodes_added)) + random_city_idx = np.random.choice(x, len(x), False) + + # cyclic connection + random_city_idx = np.append(random_city_idx, random_city_idx[0]) + + for city_loop in range(len(random_city_idx) - 1): + idx_a = random_city_idx[city_loop + 1] + idx_b = random_city_idx[city_loop] + s_nodes = start_nodes_added[idx_a] + e_nodes = end_nodes_added[idx_b] + + max_input_output = max(len(s_nodes), len(e_nodes)) + max_input_output = min(intern_connect_max_nbr_of_shortes_city, max_input_output) + + idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False) + idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False) + + if len(idx_s_nodes) < max_input_output: + idx_s_nodes = np.append(idx_s_nodes, np.random.choice(np.arange(len(s_nodes)), max_input_output - len( + idx_s_nodes))) + if len(idx_e_nodes) < max_input_output: + idx_e_nodes = np.append(idx_e_nodes, + np.random.choice(np.arange(len(idx_e_nodes)), max_input_output - len( + idx_e_nodes))) + + if len(idx_s_nodes) > intern_connect_max_nbr_of_shortes_city: + idx_s_nodes = np.random.choice(idx_s_nodes, intern_connect_max_nbr_of_shortes_city, False) + if len(idx_e_nodes) > intern_connect_max_nbr_of_shortes_city: + idx_e_nodes = np.random.choice(idx_e_nodes, intern_connect_max_nbr_of_shortes_city, False) + + for i in range(max_input_output): + start_node = s_nodes[idx_s_nodes[i]] + end_node = e_nodes[idx_e_nodes[i]] + grid_map.grid[start_node] = 0 + grid_map.grid[end_node] = 0 + connection = connect_nodes(rail_trans, grid_map, start_node, end_node, a_star_distance_function) + if len(connection) > 0: + nodes_added.append(start_node) + nodes_added.append(end_node) + else: + if print_out_info: + print("connect_random_stations : connect_nodes -> no path found") + + def remove_switch_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + train_stations: IntVector2DArray): + tmp_train_stations = copy.deepcopy(train_stations) + for city_loop in range(len(train_stations)): + for n in tmp_train_stations[city_loop]: + do_remove = True + trans = rail_trans.transition_list[1] + for _ in range(4): + trans = rail_trans.rotate_transition(trans, rotation=90) + if grid_map.grid[n] == trans: + do_remove = False + if do_remove: + train_stations[city_loop].remove(n) + + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + grid_map.grid.fill(0) + np.random.seed(seed + num_resets) + + intern_city_size = city_size + if city_size < 3: + warnings.warn("min city_size requried to be > 3!") + intern_city_size = 3 + if print_out_info: + print("intern_city_size:", intern_city_size) + + intern_max_number_of_station_tracks = max_number_of_station_tracks + if max_number_of_station_tracks < 1: + warnings.warn("min max_number_of_station_tracks requried to be > 1!") + intern_max_number_of_station_tracks = 1 + if print_out_info: + print("intern_max_number_of_station_tracks:", intern_max_number_of_station_tracks) + + intern_nbr_of_switches_per_station_track = nbr_of_switches_per_station_track + if nbr_of_switches_per_station_track < 1: + warnings.warn("min intern_nbr_of_switches_per_station_track requried to be > 2!") + intern_nbr_of_switches_per_station_track = 2 + if print_out_info: + print("intern_nbr_of_switches_per_station_track:", intern_nbr_of_switches_per_station_track) + + intern_connect_max_nbr_of_shortes_city = connect_max_nbr_of_shortes_city + if connect_max_nbr_of_shortes_city < 1: + warnings.warn("min intern_connect_max_nbr_of_shortes_city requried to be > 1!") + intern_connect_max_nbr_of_shortes_city = 1 + if print_out_info: + print("intern_connect_max_nbr_of_shortes_city:", intern_connect_max_nbr_of_shortes_city) + + # ---------------------------------------------------------------------------------- + # generate city locations + generate_city_locations, max_num_cities = do_generate_city_locations(width, height, intern_city_size, + intern_max_number_of_station_tracks) + + # ---------------------------------------------------------------------------------- + # apply orientation to cities (horizontal, vertical) + generate_city_locations = do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles) + + # ---------------------------------------------------------------------------------- + # generate city topology + nodes_added, train_stations_slots, s_nodes, e_nodes, station_tracks = \ + create_stations_from_city_locations(rail_trans, grid_map, + generate_city_locations, + intern_max_number_of_station_tracks) + # build switches + create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added, + intern_nbr_of_switches_per_station_track) + + # ---------------------------------------------------------------------------------- + # connect stations + if do_random_connect_stations: + connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, + intern_connect_max_nbr_of_shortes_city) + else: + connect_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, + intern_connect_max_nbr_of_shortes_city) + + # ---------------------------------------------------------------------------------- + # fix all transition at starting / ending points (mostly add a dead end, if missing) + # TODO we might have to remove the fixing stuff in the future + for i in range(len(nodes_added)): + grid_map.fix_transitions(nodes_added[i]) + + # ---------------------------------------------------------------------------------- + # remove stations where underlaying rail is a switch + remove_switch_stations(rail_trans, grid_map, train_stations_slots) + + # ---------------------------------------------------------------------------------- + # Slot availability in node + node_available_start = [] + node_available_target = [] + for node_idx in range(max_num_cities): + node_available_start.append(len(train_stations_slots[node_idx])) + node_available_target.append(len(train_stations_slots[node_idx])) + + # Assign agents to slots + agent_start_targets_nodes = [] + for agent_idx in range(num_agents): + avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] + avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] + if len(avail_target_nodes) == 0: + num_agents -= 1 + continue + start_node = np.random.choice(avail_start_nodes) + target_node = np.random.choice(avail_target_nodes) + tries = 0 + found_agent_pair = True + while target_node == start_node: + target_node = np.random.choice(avail_target_nodes) + tries += 1 + # Test again with new start node if no pair is found (This code needs to be improved) + if (tries + 1) % 10 == 0: + start_node = np.random.choice(avail_start_nodes) + if tries > 100: + warnings.warn("Could not set train_stations, removing agent!") + found_agent_pair = False + break + if found_agent_pair: + node_available_start[start_node] -= 1 + node_available_target[target_node] -= 1 + agent_start_targets_nodes.append((start_node, target_node)) + else: + num_agents -= 1 + + return grid_map, {'agents_hints': { + 'num_agents': num_agents, + 'agent_start_targets_nodes': agent_start_targets_nodes, + 'train_stations': train_stations_slots + }} + + return generator diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 467e4a7fe2e00a3c6dc1720382ca6188b6affd7f..b3576a2bec77f75afc9331cc6c190649590a990c 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -243,3 +243,8 @@ def schedule_from_file(filename) -> ScheduleGenerator: return agents_position, agents_direction, agents_target, agents_speed, agents_malfunction return generator + + +# we can us the same schedule generator for city_rail_generator +# in order to be able to change this transparently in the future, we use a different name. +city_schedule_generator = sparse_schedule_generator diff --git a/flatland/utils/misc.py b/flatland/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e2282698c7a8030a2d8637d7ef9acd8d80afb517 --- /dev/null +++ b/flatland/utils/misc.py @@ -0,0 +1,3 @@ +# https://stackoverflow.com/questions/715417/converting-from-a-string-to-boolean-in-python +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") diff --git a/tests/test_flatland_core_grid4_generators_util.py b/tests/test_flatland_core_grid4_generators_util.py new file mode 100644 index 0000000000000000000000000000000000000000..72deddc66eacc71a5aa840b49225e5ba056a8b84 --- /dev/null +++ b/tests/test_flatland_core_grid4_generators_util.py @@ -0,0 +1,69 @@ +import numpy as np + +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes, connect_to_nodes + + +def test_build_railway_infrastructure(): + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=20, height=20, transitions=rail_trans) + grid_map.grid.fill(0) + np.random.seed(0) + + start_point = (2, 2) + end_point = (8, 8) + connection_001 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_001_expected = [(2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (3, 8), (4, 8), (5, 8), (6, 8), + (7, 8), (8, 8)] + + start_point = (1, 3) + end_point = (1, 7) + connection_002 = connect_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)] + + start_point = (6, 2) + end_point = (6, 5) + connection_003 = connect_from_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)] + + start_point = (7, 5) + end_point = (8, 9) + connection_004 = connect_to_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)] + + assert connection_001 == connection_001_expected, \ + "actual={}, expected={}".format(connection_001, connection_001_expected) + assert connection_002 == connection_002_expected, \ + "actual={}, expected={}".format(connection_002, connection_002_expected) + assert connection_003 == connection_003_expected, \ + "actual={}, expected={}".format(connection_003, connection_003_expected) + assert connection_004 == connection_004_expected, \ + "actual={}, expected={}".format(connection_004, connection_004_expected) + + grid_map_grid_expected = [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1025, 1025, 1025, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 4, 1025, 1025, 1025, 1025, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1025, 1025, 256, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 4, 1025, 1025, 33825, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ] + + assert np.all(grid_map.grid == grid_map_grid_expected), \ + "actual={}, expected={}".format(grid_map.grid, grid_map_grid_expected) diff --git a/tests/test_flatland_core_grid_grid_utils.py b/tests/test_flatland_core_grid_grid_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d598255375b6b86cb44132e03796403b968676c0 --- /dev/null +++ b/tests/test_flatland_core_grid_grid_utils.py @@ -0,0 +1,150 @@ +import numpy as np + +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d + + +def test_vec2d_is_equal(): + node_a = (1, 2) + node_b = (2, 4) + node_c = (1, 2) + res_1 = Vec2d.is_equal(node_a, node_b) + res_2 = Vec2d.is_equal(node_a, node_c) + + assert not res_1 + assert res_2 + + +def test_vec2d_subtract(): + node_a = (1, 2) + node_b = (2, 4) + res_1 = Vec2d.subtract(node_a, node_b) + res_2 = Vec2d.subtract(node_b, node_a) + assert res_1 != res_2 + assert res_1 == (-1, -2) + assert res_2 == (1, 2) + + +def test_vec2d_add(): + node_a = (1, 2) + node_b = (2, 3) + res_1 = Vec2d.add(node_a, node_b) + res_2 = Vec2d.add(node_b, node_a) + assert res_1 == res_2 + assert res_1 == (3, 5) + + +def test_vec2d_make_orthogonal(): + node_a = (1, 2) + res_1 = Vec2d.make_orthogonal(node_a) + assert res_1 == (2, -1) + + +def test_vec2d_euclidean_distance(): + node_a = (3, -7) + node_0 = (0, 0) + assert Vec2d.get_euclidean_distance(node_a, node_0) == Vec2d.get_norm(node_a) + + +def test_vec2d_manhattan_distance(): + node_a = (3, -7) + node_0 = (0, 0) + assert Vec2d.get_manhattan_distance(node_a, node_0) == 3 + 7 + + +def test_vec2d_chebyshev_distance(): + node_a = (3, -7) + node_0 = (0, 0) + assert Vec2d.get_chebyshev_distance(node_a, node_0) == 7 + node_b = (-3, 7) + node_0 = (0, 0) + assert Vec2d.get_chebyshev_distance(node_b, node_0) == 7 + node_c = (3, 7) + node_0 = (0, 0) + assert Vec2d.get_chebyshev_distance(node_c, node_0) == 7 + + +def test_vec2d_norm(): + node_a = (1, 2) + node_b = (1, -2) + res_1 = Vec2d.get_norm(node_a) + res_2 = Vec2d.get_norm(node_b) + assert np.sqrt(1 * 1 + 2 * 2) == res_1 + assert np.sqrt(1 * 1 + (-2) * (-2)) == res_2 + + +def test_vec2d_normalize(): + node_a = (1, 2) + node_b = (1, -2) + res_1 = Vec2d.normalize(node_a) + res_2 = Vec2d.normalize(node_b) + assert np.isclose(1.0, Vec2d.get_norm(res_1)) + assert np.isclose(1.0, Vec2d.get_norm(res_2)) + + +def test_vec2d_scale(): + node_a = (1, 2) + node_b = (1, -2) + res_1 = Vec2d.scale(node_a, 2) + res_2 = Vec2d.scale(node_b, -2.5) + assert res_1 == (2, 4) + assert res_2 == (-2.5, 5) + + +def test_vec2d_round(): + node_a = (-1.95, -2.2) + node_b = (1.95, 2.2) + res_1 = Vec2d.round(node_a) + res_2 = Vec2d.round(node_b) + assert res_1 == (-2, -2) + assert res_2 == (2, 2) + + +def test_vec2d_ceil(): + node_a = (-1.95, -2.2) + node_b = (1.95, 2.2) + res_1 = Vec2d.ceil(node_a) + res_2 = Vec2d.ceil(node_b) + assert res_1 == (-1, -2) + assert res_2 == (2, 3) + + +def test_vec2d_floor(): + node_a = (-1.95, -2.2) + node_b = (1.95, 2.2) + res_1 = Vec2d.floor(node_a) + res_2 = Vec2d.floor(node_b) + assert res_1 == (-2, -3) + assert res_2 == (1, 2) + + +def test_vec2d_bound(): + node_a = (-1.95, -2.2) + node_b = (1.95, 2.2) + res_1 = Vec2d.bound(node_a, -1, 0) + res_2 = Vec2d.bound(node_b, 2, 2.2) + assert res_1 == (-1, -1) + assert res_2 == (2, 2.2) + + +def test_vec2d_rotate(): + node_a = (-1.95, -2.2) + res_1 = Vec2d.rotate(node_a, -90.0) + res_2 = Vec2d.rotate(node_a, 0.0) + res_3 = Vec2d.rotate(node_a, 90.0) + res_4 = Vec2d.rotate(node_a, 180.0) + res_5 = Vec2d.rotate(node_a, 270.0) + res_6 = Vec2d.rotate(node_a, 30.0) + + res_1 = (Vec2d.get_norm(Vec2d.subtract(res_1, (-2.2, 1.95)))) + res_2 = (Vec2d.get_norm(Vec2d.subtract(res_2, (-1.95, -2.2)))) + res_3 = (Vec2d.get_norm(Vec2d.subtract(res_3, (2.2, -1.95)))) + res_4 = (Vec2d.get_norm(Vec2d.subtract(res_4, (1.95, 2.2)))) + res_5 = (Vec2d.get_norm(Vec2d.subtract(res_5, (-2.2, 1.95)))) + res_6 = (Vec2d.get_norm(Vec2d.subtract(res_6, (-0.5887495373796556, -2.880255888325765)))) + + assert np.isclose(0, res_1) + assert np.isclose(0, res_2) + assert np.isclose(0, res_3) + assert np.isclose(0, res_4) + assert np.isclose(0, res_5) + assert np.isclose(0, res_6) diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index 9d4a72c056f9825383f1c9d055509e870219fd9f..b4c268a78826770f89cb59477465949780b960a0 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -2,12 +2,10 @@ # -*- coding: utf-8 -*- """Tests for `flatland` package.""" -import numpy as np - from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid8 import Grid8Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions -from flatland.core.grid.grid4_utils import validate_new_transition +from flatland.core.transition_map import GridTransitionMap # remove whitespace in string; keep whitespace below for easier reading @@ -117,35 +115,35 @@ def test_is_valid_railenv_transitions(): def test_adding_new_valid_transition(): rail_trans = RailEnvTransitions() - rail_array = np.zeros(shape=(15, 15), dtype=np.uint16) + grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans) # adding straight - assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True) + assert (grid_map.validate_new_transition((4, 5), (5, 5), (6, 5), (10, 10)) is True) # adding valid right turn - assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True) + assert (grid_map.validate_new_transition((5, 4), (5, 5), (5, 6), (10, 10)) is True) # adding valid left turn - assert (validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True) + assert (grid_map.validate_new_transition((5, 6), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn - rail_array[(5, 5)] = rail_trans.transitions[2] - assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + grid_map.grid[(5, 5)] = rail_trans.transitions[2] + assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False) # should create #4 -> valid - rail_array[(5, 5)] = rail_trans.transitions[3] - assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True) + grid_map.grid[(5, 5)] = rail_trans.transitions[3] + assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn - rail_array[(5, 5)] = rail_trans.transitions[7] - assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + grid_map.grid[(5, 5)] = rail_trans.transitions[7] + assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False) # test path start condition - rail_array[(5, 5)] = rail_trans.transitions[0] - assert (validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True) + grid_map.grid[(5, 5)] = rail_trans.transitions[0] + assert (grid_map.validate_new_transition(None, (5, 5), (5, 6), (10, 10)) is True) # test path end condition - rail_array[(5, 5)] = rail_trans.transitions[0] - assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True) + grid_map.grid[(5, 5)] = rail_trans.transitions[0] + assert (grid_map.validate_new_transition((5, 4), (5, 5), (6, 5), (6, 5)) is True) def test_valid_railenv_transitions(): diff --git a/tests/test_flatland_envs_city_generator.py b/tests/test_flatland_envs_city_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fe39d785712c88f03af09c7a6d1dac715a585db3 --- /dev/null +++ b/tests/test_flatland_envs_city_generator.py @@ -0,0 +1,301 @@ +import numpy as np + +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators_city_generator import city_generator +from flatland.envs.schedule_generators import city_schedule_generator + + +def test_city_generator(): + dist_fun = Vec2d.get_manhattan_distance + env = RailEnv(width=50, + height=50, + rail_generator=city_generator(num_cities=5, + city_size=10, + allowed_rotation_angles=[90], + max_number_of_station_tracks=4, + nbr_of_switches_per_station_track=2, + connect_max_nbr_of_shortes_city=2, + do_random_connect_stations=False, + a_star_distance_function=dist_fun, + seed=0, + print_out_info=False + ), + schedule_generator=city_schedule_generator(), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) + + expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type()) + + expected_grid_map[8][16]=4 + expected_grid_map[8][17]=5633 + expected_grid_map[8][18]=1025 + expected_grid_map[8][19]=1025 + expected_grid_map[8][20]=17411 + expected_grid_map[8][21]=1025 + expected_grid_map[8][22]=1025 + expected_grid_map[8][23]=1025 + expected_grid_map[8][24]=1025 + expected_grid_map[8][25]=1025 + expected_grid_map[8][26]=4608 + expected_grid_map[9][16]=16386 + expected_grid_map[9][17]=50211 + expected_grid_map[9][18]=1025 + expected_grid_map[9][19]=1025 + expected_grid_map[9][20]=3089 + expected_grid_map[9][21]=1025 + expected_grid_map[9][22]=256 + expected_grid_map[9][26]=32800 + expected_grid_map[10][6]=16386 + expected_grid_map[10][7]=1025 + expected_grid_map[10][8]=1025 + expected_grid_map[10][9]=1025 + expected_grid_map[10][10]=1025 + expected_grid_map[10][11]=1025 + expected_grid_map[10][12]=1025 + expected_grid_map[10][13]=1025 + expected_grid_map[10][14]=1025 + expected_grid_map[10][15]=1025 + expected_grid_map[10][16]=33825 + expected_grid_map[10][17]=34864 + expected_grid_map[10][26]=32800 + expected_grid_map[11][6]=32800 + expected_grid_map[11][16]=32800 + expected_grid_map[11][17]=32800 + expected_grid_map[11][26]=32800 + expected_grid_map[12][6]=32800 + expected_grid_map[12][16]=32800 + expected_grid_map[12][17]=32800 + expected_grid_map[12][26]=32800 + expected_grid_map[13][6]=32800 + expected_grid_map[13][16]=32800 + expected_grid_map[13][17]=32800 + expected_grid_map[13][26]=32800 + expected_grid_map[14][6]=32800 + expected_grid_map[14][16]=32800 + expected_grid_map[14][17]=32800 + expected_grid_map[14][26]=32800 + expected_grid_map[15][6]=32800 + expected_grid_map[15][16]=32800 + expected_grid_map[15][17]=32800 + expected_grid_map[15][26]=32800 + expected_grid_map[16][6]=32800 + expected_grid_map[16][16]=32800 + expected_grid_map[16][17]=32800 + expected_grid_map[16][26]=32800 + expected_grid_map[17][6]=32800 + expected_grid_map[17][16]=72 + expected_grid_map[17][17]=1097 + expected_grid_map[17][18]=1025 + expected_grid_map[17][19]=1025 + expected_grid_map[17][20]=1025 + expected_grid_map[17][21]=1025 + expected_grid_map[17][22]=1025 + expected_grid_map[17][23]=1025 + expected_grid_map[17][24]=1025 + expected_grid_map[17][25]=1025 + expected_grid_map[17][26]=33825 + expected_grid_map[17][27]=4608 + expected_grid_map[18][6]=32800 + expected_grid_map[18][26]=72 + expected_grid_map[18][27]=52275 + expected_grid_map[18][28]=5633 + expected_grid_map[18][29]=17411 + expected_grid_map[18][30]=1025 + expected_grid_map[18][31]=1025 + expected_grid_map[18][32]=256 + expected_grid_map[19][6]=32800 + expected_grid_map[19][25]=16386 + expected_grid_map[19][26]=1025 + expected_grid_map[19][27]=2136 + expected_grid_map[19][28]=1097 + expected_grid_map[19][29]=1097 + expected_grid_map[19][30]=5633 + expected_grid_map[19][31]=1025 + expected_grid_map[19][32]=256 + expected_grid_map[20][6]=32800 + expected_grid_map[20][25]=32800 + expected_grid_map[20][26]=16386 + expected_grid_map[20][27]=17411 + expected_grid_map[20][28]=1025 + expected_grid_map[20][29]=1025 + expected_grid_map[20][30]=3089 + expected_grid_map[20][31]=1025 + expected_grid_map[20][32]=256 + expected_grid_map[21][6]=32800 + expected_grid_map[21][16]=16386 + expected_grid_map[21][17]=1025 + expected_grid_map[21][18]=1025 + expected_grid_map[21][19]=1025 + expected_grid_map[21][20]=1025 + expected_grid_map[21][21]=1025 + expected_grid_map[21][22]=1025 + expected_grid_map[21][23]=1025 + expected_grid_map[21][24]=1025 + expected_grid_map[21][25]=33825 + expected_grid_map[21][26]=33825 + expected_grid_map[21][27]=2064 + expected_grid_map[22][6]=32800 + expected_grid_map[22][16]=32800 + expected_grid_map[22][25]=32800 + expected_grid_map[22][26]=32800 + expected_grid_map[23][6]=32800 + expected_grid_map[23][16]=32800 + expected_grid_map[23][25]=32800 + expected_grid_map[23][26]=32800 + expected_grid_map[24][6]=32800 + expected_grid_map[24][16]=32800 + expected_grid_map[24][25]=32800 + expected_grid_map[24][26]=32800 + expected_grid_map[25][6]=32800 + expected_grid_map[25][16]=32800 + expected_grid_map[25][25]=32800 + expected_grid_map[25][26]=32800 + expected_grid_map[26][6]=32800 + expected_grid_map[26][16]=32800 + expected_grid_map[26][25]=32800 + expected_grid_map[26][26]=32800 + expected_grid_map[27][6]=72 + expected_grid_map[27][7]=1025 + expected_grid_map[27][8]=1025 + expected_grid_map[27][9]=17411 + expected_grid_map[27][10]=1025 + expected_grid_map[27][11]=1025 + expected_grid_map[27][12]=1025 + expected_grid_map[27][13]=1025 + expected_grid_map[27][14]=1025 + expected_grid_map[27][15]=4608 + expected_grid_map[27][16]=72 + expected_grid_map[27][17]=17411 + expected_grid_map[27][18]=5633 + expected_grid_map[27][19]=1025 + expected_grid_map[27][20]=1025 + expected_grid_map[27][21]=1025 + expected_grid_map[27][22]=1025 + expected_grid_map[27][23]=1025 + expected_grid_map[27][24]=1025 + expected_grid_map[27][25]=33825 + expected_grid_map[27][26]=2064 + expected_grid_map[28][6]=4 + expected_grid_map[28][7]=1025 + expected_grid_map[28][8]=1025 + expected_grid_map[28][9]=3089 + expected_grid_map[28][10]=1025 + expected_grid_map[28][11]=1025 + expected_grid_map[28][12]=1025 + expected_grid_map[28][13]=1025 + expected_grid_map[28][14]=4608 + expected_grid_map[28][15]=72 + expected_grid_map[28][16]=1025 + expected_grid_map[28][17]=2136 + expected_grid_map[28][18]=1097 + expected_grid_map[28][19]=5633 + expected_grid_map[28][20]=5633 + expected_grid_map[28][21]=1025 + expected_grid_map[28][22]=256 + expected_grid_map[28][25]=32800 + expected_grid_map[29][6]=4 + expected_grid_map[29][7]=5633 + expected_grid_map[29][8]=20994 + expected_grid_map[29][9]=5633 + expected_grid_map[29][10]=1025 + expected_grid_map[29][11]=1025 + expected_grid_map[29][12]=1025 + expected_grid_map[29][13]=1025 + expected_grid_map[29][14]=1097 + expected_grid_map[29][15]=5633 + expected_grid_map[29][16]=1025 + expected_grid_map[29][17]=17411 + expected_grid_map[29][18]=5633 + expected_grid_map[29][19]=1097 + expected_grid_map[29][20]=3089 + expected_grid_map[29][21]=20994 + expected_grid_map[29][22]=1025 + expected_grid_map[29][23]=1025 + expected_grid_map[29][24]=1025 + expected_grid_map[29][25]=2064 + expected_grid_map[30][6]=16386 + expected_grid_map[30][7]=38505 + expected_grid_map[30][8]=3089 + expected_grid_map[30][9]=1097 + expected_grid_map[30][10]=1025 + expected_grid_map[30][11]=1025 + expected_grid_map[30][12]=256 + expected_grid_map[30][15]=32800 + expected_grid_map[30][16]=16386 + expected_grid_map[30][17]=52275 + expected_grid_map[30][18]=1097 + expected_grid_map[30][19]=1025 + expected_grid_map[30][20]=1025 + expected_grid_map[30][21]=3089 + expected_grid_map[30][22]=256 + expected_grid_map[31][6]=32800 + expected_grid_map[31][7]=32800 + expected_grid_map[31][15]=72 + expected_grid_map[31][16]=37408 + expected_grid_map[31][17]=32800 + expected_grid_map[32][6]=32800 + expected_grid_map[32][7]=32800 + expected_grid_map[32][16]=32800 + expected_grid_map[32][17]=32800 + expected_grid_map[33][6]=32800 + expected_grid_map[33][7]=32800 + expected_grid_map[33][16]=32800 + expected_grid_map[33][17]=32800 + expected_grid_map[34][6]=32800 + expected_grid_map[34][7]=32800 + expected_grid_map[34][16]=32800 + expected_grid_map[34][17]=32800 + expected_grid_map[35][6]=32800 + expected_grid_map[35][7]=32800 + expected_grid_map[35][16]=32800 + expected_grid_map[35][17]=32800 + expected_grid_map[36][6]=32800 + expected_grid_map[36][7]=32800 + expected_grid_map[36][16]=32800 + expected_grid_map[36][17]=32800 + expected_grid_map[37][6]=72 + expected_grid_map[37][7]=1097 + expected_grid_map[37][8]=1025 + expected_grid_map[37][9]=1025 + expected_grid_map[37][10]=1025 + expected_grid_map[37][11]=1025 + expected_grid_map[37][12]=1025 + expected_grid_map[37][13]=1025 + expected_grid_map[37][14]=1025 + expected_grid_map[37][15]=1025 + expected_grid_map[37][16]=33897 + expected_grid_map[37][17]=37408 + expected_grid_map[38][16]=72 + expected_grid_map[38][17]=52275 + expected_grid_map[38][18]=5633 + expected_grid_map[38][19]=17411 + expected_grid_map[38][20]=1025 + expected_grid_map[38][21]=1025 + expected_grid_map[38][22]=256 + expected_grid_map[39][16]=4 + expected_grid_map[39][17]=52275 + expected_grid_map[39][18]=3089 + expected_grid_map[39][19]=1097 + expected_grid_map[39][20]=5633 + expected_grid_map[39][21]=1025 + expected_grid_map[39][22]=256 + expected_grid_map[40][16]=4 + expected_grid_map[40][17]=1097 + expected_grid_map[40][18]=1025 + expected_grid_map[40][19]=1025 + expected_grid_map[40][20]=3089 + expected_grid_map[40][21]=1025 + expected_grid_map[40][22]=256 + + assert np.array_equal(env.rail.grid, expected_grid_map), "actual={}, expected={}".format(env.rail.grid, + expected_grid_map) + + s0 = 0 + s1 = 0 + for a in range(env.get_num_agents()): + s0 = Vec2d.get_manhattan_distance(env.agents[a].position, (0, 0)) + s1 = Vec2d.get_chebyshev_distance(env.agents[a].position, (0, 0)) + assert s0 == 58, "actual={}".format(s0) + assert s1 == 38, "actual={}".format(s1) diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index d363597e107a63cdfc0c8f6f429e0f023b0b7c38..e164752483e2b4ad5896d754d378a5519c960237 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -2,6 +2,7 @@ import random import numpy as np +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator @@ -24,11 +25,694 @@ def test_sparse_rail_generator(): schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) - # reset to initialize agents_static - env_renderer = RenderTool(env, gl="PILSVG", ) - env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - env_renderer.gl.save_image("./sparse_generator_false.png") - # TODO test assertions! + expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type()) + expected_grid_map[1][33] = 8192 + expected_grid_map[2][33] = 32800 + expected_grid_map[3][31] = 4 + expected_grid_map[3][32] = 4608 + expected_grid_map[3][33] = 32800 + expected_grid_map[4][30] = 16386 + expected_grid_map[4][31] = 17411 + expected_grid_map[4][32] = 1097 + expected_grid_map[4][33] = 38505 + expected_grid_map[4][34] = 1025 + expected_grid_map[4][35] = 5633 + expected_grid_map[4][36] = 1025 + expected_grid_map[4][37] = 1025 + expected_grid_map[4][38] = 1025 + expected_grid_map[4][39] = 1025 + expected_grid_map[4][40] = 1025 + expected_grid_map[4][41] = 1025 + expected_grid_map[4][42] = 1025 + expected_grid_map[4][43] = 1025 + expected_grid_map[4][44] = 1025 + expected_grid_map[4][45] = 1025 + expected_grid_map[4][46] = 1025 + expected_grid_map[4][47] = 1025 + expected_grid_map[4][48] = 4608 + expected_grid_map[5][30] = 128 + expected_grid_map[5][31] = 32800 + expected_grid_map[5][33] = 32800 + expected_grid_map[5][35] = 32800 + expected_grid_map[5][48] = 32800 + expected_grid_map[6][30] = 4 + expected_grid_map[6][31] = 2064 + expected_grid_map[6][33] = 32800 + expected_grid_map[6][35] = 128 + expected_grid_map[6][48] = 32800 + expected_grid_map[7][33] = 32872 + expected_grid_map[7][34] = 1025 + expected_grid_map[7][35] = 1025 + expected_grid_map[7][36] = 1025 + expected_grid_map[7][37] = 1025 + expected_grid_map[7][38] = 1025 + expected_grid_map[7][39] = 1025 + expected_grid_map[7][40] = 1025 + expected_grid_map[7][41] = 20994 + expected_grid_map[7][42] = 1025 + expected_grid_map[7][43] = 1025 + expected_grid_map[7][44] = 1025 + expected_grid_map[7][45] = 1025 + expected_grid_map[7][46] = 1025 + expected_grid_map[7][47] = 4608 + expected_grid_map[7][48] = 32800 + expected_grid_map[8][3] = 16386 + expected_grid_map[8][4] = 1025 + expected_grid_map[8][5] = 1025 + expected_grid_map[8][6] = 1025 + expected_grid_map[8][7] = 1025 + expected_grid_map[8][8] = 1025 + expected_grid_map[8][9] = 1025 + expected_grid_map[8][10] = 1025 + expected_grid_map[8][11] = 5633 + expected_grid_map[8][12] = 1025 + expected_grid_map[8][13] = 1025 + expected_grid_map[8][14] = 1025 + expected_grid_map[8][15] = 1025 + expected_grid_map[8][16] = 1025 + expected_grid_map[8][17] = 1025 + expected_grid_map[8][18] = 4608 + expected_grid_map[8][33] = 32800 + expected_grid_map[8][41] = 32800 + expected_grid_map[8][47] = 32800 + expected_grid_map[8][48] = 32800 + expected_grid_map[9][3] = 32800 + expected_grid_map[9][11] = 32800 + expected_grid_map[9][12] = 8192 + expected_grid_map[9][13] = 8192 + expected_grid_map[9][18] = 32800 + expected_grid_map[9][33] = 32800 + expected_grid_map[9][41] = 32800 + expected_grid_map[9][47] = 32800 + expected_grid_map[9][48] = 32800 + expected_grid_map[10][3] = 32800 + expected_grid_map[10][8] = 8192 + expected_grid_map[10][11] = 32800 + expected_grid_map[10][12] = 32800 + expected_grid_map[10][13] = 32800 + expected_grid_map[10][18] = 32800 + expected_grid_map[10][33] = 32800 + expected_grid_map[10][41] = 32800 + expected_grid_map[10][47] = 32800 + expected_grid_map[10][48] = 32800 + expected_grid_map[11][3] = 32800 + expected_grid_map[11][8] = 32800 + expected_grid_map[11][11] = 32800 + expected_grid_map[11][12] = 32800 + expected_grid_map[11][13] = 32800 + expected_grid_map[11][18] = 32800 + expected_grid_map[11][33] = 32800 + expected_grid_map[11][41] = 32800 + expected_grid_map[11][47] = 32800 + expected_grid_map[11][48] = 32800 + expected_grid_map[12][3] = 32800 + expected_grid_map[12][8] = 72 + expected_grid_map[12][9] = 1025 + expected_grid_map[12][10] = 17411 + expected_grid_map[12][11] = 52275 + expected_grid_map[12][12] = 3089 + expected_grid_map[12][13] = 3089 + expected_grid_map[12][14] = 1025 + expected_grid_map[12][15] = 1025 + expected_grid_map[12][16] = 1025 + expected_grid_map[12][17] = 1025 + expected_grid_map[12][18] = 33825 + expected_grid_map[12][19] = 1025 + expected_grid_map[12][20] = 1025 + expected_grid_map[12][21] = 1025 + expected_grid_map[12][22] = 1025 + expected_grid_map[12][23] = 1025 + expected_grid_map[12][24] = 1025 + expected_grid_map[12][25] = 1025 + expected_grid_map[12][26] = 1025 + expected_grid_map[12][27] = 1025 + expected_grid_map[12][28] = 1025 + expected_grid_map[12][29] = 1025 + expected_grid_map[12][30] = 1025 + expected_grid_map[12][31] = 1025 + expected_grid_map[12][32] = 1025 + expected_grid_map[12][33] = 33825 + expected_grid_map[12][34] = 1025 + expected_grid_map[12][35] = 1025 + expected_grid_map[12][36] = 1025 + expected_grid_map[12][37] = 1025 + expected_grid_map[12][38] = 1025 + expected_grid_map[12][39] = 1025 + expected_grid_map[12][40] = 1025 + expected_grid_map[12][41] = 35889 + expected_grid_map[12][42] = 4608 + expected_grid_map[12][47] = 32800 + expected_grid_map[12][48] = 32800 + expected_grid_map[13][3] = 32800 + expected_grid_map[13][10] = 32800 + expected_grid_map[13][11] = 32872 + expected_grid_map[13][12] = 1025 + expected_grid_map[13][13] = 256 + expected_grid_map[13][15] = 8192 + expected_grid_map[13][16] = 8192 + expected_grid_map[13][17] = 8192 + expected_grid_map[13][18] = 32800 + expected_grid_map[13][20] = 8192 + expected_grid_map[13][33] = 32800 + expected_grid_map[13][41] = 32800 + expected_grid_map[13][42] = 32800 + expected_grid_map[13][47] = 32800 + expected_grid_map[13][48] = 32800 + expected_grid_map[14][3] = 32800 + expected_grid_map[14][10] = 128 + expected_grid_map[14][11] = 32800 + expected_grid_map[14][15] = 72 + expected_grid_map[14][16] = 37408 + expected_grid_map[14][17] = 32800 + expected_grid_map[14][18] = 32800 + expected_grid_map[14][20] = 32800 + expected_grid_map[14][33] = 32800 + expected_grid_map[14][41] = 32800 + expected_grid_map[14][42] = 32800 + expected_grid_map[14][47] = 32800 + expected_grid_map[14][48] = 32800 + expected_grid_map[15][3] = 32800 + expected_grid_map[15][11] = 32800 + expected_grid_map[15][15] = 4 + expected_grid_map[15][16] = 1097 + expected_grid_map[15][17] = 1097 + expected_grid_map[15][18] = 3089 + expected_grid_map[15][19] = 1025 + expected_grid_map[15][20] = 3089 + expected_grid_map[15][21] = 1025 + expected_grid_map[15][22] = 1025 + expected_grid_map[15][23] = 1025 + expected_grid_map[15][24] = 1025 + expected_grid_map[15][25] = 1025 + expected_grid_map[15][26] = 1025 + expected_grid_map[15][27] = 1025 + expected_grid_map[15][28] = 1025 + expected_grid_map[15][29] = 1025 + expected_grid_map[15][30] = 1025 + expected_grid_map[15][31] = 1025 + expected_grid_map[15][32] = 1025 + expected_grid_map[15][33] = 33825 + expected_grid_map[15][34] = 1025 + expected_grid_map[15][35] = 1025 + expected_grid_map[15][36] = 1025 + expected_grid_map[15][37] = 1025 + expected_grid_map[15][38] = 1025 + expected_grid_map[15][39] = 1025 + expected_grid_map[15][40] = 1025 + expected_grid_map[15][41] = 35889 + expected_grid_map[15][42] = 37408 + expected_grid_map[15][47] = 32800 + expected_grid_map[15][48] = 32800 + expected_grid_map[16][3] = 32800 + expected_grid_map[16][7] = 8192 + expected_grid_map[16][11] = 32800 + expected_grid_map[16][33] = 32800 + expected_grid_map[16][41] = 32800 + expected_grid_map[16][42] = 32800 + expected_grid_map[16][47] = 32800 + expected_grid_map[16][48] = 32800 + expected_grid_map[17][3] = 32800 + expected_grid_map[17][7] = 32800 + expected_grid_map[17][9] = 8192 + expected_grid_map[17][10] = 8192 + expected_grid_map[17][11] = 32800 + expected_grid_map[17][33] = 32800 + expected_grid_map[17][41] = 32800 + expected_grid_map[17][42] = 32800 + expected_grid_map[17][47] = 32800 + expected_grid_map[17][48] = 32800 + expected_grid_map[18][3] = 32800 + expected_grid_map[18][7] = 32800 + expected_grid_map[18][8] = 8192 + expected_grid_map[18][9] = 32800 + expected_grid_map[18][10] = 32800 + expected_grid_map[18][11] = 32800 + expected_grid_map[18][33] = 32800 + expected_grid_map[18][41] = 32800 + expected_grid_map[18][42] = 32800 + expected_grid_map[18][47] = 32800 + expected_grid_map[18][48] = 32800 + expected_grid_map[19][3] = 72 + expected_grid_map[19][4] = 1025 + expected_grid_map[19][5] = 1025 + expected_grid_map[19][6] = 1025 + expected_grid_map[19][7] = 1097 + expected_grid_map[19][8] = 1097 + expected_grid_map[19][9] = 1097 + expected_grid_map[19][10] = 52275 + expected_grid_map[19][11] = 33825 + expected_grid_map[19][12] = 1025 + expected_grid_map[19][13] = 1025 + expected_grid_map[19][14] = 4608 + expected_grid_map[19][33] = 32800 + expected_grid_map[19][41] = 32800 + expected_grid_map[19][42] = 32800 + expected_grid_map[19][47] = 32800 + expected_grid_map[19][48] = 32800 + expected_grid_map[20][7] = 4 + expected_grid_map[20][8] = 1025 + expected_grid_map[20][9] = 1025 + expected_grid_map[20][10] = 34864 + expected_grid_map[20][11] = 32800 + expected_grid_map[20][14] = 32800 + expected_grid_map[20][33] = 32800 + expected_grid_map[20][41] = 32800 + expected_grid_map[20][42] = 32800 + expected_grid_map[20][47] = 32800 + expected_grid_map[20][48] = 32800 + expected_grid_map[21][10] = 32800 + expected_grid_map[21][11] = 32800 + expected_grid_map[21][14] = 32800 + expected_grid_map[21][24] = 8192 + expected_grid_map[21][33] = 32872 + expected_grid_map[21][34] = 1025 + expected_grid_map[21][35] = 1025 + expected_grid_map[21][36] = 1025 + expected_grid_map[21][37] = 1025 + expected_grid_map[21][38] = 1025 + expected_grid_map[21][39] = 1025 + expected_grid_map[21][40] = 1025 + expected_grid_map[21][41] = 33825 + expected_grid_map[21][42] = 38505 + expected_grid_map[21][43] = 1025 + expected_grid_map[21][44] = 1025 + expected_grid_map[21][45] = 1025 + expected_grid_map[21][46] = 1025 + expected_grid_map[21][47] = 37408 + expected_grid_map[21][48] = 32800 + expected_grid_map[22][10] = 32800 + expected_grid_map[22][11] = 32800 + expected_grid_map[22][14] = 32800 + expected_grid_map[22][22] = 8192 + expected_grid_map[22][24] = 32800 + expected_grid_map[22][27] = 8192 + expected_grid_map[22][33] = 32800 + expected_grid_map[22][41] = 32800 + expected_grid_map[22][42] = 32800 + expected_grid_map[22][47] = 32800 + expected_grid_map[22][48] = 32800 + expected_grid_map[23][10] = 32800 + expected_grid_map[23][11] = 32800 + expected_grid_map[23][14] = 32800 + expected_grid_map[23][22] = 72 + expected_grid_map[23][23] = 17411 + expected_grid_map[23][24] = 1097 + expected_grid_map[23][25] = 17411 + expected_grid_map[23][26] = 1025 + expected_grid_map[23][27] = 3089 + expected_grid_map[23][28] = 1025 + expected_grid_map[23][29] = 1025 + expected_grid_map[23][30] = 1025 + expected_grid_map[23][31] = 1025 + expected_grid_map[23][32] = 1025 + expected_grid_map[23][33] = 33825 + expected_grid_map[23][34] = 1025 + expected_grid_map[23][35] = 1025 + expected_grid_map[23][36] = 1025 + expected_grid_map[23][37] = 1025 + expected_grid_map[23][38] = 1025 + expected_grid_map[23][39] = 1025 + expected_grid_map[23][40] = 1025 + expected_grid_map[23][41] = 3089 + expected_grid_map[23][42] = 34864 + expected_grid_map[23][47] = 32800 + expected_grid_map[23][48] = 32800 + expected_grid_map[24][10] = 32800 + expected_grid_map[24][11] = 32800 + expected_grid_map[24][14] = 32800 + expected_grid_map[24][23] = 32800 + expected_grid_map[24][24] = 4 + expected_grid_map[24][25] = 34864 + expected_grid_map[24][33] = 32800 + expected_grid_map[24][42] = 32800 + expected_grid_map[24][47] = 32800 + expected_grid_map[24][48] = 32800 + expected_grid_map[25][10] = 32800 + expected_grid_map[25][11] = 32800 + expected_grid_map[25][14] = 32800 + expected_grid_map[25][23] = 128 + expected_grid_map[25][25] = 32800 + expected_grid_map[25][33] = 32800 + expected_grid_map[25][42] = 32800 + expected_grid_map[25][47] = 32800 + expected_grid_map[25][48] = 32800 + expected_grid_map[26][10] = 32800 + expected_grid_map[26][11] = 32800 + expected_grid_map[26][14] = 32800 + expected_grid_map[26][25] = 32800 + expected_grid_map[26][33] = 32800 + expected_grid_map[26][42] = 32800 + expected_grid_map[26][47] = 32800 + expected_grid_map[26][48] = 32800 + expected_grid_map[27][10] = 32800 + expected_grid_map[27][11] = 32800 + expected_grid_map[27][14] = 32800 + expected_grid_map[27][25] = 32800 + expected_grid_map[27][33] = 32800 + expected_grid_map[27][42] = 32800 + expected_grid_map[27][47] = 32800 + expected_grid_map[27][48] = 32800 + expected_grid_map[28][10] = 32800 + expected_grid_map[28][11] = 32800 + expected_grid_map[28][14] = 32800 + expected_grid_map[28][25] = 32800 + expected_grid_map[28][33] = 49186 + expected_grid_map[28][34] = 256 + expected_grid_map[28][42] = 32800 + expected_grid_map[28][44] = 8192 + expected_grid_map[28][45] = 8192 + expected_grid_map[28][47] = 32800 + expected_grid_map[28][48] = 32800 + expected_grid_map[28][49] = 8192 + expected_grid_map[29][10] = 32800 + expected_grid_map[29][11] = 32800 + expected_grid_map[29][14] = 32800 + expected_grid_map[29][25] = 32800 + expected_grid_map[29][32] = 16386 + expected_grid_map[29][33] = 37408 + expected_grid_map[29][34] = 8192 + expected_grid_map[29][42] = 32800 + expected_grid_map[29][44] = 72 + expected_grid_map[29][45] = 37408 + expected_grid_map[29][47] = 32800 + expected_grid_map[29][48] = 32800 + expected_grid_map[29][49] = 32800 + expected_grid_map[30][10] = 32800 + expected_grid_map[30][11] = 32800 + expected_grid_map[30][14] = 32800 + expected_grid_map[30][25] = 32800 + expected_grid_map[30][32] = 128 + expected_grid_map[30][33] = 49186 + expected_grid_map[30][34] = 33825 + expected_grid_map[30][35] = 1025 + expected_grid_map[30][36] = 1025 + expected_grid_map[30][37] = 1025 + expected_grid_map[30][38] = 1025 + expected_grid_map[30][39] = 5633 + expected_grid_map[30][40] = 1025 + expected_grid_map[30][41] = 1025 + expected_grid_map[30][42] = 2064 + expected_grid_map[30][45] = 16458 + expected_grid_map[30][46] = 17411 + expected_grid_map[30][47] = 38505 + expected_grid_map[30][48] = 38433 + expected_grid_map[30][49] = 2064 + expected_grid_map[31][10] = 32800 + expected_grid_map[31][11] = 32800 + expected_grid_map[31][14] = 32800 + expected_grid_map[31][25] = 32800 + expected_grid_map[31][30] = 8192 + expected_grid_map[31][31] = 4 + expected_grid_map[31][32] = 17411 + expected_grid_map[31][33] = 34864 + expected_grid_map[31][34] = 32800 + expected_grid_map[31][39] = 32800 + expected_grid_map[31][45] = 32800 + expected_grid_map[31][46] = 32800 + expected_grid_map[31][47] = 32800 + expected_grid_map[31][48] = 32800 + expected_grid_map[32][10] = 32800 + expected_grid_map[32][11] = 32800 + expected_grid_map[32][14] = 32800 + expected_grid_map[32][25] = 32800 + expected_grid_map[32][30] = 72 + expected_grid_map[32][31] = 1025 + expected_grid_map[32][32] = 2064 + expected_grid_map[32][33] = 32872 + expected_grid_map[32][34] = 2064 + expected_grid_map[32][39] = 32800 + expected_grid_map[32][45] = 128 + expected_grid_map[32][46] = 128 + expected_grid_map[32][47] = 32800 + expected_grid_map[32][48] = 32800 + expected_grid_map[33][10] = 32800 + expected_grid_map[33][11] = 32800 + expected_grid_map[33][14] = 32872 + expected_grid_map[33][15] = 1025 + expected_grid_map[33][16] = 1025 + expected_grid_map[33][17] = 1025 + expected_grid_map[33][18] = 1025 + expected_grid_map[33][19] = 1025 + expected_grid_map[33][20] = 1025 + expected_grid_map[33][21] = 1025 + expected_grid_map[33][22] = 1025 + expected_grid_map[33][23] = 1025 + expected_grid_map[33][24] = 1025 + expected_grid_map[33][25] = 35889 + expected_grid_map[33][26] = 1025 + expected_grid_map[33][27] = 1025 + expected_grid_map[33][28] = 1025 + expected_grid_map[33][29] = 1025 + expected_grid_map[33][30] = 1025 + expected_grid_map[33][31] = 1025 + expected_grid_map[33][32] = 1025 + expected_grid_map[33][33] = 34864 + expected_grid_map[33][39] = 32800 + expected_grid_map[33][47] = 32800 + expected_grid_map[33][48] = 32800 + expected_grid_map[34][5] = 16386 + expected_grid_map[34][6] = 1025 + expected_grid_map[34][7] = 1025 + expected_grid_map[34][8] = 1025 + expected_grid_map[34][9] = 1025 + expected_grid_map[34][10] = 33825 + expected_grid_map[34][11] = 3089 + expected_grid_map[34][12] = 1025 + expected_grid_map[34][13] = 1025 + expected_grid_map[34][14] = 33825 + expected_grid_map[34][15] = 1025 + expected_grid_map[34][16] = 1025 + expected_grid_map[34][17] = 1025 + expected_grid_map[34][18] = 1025 + expected_grid_map[34][19] = 1025 + expected_grid_map[34][20] = 1025 + expected_grid_map[34][21] = 1025 + expected_grid_map[34][22] = 1025 + expected_grid_map[34][23] = 1025 + expected_grid_map[34][24] = 1025 + expected_grid_map[34][25] = 2064 + expected_grid_map[34][33] = 32800 + expected_grid_map[34][39] = 32800 + expected_grid_map[34][47] = 32800 + expected_grid_map[34][48] = 32800 + expected_grid_map[35][5] = 32800 + expected_grid_map[35][10] = 32800 + expected_grid_map[35][14] = 32800 + expected_grid_map[35][16] = 8192 + expected_grid_map[35][33] = 32800 + expected_grid_map[35][39] = 32800 + expected_grid_map[35][47] = 32800 + expected_grid_map[35][48] = 32800 + expected_grid_map[36][5] = 32800 + expected_grid_map[36][10] = 32800 + expected_grid_map[36][14] = 32800 + expected_grid_map[36][16] = 32800 + expected_grid_map[36][17] = 8192 + expected_grid_map[36][19] = 8192 + expected_grid_map[36][33] = 32800 + expected_grid_map[36][39] = 32800 + expected_grid_map[36][41] = 8192 + expected_grid_map[36][47] = 32800 + expected_grid_map[36][48] = 32800 + expected_grid_map[37][5] = 32800 + expected_grid_map[37][10] = 32800 + expected_grid_map[37][14] = 32800 + expected_grid_map[37][16] = 32800 + expected_grid_map[37][17] = 49186 + expected_grid_map[37][18] = 1025 + expected_grid_map[37][19] = 2064 + expected_grid_map[37][33] = 32800 + expected_grid_map[37][39] = 32800 + expected_grid_map[37][41] = 32800 + expected_grid_map[37][42] = 16386 + expected_grid_map[37][43] = 256 + expected_grid_map[37][47] = 32800 + expected_grid_map[37][48] = 32800 + expected_grid_map[38][5] = 72 + expected_grid_map[38][6] = 1025 + expected_grid_map[38][7] = 1025 + expected_grid_map[38][8] = 1025 + expected_grid_map[38][9] = 1025 + expected_grid_map[38][10] = 33825 + expected_grid_map[38][11] = 1025 + expected_grid_map[38][12] = 1025 + expected_grid_map[38][13] = 1025 + expected_grid_map[38][14] = 33897 + expected_grid_map[38][15] = 17411 + expected_grid_map[38][16] = 1097 + expected_grid_map[38][17] = 38505 + expected_grid_map[38][18] = 256 + expected_grid_map[38][33] = 32800 + expected_grid_map[38][39] = 32800 + expected_grid_map[38][41] = 32800 + expected_grid_map[38][42] = 32800 + expected_grid_map[38][43] = 8192 + expected_grid_map[38][47] = 32800 + expected_grid_map[38][48] = 32800 + expected_grid_map[39][10] = 32800 + expected_grid_map[39][14] = 32800 + expected_grid_map[39][15] = 32800 + expected_grid_map[39][17] = 32800 + expected_grid_map[39][18] = 4 + expected_grid_map[39][19] = 4608 + expected_grid_map[39][33] = 32800 + expected_grid_map[39][39] = 49186 + expected_grid_map[39][40] = 17411 + expected_grid_map[39][41] = 1097 + expected_grid_map[39][42] = 52275 + expected_grid_map[39][43] = 3089 + expected_grid_map[39][44] = 4608 + expected_grid_map[39][47] = 32800 + expected_grid_map[39][48] = 32800 + expected_grid_map[40][10] = 32800 + expected_grid_map[40][14] = 32800 + expected_grid_map[40][15] = 128 + expected_grid_map[40][17] = 32800 + expected_grid_map[40][18] = 8192 + expected_grid_map[40][19] = 32800 + expected_grid_map[40][33] = 32800 + expected_grid_map[40][39] = 32800 + expected_grid_map[40][40] = 32800 + expected_grid_map[40][42] = 32872 + expected_grid_map[40][43] = 4608 + expected_grid_map[40][44] = 32800 + expected_grid_map[40][47] = 32800 + expected_grid_map[40][48] = 32800 + expected_grid_map[41][10] = 32800 + expected_grid_map[41][14] = 32800 + expected_grid_map[41][17] = 32800 + expected_grid_map[41][18] = 72 + expected_grid_map[41][19] = 37408 + expected_grid_map[41][21] = 8192 + expected_grid_map[41][33] = 32800 + expected_grid_map[41][39] = 32800 + expected_grid_map[41][40] = 128 + expected_grid_map[41][42] = 32800 + expected_grid_map[41][43] = 128 + expected_grid_map[41][44] = 32800 + expected_grid_map[41][47] = 32800 + expected_grid_map[41][48] = 32800 + expected_grid_map[42][10] = 32800 + expected_grid_map[42][14] = 72 + expected_grid_map[42][15] = 1025 + expected_grid_map[42][16] = 1025 + expected_grid_map[42][17] = 33825 + expected_grid_map[42][18] = 17411 + expected_grid_map[42][19] = 52275 + expected_grid_map[42][20] = 5633 + expected_grid_map[42][21] = 3089 + expected_grid_map[42][22] = 1025 + expected_grid_map[42][23] = 1025 + expected_grid_map[42][24] = 1025 + expected_grid_map[42][25] = 1025 + expected_grid_map[42][26] = 1025 + expected_grid_map[42][27] = 1025 + expected_grid_map[42][28] = 1025 + expected_grid_map[42][29] = 1025 + expected_grid_map[42][30] = 4608 + expected_grid_map[42][33] = 32800 + expected_grid_map[42][39] = 32800 + expected_grid_map[42][42] = 32800 + expected_grid_map[42][44] = 32800 + expected_grid_map[42][47] = 32800 + expected_grid_map[42][48] = 32800 + expected_grid_map[43][10] = 32800 + expected_grid_map[43][17] = 32800 + expected_grid_map[43][18] = 128 + expected_grid_map[43][19] = 32800 + expected_grid_map[43][20] = 32800 + expected_grid_map[43][30] = 32800 + expected_grid_map[43][33] = 32800 + expected_grid_map[43][39] = 32800 + expected_grid_map[43][42] = 32800 + expected_grid_map[43][44] = 32800 + expected_grid_map[43][47] = 32800 + expected_grid_map[43][48] = 32800 + expected_grid_map[44][4] = 4 + expected_grid_map[44][5] = 1025 + expected_grid_map[44][6] = 1025 + expected_grid_map[44][7] = 1025 + expected_grid_map[44][8] = 1025 + expected_grid_map[44][9] = 1025 + expected_grid_map[44][10] = 3089 + expected_grid_map[44][11] = 1025 + expected_grid_map[44][12] = 1025 + expected_grid_map[44][13] = 1025 + expected_grid_map[44][14] = 1025 + expected_grid_map[44][15] = 1025 + expected_grid_map[44][16] = 1025 + expected_grid_map[44][17] = 3089 + expected_grid_map[44][18] = 1025 + expected_grid_map[44][19] = 2064 + expected_grid_map[44][20] = 128 + expected_grid_map[44][30] = 72 + expected_grid_map[44][31] = 1025 + expected_grid_map[44][32] = 1025 + expected_grid_map[44][33] = 35889 + expected_grid_map[44][34] = 1025 + expected_grid_map[44][35] = 1025 + expected_grid_map[44][36] = 1025 + expected_grid_map[44][37] = 1025 + expected_grid_map[44][38] = 1025 + expected_grid_map[44][39] = 33825 + expected_grid_map[44][40] = 1025 + expected_grid_map[44][41] = 1025 + expected_grid_map[44][42] = 2064 + expected_grid_map[44][44] = 32800 + expected_grid_map[44][47] = 32800 + expected_grid_map[44][48] = 32800 + expected_grid_map[45][33] = 32872 + expected_grid_map[45][34] = 1025 + expected_grid_map[45][35] = 1025 + expected_grid_map[45][36] = 1025 + expected_grid_map[45][37] = 1025 + expected_grid_map[45][38] = 1025 + expected_grid_map[45][39] = 33825 + expected_grid_map[45][40] = 1025 + expected_grid_map[45][41] = 1025 + expected_grid_map[45][42] = 1025 + expected_grid_map[45][43] = 1025 + expected_grid_map[45][44] = 1097 + expected_grid_map[45][45] = 1025 + expected_grid_map[45][46] = 1025 + expected_grid_map[45][47] = 34864 + expected_grid_map[45][48] = 32800 + expected_grid_map[46][33] = 32800 + expected_grid_map[46][39] = 32800 + expected_grid_map[46][47] = 32800 + expected_grid_map[46][48] = 32800 + expected_grid_map[47][33] = 32800 + expected_grid_map[47][39] = 32800 + expected_grid_map[47][47] = 32800 + expected_grid_map[47][48] = 128 + expected_grid_map[48][33] = 32800 + expected_grid_map[48][39] = 32800 + expected_grid_map[48][47] = 32800 + expected_grid_map[49][33] = 72 + expected_grid_map[49][34] = 1025 + expected_grid_map[49][35] = 1025 + expected_grid_map[49][36] = 1025 + expected_grid_map[49][37] = 1025 + expected_grid_map[49][38] = 1025 + expected_grid_map[49][39] = 2136 + expected_grid_map[49][40] = 1025 + expected_grid_map[49][41] = 1025 + expected_grid_map[49][42] = 1025 + expected_grid_map[49][43] = 1025 + expected_grid_map[49][44] = 1025 + expected_grid_map[49][45] = 1025 + expected_grid_map[49][46] = 1025 + expected_grid_map[49][47] = 2064 + + assert np.array_equal(env.rail.grid, expected_grid_map), "actual={}, expected={}".format(env.rail.grid, + expected_grid_map) + s0 = 0 + s1 = 0 + for a in range(env.get_num_agents()): + s0 = Vec2d.get_manhattan_distance(env.agents[a].position, (0, 0)) + s1 = Vec2d.get_chebyshev_distance(env.agents[a].position, (0, 0)) + assert s0 == 53, "actual={}".format(s0) + assert s1 == 36, "actual={}".format(s1) def test_sparse_rail_generator_deterministic(): @@ -897,6 +1581,7 @@ def test_rail_env_action_required_info(): if done_always_action['__all__']: break + env_renderer.close_window() def test_rail_env_malfunction_speed_info(): @@ -947,6 +1632,7 @@ def test_rail_env_malfunction_speed_info(): if done['__all__']: break + env_renderer.close_window() def test_sparse_generator_with_too_man_cities_does_not_break_down(): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index fde9df58663993ae170c4c1e3fea55637feb4282..884a2a51f84a40a45acced32e7310dcf4d497944 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -48,7 +48,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: new_position = get_new_position(agent.position, direction) - min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) + min_distances.append( + self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index 8248c675995fc5c906e82d8650a5b619e7b038f2..853b025f2ebd39949453f35e0d053e519163237c 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -45,14 +45,12 @@ def test_render_env(save_new_images=False): oEnv.rail.load_transition_map('env_data.tests', "test1.npy") oRT = rt.RenderTool(oEnv, gl="PILSVG") oRT.render_env(show=False) - checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images) oRT = rt.RenderTool(oEnv, gl="PIL") oRT.render_env() checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images) - - + def main(): if len(sys.argv) == 2 and sys.argv[1] == "save": test_render_env(save_new_images=True)