From ea7a351eb41a28c0d96fa69ee1c3b6f4cc123781 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 2 Sep 2019 17:03:19 +0200
Subject: [PATCH] #154 actionable agents in info dict

---
 docs/flatland_2.0.md                          | 29 +++
 flatland/envs/rail_env.py                     | 26 ++++--
 ...est_flatland_envs_sparse_rail_generator.py | 86 +++++++++++++++++++
 3 files changed, 134 insertions(+), 7 deletions(-)

diff --git a/docs/flatland_2.0.md b/docs/flatland_2.0.md
index e35251dd..6ce07c90 100644
--- a/docs/flatland_2.0.md
+++ b/docs/flatland_2.0.md
@@ -150,6 +150,35 @@ Because the different speeds are implemented as fractions the agents ability to
     - Agents can make observations at any time step. Make sure to discard observations without any information. See this [example](https://gitlab.aicrowd.com/flatland/baselines/blob/master/torch_training/training_navigation.py) for a simple implementation.
 - The environment checks whether an agent is allowed to move to the next cell only at the time of the switch to the next cell
 
+You can check whether an agent's action will be taken into account in the environment's next step:
+```
+obs, rew, done, info = env.step(actions)
+...
+action_dict = dict()
+for a in range(env.get_num_agents()):
+    if info['actionable_agents'][a]:
+        action_dict.update({a: ...})
+```
+Note that `info['actionable_agents'][a]` being `True` does not guarantee that the action will have an effect:
+if the next cell is blocked, the move cannot be performed. If the action is valid, it will be performed, though.
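+
+For example, a minimal interaction loop might look like this. This is a sketch
+(assuming an existing `RailEnv` instance `env`): the random action choice is just
+a placeholder for a real policy, and all agents are assumed actionable before the first step.
+```
+import numpy as np
+
+obs = env.reset()
+done = {'__all__': False}
+info = {'actionable_agents': {a: True for a in range(env.get_num_agents())}}
+while not done['__all__']:
+    action_dict = dict()
+    for a in range(env.get_num_agents()):
+        if info['actionable_agents'][a]:
+            # placeholder policy: pick one of the 5 rail actions at random
+            action_dict.update({a: np.random.choice(np.arange(5))})
+    obs, rewards, done, info = env.step(action_dict)
+```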
 
 ## Example code
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index d59ca7dc..0e412d0c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -310,7 +310,11 @@ class RailEnv(Environment):
 
         if self.dones["__all__"]:
             self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
-            return self._get_observations(), self.rewards_dict, self.dones, {}
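+            # The episode is over: no agent can take an effective action any more.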
+            info_dict = {
+                'actionable_agents': {i: False for i in range(self.get_num_agents())}
+            }
+            return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
         for i_agent in range(self.get_num_agents()):
             agent = self.agents[i_agent]
@@ -422,18 +425,17 @@ class RailEnv(Environment):
             if agent.speed_data['position_fraction'] >= 1.0:
 
                 # Perform stored action to transition to the next cell
-
                 cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
                     self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
 
-                # Check that everything is still fee and that the agent can move
+                # Check that everything is still free and that the agent can move
                 if all([new_cell_valid, transition_valid, cell_free]):
                     agent.position = new_position
                     agent.direction = new_direction
                     agent.speed_data['position_fraction'] = 0.0
-                else:
-                    # If the agent cannot move due to any reason, we set its state to not moving
-                    agent.moving = False
+                # Note: the agent is deliberately no longer set to moving = False
+                # here; if it cannot move for any reason (e.g. the next cell is
+                # blocked), it simply keeps trying on subsequent steps.
 
             if np.equal(agent.position, agent.target).all():
                 self.dones[i_agent] = True
@@ -451,7 +453,16 @@ class RailEnv(Environment):
             for k in self.dones.keys():
                 self.dones[k] = True
 
-        return self._get_observations(), self.rewards_dict, self.dones, {}
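+        # An agent can register a new action only once it has (almost) completed
+        # the transition into its current cell (position_fraction close to zero).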
+        actionable_agents = {i: self.agents[i].speed_data['position_fraction'] <= epsilon
+                             for i in range(self.get_num_agents())
+                             }
+        info_dict = {
+            'actionable_agents': actionable_agents
+        }
+
+        return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
     def _check_action_on_agent(self, action, agent):
         # compute number of possible transitions in the current
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index c60d5062..4f481dba 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
@@ -24,3 +26,87 @@ def test_sparse_rail_generator():
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
     env_renderer.gl.save_image("./sparse_generator_false.png")
+    # TODO test assertions!
+
+
+def test_rail_env_actionable():
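+    """Stepping with actions for all agents should be equivalent to stepping
+    with actions only for the agents flagged as actionable in the info dict."""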
+    np.random.seed(0)
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train
+    env_always_action = RailEnv(width=50,
+                                height=50,
+                                rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                                     num_intersections=10,
+                                                                     # Number of intersections in map
+                                                                     num_trainstations=50,
+                                                                     # Number of possible start/targets on map
+                                                                     min_node_dist=6,  # Minimal distance of nodes
+                                                                     node_radius=3,
+                                                                     # Proximity of stations to city center
+                                                                     num_neighb=3,
+                                                                     # Number of connections to other cities
+                                                                     seed=5,  # Random seed
+                                                                     grid_mode=False  # Ordered distribution of nodes
+                                                                     ),
+                                schedule_generator=sparse_schedule_generator(speed_ration_map),
+                                number_of_agents=10,
+                                obs_builder_object=GlobalObsForRailEnv())
+    np.random.seed(0)
+    env_only_if_actionable = RailEnv(width=50,
+                                     height=50,
+                                     rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                                          num_intersections=10,
+                                                                          # Number of intersections in map
+                                                                          num_trainstations=50,
+                                                                          # Number of possible start/targets on map
+                                                                          min_node_dist=6,  # Minimal distance of nodes
+                                                                          node_radius=3,
+                                                                          # Proximity of stations to city center
+                                                                          num_neighb=3,
+                                                                          # Number of connections to other cities
+                                                                          seed=5,  # Random seed
+                                                                          grid_mode=False
+                                                                          # Ordered distribution of nodes
+                                                                          ),
+                                     schedule_generator=sparse_schedule_generator(speed_ration_map),
+                                     number_of_agents=10,
+                                     obs_builder_object=GlobalObsForRailEnv())
+    env_renderer = RenderTool(env_always_action, gl="PILSVG", )
+
+    for step in range(100):
+        print("step {}".format(step))
+
+        action_dict_always_action = dict()
+        action_dict_only_if_actionable = dict()
+        # Choose an action for each agent in the environment
+        for a in range(env_always_action.get_num_agents()):
+            action = np.random.choice(np.arange(4))
+            action_dict_always_action.update({a: action})
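+            # In the first step there is no info dict yet, so every agent gets an
+            # action; afterwards, only agents flagged as actionable get one.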
+            if step == 0 or info_only_if_actionable['actionable_agents'][a]:
+                action_dict_only_if_actionable.update({a: action})
+            else:
+                print("[{}] not actionable {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data))
+
+        obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
+            action_dict_always_action)
+        obs_only_if_actionable, rewards_only_if_actionable, done_only_if_actionable, info_only_if_actionable = env_only_if_actionable.step(
+            action_dict_only_if_actionable)
+
+        for a in range(env_always_action.get_num_agents()):
+            assert len(obs_always_action[a]) == len(obs_only_if_actionable[a])
+            for i in range(len(obs_always_action[a])):
+                assert np.array_equal(obs_always_action[a][i], obs_only_if_actionable[a][i])
+            assert np.array_equal(rewards_always_action[a], rewards_only_if_actionable[a])
+            assert np.array_equal(done_always_action[a], done_only_if_actionable[a])
+            assert info_always_action['actionable_agents'][a] == info_only_if_actionable['actionable_agents'][a]
+
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+
+        if done_always_action['__all__']:
+            break
-- 
GitLab