From ddfc3f36352f3485e497de5a3163af240a64f953 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 1 Oct 2019 13:34:28 -0400
Subject: [PATCH] new trainstation distribution

---
 flatland/envs/rail_generators.py     | 52 +++++-----------------------
 flatland/envs/schedule_generators.py | 45 +++++-------------------
 2 files changed, 18 insertions(+), 79 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index d7050d63..ad5c78a1 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -604,13 +604,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         print("City build time", time.time() - city_build_time)
         # Populate cities
         train_station_time = time.time()
-        train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, free_tracks, grid_map)
+        train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, node_radius, free_tracks,
+                                                                             grid_map)
         print("Trainstation placing time", time.time() - train_station_time)
 
-        # Adjust the number of agents if you could not build enough trainstations
-        if num_agents > built_num_trainstation:
-            num_agents = built_num_trainstation
-            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
 
         # Fix all transition elements
         grid_fix_time = time.time()
@@ -822,7 +819,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                     free_tracks[current_city].append(current_track)
         return through_path_cells, free_tracks
 
-    def _set_trainstation_positions(node_positions, free_tracks, grid_map):
+    def _set_trainstation_positions(node_positions, node_radius, free_tracks, grid_map):
         """
 
         :param node_positions:
@@ -836,20 +833,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         built_num_trainstations = 0
         for current_city in range(len(node_positions)):
             for track_nbr in range(len(free_tracks[current_city])):
-                for possible_location in free_tracks[current_city][track_nbr]:
-                    # Only build trainstation on non diverging elements
-                    cell_type = grid_map.get_full_transitions(*possible_location)
-                    nbits = 0
-                    while cell_type > 0:
-                        nbits += (cell_type & 1)
-                        cell_type = cell_type >> 1
-                    if 1 <= nbits <= 2:
-                        built_num_trainstations += 1
-                        if track_nbr % 2 == 0:
-                            left += 1
-                        else:
-                            right += 1
-                        train_stations[current_city].append((possible_location, track_nbr))
+                possible_location = free_tracks[current_city][track_nbr][node_radius]
+                train_stations[current_city].append((possible_location, track_nbr))
         return train_stations, built_num_trainstations
 
     def _generate_start_target_pairs(num_agents, nb_nodes, train_stations):
@@ -881,29 +866,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
             p_avail_start = [float(i) / sum_start for i in np.array(node_available_start)[avail_start_nodes]]
             p_avail_target = [float(i) / sum_target for i in np.array(node_available_target)[avail_target_nodes]]
 
-            start_node = np.random.choice(avail_start_nodes, p=p_avail_start)
-            target_node = np.random.choice(avail_target_nodes, p=p_avail_target)
-            tries = 0
-            found_agent_pair = True
-            while target_node == start_node:
-                target_node = np.random.choice(avail_target_nodes)
-                tries += 1
-                # Test again with new start node if no pair is found (This code needs to be improved)
-                if (tries + 1) % 10 == 0:
-                    start_node = np.random.choice(avail_start_nodes)
-                if tries > 100:
-                    warnings.warn("Could not set trainstations, removing agent!")
-                    found_agent_pair = False
-                    break
-            if found_agent_pair:
-                node_available_start[start_node] -= 1
-                node_available_target[target_node] -= 1
-                shortest_path = nx.astar_path(G, start_node, target_node, weight='length')
-                start_orientation = nx.get_edge_attributes(G, "direction")[(shortest_path[0], shortest_path[1])]
-                agent_start_targets_nodes.append((start_node, target_node, start_orientation))
-
-            else:
-                num_agents -= 1
+            start_target_tuple = np.random.choice(avail_start_nodes, p=p_avail_start, size=2, replace=False)
+            start_node = start_target_tuple[0]
+            target_node = start_target_tuple[1]
+            agent_start_targets_nodes.append((start_node, target_node, 0))
         return agent_start_targets_nodes, num_agents
 
     def _fix_transitions(city_cells, inter_city_lines, grid_map):
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index d6f5a2c3..30f6b57e 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -70,47 +70,20 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
         agents_position = []
         agents_target = []
         agents_direction = []
-        start_slots = train_stations
-        target_slots = train_stations
         for agent_idx in range(num_agents):
             # Set target for agent
-            current_target_node = agent_start_targets_nodes[agent_idx][1]
-            target_station_idx = np.random.randint(len(target_slots[current_target_node]))
-            target = target_slots[current_target_node][target_station_idx]
-            target_slots[current_target_node].pop(target_station_idx)
-            agents_target.append((target[0][0], target[0][1]))
-
-            # Set start for agent and corresponding orientation
-            current_start_node = agent_start_targets_nodes[agent_idx][0]
-            agent_start_orientation = agent_start_targets_nodes[agent_idx][2]
-
-            # Place the agent on the corresponding track
-            if city_orientations[current_start_node] == agent_start_orientation:
-                track_to_use = 0
-            else:
-                track_to_use = 1
-
-            for i in range(len(start_slots[current_start_node])):
-                if start_slots[current_start_node][i][1] == track_to_use:
-                    start_station_idx = i
-                    break
-                else:
-                    start_station_idx = None
-
-            if start_station_idx is None:
-                warnings.warn("No slot available with required start orientation")
-                continue
-            start = start_slots[current_start_node][start_station_idx]
-            start_slots[current_start_node].pop(start_station_idx)
-
+            start_city = agent_start_targets_nodes[agent_idx][0]
+            target_city = agent_start_targets_nodes[agent_idx][0]
+            agent_orientation = agent_start_targets_nodes[agent_idx][2]
+            start_city_idx = np.random.randint(len(train_stations[start_city]))
+            start = train_stations[start_city][start_city_idx]
+            target_station_idx = np.random.randint(len(train_stations[target_city]))
+            target = train_stations[target_city][target_station_idx]
             agents_position.append((start[0][0], start[0][1]))
-            agents_direction.append(agent_start_orientation)
+            agents_target.append((target[0][0], target[0][1]))
+            agents_direction.append(agent_orientation)
             # Orient the agent correctly
 
-        for agent_idx in range(1, len(agents_position)):
-            agents_position[agent_idx] = agents_position[0]
-            agents_direction[agent_idx] = (agents_direction[0] + np.random.choice([0, 2])) % 4
-
         if speed_ratio_map:
             speeds = speed_initialization_helper(num_agents, speed_ratio_map)
         else:
-- 
GitLab