From ea7a351eb41a28c0d96fa69ee1c3b6f4cc123781 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 2 Sep 2019 17:03:19 +0200
Subject: [PATCH] #154 actionable agents in info dict

---
 docs/flatland_2.0.md                           | 12 +++
 flatland/envs/rail_env.py                      | 25 ++++--
 ...est_flatland_envs_sparse_rail_generator.py | 82 +++++++++++++++++++
 3 files changed, 112 insertions(+), 7 deletions(-)

diff --git a/docs/flatland_2.0.md b/docs/flatland_2.0.md
index e35251dd..6ce07c90 100644
--- a/docs/flatland_2.0.md
+++ b/docs/flatland_2.0.md
@@ -150,6 +150,18 @@ Because the different speeds are implemented as fractions the agents ability to
 - Agents can make observations at any time step. Make sure to dscard observations without any information. See this [example](https://gitlab.aicrowd.com/flatland/baselines/blob/master/torch_training/training_navigation.py) for a simple implementation.
 - The environment checks if agent is allowed to move to next cell only at the time of the switch to the next cell
 
+You can check whether an agent can act in the environment's next step:
+```
+obs, rew, done, info = env.step(actions)
+...
+action_dict = dict()
+for a in range(env.get_num_agents()):
+    if info['actionable_agents'][a]:
+        action_dict.update({a: ...})
+
+```
+Note that `info['actionable_agents'][a]` does not guarantee that the action will have an effect:
+if the next cell is blocked, the action cannot be performed. If the action is valid, it will be performed, though.
 
 ## Example code
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index d59ca7dc..0e412d0c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -310,7 +310,10 @@ class RailEnv(Environment):
 
         if self.dones["__all__"]:
             self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
-            return self._get_observations(), self.rewards_dict, self.dones, {}
+            info_dict = {
+                'actionable_agents': {i: False for i in range(self.get_num_agents())}
+            }
+            return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
         for i_agent in range(self.get_num_agents()):
             agent = self.agents[i_agent]
@@ -422,18 +425,17 @@ class RailEnv(Environment):
 
             if agent.speed_data['position_fraction'] >= 1.0:
                 # Perform stored action to transition to the next cell
-
                 cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
                     self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
 
-                # Check that everything is still fee and that the agent can move
+                # Check that everything is still free and that the agent can move
                 if all([new_cell_valid, transition_valid, cell_free]):
                     agent.position = new_position
                     agent.direction = new_direction
                     agent.speed_data['position_fraction'] = 0.0
-                else:
-                    # If the agent cannot move due to any reason, we set its state to not moving
-                    agent.moving = False
+                # else:
+                #     # If the agent cannot move due to any reason, we set its state to not moving
+                #     agent.moving = False
 
             if np.equal(agent.position, agent.target).all():
                 self.dones[i_agent] = True
@@ -451,7 +453,16 @@ class RailEnv(Environment):
         for k in self.dones.keys():
             self.dones[k] = True
 
-        return self._get_observations(), self.rewards_dict, self.dones, {}
+        actionable_agents = {i: self.agents[i].speed_data['position_fraction'] <= epsilon \
+                             for i in range(self.get_num_agents())
+                             }
+        info_dict = {
+            'actionable_agents': actionable_agents
+        }
+
+        for i, agent in enumerate(self.agents):
+            print(" {}: {}".format(i, agent.position))
+        return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
     def _check_action_on_agent(self, action, agent):
         # compute number of possible transitions in the current
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index c60d5062..4f481dba 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
@@ -24,3 +26,83 @@ def test_sparse_rail_generator():
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
     env_renderer.gl.save_image("./sparse_generator_false.png")
+    # TODO test assertions!
+
+
+def test_rail_env_actionable():
+    np.random.seed(0)
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train
+    env_always_action = RailEnv(width=50,
+                                height=50,
+                                rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                                     num_intersections=10,
+                                                                     # Number of intersections in map
+                                                                     num_trainstations=50,
+                                                                     # Number of possible start/targets on map
+                                                                     min_node_dist=6,  # Minimal distance of nodes
+                                                                     node_radius=3,
+                                                                     # Proximity of stations to city center
+                                                                     num_neighb=3,
+                                                                     # Number of connections to other cities
+                                                                     seed=5,  # Random seed
+                                                                     grid_mode=False  # Ordered distribution of nodes
+                                                                     ),
+                                schedule_generator=sparse_schedule_generator(speed_ration_map),
+                                number_of_agents=10,
+                                obs_builder_object=GlobalObsForRailEnv())
+    np.random.seed(0)
+    env_only_if_actionable = RailEnv(width=50,
+                                     height=50,
+                                     rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                                          num_intersections=10,
+                                                                          # Number of intersections in map
+                                                                          num_trainstations=50,
+                                                                          # Number of possible start/targets on map
+                                                                          min_node_dist=6,  # Minimal distance of nodes
+                                                                          node_radius=3,
+                                                                          # Proximity of stations to city center
+                                                                          num_neighb=3,
+                                                                          # Number of connections to other cities
+                                                                          seed=5,  # Random seed
+                                                                          grid_mode=False
+                                                                          # Ordered distribution of nodes
+                                                                          ),
+                                     schedule_generator=sparse_schedule_generator(speed_ration_map),
+                                     number_of_agents=10,
+                                     obs_builder_object=GlobalObsForRailEnv())
+    env_renderer = RenderTool(env_always_action, gl="PILSVG", )
+
+    for step in range(100):
+        print("step {}".format(step))
+
+        action_dict_always_action = dict()
+        action_dict_only_if_actionable = dict()
+        # Choose an action for each agent in the environment
+        for a in range(env_always_action.get_num_agents()):
+            action = np.random.choice(np.arange(4))
+            action_dict_always_action.update({a: action})
+            if step == 0 or info_only_if_actionable['actionable_agents'][a]:
+                action_dict_only_if_actionable.update({a: action})
+            else:
+                print("[{}] not actionable {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data))
+
+        obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
+            action_dict_always_action)
+        obs_only_if_actionable, rewards_only_if_actionable, done_only_if_actionable, info_only_if_actionable = env_only_if_actionable.step(
+            action_dict_only_if_actionable)
+
+        for a in range(env_always_action.get_num_agents()):
+            assert len(obs_always_action[a]) == len(obs_only_if_actionable[a])
+            for i in range(len(obs_always_action[a])):
+                assert np.array_equal(obs_always_action[a][i], obs_only_if_actionable[a][i])
+            assert np.array_equal(rewards_always_action[a], rewards_only_if_actionable[a])
+            assert np.array_equal(done_always_action[a], done_only_if_actionable[a])
+            assert info_always_action['actionable_agents'][a] == info_only_if_actionable['actionable_agents'][a]
+
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+
+        if done_always_action['__all__']:
+            break
-- 
GitLab
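
For reference, a minimal, self-contained sketch of a rollout loop that consumes the new `info['actionable_agents']` entry, mirroring `test_rail_env_actionable` above. It is illustrative rather than part of the patch: the `flatland.envs.schedule_generators` import path and the explicit `env.reset()` call are assumptions not shown in the diff, and the random action choice stands in for a real policy.

```python
import numpy as np

from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
# Assumed import path for the schedule generator used in the test above.
from flatland.envs.schedule_generators import sparse_schedule_generator

np.random.seed(0)

# Environment set up as in test_rail_env_actionable; parameters are illustrative.
speed_ration_map = {1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25}
env = RailEnv(width=50,
              height=50,
              rail_generator=sparse_rail_generator(num_cities=10, num_intersections=10,
                                                   num_trainstations=50, min_node_dist=6,
                                                   node_radius=3, num_neighb=3, seed=5,
                                                   grid_mode=False),
              schedule_generator=sparse_schedule_generator(speed_ration_map),
              number_of_agents=10,
              obs_builder_object=GlobalObsForRailEnv())

obs = env.reset()  # assumed explicit reset; the test relies on the freshly built env
info = None
for step in range(100):
    action_dict = dict()
    for a in range(env.get_num_agents()):
        # Only choose an action for agents the environment reports as actionable;
        # on the first step there is no info dict yet, so act for every agent.
        if info is None or info['actionable_agents'][a]:
            action_dict[a] = np.random.choice(np.arange(4))  # placeholder random policy
    obs, rewards, done, info = env.step(action_dict)
    if done['__all__']:
        break
```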