diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index 1e80021c4c3c2f06e6d5c897692b814a5d063025..42917fb79e8d548f0ab29afb43c6de72ebb7d714 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -1,8 +1,6 @@
 # In Flatland you can use custom observation builders and predictors
 # Observation builders generate the observation needed by the controller
 # Predictors can be used to do short-term prediction, which can help in avoiding conflicts in the network
-import time
-
 from flatland.envs.observations import GlobalObsForRailEnv
 # First of all we import the Flatland rail environment
 from flatland.envs.rail_env import RailEnv
@@ -91,6 +89,7 @@ env_renderer = RenderTool(env, gl="PILSVG",
                           screen_width=1000)  # Adjust these parameters to fit your resolution
 
+# The first thing we notice is that some agents don't have feasible paths to their target.
 # We first look at the map we have created
 # env_renderer.render_env(show=True)
 
 
@@ -217,19 +216,19 @@ for agent_idx, agent in enumerate(env.agents):
 for a in range(env.get_num_agents()):
     action = controller.act(0)
     action_dict.update({a: action})
-
 # Do the environment step
 observations, rewards, dones, information = env.step(action_dict)
 
-print("\n Thefollowing agents can register an action:")
+print("\n The following agents can register an action:")
 print("========================================")
-print(information['action_required'])
+for info in information['action_required']:
+    print("Agent {} needs to submit an action.".format(info))
 
 # We recommend that you monitor the malfunction data and the action required in order to optimize your training
 # and controlling code.
 
 # Let us now look at an episode playing out with random actions performed
 
-print("Start episode...")
+print("\nStart episode...")
 # Reset the rendering system
 env_renderer.reset()
@@ -237,10 +236,12 @@ env_renderer.reset()
 
 # Here you can also further enhance the provided observation by means of normalization
 # See training navigation example in the baseline repository
+
 score = 0
 # Run episode
 frame_step = 0
-for step in range(500):
+
+for step in range(10):
     # Choose an action for each agent in the environment
     for a in range(env.get_num_agents()):
         action = controller.act(observations[a])
@@ -248,10 +249,9 @@
 
     # Environment step which returns the observations for all agents, their corresponding
    # reward and whether they are done
-    start_time = time.time()
+
     next_obs, all_rewards, done, _ = env.step(action_dict)
-    end_time = time.time()
-    print(end_time - start_time)
+
     # env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
     frame_step += 1
     # Update replay buffer and train agent
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index ef05733db7564dc6a9bfb132cf42c047b9928b7a..ef27ac1a18db6eb05ee5380c1a1a75332733142a 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -82,29 +82,37 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
     agents_position = []
     agents_target = []
     agents_direction = []
 
-    for agent_idx in range(num_agents):
-
-        # Set target for agent
-        start_city = agent_start_targets_cities[agent_idx][0]
-        target_city = agent_start_targets_cities[agent_idx][1]
-        start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
-        target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
-        start = train_stations[start_city][start_idx]
-        target = train_stations[target_city][target_idx]
-        while start[1] % 2 != 0:
-            start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
-            start = train_stations[start_city][start_idx]
-        while target[1] % 2 != 1:
-            target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
-            target = train_stations[target_city][target_idx]
-        agent_orientation = (agent_start_targets_cities[agent_idx][2] + 2 * start[1]) % 4
-        if not rail.check_path_exists(start[0], agent_orientation, target[0]):
-            agent_orientation = (agent_orientation + 2) % 4
-        if not (rail.check_path_exists(start[0], agent_orientation, target[0])):
-            warnings.warn("Infeasible task for agent {}".format(agent_idx))
+    for agent_idx in range(num_agents):
+        infeasible_agent = True
+        tries = 0
+        while infeasible_agent:
+            tries += 1
+            infeasible_agent = False
+            # Set target for agent
+            city_idx = np.random.randint(len(agent_start_targets_cities))
+            start_city = agent_start_targets_cities[city_idx][0]
+            target_city = agent_start_targets_cities[city_idx][1]
+            start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
+            target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
+            start = train_stations[start_city][start_idx]
+            target = train_stations[target_city][target_idx]
+            while start[1] % 2 != 0:
+                start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
+                start = train_stations[start_city][start_idx]
+            while target[1] % 2 != 1:
+                target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
+                target = train_stations[target_city][target_idx]
+            agent_orientation = (agent_start_targets_cities[city_idx][2] + 2 * start[1]) % 4
+            if not rail.check_path_exists(start[0], agent_orientation, target[0]):
+                agent_orientation = (agent_orientation + 2) % 4
+            if not (rail.check_path_exists(start[0], agent_orientation, target[0])):
+                infeasible_agent = True
+            if tries >= 100:
+                warnings.warn("Did not find any possible path, check your parameters!!!")
+                break
 
         agents_position.append((start[0][0], start[0][1]))
         agents_target.append((target[0][0], target[0][1]))
 
 
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 7bd6334b877478c19205e56431dd95111a97caee..f254e09126de58f856649eb7a777ed5ca0c49c8e 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -504,7 +504,7 @@
 
     for a in range(env.get_num_agents()):
         s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0))
         s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0))
-        assert s0 == 33, "actual={}".format(s0)
+        assert s0 == 31, "actual={}".format(s0)
         assert s1 == 24, "actual={}".format(s1)
 
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index e9b5a15dade4fd30d6718886eed02a483172f159..8bbfefbb2a0c5b8da3c0c7ff465290294d568f84 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -122,11 +122,11 @@
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 20, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])
 
     # Check that 20 stops were performed
-    assert agent_halts == 21
+    assert agent_halts == 20
 
     # Check that malfunctioning data was standing around
     assert total_down_time > 0
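Reviewer note on the schedule_generators.py change: the heart of this PR is the new retry loop, which resamples a whole (start city, target city) pair whenever `rail.check_path_exists` fails in both orientations, and only warns once 100 attempts are exhausted. The sketch below distills that pattern in isolation; it is a minimal illustration, not Flatland code, and `sample_feasible_task`, `candidate_tasks`, and `is_feasible` are hypothetical stand-ins for the PR's `agent_start_targets_cities` list and `rail.check_path_exists` check.

```python
import warnings

import numpy as np


def sample_feasible_task(candidate_tasks, is_feasible, max_tries=100):
    """Resample a candidate task until one passes the feasibility check."""
    tries = 0
    infeasible = True
    task = None
    while infeasible:
        tries += 1
        infeasible = False
        # Draw a completely fresh candidate on every attempt, as the new
        # loop does; the old code kept the same city pair per agent.
        task = candidate_tasks[np.random.randint(len(candidate_tasks))]
        if not is_feasible(task):
            infeasible = True
        if tries >= max_tries:
            warnings.warn("Did not find any possible path, check your parameters!!!")
            break
    return task


# Toy usage: a task is a (start, target) tuple, and the lambda stands in
# for a real path-existence check on the rail network.
tasks = [(0, 1), (2, 4), (3, 3)]
print(sample_feasible_task(tasks, lambda t: (t[0] + t[1]) % 2 == 0))
```

Note that, like the PR's loop, the sketch still returns the last sampled task when the retry budget runs out, so a caller can end up with an infeasible assignment after the warning; this mirrors how the generator still appends the final start and target positions after warning.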