From 6f8e4f379e6e8b296bc65d6a9b66df295834dfbb Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Fri, 16 Aug 2019 17:58:13 -0400
Subject: [PATCH] introduced cities and intersections. These will make up the
 sparse network. THen you can add trainstations to the cities and populate
 them with tasks for agents (start/target). Orientation of agents needs to be
 fixed. Also check for invalid transitions in cities and nodes needs to be
 implemented.

---
 flatland/envs/generators.py                   | 135 ++++++++++++------
 ...test_flatland_env_sparse_rail_generator.py |  25 ++--
 2 files changed, 105 insertions(+), 55 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index d63e161a..be21f86c 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -9,7 +9,7 @@ from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic
-from flatland.envs.grid4_generators_utils import connect_rail, connect_from_nodes, connect_nodes, connect_to_nodes
+from flatland.envs.grid4_generators_utils import connect_rail, connect_from_nodes, connect_nodes
 from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
 
@@ -776,12 +776,13 @@ def realistic_rail_generator(nr_start_goal=1,  seed=0):
     return generator
 
 
-def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2,
+def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstations=2, min_node_dist=20, node_radius=2,
+                          num_neighb=4,
                           seed=0):
     '''
 
     :param nr_train_stations:
-    :param nr_nodes:
+    :param num_cities:
     :param mean_node_neighbours:
     :param min_node_dist:
     :param seed:
@@ -789,6 +790,11 @@ def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_
     '''
 
     def generator(width, height, num_agents, num_resets=0):
+
+        if num_agents > num_trainstations:
+            num_agents = num_trainstations
+            warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
+
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
@@ -797,12 +803,12 @@ def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_
 
         # Generate a set of nodes for the sparse network
         node_positions = []
-        for node_idx in range(nr_nodes):
+        for node_idx in range(num_cities + num_intersections):
             to_close = True
             tries = 0
             while to_close:
-                x_tmp = np.random.randint(height)
-                y_tmp = np.random.randint(width)
+                x_tmp = 1 + np.random.randint(height - 1)
+                y_tmp = 1 + np.random.randint(width - 1)
                 to_close = False
                 for node_pos in node_positions:
                     if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
@@ -815,7 +821,7 @@ def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_
                     break
 
         # Chose node connection
-        available_nodes = np.arange(nr_nodes)
+        available_nodes = np.arange(num_cities + num_intersections)
         current_node = 0
         node_stack = [current_node]
 
@@ -824,61 +830,98 @@ def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_
             delete_idx = np.where(available_nodes == current_node)
             available_nodes = np.delete(available_nodes, delete_idx, 0)
 
-            # Get random number of neighbors
-            num_neighb = 2  # np.random.randint(1, max_neigbours)
+            # Sort available neighbors according to their distance.
+            node_dist = []
+            for av_node in available_nodes:
+                node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
+            available_nodes = available_nodes[np.argsort(node_dist)]
+
+            # Set number of neighboring nodes
+            # np.random.randint(1, max_neigbours)
+
             if len(available_nodes) >= num_neighb:
-                connected_neighb_idx = np.random.choice(available_nodes, num_neighb, replace=False)
+                connected_neighb_idx = available_nodes[
+                                       0:2]  # np.random.choice(available_nodes, num_neighb, replace=False)
             else:
                 connected_neighb_idx = available_nodes
 
+            # Connect to the neighboring nodes
             for neighb in connected_neighb_idx:
                 if neighb not in node_stack:
                     node_stack.append(neighb)
                 new_path = connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
             node_stack.pop(0)
 
-        # Generate start and target node directory for all agents
+        # Place train stations close to the node
+        train_stations = [[] for i in range(num_cities)]
+
+        for station in range(num_trainstations):
+            trainstation_node = int(station / num_trainstations * num_cities)
+
+            station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), 0,
+                                height - 1)
+            station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), 0,
+                                width - 1)
+            while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
+                trainstation_node] or \
+                rail_array[(station_x, station_y)] != 0:
+                station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
+                                    0,
+                                    height - 1)
+                station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
+                                    0,
+                                    width - 1)
+            train_stations[trainstation_node].append((station_x, station_y))
+
+            # Connect train station to the correct node
+            new_path = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node],
+                                          (station_x, station_y))
+
+        # Generate start and target node directory for all agents.
+        # Assure that start and target are not in the same node
         agent_start_targets_nodes = []
+
+        # Slot availability in node
+        node_available_start = []
+        node_available_target = []
+        for node_idx in range(num_cities):
+            node_available_start.append(len(train_stations[node_idx]))
+            node_available_target.append(len(train_stations[node_idx]))
+
+        # Assign agents to slots
         for agent_idx in range(num_agents):
-            start_target_tuple = np.random.choice(nr_nodes, 2, replace=False)
-            agent_start_targets_nodes.append(start_target_tuple)
+            av_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
+            av_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
+            start_node = np.random.choice(av_start_nodes)
+            target_node = np.random.choice(av_target_nodes)
+            while target_node == start_node:
+                target_node = np.random.choice(av_target_nodes)
+            node_available_start[start_node] -= 1
+            node_available_target[target_node] -= 1
+            print(node_available_target, node_available_start)
+
+            agent_start_targets_nodes.append((start_node, target_node))
+
+        # Place agents and targets within available train stations
 
-        # Generate actual start and target locations from around nodes
         agents_position = []
         agents_target = []
         agents_direction = []
-        agent_idx = 0
-        for start_target in agent_start_targets_nodes:
-            start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), 0,
-                              height - 1)
-            start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0,
-                              width - 1)
-            target_x = np.clip(node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
-                               height - 1)
-            target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
-                               width - 1)
-            # Make sure we don't put to starts or targets on same cell
-            while (start_x, start_y) in agents_position or (start_x, start_y) == node_positions[start_target[0]]:
-                start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius),
-                                  0,
-                                  height - 1)
-                start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius),
-                                  0,
-                                  width - 1)
-            while (target_x, target_y) in agents_target or (target_x, target_y) == node_positions[start_target[1]] or \
-                rail_array[(target_x, target_y)] != 0:
-                target_x = np.clip(
-                    node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
-                    height - 1)
-                target_y = np.clip(
-                    node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
-                    width - 1)
-            agents_position.append((start_x, start_y))
-            agents_target.append((target_x, target_y))
-            new_path = connect_to_nodes(rail_trans, rail_array, agents_position[agent_idx],
-                                        node_positions[start_target[0]])
-            new_path = connect_from_nodes(rail_trans, rail_array, node_positions[start_target[1]],
-                                          agents_target[agent_idx])
+
+        for agent_idx in range(num_agents):
+            current_start_node = agent_start_targets_nodes[agent_idx][0]
+            current_target_node = agent_start_targets_nodes[agent_idx][1]
+            target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+            target = train_stations[current_target_node][target_station_idx]
+            start = train_stations[current_start_node][start_station_idx]
+            while (target[0], target[1]) in agents_target:
+                target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+                start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+                target = train_stations[current_target_node][target_station_idx]
+                start = train_stations[current_start_node][start_station_idx]
+            agents_position.append((start[0], start[1]))
+            agents_target.append((target[0], target[1]))
             agents_direction.append(0)
             agent_idx += 1
 
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index 49739297..2a361ca9 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -1,11 +1,12 @@
 import time
 
-from flatland.envs.generators import sparse_rail_generator,realistic_rail_generator
+import numpy as np
+
+from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
-import numpy as np
 
 def test_realistic_rail_generator():
     for test_loop in range(5):
@@ -22,14 +23,20 @@ def test_realistic_rail_generator():
         env_renderer.close_window()
 
 def test_sparse_rail_generator():
-
-    env = RailEnv(width=20,
-                  height=20,
-                  rail_generator=sparse_rail_generator(nr_nodes=3, min_node_dist=8,
-                                                       node_radius=4),
-                  number_of_agents=15,
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                       num_intersections=3,  # Number of interesections in map
+                                                       num_trainstations=30,  # Number of possible start/targets on map
+                                                       min_node_dist=10,  # Minimal distance of nodes
+                                                       node_radius=2,  # Proximity of stations to city center
+                                                       num_neighb=4,  # Number of connections to other cities
+                                                       seed=15,  # Random seed
+                                                       ),
+                  number_of_agents=20,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
-    time.sleep(2)
+    env_renderer.gl.save_image("flatalnd_2_0.png")
+    time.sleep(100)
-- 
GitLab