From 5ff4a7a86a75bd2f63746db134eac7dd87af1655 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 26 Sep 2019 17:01:37 -0400
Subject: [PATCH] fixed orientation where x-axis was inverted in connection
 point generation

---
 examples/flatland_2_0_example.py | 14 ++---
 flatland/envs/rail_generators.py | 98 ++++++++++++--------------------
 2 files changed, 44 insertions(+), 68 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index fc358a2a..ccfbecbd 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -32,16 +32,16 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=3,  # Number of cities in map (where train stations are)
-                                                   num_trainstations=100,  # Number of possible start/targets on map
-                                                   min_node_dist=10,  # Minimal distance of nodes
+              rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
+                                                   num_trainstations=50,  # Number of possible start/targets on map
+                                                   min_node_dist=30,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
                                                    num_neighb=3,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
-                                                   grid_mode=False,
+                                                   grid_mode=True,
                                                    nr_parallel_tracks=2,
-                                                   connectin_points_per_side=100,
-                                                   max_nr_connection_directions=3,
+                                                   connection_points_per_side=2,
+                                                   max_nr_connection_directions=4,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
               number_of_agents=50,
@@ -114,7 +114,7 @@ for step in range(500):
     # reward and whether their are done
     next_obs, all_rewards, done, _ = env.step(action_dict)
     env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
-    time.sleep(1)
+    time.sleep(10)
     frame_step += 1
     # Update replay buffer and train agent
     for a in range(env.get_num_agents()):
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index b7ca0f4d..8c472cba 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -533,7 +533,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
 
 
 def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, node_radius=2,
-                          num_neighb=3, nr_parallel_tracks=2, grid_mode=False, connectin_points_per_side=4,
+                          num_neighb=3, nr_parallel_tracks=2, grid_mode=False, connection_points_per_side=4,
                           max_nr_connection_directions=2,
                           seed=0) -> RailGenerator:
     """
@@ -598,7 +598,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
 
         # Set up connection points for all cities
         connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius,
-                                                                              connectin_points_per_side,
+                                                                              connection_points_per_side,
                                                                               max_nr_connection_directions)
 
         # Connect the cities through the connection points
@@ -712,13 +712,13 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                 for connection_idx in range(connections_per_direction[direction]):
                     if direction == 0:
                         connection_points_coordinates.append(
-                            (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))
+                            (node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]))
                     if direction == 1:
                         connection_points_coordinates.append(
                             (node_position[0] + connection_slots[connection_idx], node_position[1] + node_size))
                     if direction == 2:
                         connection_points_coordinates.append(
-                            (node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]))
+                            (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))
                     if direction == 3:
                         connection_points_coordinates.append(
                             (node_position[0] + connection_slots[connection_idx], node_position[1] - node_size))
@@ -737,61 +737,37 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
         :param grid_map: Grid map
         :return:
         """
-
-        # Start at some node
-        available_nodes = np.arange(len(node_positions))
-        current_node = np.random.randint(len(available_nodes))
-        node_stack = [current_node]
-        open_nodes = np.copy(available_nodes)
         boarder_connections = set()
-        while len(open_nodes) > 0:
-            if len(node_stack) > 0:
-                current_node = node_stack[0]
-            else:
-                current_node = np.random.choice(open_nodes)
-                node_stack.append(current_node)
-            delete_idx = np.where(open_nodes == current_node)
-            open_nodes = np.delete(open_nodes, delete_idx, 0)
-
-            # Sort available neighbors according to their distance.
-            node_dist = []
-            for av_node in available_nodes:
-                node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
-            available_nodes = available_nodes[np.argsort(node_dist)]
-
-            # Set number of neighboring
-            allowed_connections = np.count_nonzero(connection_info[current_node])
-
-            if len(available_nodes) >= allowed_connections:
-                connected_neighb_idx = available_nodes[1:allowed_connections + 1]
-            else:
-                connected_neighb_idx = available_nodes
-
-            # Connect to the neighboring nodes
-            for neighb in connected_neighb_idx:
-                if neighb not in node_stack and neighb in open_nodes:
-                    node_stack.append(neighb)
-
-                dist_from_center = distance_on_rail(node_positions[current_node], node_positions[neighb])
-                connection_distances = []
-                for tmp_out_connection_point in connection_points[current_node]:
-                    tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb])
-                    connection_distances.append(tmp_dist_to_node)
-                possible_connection_points = argsort(connection_distances)
-                for sort_idx in possible_connection_points[:nr_parallel_tracks]:
-                    # Find closest connection point
-                    tmp_out_connection_point = connection_points[current_node][sort_idx]
-                    min_connection_dist = np.inf
-                    for tmp_in_connection_point in connection_points[neighb]:
-                        tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point)
-                        if tmp_dist < min_connection_dist:
-                            min_connection_dist = tmp_dist
-                            neighb_connection_point = tmp_in_connection_point
-                    connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point)
-                    boarder_connections.add((tmp_out_connection_point, current_node))
-                    boarder_connections.add((neighb_connection_point, neighb))
-
-            node_stack.pop(0)
+        # Start at some node
+        for current_node in np.arange(len(node_positions)):
+            direction = 0
+            for nbr_connection_points in connection_info[current_node]:
+                if nbr_connection_points > 0:
+                    neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions)
+                    print(current_node, direction, neighb_idx, connection_info[current_node])
+                else:
+                    continue
+
+                if neighb_idx is not None:
+                    connection_distances = []
+                    for tmp_out_connection_point in connection_points[current_node]:
+                        tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb_idx])
+                        connection_distances.append(tmp_dist_to_node)
+                    possible_connection_points = argsort(connection_distances)
+                    for sort_idx in possible_connection_points[:nr_parallel_tracks]:
+                        # Find closest connection point
+                        tmp_out_connection_point = connection_points[current_node][sort_idx]
+                        min_connection_dist = np.inf
+                        for tmp_in_connection_point in connection_points[neighb_idx]:
+                            tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point)
+                            if tmp_dist < min_connection_dist:
+                                min_connection_dist = tmp_dist
+                                neighb_connection_point = tmp_in_connection_point
+                        connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point)
+                        boarder_connections.add((tmp_out_connection_point, current_node))
+                        boarder_connections.add((neighb_connection_point, neighb_idx))
+                direction += 1
+
 
     def _build_cities(node_positions, connection_points, rail_trans, grid_map):
         # Place train stations close to the node
@@ -921,7 +897,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
 
     def _closest_neigh_in_direction(current_node, direction, node_positions):
         # Sort available neighbors according to their distance.
-        available_nodes = np.arange(node_positions)
+        available_nodes = np.arange(len(node_positions))
         node_dist = []
         for av_node in available_nodes:
             node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
@@ -931,7 +907,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
             distance_0 = np.abs(node_positions[current_node][0] - node_positions[neighb][0])
             distance_1 = np.abs(node_positions[current_node][1] - node_positions[neighb][1])
             if direction == 0:
-                if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0:
+                if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0:
                     return neighb
 
             if direction == 1:
@@ -939,7 +915,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                     return neighb
 
             if direction == 2:
-                if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0:
+                if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0:
                     return neighb
 
             if direction == 3:
-- 
GitLab