From 8afa13e1d4bd4ceda3df561a2a9a2b8af3783975 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 19 Aug 2019 11:36:34 -0400
Subject: [PATCH] removed constraint on start location

---
 flatland/envs/generators.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 7d160d06..a4b627eb 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1034,14 +1034,18 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
         agents_direction = []
 
         for agent_idx in range(num_agents):
+
+            # Set target for agent
             current_target_node = agent_start_targets_nodes[agent_idx][1]
             target_station_idx = np.random.randint(len(train_stations[current_target_node]))
             target = train_stations[current_target_node][target_station_idx]
             while (target[0], target[1]) in agents_target:
                 target_station_idx = np.random.randint(len(train_stations[current_target_node]))
                 target = train_stations[current_target_node][target_station_idx]
+
             agents_target.append((target[0], target[1]))
 
+            # Set start for agent
             current_start_node = agent_start_targets_nodes[agent_idx][0]
             start_station_idx = np.random.randint(len(train_stations[current_start_node]))
             start = train_stations[current_start_node][start_station_idx]
@@ -1049,6 +1053,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             while (start[0], start[1]) in agents_position:
                 start_station_idx = np.random.randint(len(train_stations[current_start_node]))
                 start = train_stations[current_start_node][start_station_idx]
+
             agents_position.append((start[0], start[1]))
 
             # Orient the agent correctly
@@ -1057,7 +1062,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 if any(transitions) > 0:
                     agents_direction.append(orientation)
                     continue
-            agent_idx += 1
 
         return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
-- 
GitLab