Skip to content
Snippets Groups Projects
Commit 3739bd5d authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

also fixing inter_city lines

parent 37239f53
No related branches found
No related tags found
No related merge requests found
...@@ -582,7 +582,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, ...@@ -582,7 +582,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
# Connect the cities through the connection points # Connect the cities through the connection points
city_connection_time = time.time() city_connection_time = time.time()
_connect_cities(node_positions, outer_connection_points, connection_info, city_cells, rail_trans, grid_map) inter_city_lines = _connect_cities(node_positions, outer_connection_points, connection_info, city_cells,
rail_trans, grid_map)
print("City connection time", time.time() - city_connection_time) print("City connection time", time.time() - city_connection_time)
# Build inner cities # Build inner cities
city_build_time = time.time() city_build_time = time.time()
...@@ -604,7 +605,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, ...@@ -604,7 +605,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
# Fix all transition elements # Fix all transition elements
grid_fix_time = time.time() grid_fix_time = time.time()
_fix_transitions(city_cells, grid_map) _fix_transitions(city_cells, inter_city_lines, grid_map)
print("Grid fix time", time.time() - grid_fix_time) print("Grid fix time", time.time() - grid_fix_time)
# Generate start target pairs # Generate start target pairs
...@@ -728,6 +729,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, ...@@ -728,6 +729,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
:param grid_map: Grid map :param grid_map: Grid map
:return: :return:
""" """
all_paths = []
for current_node in np.arange(len(node_positions)): for current_node in np.arange(len(node_positions)):
direction = 0 direction = 0
connected_to_city = [] connected_to_city = []
...@@ -758,11 +760,12 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, ...@@ -758,11 +760,12 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
if tmp_dist < min_connection_dist: if tmp_dist < min_connection_dist:
min_connection_dist = tmp_dist min_connection_dist = tmp_dist
neighb_connection_point = tmp_in_connection_point neighb_connection_point = tmp_in_connection_point
connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
city_cells) city_cells)
all_paths.extend(new_line)
direction += 1 direction += 1
return
return all_paths
def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans, def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans,
grid_map): grid_map):
...@@ -891,14 +894,15 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, ...@@ -891,14 +894,15 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
num_agents -= 1 num_agents -= 1
return agent_start_targets_nodes, num_agents return agent_start_targets_nodes, num_agents
def _fix_transitions(city_cells, grid_map): def _fix_transitions(city_cells, inter_city_lines, grid_map):
""" """
Function to fix all transition elements in environment Function to fix all transition elements in environment
""" """
# Fix all nodes with illegal transition maps # Fix all nodes with illegal transition maps
rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int') rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int')
rails_to_fix_cnt = 0 rails_to_fix_cnt = 0
for cell in city_cells: cells_to_fix = city_cells + inter_city_lines
for cell in cells_to_fix:
check = grid_map.cell_neighbours_valid(cell, True) check = grid_map.cell_neighbours_valid(cell, True)
if grid_map.grid[cell] == int('1000010000100001', 2): if grid_map.grid[cell] == int('1000010000100001', 2):
grid_map.fix_transitions(cell) grid_map.fix_transitions(cell)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment