From a6d1326abb6fb8bde23ab9e6f00acd73fcd021a5 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Mon, 30 Sep 2019 22:54:27 +0200 Subject: [PATCH] performance boost --- flatland/envs/rail_generators.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index cc516de2..0ec268dc 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -896,25 +896,20 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, Function to fix all transition elements in environment """ # Fix all nodes with illegal transition maps - empty_to_fix = [] - rails_to_fix = [] + rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int') + rails_to_fix_cnt = 0 for cell in city_cells: check = grid_map.cell_neighbours_valid(cell, True) if grid_map.grid[cell] == int('1000010000100001', 2): grid_map.fix_transitions(cell) if not check: - if grid_map.grid[cell] == 0: - empty_to_fix.append(cell) - else: - rails_to_fix.append(cell) - - # Fix empty cells first to avoid cutting the network - for cell in empty_to_fix: - grid_map.fix_transitions(cell) + rails_to_fix[2 * rails_to_fix_cnt] = cell[0] + rails_to_fix[2 * rails_to_fix_cnt + 1] = cell[1] + rails_to_fix_cnt += 1 # Fix all other cells - for cell in rails_to_fix: - grid_map.fix_transitions(cell) + for cell in range(rails_to_fix_cnt): + grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1])) def _closest_neigh_in_direction(current_node, node_positions): """ -- GitLab