From 979b637c94760bf0c35b14faae6f7aa0621666c5 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sat, 17 Aug 2019 18:09:02 -0400
Subject: [PATCH] minor bugfixes

---
 examples/flatland_2_0_example.py | 25 +++++++++++++------------
 flatland/envs/generators.py      |  5 ++---
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 96fb87f8..15b19f8c 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -12,24 +12,25 @@ np.random.seed(1)
 # Training on simple small tasks is the best way to get familiar with the environment
 
 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.5,
-                   'malfunction_rate': 30,
-                   'min_duration': 3,
-                   'max_duration': 10}
+stochastic_data = {'prop_malfunction': 0.5,  # Percentage of defective agents
+                   'malfunction_rate': 30,  # Rate of malfunction occurence
+                   'min_duration': 3,  # Minimal duration of malfunction
+                   'max_duration': 10  # Max duration of malfunction
+                   }
 
 
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
-env = RailEnv(width=50,
-              height=50,
-              rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
-                                                   num_intersections=3,  # Number of interesections in map
-                                                   num_trainstations=40,  # Number of possible start/targets on map
-                                                   min_node_dist=10,  # Minimal distance of nodes
+env = RailEnv(width=10,
+              height=10,
+              rail_generator=sparse_rail_generator(num_cities=3,  # Number of cities in map
+                                                   num_intersections=1,  # Number of interesections in map
+                                                   num_trainstations=8,  # Number of possible start/targets on map
+                                                   min_node_dist=3,  # Minimal distance of nodes
                                                    node_radius=2,  # Proximity of stations to city center
-                                                   num_neighb=4,  # Number of connections to other cities
+                                                   num_neighb=2,  # Number of connections to other cities
                                                    seed=15,  # Random seed
                                                    ),
-              number_of_agents=10,
+              number_of_agents=5,
               stochastic_data=stochastic_data,  # Malfunction generator data
               obs_builder_object=TreeObservation)
 
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 0b7ea708..2aa14899 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -875,7 +875,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             # Set number of neighboring nodes
             if len(available_nodes) >= num_neighb:
                 connected_neighb_idx = available_nodes[
-                                       0:2]  # np.random.choice(available_nodes, num_neighb, replace=False)
+                                       0:num_neighb]  # np.random.choice(available_nodes, num_neighb, replace=False)
             else:
                 connected_neighb_idx = available_nodes
 
@@ -885,10 +885,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                     node_stack.append(neighb)
                 connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
             node_stack.pop(0)
-
         # Place train stations close to the node
         # We currently place them uniformly distirbuted among all cities
-
         train_stations = [[] for i in range(num_cities)]
 
         for station in range(num_trainstations):
@@ -940,6 +938,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 tries += 1
                 # Test again with new start node if no pair is found (This code needs to be improved)
                 if tries > 10:
+                    break
                     start_node = np.random.choice(avail_start_nodes)
 
             node_available_start[start_node] -= 1
-- 
GitLab