From ef0b7175d48f8f727679ff467238990cbd1a8f4c Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 17 Sep 2019 13:19:20 +0200 Subject: [PATCH] refactoring and clean up --- .../Simple_Realistic_Railway_Generator.py | 4 +- flatland/core/grid/grid4_astar.py | 10 ++--- flatland/core/grid/grid4_utils.py | 42 +----------------- flatland/core/transition_map.py | 43 ++++++++++++++++++- tests/test_flatland_core_transitions.py | 32 +++++++------- 5 files changed, 66 insertions(+), 65 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 6e5e6a56..f935d3e3 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -451,7 +451,7 @@ def realistic_rail_generator(num_cities=5, if os.path.exists("./../render_output/"): for itrials in range(1000): print(itrials, "generate new city") - np.random.seed(0 * int(time.time())) + np.random.seed(itrials) env = RailEnv(width=40 + np.random.choice(100), height=40 + np.random.choice(100), rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10), @@ -462,7 +462,7 @@ if os.path.exists("./../render_output/"): connect_max_nbr_of_shortes_city=2, do_random_connect_stations=False, # Number of cities in map - seed=0*int(time.time()), # Random seed + seed=itrials, # Random seed print_out_info=True ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index c04d71dd..d0b2cc97 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,4 +1,4 @@ -from flatland.core.grid.grid4_utils import validate_new_transition + from flatland.core.grid.grid_utils import IntVector2D from flatland.core.grid.grid_utils import IntVector2DArrayType from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d @@ -12,9 +12,9 @@ class AStarNode: def __init__(self, parent: IntVector2D = None, pos: IntVector2D = None): self.parent: IntVector2D = parent self.pos: IntVector2D = pos - self.g: float = 0.0 - self.h: float = 0.0 - self.f: float = 0.0 + self.g = 0.0 + self.h = 0.0 + self.f = 0.0 def __eq__(self, other: IntVector2D): return self.pos == other.pos @@ -80,7 +80,7 @@ def a_star(rail_trans: RailEnvTransitions, continue # validate positions - if not validate_new_transition(rail_trans, grid_map.grid, prev_pos, current_node.pos, node_pos, + if not grid_map.validate_new_transition(rail_trans, prev_pos, current_node.pos, node_pos, end_node.pos): continue diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py index d64a160b..0a0ba6b8 100644 --- a/flatland/core/grid/grid4_utils.py +++ b/flatland/core/grid/grid4_utils.py @@ -1,7 +1,8 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid_utils import IntVector2DArrayType -def get_direction(pos1, pos2) -> Grid4TransitionsEnum: +def get_direction(pos1: IntVector2DArrayType, pos2: IntVector2DArrayType) -> Grid4TransitionsEnum: """ Assumes pos1 and pos2 are adjacent location on grid. Returns direction (int) that can be used with transitions. @@ -23,45 +24,6 @@ def mirror(dir): return (dir + 2) % 4 -def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos): - # start by getting direction used to get to current node - # and direction from current node to possible child node - new_dir = get_direction(current_pos, new_pos) - if prev_pos is not None: - current_dir = get_direction(prev_pos, current_pos) - else: - current_dir = new_dir - # create new transition that would go to child - new_trans = rail_array[current_pos] - if prev_pos is None: - if new_trans == 0: - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # check if matches existing layout - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - if new_pos == end_pos: - # need to validate end pos setup as well - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # need to flip direction because of how end points are defined - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # check if matches existing layout - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - - if not rail_trans.is_valid(new_trans_e): - return False - - # is transition is valid? - return rail_trans.is_valid(new_trans) - - def get_new_position(position, movement): """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ if movement == Grid4TransitionsEnum.NORTH: diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 232d6fda..2b06714b 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -7,7 +7,8 @@ from importlib_resources import path from numpy import array from flatland.core.grid.grid4 import Grid4Transitions -from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.grid4_utils import get_new_position, get_direction +from flatland.core.grid.grid_utils import IntVector2DArrayType from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions @@ -540,6 +541,46 @@ class GridTransitionMap(TransitionMap): self.set_transitions((rcPos[0], rcPos[1]), transition) return True + def validate_new_transition(self, rail_trans: RailEnvTransitions, + prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType, + new_pos: IntVector2DArrayType, end_pos: IntVector2DArrayType): + # start by getting direction used to get to current node + # and direction from current node to possible child node + new_dir = get_direction(current_pos, new_pos) + if prev_pos is not None: + current_dir = get_direction(prev_pos, current_pos) + else: + current_dir = new_dir + # create new transition that would go to child + new_trans = self.grid[current_pos] + if prev_pos is None: + if new_trans == 0: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # check if matches existing layout + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + if new_pos == end_pos: + # need to validate end pos setup as well + new_trans_e = self.grid[end_pos] + if new_trans_e == 0: + # need to flip direction because of how end points are defined + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # check if matches existing layout + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + + if not rail_trans.is_valid(new_trans_e): + return False + + # is transition is valid? + return rail_trans.is_valid(new_trans) + def mirror(dir): return (dir + 2) % 4 diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index 9d4a72c0..9c01be4b 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -2,12 +2,10 @@ # -*- coding: utf-8 -*- """Tests for `flatland` package.""" -import numpy as np - from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid8 import Grid8Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions -from flatland.core.grid.grid4_utils import validate_new_transition +from flatland.core.transition_map import GridTransitionMap # remove whitespace in string; keep whitespace below for easier reading @@ -117,35 +115,35 @@ def test_is_valid_railenv_transitions(): def test_adding_new_valid_transition(): rail_trans = RailEnvTransitions() - rail_array = np.zeros(shape=(15, 15), dtype=np.uint16) + grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans) # adding straight - assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True) + assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (6, 5), (10, 10)) is True) # adding valid right turn - assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True) + assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (5, 6), (10, 10)) is True) # adding valid left turn - assert (validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True) + assert (grid_map.validate_new_transition(rail_trans, (5, 6), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn - rail_array[(5, 5)] = rail_trans.transitions[2] - assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + grid_map.grid[(5, 5)] = rail_trans.transitions[2] + assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False) # should create #4 -> valid - rail_array[(5, 5)] = rail_trans.transitions[3] - assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True) + grid_map.grid[(5, 5)] = rail_trans.transitions[3] + assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn - rail_array[(5, 5)] = rail_trans.transitions[7] - assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + grid_map.grid[(5, 5)] = rail_trans.transitions[7] + assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False) # test path start condition - rail_array[(5, 5)] = rail_trans.transitions[0] - assert (validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True) + grid_map.grid[(5, 5)] = rail_trans.transitions[0] + assert (grid_map.validate_new_transition(rail_trans, None, (5, 5), (5, 6), (10, 10)) is True) # test path end condition - rail_array[(5, 5)] = rail_trans.transitions[0] - assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True) + grid_map.grid[(5, 5)] = rail_trans.transitions[0] + assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (6, 5), (6, 5)) is True) def test_valid_railenv_transitions(): -- GitLab