From 96f38175bfc927352d50ce5b6fdb025be0f5c43a Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 31 Aug 2019 14:35:38 -0400
Subject: [PATCH] updated example of Flatland 2.0 and minor changes to schedule
 generator

---
 examples/flatland_2_0_example.py     | 24 +++++++++++++-----------
 flatland/envs/schedule_generators.py |  5 ++++-
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 71a185c7..6f6c4b08 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -1,9 +1,9 @@
 import numpy as np
-from flatland.envs.rail_generators import sparse_rail_generator
 
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
@@ -13,17 +13,23 @@ np.random.seed(1)
 # Training on simple small tasks is the best way to get familiar with the environment
 
 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.5,  # Percentage of defective agents
+stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    'malfunction_rate': 30,  # Rate of malfunction occurence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 10  # Max duration of malfunction
                    }
 
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
-env = RailEnv(width=20,
-              height=20,
-              rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map (where train stations are)
-                                                   num_intersections=1,  # Number of intersections (no start / target)
+
+speed_ration_map = {1.: 0.1,  # Fast passenger train
+                    0.5: 0.2,  # Slow commuter train
+                    0.25: 0.2,  # Fast freight train
+                    0.125: 0.5}  # Slow freight train
+
+env = RailEnv(width=50,
+              height=50,
+              rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map (where train stations are)
+                                                   num_intersections=5,  # Number of intersections (no start / target)
                                                    num_trainstations=15,  # Number of possible start/targets on map
                                                    min_node_dist=3,  # Minimal distance of nodes
                                                    node_radius=3,  # Proximity of stations to city center
@@ -32,7 +38,7 @@ env = RailEnv(width=20,
                                                    realistic_mode=True,
                                                    enhance_intersection=True
                                                    ),
-              schedule_generator=sparse_schedule_generator(),
+              schedule_generator=sparse_schedule_generator(speed_ration_map),
               number_of_agents=5,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
@@ -83,10 +89,6 @@ action_dict = dict()
 print("Start episode...")
 # Reset environment and get initial observations for all agents
 obs = env.reset()
-# Update/Set agent's speed
-for idx in range(env.get_num_agents()):
-    speed = 1.0 / ((idx % 5) + 1.0)
-    env.agents[idx].speed_data["speed"] = speed
 
 # Reset the rendering sytem
 env_renderer.reset()
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index 4843e004..a0f6825d 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -60,7 +60,10 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
     def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
         train_stations = hints['train_stations']
         agent_start_targets_nodes = hints['agent_start_targets_nodes']
-        num_agents = hints['num_agents']
+        max_num_agents = hints['num_agents']
+        if num_agents > max_num_agents:
+            num_agents = max_num_agents
+            warnings.warn("Too many agents! Changes number of agents.")
         # Place agents and targets within available train stations
         agents_position = []
         agents_target = []
-- 
GitLab