From f17ec7e06e307a2fe69b53a4f131fa6ecbe327b3 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 07:35:06 -0400
Subject: [PATCH] minor bugfix in level generator

---
 examples/flatland_2_0_example.py     | 15 +++++++++------
 flatland/envs/rail_generators.py     |  6 ++++--
 flatland/envs/schedule_generators.py |  4 ++--
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index a99009a5..1a18cb40 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -1,3 +1,5 @@
+import time
+
 import numpy as np
 
 from flatland.envs.observations import TreeObsForRailEnv
@@ -30,18 +32,18 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=20,  # Number of cities in map (where train stations are)
-                                                   num_intersections=5,  # Number of intersections (no start / target)
-                                                   num_trainstations=15,  # Number of possible start/targets on map
+              rail_generator=sparse_rail_generator(num_cities=25,  # Number of cities in map (where train stations are)
+                                                   num_intersections=0,  # Number of intersections (no start / target)
+                                                   num_trainstations=0,  # Number of possible start/targets on map
                                                    min_node_dist=3,  # Minimal distance of nodes
                                                    node_radius=2,  # Proximity of stations to city center
-                                                   num_neighb=4,  # Number of connections to other cities/intersections
+                                                   num_neighb=3,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
                                                    realistic_mode=True,
-                                                   enhance_intersection=True
+                                                   enhance_intersection=False
                                                    ),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
-              number_of_agents=10,
+              number_of_agents=0,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
 
@@ -112,6 +114,7 @@ for step in range(500):
     next_obs, all_rewards, done, _ = env.step(action_dict)
     env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
     frame_step += 1
+    time.sleep(10.1)
     # Update replay buffer and train agent
     for a in range(env.get_num_agents()):
         agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index b69d80a0..93e7ce55 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -675,9 +675,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
 
         # Place train stations close to the node
         # We currently place them uniformly distirbuted among all cities
+        built_num_trainstation = 0
+        train_stations = [[] for i in range(num_cities)]
+
         if num_cities > 1:
-            train_stations = [[] for i in range(num_cities)]
-            built_num_trainstation = 0
+
             for station in range(num_trainstations):
                 spot_found = True
                 trainstation_node = int(station / num_trainstations * num_cities)
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index a3a6dc1e..7a116f9e 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -110,7 +110,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
         else:
             speeds = [1.0] * len(agents_position)
 
-        return agents_position, agents_direction, agents_target, speeds
+        return agents_position, agents_direction, agents_target, speeds, None
 
     return generator
 
@@ -203,7 +203,7 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
                         np.random.choice(len(valid_starting_directions), 1)[0]]
 
         agents_speed = speed_initialization_helper(num_agents, speed_ratio_map)
-        return agents_position, agents_direction, agents_target, agents_speed
+        return agents_position, agents_direction, agents_target, agents_speed, None
 
     return generator
 
-- 
GitLab