Commit 9babe928 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

refactoring

parent ef0b7175
import copy
import os
import time
import warnings
import numpy as np
......@@ -278,7 +277,6 @@ def realistic_rail_generator(num_cities=5,
grid_map.grid[start_node] = tmp_trans_sn
grid_map.grid[end_node] = tmp_trans_en
connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added)
def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
......
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
......@@ -75,13 +74,15 @@ def a_star(rail_trans: RailEnvTransitions,
else:
prev_pos = None
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
# update the "current" pos
node_pos = Vec2d.add(current_node.pos, new_pos)
# is node_pos inside the grid?
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
continue
# validate positions
if not grid_map.validate_new_transition(rail_trans, prev_pos, current_node.pos, node_pos,
end_node.pos):
if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos):
continue
# create new node
......
......@@ -10,6 +10,19 @@ IntVector2DArrayType = []
class Vec2dOperations:
@staticmethod
def is_equal(node_a: Vector2D, node_b: Vector2D) -> bool:
"""
vector operation : node_a + node_b
:param node_a: tuple with coordinate (x,y) or 2d vector
:param node_b: tuple with coordinate (x,y) or 2d vector
:return:
-------
check if node_a and nobe_b are equal
"""
return node_a[0] == node_b[0] and node_a[1] == node_b[1]
@staticmethod
def subtract(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
"""
......
......@@ -9,6 +9,7 @@ from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArrayType
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
......@@ -301,7 +302,7 @@ class GridTransitionMap(TransitionMap):
self.height = new_height
self.grid = new_grid
def is_dead_end(self, rcPos):
def is_dead_end(self, rcPos: IntVector2DArrayType):
"""
Check if the cell is a dead-end.
......@@ -321,7 +322,7 @@ class GridTransitionMap(TransitionMap):
tmp = tmp >> 1
return nbits == 1
def is_simple_turn(self, rcPos):
def is_simple_turn(self, rcPos: IntVector2DArrayType):
"""
Check if the cell is a left/right simple turn
......@@ -348,7 +349,7 @@ class GridTransitionMap(TransitionMap):
return is_simple_turn(tmp)
def check_path_exists(self, start, direction, end):
def check_path_exists(self, start: IntVector2DArrayType, direction: int, end: IntVector2DArrayType):
# print("_path_exists({},{},{}".format(start, direction, end))
# BFS - Check if a path exists between the 2 nodes
......@@ -358,7 +359,8 @@ class GridTransitionMap(TransitionMap):
node = stack.pop()
node_position = node[0]
node_direction = node[1]
if node_position[0] == end[0] and node_position[1] == end[1]:
if Vec2d.is_equal(node_position, end):
return True
if node not in visited:
visited.add(node)
......@@ -371,7 +373,7 @@ class GridTransitionMap(TransitionMap):
return False
def cell_neighbours_valid(self, rcPos, check_this_cell=False):
def cell_neighbours_valid(self, rcPos: IntVector2DArrayType, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
......@@ -423,7 +425,7 @@ class GridTransitionMap(TransitionMap):
return True
def fix_neighbours(self, rcPos, check_this_cell=False):
def fix_neighbours(self, rcPos: IntVector2DArrayType, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
......@@ -476,7 +478,7 @@ class GridTransitionMap(TransitionMap):
return True
def fix_transitions(self, rcPos):
def fix_transitions(self, rcPos: IntVector2DArrayType):
"""
Fixes broken transitions
"""
......@@ -541,9 +543,9 @@ class GridTransitionMap(TransitionMap):
self.set_transitions((rcPos[0], rcPos[1]), transition)
return True
def validate_new_transition(self, rail_trans: RailEnvTransitions,
prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType,
def validate_new_transition(self, prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType,
new_pos: IntVector2DArrayType, end_pos: IntVector2DArrayType):
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
......@@ -556,30 +558,30 @@ class GridTransitionMap(TransitionMap):
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if new_pos == end_pos:
new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if Vec2d.is_equal(new_pos, end_pos):
# need to validate end pos setup as well
new_trans_e = self.grid[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1)
if not rail_trans.is_valid(new_trans_e):
if not self.transitions.is_valid(new_trans_e):
return False
# is transition is valid?
return rail_trans.is_valid(new_trans)
return self.transitions.is_valid(new_trans)
def mirror(dir):
......
......@@ -118,32 +118,32 @@ def test_adding_new_valid_transition():
grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans)
# adding straight
assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
assert (grid_map.validate_new_transition((4, 5), (5, 5), (6, 5), (10, 10)) is True)
# adding valid right turn
assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
assert (grid_map.validate_new_transition((5, 4), (5, 5), (5, 6), (10, 10)) is True)
# adding valid left turn
assert (grid_map.validate_new_transition(rail_trans, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
assert (grid_map.validate_new_transition((5, 6), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
grid_map.grid[(5, 5)] = rail_trans.transitions[2]
assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)
# should create #4 -> valid
grid_map.grid[(5, 5)] = rail_trans.transitions[3]
assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
grid_map.grid[(5, 5)] = rail_trans.transitions[7]
assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)
# test path start condition
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
assert (grid_map.validate_new_transition(rail_trans, None, (5, 5), (5, 6), (10, 10)) is True)
assert (grid_map.validate_new_transition(None, (5, 5), (5, 6), (10, 10)) is True)
# test path end condition
grid_map.grid[(5, 5)] = rail_trans.transitions[0]
assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
assert (grid_map.validate_new_transition((5, 4), (5, 5), (6, 5), (6, 5)) is True)
def test_valid_railenv_transitions():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment