From 9babe92826827a697ef488776cc562b8d89a3a8d Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 17 Sep 2019 13:35:53 +0200 Subject: [PATCH] refactoring --- .../Simple_Realistic_Railway_Generator.py | 2 - flatland/core/grid/grid4_astar.py | 9 +++-- flatland/core/grid/grid_utils.py | 13 +++++++ flatland/core/transition_map.py | 38 ++++++++++--------- tests/test_flatland_core_transitions.py | 16 ++++---- 5 files changed, 46 insertions(+), 32 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index f935d3e3..09f3b538 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -1,6 +1,5 @@ import copy import os -import time import warnings import numpy as np @@ -278,7 +277,6 @@ def realistic_rail_generator(num_cities=5, grid_map.grid[start_node] = tmp_trans_sn grid_map.grid[end_node] = tmp_trans_en - connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added) def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index d0b2cc97..779ee984 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,4 +1,3 @@ - from flatland.core.grid.grid_utils import IntVector2D from flatland.core.grid.grid_utils import IntVector2DArrayType from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d @@ -75,13 +74,15 @@ def a_star(rail_trans: RailEnvTransitions, else: prev_pos = None for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: - node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) + # update the "current" pos + node_pos = Vec2d.add(current_node.pos, new_pos) + + # is node_pos inside the grid? if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0: continue # validate positions - if not grid_map.validate_new_transition(rail_trans, prev_pos, current_node.pos, node_pos, - end_node.pos): + if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos): continue # create new node diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py index a9d6ffaa..f9657bf3 100644 --- a/flatland/core/grid/grid_utils.py +++ b/flatland/core/grid/grid_utils.py @@ -10,6 +10,19 @@ IntVector2DArrayType = [] class Vec2dOperations: + @staticmethod + def is_equal(node_a: Vector2D, node_b: Vector2D) -> bool: + """ + vector operation : node_a + node_b + + :param node_a: tuple with coordinate (x,y) or 2d vector + :param node_b: tuple with coordinate (x,y) or 2d vector + :return: + ------- + check if node_a and nobe_b are equal + """ + return node_a[0] == node_b[0] and node_a[1] == node_b[1] + @staticmethod def subtract(node_a: Vector2D, node_b: Vector2D) -> Vector2D: """ diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 2b06714b..d8da9c56 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -9,6 +9,7 @@ from numpy import array from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position, get_direction from flatland.core.grid.grid_utils import IntVector2DArrayType +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions @@ -301,7 +302,7 @@ class GridTransitionMap(TransitionMap): self.height = new_height self.grid = new_grid - def is_dead_end(self, rcPos): + def is_dead_end(self, rcPos: IntVector2DArrayType): """ Check if the cell is a dead-end. @@ -321,7 +322,7 @@ class GridTransitionMap(TransitionMap): tmp = tmp >> 1 return nbits == 1 - def is_simple_turn(self, rcPos): + def is_simple_turn(self, rcPos: IntVector2DArrayType): """ Check if the cell is a left/right simple turn @@ -348,7 +349,7 @@ class GridTransitionMap(TransitionMap): return is_simple_turn(tmp) - def check_path_exists(self, start, direction, end): + def check_path_exists(self, start: IntVector2DArrayType, direction: int, end: IntVector2DArrayType): # print("_path_exists({},{},{}".format(start, direction, end)) # BFS - Check if a path exists between the 2 nodes @@ -358,7 +359,8 @@ class GridTransitionMap(TransitionMap): node = stack.pop() node_position = node[0] node_direction = node[1] - if node_position[0] == end[0] and node_position[1] == end[1]: + + if Vec2d.is_equal(node_position, end): return True if node not in visited: visited.add(node) @@ -371,7 +373,7 @@ class GridTransitionMap(TransitionMap): return False - def cell_neighbours_valid(self, rcPos, check_this_cell=False): + def cell_neighbours_valid(self, rcPos: IntVector2DArrayType, check_this_cell=False): """ Check validity of cell at rcPos = tuple(row, column) Checks that: @@ -423,7 +425,7 @@ class GridTransitionMap(TransitionMap): return True - def fix_neighbours(self, rcPos, check_this_cell=False): + def fix_neighbours(self, rcPos: IntVector2DArrayType, check_this_cell=False): """ Check validity of cell at rcPos = tuple(row, column) Checks that: @@ -476,7 +478,7 @@ class GridTransitionMap(TransitionMap): return True - def fix_transitions(self, rcPos): + def fix_transitions(self, rcPos: IntVector2DArrayType): """ Fixes broken transitions """ @@ -541,9 +543,9 @@ class GridTransitionMap(TransitionMap): self.set_transitions((rcPos[0], rcPos[1]), transition) return True - def validate_new_transition(self, rail_trans: RailEnvTransitions, - prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType, + def validate_new_transition(self, prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType, new_pos: IntVector2DArrayType, end_pos: IntVector2DArrayType): + # start by getting direction used to get to current node # and direction from current node to possible child node new_dir = get_direction(current_pos, new_pos) @@ -556,30 +558,30 @@ class GridTransitionMap(TransitionMap): if prev_pos is None: if new_trans == 0: # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1) else: # check if matches existing layout - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1) else: # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1) # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - if new_pos == end_pos: + new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + if Vec2d.is_equal(new_pos, end_pos): # need to validate end pos setup as well new_trans_e = self.grid[end_pos] if new_trans_e == 0: # need to flip direction because of how end points are defined - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) else: # check if matches existing layout - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1) - if not rail_trans.is_valid(new_trans_e): + if not self.transitions.is_valid(new_trans_e): return False # is transition is valid? - return rail_trans.is_valid(new_trans) + return self.transitions.is_valid(new_trans) def mirror(dir): diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index 9c01be4b..b4c268a7 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -118,32 +118,32 @@ def test_adding_new_valid_transition(): grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans) # adding straight - assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (6, 5), (10, 10)) is True) + assert (grid_map.validate_new_transition((4, 5), (5, 5), (6, 5), (10, 10)) is True) # adding valid right turn - assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (5, 6), (10, 10)) is True) + assert (grid_map.validate_new_transition((5, 4), (5, 5), (5, 6), (10, 10)) is True) # adding valid left turn - assert (grid_map.validate_new_transition(rail_trans, (5, 6), (5, 5), (5, 6), (10, 10)) is True) + assert (grid_map.validate_new_transition((5, 6), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn grid_map.grid[(5, 5)] = rail_trans.transitions[2] - assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False) # should create #4 -> valid grid_map.grid[(5, 5)] = rail_trans.transitions[3] - assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is True) + assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn grid_map.grid[(5, 5)] = rail_trans.transitions[7] - assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False) # test path start condition grid_map.grid[(5, 5)] = rail_trans.transitions[0] - assert (grid_map.validate_new_transition(rail_trans, None, (5, 5), (5, 6), (10, 10)) is True) + assert (grid_map.validate_new_transition(None, (5, 5), (5, 6), (10, 10)) is True) # test path end condition grid_map.grid[(5, 5)] = rail_trans.transitions[0] - assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (6, 5), (6, 5)) is True) + assert (grid_map.validate_new_transition((5, 4), (5, 5), (6, 5), (6, 5)) is True) def test_valid_railenv_transitions(): -- GitLab