Commit ef0b7175 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

refactoring and clean up

parent 11feb9b0
Pipeline #2037 passed with stages
in 32 minutes and 14 seconds
......@@ -451,7 +451,7 @@ def realistic_rail_generator(num_cities=5,
if os.path.exists("./../render_output/"):
for itrials in range(1000):
print(itrials, "generate new city")
np.random.seed(0 * int(time.time()))
np.random.seed(itrials)
env = RailEnv(width=40 + np.random.choice(100),
height=40 + np.random.choice(100),
rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
......@@ -462,7 +462,7 @@ if os.path.exists("./../render_output/"):
connect_max_nbr_of_shortes_city=2,
do_random_connect_stations=False,
# Number of cities in map
seed=0*int(time.time()), # Random seed
seed=itrials, # Random seed
print_out_info=True
),
schedule_generator=sparse_schedule_generator(),
......
from flatland.core.grid.grid4_utils import validate_new_transition
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
......@@ -12,9 +12,9 @@ class AStarNode:
def __init__(self, parent: IntVector2D = None, pos: IntVector2D = None):
self.parent: IntVector2D = parent
self.pos: IntVector2D = pos
self.g: float = 0.0
self.h: float = 0.0
self.f: float = 0.0
self.g = 0.0
self.h = 0.0
self.f = 0.0
def __eq__(self, other: IntVector2D):
return self.pos == other.pos
......@@ -80,7 +80,7 @@ def a_star(rail_trans: RailEnvTransitions,
continue
# validate positions
if not validate_new_transition(rail_trans, grid_map.grid, prev_pos, current_node.pos, node_pos,
if not grid_map.validate_new_transition(rail_trans, prev_pos, current_node.pos, node_pos,
end_node.pos):
continue
......
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2DArrayType
def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
def get_direction(pos1: IntVector2DArrayType, pos2: IntVector2DArrayType) -> Grid4TransitionsEnum:
"""
Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions.
......@@ -23,45 +24,6 @@ def mirror(dir):
return (dir + 2) % 4
def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = rail_array[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if new_pos == end_pos:
# need to validate end pos setup as well
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
if not rail_trans.is_valid(new_trans_e):
return False
# is transition is valid?
return rail_trans.is_valid(new_trans)
def get_new_position(position, movement):
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
......
......@@ -7,7 +7,8 @@ from importlib_resources import path
from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
......@@ -540,6 +541,46 @@ class GridTransitionMap(TransitionMap):
self.set_transitions((rcPos[0], rcPos[1]), transition)
return True
def validate_new_transition(self, rail_trans: RailEnvTransitions,
prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType,
new_pos: IntVector2DArrayType, end_pos: IntVector2DArrayType):
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = self.grid[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if new_pos == end_pos:
# need to validate end pos setup as well
new_trans_e = self.grid[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
if not rail_trans.is_valid(new_trans_e):
return False
# is transition is valid?
return rail_trans.is_valid(new_trans)
def mirror(dir):
return (dir + 2) % 4
......
......@@ -2,12 +2,10 @@
# -*- coding: utf-8 -*-
"""Tests for `flatland` package."""
import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid8 import Grid8Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.grid.grid4_utils import validate_new_transition
from flatland.core.transition_map import GridTransitionMap
# remove whitespace in string; keep whitespace below for easier reading
......@@ -117,35 +115,35 @@ def test_is_valid_railenv_transitions():
def test_adding_new_valid_transition():
rail_trans = RailEnvTransitions()
rail_array = np.zeros(shape=(15, 15), dtype=np.uint16)
grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans)
# adding straight
assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
# adding valid right turn
assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
# adding valid left turn
assert (validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
assert (grid_map.validate_new_transition(rail_trans, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
rail_array[(5, 5)] = rail_trans.transitions[2]
assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
grid_map.grid[(5, 5)] = rail_trans.transitions[2]
assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
# should create #4 -> valid
rail_array[(5, 5)] = rail_trans.transitions[3]
assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
grid_map.grid[(5, 5)] = rail_trans.transitions[3]
assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
rail_array[(5, 5)] = rail_trans.transitions[7]
assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
grid_map.grid[(5, 5)] = rail_trans.transitions[7]
assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
# test path start condition
rail_array[(5, 5)] = rail_trans.transitions[0]
assert (validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True)
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
assert (grid_map.validate_new_transition(rail_trans, None, (5, 5), (5, 6), (10, 10)) is True)
# test path end condition
rail_array[(5, 5)] = rail_trans.transitions[0]
assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
def test_valid_railenv_transitions():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment