Commit a1085186 authored by Erik Nygren's avatar Erik Nygren
Browse files

introduced new inner city parallel track connector

parent 3e30bd8a
......@@ -36,7 +36,7 @@ env = RailEnv(width=50,
# Number of cities in map (where train stations are)
seed=1, # Random seed
grid_mode=False,
max_rails_between_cities=2,
max_rails_between_cities=3,
max_rails_in_city=6,
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
......
......@@ -561,11 +561,11 @@ class GridTransitionMap(TransitionMap):
if number_of_incoming == 3:
self.set_transitions(rcPos, 0)
hole = np.argwhere(incoming_connections < 1)[0][0]
if direction > 0:
if direction >= 0:
switch_type_idx = (direction - hole + 3) % 4
if switch_type_idx == 2:
transition = simple_switch_west_south
if switch_type_idx == 0:
transition = simple_switch_west_south
elif switch_type_idx == 2:
transition = simple_switch_east_south
else:
transition = np.random.choice(three_way_transitions, 1)
......
......@@ -9,7 +9,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point, get_new_position
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
......@@ -126,3 +126,33 @@ def connect_straight_line_in_grid_map(grid_map: GridTransitionMap, start: IntVec
grid_map.grid[cell] = transition
return path
def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, rail_trans: RailEnvTransitions):
"""
Fix inner city nodes
:param grid_map:
:param start:
:param rail_trans:
:return:
"""
corner_directions = []
for direction in range(4):
tmp_pos = get_new_position(inner_node_pos, direction)
if grid_map.grid[tmp_pos] > 0:
corner_directions.append(direction)
if len(corner_directions) == 2:
transition = 0
transition = rail_trans.set_transition(transition, mirror(corner_directions[0]), corner_directions[1], 1)
transition = rail_trans.set_transition(transition, mirror(corner_directions[1]), corner_directions[0], 1)
grid_map.grid[inner_node_pos] = transition
tmp_pos = get_new_position(inner_node_pos, corner_directions[0])
transition = grid_map.grid[tmp_pos]
transition = rail_trans.set_transition(transition, corner_directions[0], mirror(corner_directions[0]), 1)
grid_map.grid[tmp_pos] = transition
tmp_pos = get_new_position(inner_node_pos, corner_directions[1])
transition = grid_map.grid[tmp_pos]
transition = rail_trans.set_transition(transition, corner_directions[1], mirror(corner_directions[1]),
1)
grid_map.grid[tmp_pos] = transition
return
......@@ -12,7 +12,8 @@ from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, In
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
fix_inner_nodes
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
......@@ -698,22 +699,35 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
for direction in range(4):
connection_slots = np.arange(nr_of_connection_points) - start_idx
inner_point_offset = np.abs(connection_slots) + np.clip(connection_slots, 0, 1)
for connection_idx in range(connections_per_direction[direction]):
if direction == 0:
tmp_coordinates = (
city_position[0] - city_radius + inner_point_offset[connection_idx],
city_position[1] + connection_slots[connection_idx])
out_tmp_coordinates = (
city_position[0] - city_radius, city_position[1] + connection_slots[connection_idx])
if direction == 1:
tmp_coordinates = (
city_position[0] + connection_slots[connection_idx],
city_position[1] + city_radius - inner_point_offset[connection_idx])
out_tmp_coordinates = (
city_position[0] + connection_slots[connection_idx], city_position[1] + city_radius)
if direction == 2:
tmp_coordinates = (
city_position[0] + city_radius - inner_point_offset[connection_idx],
city_position[1] + connection_slots[connection_idx])
out_tmp_coordinates = (
city_position[0] + city_radius, city_position[1] + connection_slots[connection_idx])
if direction == 3:
tmp_coordinates = (
city_position[0] + connection_slots[connection_idx],
city_position[1] - city_radius + inner_point_offset[connection_idx])
out_tmp_coordinates = (
city_position[0] + connection_slots[connection_idx], city_position[1] - city_radius)
connection_points_coordinates_inner[direction].append(tmp_coordinates)
if connection_idx in range(start_idx, start_idx + number_of_out_rails):
connection_points_coordinates_outer[direction].append(tmp_coordinates)
connection_points_coordinates_outer[direction].append(out_tmp_coordinates)
inner_connection_points.append(connection_points_coordinates_inner)
outer_connection_points.append(connection_points_coordinates_outer)
......@@ -789,11 +803,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
:param grid_map:
:return: Returns the cells of the through path which cannot be occupied by trainstations
"""
through_path_cells: List[IntVector2DArray] = [[] for i in range(len(city_positions))]
free_rails: List[List[List[IntVector2D]]] = [[] for i in range(len(city_positions))]
for current_city in range(len(city_positions)):
all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in
sublist]
# This part only works if we have keep same number of connection points for both directions
# Also only works with two connection direction at each city
for i in range(4):
......@@ -802,18 +814,29 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
break
opposite_boarder = (boarder + 2) % 4
boarder_one = inner_connection_points[current_city][boarder]
boarder_two = inner_connection_points[current_city][opposite_boarder]
# Connect the ends of the tracks
connect_straight_line_in_grid_map(grid_map, boarder_one[0], boarder_one[-1], rail_trans)
connect_straight_line_in_grid_map(grid_map, boarder_two[0], boarder_two[-1], rail_trans)
nr_of_connection_points = len(inner_connection_points[current_city][boarder])
number_of_out_rails = len(outer_connection_points[current_city][boarder])
start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
# Connect parallel tracks
for track_id in range(len(inner_connection_points[current_city][boarder])):
for track_id in range(nr_of_connection_points):
source = inner_connection_points[current_city][boarder][track_id]
target = inner_connection_points[current_city][opposite_boarder][track_id]
current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans)
for track_id in range(nr_of_connection_points):
source = inner_connection_points[current_city][boarder][track_id]
target = inner_connection_points[current_city][opposite_boarder][track_id]
fix_inner_nodes(
grid_map, source, rail_trans)
fix_inner_nodes(
grid_map, target, rail_trans)
if start_idx <= track_id < start_idx + number_of_out_rails:
source_outer = outer_connection_points[current_city][boarder][track_id - start_idx]
target_outer = outer_connection_points[current_city][opposite_boarder][track_id - start_idx]
connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans)
connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans)
free_rails[current_city].append(current_track)
return free_rails
......@@ -870,18 +893,20 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
cells_to_fix = city_cells + inter_city_lines
for cell in cells_to_fix:
cell_valid = grid_map.cell_neighbours_valid(cell, True)
if grid_map.grid[cell] == int('1000010000100001', 2):
grid_map.fix_transitions(cell)
# cell_valid = grid_map.transitions.is_valid(cell)
# if grid_map.grid[cell] == int('1000010000100001', 2):
# grid_map.fix_transitions(cell)
# if bin(grid_map.grid[cell]).count("1") == 4:
# cell_valid = False
# print("fixing cell", cell, vector_field[cell])
if not cell_valid:
rails_to_fix[3 * rails_to_fix_cnt] = cell[0]
rails_to_fix[3 * rails_to_fix_cnt + 1] = cell[1]
rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[(cell[0], cell[1])]
rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[cell]
rails_to_fix_cnt += 1
# Fix all other cells
for cell in range(rails_to_fix_cnt):
grid_map.fix_transitions((rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]),
rails_to_fix[3 * rails_to_fix_cnt + 2])
grid_map.fix_transitions((rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]), rails_to_fix[3 * cell + 2])
def _closest_neighbour_in_grid4_directions(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]:
"""
......
......@@ -570,8 +570,8 @@ def test_sparse_rail_generator():
for a in range(env.get_num_agents()):
s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0))
s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0))
assert s0 == 61, "actual={}".format(s0)
assert s1 == 42, "actual={}".format(s1)
assert s0 == 39, "actual={}".format(s0)
assert s1 == 27, "actual={}".format(s1)
def test_sparse_rail_generator_deterministic():
......
......@@ -158,7 +158,7 @@ def test_malfunction_process_statistically():
env.step(action_dict)
# check that generation of malfunctions works as expected
assert nb_malfunction == 128, "nb_malfunction={}".format(nb_malfunction)
assert nb_malfunction == 152, "nb_malfunction={}".format(nb_malfunction)
def test_initial_malfunction():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment