From 3579e545184c0a6a24e25065ba5cb86624579f75 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 10 Oct 2019 10:55:38 -0400
Subject: [PATCH] fixed schedule generator so that agents always have feasible tasks
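
The sparse schedule generator previously assigned each agent a fixed
start/target city pair and only emitted a warning when no path existed
between the sampled stations, leaving the agent with an infeasible task.
The generator now resamples a random city pair and station slots until
rail.check_path_exists succeeds (checking both orientations), giving up
with a warning after 100 tries. The introduction example and the affected
test expectations are updated accordingly.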

---
 examples/introduction_flatland_2_1.py         | 22 ++++++-----
 flatland/envs/schedule_generators.py          | 44 ++++++++++++--------
 ...est_flatland_envs_sparse_rail_generator.py |  2 +-
 tests/test_flatland_malfunction.py            |  4 +-
 4 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index 1e80021c..42917fb7 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -1,8 +1,6 @@
 # In Flatland you can use custom observation builders and predictors
 # Observation builders generate the observation needed by the controller
 # Predictors can be used to do short-term prediction, which can help avoid conflicts in the network
-import time
-
 from flatland.envs.observations import GlobalObsForRailEnv
 # First of all we import the Flatland rail environment
 from flatland.envs.rail_env import RailEnv
@@ -91,6 +89,7 @@ env_renderer = RenderTool(env, gl="PILSVG",
                           screen_width=1000)  # Adjust these parameters to fit your resolution
 
 
+# The first thing we notice is that some agents don't have a feasible path to their target.
 # We first look at the map we have created
 
 # env_renderer.render_env(show=True)
@@ -217,19 +216,20 @@ for agent_idx, agent in enumerate(env.agents):
 for a in range(env.get_num_agents()):
     action = controller.act(0)
     action_dict.update({a: action})
-
 # Do the environment step
 observations, rewards, dones, information = env.step(action_dict)
-print("\n Thefollowing agents can register an action:")
+print("\n The following agents can register an action:")
 print("========================================")
-print(information['action_required'])
+for agent_handle, required in information['action_required'].items():
+    if required:
+        print("Agent {} needs to submit an action.".format(agent_handle))
 
 # We recommend that you monitor the malfunction data and the action required in order to optimize your training
 # and control code.
 
 # Let us now watch an episode play out with random actions
 
-print("Start episode...")
+print("\nStart episode...")
 
 # Reset the rendering system
 env_renderer.reset()
@@ -237,10 +236,13 @@ env_renderer.reset()
 # Here you can also further enhance the provided observation by means of normalization
 # See training navigation example in the baseline repository
 
+
 score = 0
 # Run episode
 frame_step = 0
-for step in range(500):
+
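+# Run a short episode of 10 steps; increase the range for a longer rollout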
+for step in range(10):
     # Choose an action for each agent in the environment
     for a in range(env.get_num_agents()):
         action = controller.act(observations[a])
@@ -248,10 +249,9 @@ for step in range(500):
 
     # Environment step which returns the observations for all agents, their corresponding
     # reward and whether they are done
-    start_time = time.time()
+
     next_obs, all_rewards, done, _ = env.step(action_dict)
-    end_time = time.time()
-    print(end_time - start_time)
+
     # env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
     frame_step += 1
     # Update replay buffer and train agent
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index ef05733d..ef27ac1a 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -82,29 +82,39 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
         agents_position = []
         agents_target = []
         agents_direction = []
-        for agent_idx in range(num_agents):
-
-            # Set target for agent
-            start_city = agent_start_targets_cities[agent_idx][0]
-            target_city = agent_start_targets_cities[agent_idx][1]
 
-            start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
-            target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
-            start = train_stations[start_city][start_idx]
-            target = train_stations[target_city][target_idx]
+        for agent_idx in range(num_agents):
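+            # Resample the city pair and station slots until a feasible path exists, at most 100 tries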
+            infeasible_agent = True
+            tries = 0
+            while infeasible_agent:
+                tries += 1
+                infeasible_agent = False
+                # Set target for agent
+                city_idx = np.random.randint(len(agent_start_targets_cities))
+                start_city = agent_start_targets_cities[city_idx][0]
+                target_city = agent_start_targets_cities[city_idx][1]
 
-            while start[1] % 2 != 0:
                 start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
-                start = train_stations[start_city][start_idx]
-            while target[1] % 2 != 1:
                 target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
+                start = train_stations[start_city][start_idx]
                 target = train_stations[target_city][target_idx]
-            agent_orientation = (agent_start_targets_cities[agent_idx][2] + 2 * start[1]) % 4
-            if not rail.check_path_exists(start[0], agent_orientation, target[0]):
-                agent_orientation = (agent_orientation + 2) % 4
-            if not (rail.check_path_exists(start[0], agent_orientation, target[0])):
-                warnings.warn("Infeasible task for agent {}".format(agent_idx))
 
+                while start[1] % 2 != 0:
+                    start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
+                    start = train_stations[start_city][start_idx]
+                while target[1] % 2 != 1:
+                    target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
+                    target = train_stations[target_city][target_idx]
+                agent_orientation = (agent_start_targets_cities[city_idx][2] + 2 * start[1]) % 4
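+                # If no path exists for this orientation, try the opposite one before flagging the task infeasible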
+                if not rail.check_path_exists(start[0], agent_orientation, target[0]):
+                    agent_orientation = (agent_orientation + 2) % 4
+                if not (rail.check_path_exists(start[0], agent_orientation, target[0])):
+                    infeasible_agent = True
+                if tries >= 100:
+                    warnings.warn("Did not find any possible path, check your parameters!!!")
+                    break
             agents_position.append((start[0][0], start[0][1]))
             agents_target.append((target[0][0], target[0][1]))
 
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 7bd6334b..f254e091 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -504,7 +504,7 @@ def test_sparse_rail_generator():
     for a in range(env.get_num_agents()):
         s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0))
         s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0))
-    assert s0 == 33, "actual={}".format(s0)
+    assert s0 == 31, "actual={}".format(s0)
     assert s1 == 24, "actual={}".format(s1)
 
 
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index e9b5a15d..8bbfefbb 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -122,11 +122,11 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the expected number of malfunctions occurred
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 20, "Actual {}".format(
         env.agents[0].malfunction_data['nr_malfunctions'])
 
     # Check that 20 stops were performed
-    assert agent_halts == 21
+    assert agent_halts == 20
 
     # Check that the agent accumulated some malfunction down time
     assert total_down_time > 0
-- 
GitLab