From ccd73a5b994666790f7875f51102193331a4a5e1 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 28 Aug 2019 13:33:11 +0200
Subject: [PATCH] merge #147 -> #141 schedule_generator for
 sparse_rail_generator

---
 examples/flatland_2_0_example.py     |  5 ++-
 flatland/envs/rail_generators.py     | 47 ++---------------------
 flatland/envs/schedule_generators.py | 56 ++++++++++++++++++++++++++++
 flatland/utils/graphics_pil.py       | 12 +++---
 4 files changed, 69 insertions(+), 51 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 916e50b2..9f4d62cf 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -1,9 +1,10 @@
 import numpy as np
-
 from flatland.envs.generators import sparse_rail_generator
+
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer
 from flatland.utils.rendertools import RenderTool
 
 np.random.seed(1)
@@ -31,6 +32,7 @@ env = RailEnv(width=20,
                                                    realistic_mode=True,
                                                    enhance_intersection=True
                                                    ),
+              agent_generator=sparse_rail_generator_agents_placer(),
               number_of_agents=5,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
@@ -75,7 +77,6 @@ class RandomAgent:
 # Set action space to 4 to remove stop action
 agent = RandomAgent(218, 4)
 
-
 # Empty dictionary for all agent action
 action_dict = dict()
 
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index d338301c..ed507dca 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -786,48 +786,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             else:
                 num_agents -= 1
 
-        # Place agents and targets within available train stations
-        agents_position = []
-        agents_target = []
-        agents_direction = []
-
-        for agent_idx in range(num_agents):
-            # Set target for agent
-            current_target_node = agent_start_targets_nodes[agent_idx][1]
-            target_station_idx = np.random.randint(len(train_stations[current_target_node]))
-            target = train_stations[current_target_node][target_station_idx]
-            tries = 0
-            while (target[0], target[1]) in agents_target:
-                target_station_idx = np.random.randint(len(train_stations[current_target_node]))
-                target = train_stations[current_target_node][target_station_idx]
-                tries += 1
-                if tries > 100:
-                    warnings.warn("Could not set target position, removing an agent")
-                    break
-            agents_target.append((target[0], target[1]))
-
-            # Set start for agent
-            current_start_node = agent_start_targets_nodes[agent_idx][0]
-            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
-            start = train_stations[current_start_node][start_station_idx]
-            tries = 0
-            while (start[0], start[1]) in agents_position:
-                tries += 1
-                if tries > 100:
-                    warnings.warn("Could not set start position, please change initial parameters!!!!")
-                    break
-                start_station_idx = np.random.randint(len(train_stations[current_start_node]))
-                start = train_stations[current_start_node][start_station_idx]
-
-            agents_position.append((start[0], start[1]))
-
-            # Orient the agent correctly
-            for orientation in range(4):
-                transitions = grid_map.get_transitions(start[0], start[1], orientation)
-                if any(transitions) > 0:
-                    agents_direction.append(orientation)
-                    continue
-
-        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return grid_map, {'agents_hints': {
+            'agent_start_targets_nodes': agent_start_targets_nodes,
+            'train_stations': train_stations
+        }}
 
     return generator
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index 50f31378..ef1f9666 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -1,4 +1,5 @@
 """Schedule generators (railway undertaking, "EVU")."""
+import warnings
 from typing import Tuple, List, Callable, Mapping, Optional, Any
 
 import msgpack
@@ -55,6 +56,61 @@ def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float]
     return generator
 
 
+def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+        train_stations = hints['train_stations']
+        agent_start_targets_nodes = hints['agent_start_targets_nodes']
+        # Place agents and targets within available train stations
+        agents_position = []
+        agents_target = []
+        agents_direction = []
+        for agent_idx in range(num_agents):
+            # Set target for agent
+            current_target_node = agent_start_targets_nodes[agent_idx][1]
+            target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+            target = train_stations[current_target_node][target_station_idx]
+            tries = 0
+            while (target[0], target[1]) in agents_target:
+                target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+                target = train_stations[current_target_node][target_station_idx]
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set target position, removing an agent")
+                    break
+            agents_target.append((target[0], target[1]))
+
+            # Set start for agent
+            current_start_node = agent_start_targets_nodes[agent_idx][0]
+            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+            start = train_stations[current_start_node][start_station_idx]
+            tries = 0
+            while (start[0], start[1]) in agents_position:
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set start position, please change initial parameters!!!!")
+                    break
+                start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+                start = train_stations[current_start_node][start_station_idx]
+
+            agents_position.append((start[0], start[1]))
+
+            # Orient the agent correctly
+            for orientation in range(4):
+                transitions = rail.get_transitions(start[0], start[1], orientation)
+                if any(transitions) > 0:
+                    agents_direction.append(orientation)
+                    continue
+
+        if speed_ratio_map:
+            speeds = speed_initialization_helper(num_agents, speed_ratio_map)
+        else:
+            speeds = [1.0] * len(agents_position)
+
+        return agents_position, agents_direction, agents_target, speeds
+
+    return generator
+
+
 def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
     """
     Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 6a0a9282..92a0f84f 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -172,7 +172,7 @@ class PILGL(GraphicsLayer):
     def text(self, xPx, yPx, strText, layer=RAIL_LAYER):
         xyPixLeftTop = (xPx, yPx)
         self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255))
-        
+
     def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER):
         print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer)
         xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
@@ -500,9 +500,9 @@ class PILSVG(PILGL):
                                           False)[0]
         self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER)
 
-    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, 
-            show_debug=True):
-        
+    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None,
+                    show_debug=True):
+
         if binary_trans in self.pil_rail:
             pil_track = self.pil_rail[binary_trans]
             if target is not None:
@@ -510,7 +510,7 @@ class PILSVG(PILGL):
                 target_img = Image.alpha_composite(pil_track, target_img)
                 self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER)
                 if show_debug:
-                    self.text_rowcol((row+0.8, col+0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
+                    self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
 
             if binary_trans == 0:
                 if self.background_grid[col][row] <= 4:
@@ -607,7 +607,7 @@ class PILSVG(PILGL):
 
         if show_debug:
             print("Call text:")
-            self.text_rowcol((row+0.2, col+0.2,), str(agent_idx))
+            self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx))
 
     def set_cell_occupied(self, agent_idx, row, col):
         occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)]
-- 
GitLab