diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index f2d37e9cd2bc9ec019bce1df4038668f6c18cedf..dca166a4a753aae0093da5e772fc65017a0e3cb3 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -164,14 +164,14 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> return [], [], [], [] if len(valid_positions) < num_agents: - warnings("schedule_generators: len(valid_positions) < num_agents") + warnings.warn("schedule_generators: len(valid_positions) < num_agents") return [], [], [], [] agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] agents_target_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)] - update_agents = np.ones(num_agents) + update_agents = np.zeros(num_agents) re_generate = True cnt = 0 @@ -180,12 +180,10 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> # update position for i in range(num_agents): if update_agents[i] == 1: - x = np.arange(len(valid_positions)) - x = np.setdiff1d(x,agents_position_idx) + x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx) agents_position_idx[i] = np.random.choice(x) agents_position[i] = valid_positions[agents_position_idx[i]] - x = np.arange(len(valid_positions)) - x = np.setdiff1d(x,agents_target_idx) + x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx) agents_target_idx[i] = np.random.choice(x) agents_target[i] = valid_positions[agents_target_idx[i]] update_agents = np.zeros(num_agents) @@ -212,8 +210,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> if len(valid_starting_directions) == 0: re_generate = True - print("agent:", i, new_position) - print("invalid") update_agents[i] = 1 break else: