From afa93e0bcf94c3f1b20b4bc3408a6def6012da83 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 1 Oct 2019 13:36:19 -0400
Subject: [PATCH] all agents oriented as city

---
 flatland/envs/rail_generators.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index ad5c78a1..39a91bc4 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -616,7 +616,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
 
         # Generate start target pairs
         schedule_time = time.time()
-        agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations)
+        agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations,
+                                                                             city_orientations)
         print("Schedule time", time.time() - schedule_time)
         return grid_map, {'agents_hints': {
             'num_agents': num_agents,
@@ -837,7 +838,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                 train_stations[current_city].append((possible_location, track_nbr))
         return train_stations, built_num_trainstations
 
-    def _generate_start_target_pairs(num_agents, nb_nodes, train_stations):
+    def _generate_start_target_pairs(num_agents, nb_nodes, train_stations, city_orientation):
         """
         Fill the trainstation positions with targets and goals
         :param num_agents:
@@ -864,12 +865,11 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
             sum_start = sum(np.array(node_available_start)[avail_start_nodes])
             sum_target = sum(np.array(node_available_target)[avail_target_nodes])
             p_avail_start = [float(i) / sum_start for i in np.array(node_available_start)[avail_start_nodes]]
-            p_avail_target = [float(i) / sum_target for i in np.array(node_available_target)[avail_target_nodes]]
 
             start_target_tuple = np.random.choice(avail_start_nodes, p=p_avail_start, size=2, replace=False)
             start_node = start_target_tuple[0]
             target_node = start_target_tuple[1]
-            agent_start_targets_nodes.append((start_node, target_node, 0))
+            agent_start_targets_nodes.append((start_node, target_node, city_orientation[start_node]))
         return agent_start_targets_nodes, num_agents
 
     def _fix_transitions(city_cells, inter_city_lines, grid_map):
-- 
GitLab