From a6d1326abb6fb8bde23ab9e6f00acd73fcd021a5 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Mon, 30 Sep 2019 22:54:27 +0200
Subject: [PATCH] performance boost

---
 flatland/envs/rail_generators.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index cc516de2..0ec268dc 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -896,25 +896,20 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         Function to fix all transition elements in environment
         """
         # Fix all nodes with illegal transition maps
-        empty_to_fix = []
-        rails_to_fix = []
+        rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int')
+        rails_to_fix_cnt = 0
         for cell in city_cells:
             check = grid_map.cell_neighbours_valid(cell, True)
             if grid_map.grid[cell] == int('1000010000100001', 2):
                 grid_map.fix_transitions(cell)
             if not check:
-                if grid_map.grid[cell] == 0:
-                    empty_to_fix.append(cell)
-                else:
-                    rails_to_fix.append(cell)
-
-        # Fix empty cells first to avoid cutting the network
-        for cell in empty_to_fix:
-            grid_map.fix_transitions(cell)
+                rails_to_fix[2 * rails_to_fix_cnt] = cell[0]
+                rails_to_fix[2 * rails_to_fix_cnt + 1] = cell[1]
+                rails_to_fix_cnt += 1
 
         # Fix all other cells
-        for cell in rails_to_fix:
-            grid_map.fix_transitions(cell)
+        for cell in range(rails_to_fix_cnt):
+            grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1]))
 
     def _closest_neigh_in_direction(current_node, node_positions):
         """
-- 
GitLab