Skip to content
Snippets Groups Projects
Commit dc658474 authored by u229589's avatar u229589
Browse files

refactor connect_straight_line

parent 28c23e66
No related branches found
No related tags found
No related merge requests found
......@@ -296,25 +296,3 @@ def distance_on_rail(pos1, pos2, metric="Euclidean"):
return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
if metric == "Manhattan":
return np.abs(pos1[0] - pos2[0]) + np.abs(pos1[1] - pos2[1])
def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
"""
Returns the closest direction orientation of position 2 relative to position 1
:param pos1: position we are interested in
:param pos2: position we want to know it is facing
:return: direction NESW as int N:0 E:1 S:2 W:3
"""
diff_vec = np.array((pos1[0] - pos2[0], pos1[1] - pos2[1]))
axis = np.argmax(np.power(diff_vec, 2))
direction = np.sign(diff_vec[axis])
if axis == 0:
if direction > 0:
return Grid4TransitionsEnum.NORTH
else:
return Grid4TransitionsEnum.SOUTH
else:
if direction > 0:
return Grid4TransitionsEnum.WEST
else:
return Grid4TransitionsEnum.EAST
......@@ -9,7 +9,8 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position, directions_of_vector
from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position, directions_of_vector, \
direction_to_point
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
......@@ -86,8 +87,8 @@ def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, st
return path
def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D,
end: IntVector2D, openend: bool = False) -> IntVector2DArray:
def connect_straight_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D,
end: IntVector2D) -> IntVector2DArray:
"""
Generates a straight rail line from start cell to end cell.
Diagonal lines are not allowed
......@@ -95,43 +96,37 @@ def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitio
:param grid_map:
:param start: Cell coordinates for start of line
:param end: Cell coordinates for end of line
:param openend: If True then the transition at start and end is set to 0: An empty cell
:return: A list of all cells in the path
"""
# Assert that a straight line is possible
if not (start[0] == end[0] or start[1] == end[1]):
print("No straight line possible!")
return []
current_cell: IntVector2D = start
path: IntVector2DArray = [current_cell]
new_trans = grid_map.grid[current_cell]
direction = (np.clip(end[0] - start[0], -1, 1), np.clip(end[1] - start[1], -1, 1))
if direction[0] == 0:
if direction[1] > 0:
direction_int = 1
else:
direction_int = 3
else:
if direction[0] > 0:
direction_int = 2
else:
direction_int = 0
new_trans = rail_trans.set_transition(new_trans, direction_int, direction_int, 1)
new_trans = rail_trans.set_transition(new_trans, mirror(direction_int), mirror(direction_int), 1)
grid_map.grid[current_cell] = new_trans
if openend:
grid_map.grid[current_cell] = 0
# Set path
while current_cell != end:
current_cell = tuple(map(lambda x, y: x + y, current_cell, direction))
new_trans = grid_map.grid[current_cell]
new_trans = rail_trans.set_transition(new_trans, direction_int, direction_int, 1)
new_trans = rail_trans.set_transition(new_trans, mirror(direction_int), mirror(direction_int), 1)
grid_map.grid[current_cell] = new_trans
if current_cell == end and openend:
grid_map.grid[current_cell] = 0
path.append(current_cell)
direction = direction_to_point(start, end)
if direction is Grid4TransitionsEnum.NORTH or direction is Grid4TransitionsEnum.SOUTH:
start_row = min(start[0], end[0])
end_row = max(start[0], end[0]) + 1
rows = np.arange(start_row, end_row)
length = np.abs(end[0] - start[0]) + 1
cols = np.repeat(start[1], length)
else: # Grid4TransitionsEnum.EAST or Grid4TransitionsEnum.WEST
start_col = min(start[1], end[1])
end_col = max(start[1], end[1]) + 1
cols = np.arange(start_col, end_col)
length = np.abs(end[1] - start[1]) + 1
rows = np.repeat(start[0], length)
path = list(zip(rows, cols))
for cell in path:
transition = grid_map.grid[cell]
transition = rail_trans.set_transition(transition, direction, direction, 1)
transition = rail_trans.set_transition(transition, mirror(direction), mirror(direction), 1)
grid_map.grid[cell] = transition
return path
......
......@@ -6,13 +6,13 @@ import msgpack
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point, IntVector2DArray, IntVector2D, \
from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_straigt_line
from flatland.envs.grid4_generators_utils import connect_rail, connect_straight_line
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
......@@ -777,14 +777,14 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
boarder_two = inner_connection_points[current_city][opposite_boarder]
# Connect the ends of the tracks
connect_straigt_line(rail_trans, grid_map, boarder_one[0], boarder_one[-1], False)
connect_straigt_line(rail_trans, grid_map, boarder_two[0], boarder_two[-1], False)
connect_straight_line(rail_trans, grid_map, boarder_one[0], boarder_one[-1])
connect_straight_line(rail_trans, grid_map, boarder_two[0], boarder_two[-1])
# Connect parallel tracks
for track_id in range(len(inner_connection_points[current_city][boarder])):
source = inner_connection_points[current_city][boarder][track_id]
target = inner_connection_points[current_city][opposite_boarder][track_id]
current_track = connect_straigt_line(rail_trans, grid_map, source, target, False)
current_track = connect_straight_line(rail_trans, grid_map, source, target)
if target in all_outer_connection_points and source in all_outer_connection_points and len(through_path_cells[current_city]) < 1:
through_path_cells[current_city].extend(current_track)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment