Commit 67fe6124 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

allow orientation based fixing of cells

parent 91ae7d2b
Pipeline #2331 failed with stages
in 60 minutes
......@@ -30,9 +30,9 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=100,
height=100,
rail_generator=sparse_rail_generator(max_num_cities=20,
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(max_num_cities=10,
# Number of cities in map (where train stations are)
seed=1, # Random seed
grid_mode=False,
......
......@@ -14,6 +14,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
# TODO are these general classes or for grid4 only?
class TransitionMap:
"""
......@@ -499,7 +500,7 @@ class GridTransitionMap(TransitionMap):
return True
def fix_transitions(self, rcPos: IntVector2DArray):
def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1):
"""
Fixes broken transitions
"""
......@@ -559,9 +560,17 @@ class GridTransitionMap(TransitionMap):
# Find feasible connection for three entries
if number_of_incoming == 3:
self.set_transitions(rcPos, 0)
transition = np.random.choice(three_way_transitions, 1)
hole = np.argwhere(incoming_connections < 1)[0][0]
if direction > 0:
switch_type_idx = (direction - hole + 3) % 4
if switch_type_idx == 2:
transition = simple_switch_west_south
if switch_type_idx == 0:
transition = simple_switch_east_south
else:
transition = np.random.choice(three_way_transitions, 1)
else:
transition = np.random.choice(three_way_transitions, 1)
transition = transitions.rotate_transition(transition, int(hole * 90))
self.set_transitions((rcPos[0], rcPos[1]), transition)
......
......@@ -557,7 +557,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
cell_vector_field = np.zeros(shape=(height, width), dtype=int) - 1
city_radius = int(np.ceil((max_rails_in_city + 2) / 2.0)) + 1
min_nr_rails_in_city = 3
......@@ -857,17 +857,17 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
rails_to_fix_cnt = 0
cells_to_fix = city_cells + inter_city_lines
for cell in cells_to_fix:
check = grid_map.cell_neighbours_valid(cell, True)
cell_valid = grid_map.cell_neighbours_valid(cell, True)
if grid_map.grid[cell] == int('1000010000100001', 2):
grid_map.fix_transitions(cell)
if not check:
if not cell_valid:
rails_to_fix[2 * rails_to_fix_cnt] = cell[0]
rails_to_fix[2 * rails_to_fix_cnt + 1] = cell[1]
rails_to_fix_cnt += 1
# Fix all other cells
for cell in range(rails_to_fix_cnt):
grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1]))
grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1]), )
def _closest_neighbour_in_grid4_directions(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment