From 0b3967ed1929941ef5abc435e69f02e2743d6d2d Mon Sep 17 00:00:00 2001
From: nimishsantosh107 <nimishsantosh107@icloud.com>
Date: Fri, 13 Aug 2021 05:36:28 +0530
Subject: [PATCH] allows odd number of agents

---
 flatland/envs/line_generators.py | 55 ++++++++++++++++----------------
 flatland/envs/rail_env.py        |  3 --
 2 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py
index 4a2ec647..c94e416d 100644
--- a/flatland/envs/line_generators.py
+++ b/flatland/envs/line_generators.py
@@ -142,13 +142,15 @@ class SparseLineGen(BaseLineGen):
         agents_target = []
         agents_direction = []
 
-        for agent_pair_idx in range(0, num_agents, 2):
-            infeasible_agent = True
-            tries = 0
-            while infeasible_agent:
-                tries += 1
-                infeasible_agent = False
 
+        city1, city2 = None, None
+        city1_num_stations, city2_num_stations = None, None
+        city1_possible_orientations, city2_possible_orientations = None, None
+
+
+        for agent_idx in range(num_agents):
+
+            if (agent_idx % 2 == 0):
                 # Setlect 2 cities, find their num_stations and possible orientations
                 city_idx = np_random.choice(len(city_positions), 2, replace=False)
                 city1 = city_idx[0]
@@ -159,33 +161,32 @@ class SparseLineGen(BaseLineGen):
                                                 (city_orientation[city1] + 2) % 4]
                 city2_possible_orientations = [city_orientation[city2],
                                                 (city_orientation[city2] + 2) % 4]
+
                 # Agent 1 : city1 > city2, Agent 2: city2 > city1
-                agent1_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations
-                agent1_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city2_num_stations
-                agent2_start_idx = ((2 * np_random.randint(0, 10))) % city2_num_stations
-                agent2_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city1_num_stations
+                agent_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations
+                agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city2_num_stations
+
+                agent_start = train_stations[city1][agent_start_idx]
+                agent_target = train_stations[city2][agent_target_idx]
+
+                agent_orientation = np_random.choice(city1_possible_orientations)
+
+
+            else:
+                agent_start_idx = ((2 * np_random.randint(0, 10))) % city2_num_stations
+                agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city1_num_stations
                 
-                agent1_start = train_stations[city1][agent1_start_idx]
-                agent1_target = train_stations[city2][agent1_target_idx]
-                agent2_start = train_stations[city2][agent2_start_idx]
-                agent2_target = train_stations[city1][agent2_target_idx]
+                agent_start = train_stations[city2][agent_start_idx]
+                agent_target = train_stations[city1][agent_target_idx]
                             
-                agent1_orientation = np_random.choice(city1_possible_orientations)
-                agent2_orientation = np_random.choice(city2_possible_orientations)
+                agent_orientation = np_random.choice(city2_possible_orientations)
 
-                # check path exists then break if tries > 100
-                if tries >= 100:
-                    warnings.warn("Did not find any possible path, check your parameters!!!")
-                    break
             
             # agent1 details
-            agents_position.append((agent1_start[0][0], agent1_start[0][1]))
-            agents_target.append((agent1_target[0][0], agent1_target[0][1]))
-            agents_direction.append(agent1_orientation)
-            # agent2 details
-            agents_position.append((agent2_start[0][0], agent2_start[0][1]))
-            agents_target.append((agent2_target[0][0], agent2_target[0][1]))
-            agents_direction.append(agent2_orientation)
+            agents_position.append((agent_start[0][0], agent_start[0][1]))
+            agents_target.append((agent_target[0][0], agent_target[0][1]))
+            agents_direction.append(agent_orientation)
+
 
         if self.speed_ratio_map:
             speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 5d97e49d..0a0a84e6 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -23,7 +23,6 @@ from flatland.envs.rail_env_action import RailEnvActions
 from flatland.envs import malfunction_generators as mal_gen
 from flatland.envs import rail_generators as rail_gen
 from flatland.envs import line_generators as line_gen
-# NEW : Imports
 from flatland.envs.schedule_generators import schedule_generator
 from flatland.envs import persistence
 from flatland.envs import agent_chains as ac
@@ -199,8 +198,6 @@ class RailEnv(Environment):
             self.malfunction_generator = mal_gen.NoMalfunctionGen()
             self.malfunction_process_data = self.malfunction_generator.get_process_data()
         
-        if number_of_agents % 2 == 1:
-            raise ValueError("Odd number of agents is no longer supported, set number_of_agents to an even number")
         self.number_of_agents = number_of_agents
 
         # self.rail_generator: RailGenerator = rail_generator
-- 
GitLab