From 9bbf2ed67b4fd7ef7d6bf3a735707d7f12873ae3 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Thu, 3 Oct 2019 14:35:14 +0200
Subject: [PATCH] refactor from nodes to cities

---
 flatland/core/grid/grid_utils.py |  12 +-
 flatland/envs/rail_generators.py | 183 +++++++++++++------------------
 2 files changed, 85 insertions(+), 110 deletions(-)

diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py
index 9c2fe294..7f920c11 100644
--- a/flatland/core/grid/grid_utils.py
+++ b/flatland/core/grid/grid_utils.py
@@ -2,6 +2,8 @@ from typing import Tuple, Callable, List, Type
 
 import numpy as np
 
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+
 Vector2D: Type = Tuple[float, float]
 IntVector2D: Type = Tuple[int, int]
 
@@ -296,7 +298,7 @@ def distance_on_rail(pos1, pos2, metric="Euclidean"):
         return np.abs(pos1[0] - pos2[0]) + np.abs(pos1[1] - pos2[1])
 
 
-def direction_to_point(pos1, pos2):
+def direction_to_city(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
     """
     Returns the closest direction orientation of position 2 relative to position 1
     :param pos1: position we are interested in
@@ -308,11 +310,11 @@ def direction_to_point(pos1, pos2):
     direction = np.sign(diff_vec[axis])
     if axis == 0:
         if direction > 0:
-            return 0
+            return Grid4TransitionsEnum.NORTH
         else:
-            return 2
+            return Grid4TransitionsEnum.SOUTH
     else:
         if direction > 0:
-            return 3
+            return Grid4TransitionsEnum.WEST
         else:
-            return 1
+            return Grid4TransitionsEnum.EAST
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index fa18298c..a25862dc 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -7,7 +7,8 @@ import msgpack
 import numpy as np
 
 from flatland.core.grid.grid4_utils import get_direction, mirror
-from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
+from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \
+    Vec2dOperations
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.grid4_generators_utils import connect_rail, connect_cities, connect_straigt_line
@@ -545,8 +546,6 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
     :return: generator
     """
 
-    DEBUG_PRINT_TIMING = False
-
     def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
         np.random.seed(seed + num_resets)
 
@@ -560,7 +559,6 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities
 
         # Evenly distribute cities
-        node_time_start = time.time()
         if grid_mode:
             city_positions, city_cells = _generate_evenly_distr_city_positions(max_num_cities, city_radius, width, height)
         else:
@@ -568,62 +566,40 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
 
         # reduce num_cities if less were generated in random mode
         num_cities = len(city_positions)
-        if DEBUG_PRINT_TIMING:
-            print("City position time", time.time() - node_time_start, "Seconds")
 
         # Set up connection points for all cities
-        node_connection_time = time.time()
-        inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_node_connection_points(
+        inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_city_connection_points(
             city_positions, city_radius, rails_between_cities,
             rails_in_city)
-        if DEBUG_PRINT_TIMING:
-            print("Connection points", time.time() - node_connection_time)
 
         # Connect the cities through the connection points
-        city_connection_time = time.time()
         inter_city_lines = _connect_cities(city_positions, outer_connection_points, city_cells,
                                            rail_trans, grid_map)
-        if DEBUG_PRINT_TIMING:
-            print("City connection time", time.time() - city_connection_time)
         # Build inner cities
-        city_build_time = time.time()
-        through_tracks, free_tracks = _build_inner_cities(city_positions, inner_connection_points,
-                                                          outer_connection_points,
-                                                          city_radius,
-                                                          rail_trans,
-                                                          grid_map)
-        if DEBUG_PRINT_TIMING:
-            print("City build time", time.time() - city_build_time)
+        through_tracks, free_rails = _build_inner_cities(city_positions, inner_connection_points,
+                                                         outer_connection_points,
+                                                         rail_trans,
+                                                         grid_map)
         # Populate cities
-        train_station_time = time.time()
-        train_stations, built_num_trainstation = _set_trainstation_positions(city_positions, city_radius, free_tracks,
-                                                                             grid_map)
-        if DEBUG_PRINT_TIMING:
-            print("Trainstation placing time", time.time() - train_station_time)
+        train_stations, built_num_trainstation = _set_trainstation_positions(city_positions, city_radius, free_rails)
 
         # Fix all transition elements
-        grid_fix_time = time.time()
         _fix_transitions(city_cells, inter_city_lines, grid_map)
-        if DEBUG_PRINT_TIMING:
-            print("Grid fix time", time.time() - grid_fix_time)
 
         # Generate start target pairs
-        schedule_time = time.time()
-        agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, num_cities, train_stations,
-                                                                             city_orientations)
-        if DEBUG_PRINT_TIMING:
-            print("Schedule time", time.time() - schedule_time)
+        agent_start_targets_cities, num_agents = _generate_start_target_pairs(num_agents, num_cities, train_stations,
+                                                                              city_orientations)
 
         return grid_map, {'agents_hints': {
             'num_agents': num_agents,
-            'agent_start_targets_nodes': agent_start_targets_nodes,
+            'agent_start_targets_nodes': agent_start_targets_cities,
             'train_stations': train_stations,
             'city_orientations': city_orientations
         }}
 
-    def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
-        city_positions: List[Tuple[int, int]] = []
-        city_cells: List[Tuple[int, int]] = []
+    def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (IntVector2DArray, IntVector2DArray):
+        city_positions: IntVector2DArray = []
+        city_cells: IntVector2DArray = []
         for city_idx in range(num_cities):
             too_close = True
             tries = 0
@@ -632,9 +608,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                 row = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1))
                 col = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1))
                 too_close = False
-                # Check distance to nodes
-                for node_pos in city_positions:
-                    if _are_cities_overlapping((row, col), node_pos, 2 * (city_radius + 1) + 1):
+                # Check distance to cities
+                for city_pos in city_positions:
+                    if _are_cities_overlapping((row, col), city_pos, 2 * (city_radius + 1) + 1):
                         too_close = True
 
                 if not too_close:
@@ -650,7 +626,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                     break
         return city_positions, city_cells
 
-    def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
+    def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (IntVector2DArray, IntVector2DArray):
         aspect_ratio = height / width
         cities_per_row = int(np.ceil(np.sqrt(num_cities * aspect_ratio)))
         cities_per_col = int(np.ceil(num_cities / cities_per_row))
@@ -665,17 +641,17 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius))
         return city_positions, city_cells
 
-    def _generate_node_connection_points(city_positions: List[Tuple[int, int]], city_radius: int, rails_between_cities: int, rails_in_city: int = 2):
+    def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int, rails_between_cities: int, rails_in_city: int = 2):
         inner_connection_points = []
         outer_connection_points = []
         connection_info = []
         city_orientations = []
-        for node_position in city_positions:
+        for city_position in city_positions:
 
             # Chose the directions where close cities are situated
             neighb_dist = []
-            for neighb_node in city_positions:
-                neighb_dist.append(distance_on_rail(node_position, neighb_node, metric="Manhattan"))
+            for neighbour_city in city_positions:
+                neighb_dist.append(Vec2dOperations.get_manhattan_distance(city_position, neighbour_city))
             closest_neighb_idx = argsort(neighb_dist)
 
             # Store the directions to these neighbours and orient city to face closest neighbour
@@ -684,7 +660,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             if grid_mode:
                 current_closest_direction = np.random.randint(4)
             else:
-                current_closest_direction = direction_to_point(node_position, city_positions[closest_neighb_idx[idx]])
+                current_closest_direction = direction_to_city(city_position, city_positions[closest_neighb_idx[idx]])
             connection_sides_idx.append(current_closest_direction)
             connection_sides_idx.append((current_closest_direction + 2) % 4)
             city_orientations.append(current_closest_direction)
@@ -703,16 +679,16 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                 for connection_idx in range(connections_per_direction[direction]):
                     if direction == 0:
                         tmp_coordinates = (
-                            node_position[0] - city_radius, node_position[1] + connection_slots[connection_idx])
+                            city_position[0] - city_radius, city_position[1] + connection_slots[connection_idx])
                     if direction == 1:
                         tmp_coordinates = (
-                            node_position[0] + connection_slots[connection_idx], node_position[1] + city_radius)
+                            city_position[0] + connection_slots[connection_idx], city_position[1] + city_radius)
                     if direction == 2:
                         tmp_coordinates = (
-                            node_position[0] + city_radius, node_position[1] + connection_slots[connection_idx])
+                            city_position[0] + city_radius, city_position[1] + connection_slots[connection_idx])
                     if direction == 3:
                         tmp_coordinates = (
-                            node_position[0] + connection_slots[connection_idx], node_position[1] - city_radius)
+                            city_position[0] + connection_slots[connection_idx], city_position[1] - city_radius)
                     connection_points_coordinates_inner[direction].append(tmp_coordinates)
                     if connection_idx in range(start_idx, start_idx + number_of_out_rails + 1):
                         connection_points_coordinates_outer[direction].append(tmp_coordinates)
@@ -722,7 +698,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             connection_info.append(connections_per_direction)
         return inner_connection_points, outer_connection_points, connection_info, city_orientations
 
-    def _connect_cities(city_positions: List[Tuple[int, int]], connection_points, city_cells: List[Tuple[int, int]],
+    def _connect_cities(city_positions: IntVector2DArray, connection_points, city_cells: IntVector2DArray,
                         rail_trans, grid_map):
         """
         Function to connect the different cities through their connection points
@@ -749,8 +725,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                     for dir in range(4):
                         current_points = connection_points[neighb_idx][dir]
                         for tmp_in_connection_point in current_points:
-                            tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point,
-                                                        metric="Manhattan")
+                            tmp_dist = Vec2dOperations.get_manhattan_distance(tmp_out_connection_point,
+                                                                              tmp_in_connection_point)
                             if tmp_dist < min_connection_dist:
                                 min_connection_dist = tmp_dist
                                 neighb_connection_point = tmp_in_connection_point
@@ -762,7 +738,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
 
         return all_paths
 
-    def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans,
+    def _build_inner_cities(city_positions, inner_connection_points, outer_connection_points, rail_trans,
                             grid_map):
         """
         Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
@@ -773,9 +749,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         :param grid_map:
         :return: Returns the cells of the through path which cannot be occupied by trainstations
         """
-        through_path_cells = [[] for i in range(len(node_positions))]
-        free_tracks = [[] for i in range(len(node_positions))]
-        for current_city in range(len(node_positions)):
+        through_path_cells = [[] for i in range(len(city_positions))]
+        free_tracks = [[] for i in range(len(city_positions))]
+        for current_city in range(len(city_positions)):
             all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in
                                            sublist]
             # This part only works if we have keep same number of connection points for both directions
@@ -805,63 +781,61 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                     free_tracks[current_city].append(current_track)
         return through_path_cells, free_tracks
 
-    def _set_trainstation_positions(node_positions, node_radius, free_tracks, grid_map):
+    def _set_trainstation_positions(city_positions: IntVector2DArray, city_radius: int, free_rails):
         """
 
         :param city_positions:
         :param num_trainstations:
         :return:
         """
-        nb_nodes = len(node_positions)
-        train_stations = [[] for i in range(nb_nodes)]
-        left = 0
-        right = 0
+        num_cities = len(city_positions)
+        train_stations = [[] for i in range(num_cities)]
         built_num_trainstations = 0
-        for current_city in range(len(node_positions)):
-            for track_nbr in range(len(free_tracks[current_city])):
-                possible_location = free_tracks[current_city][track_nbr][node_radius]
+        for current_city in range(len(city_positions)):
+            for track_nbr in range(len(free_rails[current_city])):
+                possible_location = free_rails[current_city][track_nbr][city_radius]
                 train_stations[current_city].append((possible_location, track_nbr))
         return train_stations, built_num_trainstations
 
-    def _generate_start_target_pairs(num_agents, nb_nodes, train_stations, city_orientation):
+    def _generate_start_target_pairs(num_agents, num_cities, train_stations, city_orientation):
         """
         Fill the trainstation positions with targets and goals
         :param num_agents:
-        :param nb_nodes:
+        :param num_cities:
         :param train_stations:
         :return:
         """
-        # Generate start and target node directory for all agents.
-        # Assure that start and target are not in the same node
-        agent_start_targets_nodes = []
+        # Generate start and target city directory for all agents.
+        # Assure that start and target are not in the same city
+        agent_start_targets_cities = []
 
-        # Slot availability in node
-        node_available_start = []
-        node_available_target = []
-        for node_idx in range(nb_nodes):
-            node_available_start.append(len(train_stations[node_idx]))
-            node_available_target.append(len(train_stations[node_idx]))
+        # Slot availability in city
+        city_available_start = []
+        city_available_target = []
+        for city_idx in range(num_cities):
+            city_available_start.append(len(train_stations[city_idx]))
+            city_available_target.append(len(train_stations[city_idx]))
 
         # Assign agents to slots
         for agent_idx in range(num_agents):
-            avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
-            avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
+            avail_start_cities = [idx for idx, val in enumerate(city_available_start) if val > 0]
+            avail_target_cities = [idx for idx, val in enumerate(city_available_target) if val > 0]
             # Set probability to choose start and stop from trainstations
-            sum_start = sum(np.array(node_available_start)[avail_start_nodes])
-            sum_target = sum(np.array(node_available_target)[avail_target_nodes])
-            p_avail_start = [float(i) / sum_start for i in np.array(node_available_start)[avail_start_nodes]]
+            sum_start = sum(np.array(city_available_start)[avail_start_cities])
+            sum_target = sum(np.array(city_available_target)[avail_target_cities])
+            p_avail_start = [float(i) / sum_start for i in np.array(city_available_start)[avail_start_cities]]
 
-            start_target_tuple = np.random.choice(avail_start_nodes, p=p_avail_start, size=2, replace=False)
-            start_node = start_target_tuple[0]
-            target_node = start_target_tuple[1]
-            agent_start_targets_nodes.append((start_node, target_node, city_orientation[start_node]))
-        return agent_start_targets_nodes, num_agents
+            start_target_tuple = np.random.choice(avail_start_cities, p=p_avail_start, size=2, replace=False)
+            start_city = start_target_tuple[0]
+            target_city = start_target_tuple[1]
+            agent_start_targets_cities.append((start_city, target_city, city_orientation[start_city]))
+        return agent_start_targets_cities, num_agents
 
     def _fix_transitions(city_cells, inter_city_lines, grid_map):
         """
         Function to fix all transition elements in environment
         """
-        # Fix all nodes with illegal transition maps
+        # Fix all cities with illegal transition maps
         rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int')
         rails_to_fix_cnt = 0
         cells_to_fix = city_cells + inter_city_lines
@@ -878,35 +852,34 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         for cell in range(rails_to_fix_cnt):
             grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1]))
 
-    def _closest_neighbour_in_direction(current_city_idx: int, node_positions: List[Tuple[int, int]]):
+    def _closest_neighbour_in_direction(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]:
         """
-        Returns indices of closest neighbours in every direction NESW
-        :param current_city_idx: Index of node in city_positions list
+        Returns indices of closest neighbour in every direction NESW
+        :param current_city_idx: Index of city in city_positions list
         :param city_positions: list of all points being considered
-        :return: list of index of closest neighbours in all directions
+        :return: list of index of closest neighbour in all directions
         """
-        node_dist = []
-        closest_neighb = [None for i in range(4)]
-        for av_node in range(len(node_positions)):
-            node_dist.append(
-                distance_on_rail(node_positions[current_city_idx], node_positions[av_node], metric="Manhattan"))
-        sorted_neighbours = np.argsort(node_dist)
+        city_distances = []
+        closest_neighbour: List[int] = [None for i in range(4)]
+        for city_idx in range(len(city_positions)):
+            city_distances.append(Vec2dOperations.get_manhattan_distance(city_positions[current_city_idx], city_positions[city_idx]))
+        sorted_neighbours = np.argsort(city_distances)
         direction_set = 0
-        for neighb in sorted_neighbours[1:]:
-            direction_to_neighb = direction_to_point(node_positions[current_city_idx], node_positions[neighb])
-            if closest_neighb[direction_to_neighb] == None:
-                closest_neighb[direction_to_neighb] = neighb
+        for neighbour in sorted_neighbours[1:]:
+            direction_to_neighbour = direction_to_city(city_positions[current_city_idx], city_positions[neighbour])
+            if closest_neighbour[direction_to_neighbour] == None:
+                closest_neighbour[direction_to_neighbour] = neighbour
                 direction_set += 1
 
             if direction_set == 4:
-                return closest_neighb
-        return closest_neighb
+                return closest_neighbour
+        return closest_neighbour
 
     def argsort(seq):
         # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
         return sorted(range(len(seq)), key=seq.__getitem__)
 
-    def _get_cells_in_city(center: Tuple[int, int], radius: int) -> List[Tuple[int, int]]:
+    def _get_cells_in_city(center: IntVector2D, radius: int) -> IntVector2DArray:
         """
 
         Parameters
-- 
GitLab