From 0b3967ed1929941ef5abc435e69f02e2743d6d2d Mon Sep 17 00:00:00 2001 From: nimishsantosh107 <nimishsantosh107@icloud.com> Date: Fri, 13 Aug 2021 05:36:28 +0530 Subject: [PATCH] allows odd number of agents --- flatland/envs/line_generators.py | 55 ++++++++++++++++---------------- flatland/envs/rail_env.py | 3 -- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py index 4a2ec647..c94e416d 100644 --- a/flatland/envs/line_generators.py +++ b/flatland/envs/line_generators.py @@ -142,13 +142,15 @@ class SparseLineGen(BaseLineGen): agents_target = [] agents_direction = [] - for agent_pair_idx in range(0, num_agents, 2): - infeasible_agent = True - tries = 0 - while infeasible_agent: - tries += 1 - infeasible_agent = False + city1, city2 = None, None + city1_num_stations, city2_num_stations = None, None + city1_possible_orientations, city2_possible_orientations = None, None + + + for agent_idx in range(num_agents): + + if (agent_idx % 2 == 0): # Setlect 2 cities, find their num_stations and possible orientations city_idx = np_random.choice(len(city_positions), 2, replace=False) city1 = city_idx[0] @@ -159,33 +161,32 @@ class SparseLineGen(BaseLineGen): (city_orientation[city1] + 2) % 4] city2_possible_orientations = [city_orientation[city2], (city_orientation[city2] + 2) % 4] + # Agent 1 : city1 > city2, Agent 2: city2 > city1 - agent1_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations - agent1_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city2_num_stations - agent2_start_idx = ((2 * np_random.randint(0, 10))) % city2_num_stations - agent2_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city1_num_stations + agent_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations + agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city2_num_stations + + agent_start = train_stations[city1][agent_start_idx] + agent_target = train_stations[city2][agent_target_idx] + + agent_orientation = np_random.choice(city1_possible_orientations) + + + else: + agent_start_idx = ((2 * np_random.randint(0, 10))) % city2_num_stations + agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city1_num_stations - agent1_start = train_stations[city1][agent1_start_idx] - agent1_target = train_stations[city2][agent1_target_idx] - agent2_start = train_stations[city2][agent2_start_idx] - agent2_target = train_stations[city1][agent2_target_idx] + agent_start = train_stations[city2][agent_start_idx] + agent_target = train_stations[city1][agent_target_idx] - agent1_orientation = np_random.choice(city1_possible_orientations) - agent2_orientation = np_random.choice(city2_possible_orientations) + agent_orientation = np_random.choice(city2_possible_orientations) - # check path exists then break if tries > 100 - if tries >= 100: - warnings.warn("Did not find any possible path, check your parameters!!!") - break # agent1 details - agents_position.append((agent1_start[0][0], agent1_start[0][1])) - agents_target.append((agent1_target[0][0], agent1_target[0][1])) - agents_direction.append(agent1_orientation) - # agent2 details - agents_position.append((agent2_start[0][0], agent2_start[0][1])) - agents_target.append((agent2_target[0][0], agent2_target[0][1])) - agents_direction.append(agent2_orientation) + agents_position.append((agent_start[0][0], agent_start[0][1])) + agents_target.append((agent_target[0][0], agent_target[0][1])) + agents_direction.append(agent_orientation) + if self.speed_ratio_map: speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 5d97e49d..0a0a84e6 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -23,7 +23,6 @@ from flatland.envs.rail_env_action import RailEnvActions from flatland.envs import malfunction_generators as mal_gen from flatland.envs import rail_generators as rail_gen from flatland.envs import line_generators as line_gen -# NEW : Imports from flatland.envs.schedule_generators import schedule_generator from flatland.envs import persistence from flatland.envs import agent_chains as ac @@ -199,8 +198,6 @@ class RailEnv(Environment): self.malfunction_generator = mal_gen.NoMalfunctionGen() self.malfunction_process_data = self.malfunction_generator.get_process_data() - if number_of_agents % 2 == 1: - raise ValueError("Odd number of agents is no longer supported, set number_of_agents to an even number") self.number_of_agents = number_of_agents # self.rail_generator: RailGenerator = rail_generator -- GitLab