From ecd93268f9c95808366210520b44357fae6466e5 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Thu, 29 Aug 2019 10:30:02 +0200 Subject: [PATCH] fix convergence issue --- flatland/envs/schedule_generators.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index f2d37e9c..dca166a4 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -164,14 +164,14 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> return [], [], [], [] if len(valid_positions) < num_agents: - warnings("schedule_generators: len(valid_positions) < num_agents") + warnings.warn("schedule_generators: len(valid_positions) < num_agents") return [], [], [], [] agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] agents_target_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)] - update_agents = np.ones(num_agents) + update_agents = np.zeros(num_agents) re_generate = True cnt = 0 @@ -180,12 +180,10 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> # update position for i in range(num_agents): if update_agents[i] == 1: - x = np.arange(len(valid_positions)) - x = np.setdiff1d(x,agents_position_idx) + x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx) agents_position_idx[i] = np.random.choice(x) agents_position[i] = valid_positions[agents_position_idx[i]] - x = np.arange(len(valid_positions)) - x = np.setdiff1d(x,agents_target_idx) + x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx) agents_target_idx[i] = np.random.choice(x) agents_target[i] = valid_positions[agents_target_idx[i]] update_agents = np.zeros(num_agents) @@ -212,8 +210,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> if len(valid_starting_directions) == 0: re_generate = True - print("agent:", i, new_position) - print("invalid") update_agents[i] = 1 break else: -- GitLab