From 0107cbcb7d77763939bd872a40e6a456c64f88ca Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sat, 28 Sep 2019 15:43:39 -0400
Subject: [PATCH] introduce probability when sampling for start and target
 positions

---
 flatland/envs/rail_generators.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index b0c4cc64..a86e95d2 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -815,7 +815,13 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
         return train_stations, built_num_trainstations
 
     def _generate_start_target_pairs(num_agents, nb_nodes, train_stations):
-
+        """
+        Fill the trainstation positions with targets and goals
+        :param num_agents:
+        :param nb_nodes:
+        :param train_stations:
+        :return:
+        """
         # Generate start and target node directory for all agents.
         # Assure that start and target are not in the same node
         agent_start_targets_nodes = []
@@ -831,8 +837,13 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
         for agent_idx in range(num_agents):
             avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
             avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
-            start_node = np.random.choice(avail_start_nodes)
-            target_node = np.random.choice(avail_target_nodes)
+            # Set probability to choose start and stop from trainstations
+            sum_start = sum(node_available_start)
+            sum_target = sum(node_available_target)
+            p_avail_start = [float(i) / sum_start for i in node_available_start]
+            p_avail_target = [float(i) / sum_target for i in node_available_target]
+            start_node = np.random.choice(avail_start_nodes, p=p_avail_start)
+            target_node = np.random.choice(avail_target_nodes, p=p_avail_target)
             tries = 0
             found_agent_pair = True
             while target_node == start_node:
-- 
GitLab