From bf30a586585af953e6864363b1f0f604d1c5cc98 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 16:11:28 -0400
Subject: [PATCH] updated city placement in sparse_rail_generator Updated
 flatland 2.0 example added new feature: Malfunctioning agents will
 automatically resume operation when they are fixed

---
 examples/flatland_2_0_example.py |  8 ++++----
 flatland/envs/rail_env.py        | 24 ++++++++++++++++--------
 flatland/envs/rail_generators.py |  9 ++-------
 3 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index a1ad9a85..da8cacef 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -13,7 +13,7 @@ np.random.seed(1)
 # Training on simple small tasks is the best way to get familiar with the environment
 
 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                    'malfunction_rate': 30,  # Rate of malfunction occurence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 20  # Max duration of malfunction
@@ -31,11 +31,11 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 env = RailEnv(width=50,
               height=50,
               rail_generator=sparse_rail_generator(num_cities=25,  # Number of cities in map (where train stations are)
-                                                   num_intersections=0,  # Number of intersections (no start / target)
+                                                   num_intersections=10,  # Number of intersections (no start / target)
                                                    num_trainstations=50,  # Number of possible start/targets on map
                                                    min_node_dist=3,  # Minimal distance of nodes
-                                                   node_radius=2,  # Proximity of stations to city center
-                                                   num_neighb=3,  # Number of connections to other cities/intersections
+                                                   node_radius=4,  # Proximity of stations to city center
+                                                   num_neighb=4,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
                                                    grid_mode=True,
                                                    enhance_intersection=False
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 8486fc7f..4926fb1f 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -266,7 +266,8 @@ class RailEnv(Environment):
 
     def _agent_malfunction(self, agent):
         # Decrease counter for next event
-        agent.malfunction_data['next_malfunction'] -= 1
+        if agent.malfunction_data['malfunction_rate'] > 0:
+            agent.malfunction_data['next_malfunction'] -= 1
 
         # Only agents that have a positive rate for malfunctions and are not currently broken are considered
         if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']:
@@ -328,15 +329,22 @@ class RailEnv(Environment):
 
             # The train is broken
             if agent.malfunction_data['malfunction'] > 0:
-                agent.malfunction_data['malfunction'] -= 1
+                # Last step of malfunction --> Agent starts moving again
+                if agent.malfunction_data['malfunction'] < 2:
+                    agent.malfunction_data['malfunction'] -= 1
+                    self.agents[i_agent].moving = True
+                    action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
-                # Broken agents are stopped
-                self.rewards_dict[i_agent] += step_penalty  # * agent.speed_data['speed']
-                self.agents[i_agent].moving = False
-                action_dict[i_agent] = RailEnvActions.DO_NOTHING
+                else:
+                    agent.malfunction_data['malfunction'] -= 1
 
-                # Nothing left to do with broken agent
-                continue
+                    # Broken agents are stopped
+                    self.rewards_dict[i_agent] += step_penalty  # * agent.speed_data['speed']
+                    self.agents[i_agent].moving = False
+                    action_dict[i_agent] = RailEnvActions.DO_NOTHING
+
+                    # Nothing left to do with broken agent
+                    continue
 
             if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
                 print('ERROR: illegal action=', action_dict[i_agent],
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 39796515..c2359346 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -563,7 +563,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         node_positions = []
         city_positions = []
         intersection_positions = []
-
         # Evenly distribute cities and intersections
         if grid_mode:
             tot_num_node = num_intersections + num_cities
@@ -572,10 +571,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
             x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
             y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
-
-            fraction = 0
-            city_fraction = num_cities / tot_num_node
-            step = np.gcd(num_intersections, num_cities) / tot_num_node
+            city_idx = np.random.choice(np.arange(tot_num_node), num_cities)
 
         for node_idx in range(num_cities + num_intersections):
             to_close = True
@@ -608,10 +604,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                         warnings.warn("Could not set nodes, please change initial parameters!!!!")
                         break
             else:
-                fraction = (fraction + step) % 1.
                 x_tmp = x_positions[node_idx % nodes_per_row]
                 y_tmp = y_positions[node_idx // nodes_per_row]
-                if len(city_positions) < num_cities and fraction < city_fraction:
+                if node_idx in city_idx:
                     city_positions.append((x_tmp, y_tmp))
                 else:
                     intersection_positions.append((x_tmp, y_tmp))
-- 
GitLab