Commit ea7a351e authored by u214892

#154 actionable agents in info dict

parent 0d264139
@@ -150,6 +150,18 @@ Because the different speeds are implemented as fractions the agents ability to
- Agents can make observations at any time step. Make sure to discard observations without any information. See this [example](https://gitlab.aicrowd.com/flatland/baselines/blob/master/torch_training/training_navigation.py) for a simple implementation.
- The environment checks whether an agent is allowed to move into the next cell only at the moment the agent actually switches cells.
You can check which agents will accept a new action in the environment's next step:
```
obs, rew, done, info = env.step(actions)
...
action_dict = dict()
for a in range(env.get_num_agents()):
if info['actionable_agents'][a]:
action_dict.update({a: ...})
```
Note that `info['actionable_agents'][a]` does not guarantee that the action will have an effect:
if the next cell is blocked, the action cannot be performed. If the action is valid, however, it will be performed.
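For illustration, here is a minimal sketch (assuming an already created and reset `env`, and the `info` dict returned by a previous `env.step` call) that distinguishes *actionable* from *actually moved*: an actionable agent may still be held back if its chosen transition is blocked.
```
import numpy as np

# Sketch only: `env` and `info` come from earlier env.reset()/env.step() calls.
positions_before = {a: env.agents[a].position for a in range(env.get_num_agents())}

action_dict = dict()
for a in range(env.get_num_agents()):
    if info['actionable_agents'][a]:
        action_dict.update({a: 2})  # e.g. move forward

obs, rew, done, info = env.step(action_dict)

for a in action_dict:
    if np.array_equal(env.agents[a].position, positions_before[a]):
        # The agent was actionable, but the chosen transition was blocked or
        # invalid, so the action had no visible effect in this step.
        print("agent {} could not move".format(a))
```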
## Example code
@@ -310,7 +310,10 @@ class RailEnv(Environment):
if self.dones["__all__"]:
self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
return self._get_observations(), self.rewards_dict, self.dones, {}
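# The episode is over: no agent can act any more, so every agent is reported as not actionable.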
info_dict = {
'actionable_agents': {i: False for i in range(self.get_num_agents())}
}
return self._get_observations(), self.rewards_dict, self.dones, info_dict
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
@@ -422,18 +425,17 @@ class RailEnv(Environment):
if agent.speed_data['position_fraction'] >= 1.0:
# Perform stored action to transition to the next cell
cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
# Check that everything is still fee and that the agent can move
# Check that everything is still free and that the agent can move
if all([new_cell_valid, transition_valid, cell_free]):
agent.position = new_position
agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0
else:
# If the agent cannot move due to any reason, we set its state to not moving
agent.moving = False
# else:
# # If the agent cannot move due to any reason, we set its state to not moving
# agent.moving = False
if np.equal(agent.position, agent.target).all():
self.dones[i_agent] = True
@@ -451,7 +453,16 @@ class RailEnv(Environment):
for k in self.dones.keys():
self.dones[k] = True
return self._get_observations(), self.rewards_dict, self.dones, {}
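# An agent can accept a new action only while it is at the beginning of a cell,
# i.e. its position_fraction is (close to) zero.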
actionable_agents = {i: self.agents[i].speed_data['position_fraction'] <= epsilon \
for i in range(self.get_num_agents())
}
info_dict = {
'actionable_agents': actionable_agents
}
for i, agent in enumerate(self.agents):
print(" {}: {}".format(i, agent.position))
return self._get_observations(), self.rewards_dict, self.dones, info_dict
def _check_action_on_agent(self, action, agent):
# compute number of possible transitions in the current
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
@@ -24,3 +26,83 @@ def test_sparse_rail_generator():
env_renderer = RenderTool(env, gl="PILSVG", )
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
env_renderer.gl.save_image("./sparse_generator_false.png")
# TODO test assertions!
def test_rail_env_actionable():
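# Roll out two identically parameterised and seeded environments: one receives an
# action for every agent at every step, the other only for agents reported as
# actionable. Both rollouts must produce identical observations, rewards and dones.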
np.random.seed(0)
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env_always_action = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
num_intersections=10,
# Number of intersections in map
num_trainstations=50,
# Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3,
# Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
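# Build a second environment with the same parameters and seed so that both rollouts can be compared step by step.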
np.random.seed(0)
env_only_if_actionable = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
num_intersections=10,
# Number of intersections in map
num_trainstations=50,
# Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3,
# Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities
seed=5, # Random seed
grid_mode=False
# Ordered distribution of nodes
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env_renderer = RenderTool(env_always_action, gl="PILSVG", )
for step in range(100):
print("step {}".format(step))
action_dict_always_action = dict()
action_dict_only_if_actionable = dict()
# Choose an action for each agent in the environment
for a in range(env_always_action.get_num_agents()):
action = np.random.choice(np.arange(4))
action_dict_always_action.update({a: action})
if step == 0 or info_only_if_actionable['actionable_agents'][a]:
action_dict_only_if_actionable.update({a: action})
else:
print("[{}] not actionable {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data))
obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
action_dict_always_action)
obs_only_if_actionable, rewards_only_if_actionable, done_only_if_actionable, info_only_if_actionable = env_only_if_actionable.step(
action_dict_only_if_actionable)
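# Both rollouts must stay in lockstep: observations, rewards, dones and the
# actionable flags have to match for every agent.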
for a in range(env_always_action.get_num_agents()):
assert len(obs_always_action[a]) == len(obs_only_if_actionable[a])
for i in range(len(obs_always_action[a])):
assert np.array_equal(obs_always_action[a][i], obs_only_if_actionable[a][i])
assert np.array_equal(rewards_always_action[a], rewards_only_if_actionable[a])
assert np.array_equal(done_always_action[a], done_only_if_actionable[a])
assert info_always_action['actionable_agents'][a] == info_only_if_actionable['actionable_agents'][a]
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
if done_always_action['__all__']:
break