diff --git a/docs/flatland_2.0.md b/docs/flatland_2.0.md
index c41fea1377513721b3b96d705ddb6e4c81ccf6f4..fb9277b25636f6de619dc92e550c88ff5f74e0d1 100644
--- a/docs/flatland_2.0.md
+++ b/docs/flatland_2.0.md
@@ -90,14 +90,14 @@ This is very common for railway networks where the initial plan usually needs to
 
 We implemented a Poisson process to simulate delays by stopping agents at random times for random durations. The parameters necessary for the stochastic events can be provided when creating the environment.
 
-```
 # Use the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.5,  # Percentage of defective agents
-                   'malfunction_rate': 30,  # Rate of malfunction occurence
-                   'min_duration': 3,  # Minimal duration of malfunction
-                   'max_duration': 10  # Max duration of malfunction
-                   }
-
+```
+stochastic_data = {
+    'prop_malfunction': 0.5,  # Percentage of defective agents
+    'malfunction_rate': 30,  # Rate of malfunction occurrence
+    'min_duration': 3,  # Minimal duration of malfunction
+    'max_duration': 10  # Max duration of malfunction
+}
 ```
 
 The parameters are as follows:
@@ -109,12 +109,23 @@ The parameters are as follows:
 
 You can introduce stochasticity by simply creating the env as follows:
 
 ```
-# Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
-                   'malfunction_rate': 30,  # Rate of malfunction occurence
-                   'min_duration': 3,  # Minimal duration of malfunction
-                   'max_duration': 20  # Max duration of malfunction
-                   }
+env = RailEnv(
+    ...
+    stochastic_data=stochastic_data,  # Malfunction data generator
+    ...
+)
+```
+In your controller, you can check whether an agent is malfunctioning:
+```
+obs, rew, done, info = env.step(actions)
+...
+action_dict = dict()
+for a in range(env.get_num_agents()):
+    if info['malfunction'][a] == 0:
+        action_dict.update({a: ...})
+
+```
+
 # Custom observation builder
 tree_observation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
@@ -166,18 +177,18 @@ This action is then executed when a step to the next cell is valid. For example
 
 - Agents can make observations at any time step. Make sure to discard observations without any information. See this [example](https://gitlab.aicrowd.com/flatland/baselines/blob/master/torch_training/training_navigation.py) for a simple implementation.
 - The environment checks if agent is allowed to move to next cell only at the time of the switch to the next cell
 
-You can check whether an action has an effect in the environment's next step:
+In your controller, you can check whether an action has an effect in the environment's next step:
 
 ```
 obs, rew, done, info = env.step(actions)
 ...
 action_dict = dict()
 for a in range(env.get_num_agents()):
-    if info['actionable_agents'][a]:
+    if info['entering'][a] and info['malfunction'][a] == 0:
         action_dict.update({a: ...})
 ```
-Notice that `info['actionable_agents'][a]` does not mean that the action has an effect:
-if the next cell is blocked, the action cannot be performed. If the action is valid, it will be performend, though.
+Notice that `info['entering'][a]` does not mean that the action will have an effect:
+if the next cell is blocked or the agent is malfunctioning, the action cannot be performed.
 
 ## Example code
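Taken together, the documentation above describes the following controller pattern. The sketch below is illustrative and not part of the patch: it assumes `env` is a `RailEnv` created with `stochastic_data` as shown, that `np` is NumPy, and that `reset()` returns only the observations in this version, so the first step treats every agent as entering and healthy (the same `step == 0` special case the tests below use).

```
obs = env.reset()
num_agents = env.get_num_agents()

# No info dict exists before the first call to step(), so start permissive:
# treat every agent as entering a cell and healthy.
info = {'entering': {a: True for a in range(num_agents)},
        'malfunction': {a: 0 for a in range(num_agents)}}

done = {'__all__': False}
while not done['__all__']:
    action_dict = dict()
    for a in range(num_agents):
        # Act only when the action can take effect: the agent is about to
        # enter a new cell and is not currently broken down.
        if info['entering'][a] and info['malfunction'][a] == 0:
            action_dict.update({a: np.random.choice(np.arange(4))})  # placeholder policy
    obs, rewards, done, info = env.step(action_dict)
```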
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 9000aaa96ce25f8de0308be9ceb0c406cc522275..b9d925443731c2da5f717bfcc1bd1660539b40fd 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -312,7 +312,8 @@ class RailEnv(Environment):
         if self.dones["__all__"]:
             self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
             info_dict = {
-                'actionable_agents': {i: False for i in range(self.get_num_agents())}
+                'entering': {i: False for i in range(self.get_num_agents())},
+                'malfunction': {i: 0 for i in range(self.get_num_agents())},
             }
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
@@ -454,15 +455,18 @@ class RailEnv(Environment):
         for k in self.dones.keys():
             self.dones[k] = True
 
-        actionable_agents = {i: self.agents[i].speed_data['position_fraction'] <= epsilon \
+        entering_agents = {i: self.agents[i].speed_data['position_fraction'] <= epsilon \
                              for i in range(self.get_num_agents())
                              }
+        malfunction_agents = {i: self.agents[i].malfunction_data['malfunction'] \
+                              for i in range(self.get_num_agents())
+                              }
+
         info_dict = {
-            'actionable_agents': actionable_agents
+            'entering': entering_agents,
+            'malfunction': malfunction_agents
         }
 
-        for i, agent in enumerate(self.agents):
-            print("  {}: {}".format(i, agent.position))
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
     def _check_action_on_agent(self, action, agent):
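For reference, after this change `RailEnv.step()` returns an `info` dict with two entries per agent. The snippet below only illustrates the shape for a hypothetical three-agent environment: `entering` is the `position_fraction <= epsilon` flag computed above, and `malfunction` is the agent's `malfunction_data['malfunction']` counter, `0` for a healthy agent and positive while it is broken down.

```
# Illustrative shape only (hypothetical 3-agent environment):
info = {
    'entering': {0: True, 1: False, 2: True},  # agent about to enter a new cell?
    'malfunction': {0: 0, 1: 4, 2: 0}          # 0 = healthy, > 0 = broken down
}
```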
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 4f481dba0b370fdb10c5494f182a68434948f4f8..3aa9ea58f1db705c2a355910f70c14c44b58ff8d 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -29,7 +29,7 @@ def test_sparse_rail_generator():
     # TODO test assertions!
 
 
-def test_rail_env_actionable():
+def test_rail_env_entering_info():
     np.random.seed(0)
     speed_ration_map = {1.: 0.25,  # Fast passenger train
                         1. / 2.: 0.25,  # Fast freight train
@@ -54,7 +54,7 @@ def test_rail_env_actionable():
                           number_of_agents=10,
                           obs_builder_object=GlobalObsForRailEnv())
     np.random.seed(0)
-    env_only_if_actionable = RailEnv(width=50,
+    env_only_if_entering = RailEnv(width=50,
                                      height=50,
                                      rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
                                                                           num_intersections=10,
@@ -79,30 +79,78 @@ def test_rail_env_actionable():
         print("step {}".format(step))
 
         action_dict_always_action = dict()
-        action_dict_only_if_actionable = dict()
+        action_dict_only_if_entering = dict()
         # Choose an action for each agent in the environment
         for a in range(env_always_action.get_num_agents()):
             action = np.random.choice(np.arange(4))
             action_dict_always_action.update({a: action})
-            if step == 0 or info_only_if_actionable['actionable_agents'][a]:
-                action_dict_only_if_actionable.update({a: action})
+            if step == 0 or info_only_if_entering['entering'][a]:
+                action_dict_only_if_entering.update({a: action})
             else:
-                print("[{}] not actionable {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data))
+                print("[{}] not entering {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data))
 
         obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
             action_dict_always_action)
-        obs_only_if_actionable, rewards_only_if_actionable, done_only_if_actionable, info_only_if_actionable = env_only_if_actionable.step(
-            action_dict_only_if_actionable)
+        obs_only_if_entering, rewards_only_if_entering, done_only_if_entering, info_only_if_entering = env_only_if_entering.step(
+            action_dict_only_if_entering)
 
         for a in range(env_always_action.get_num_agents()):
-            assert len(obs_always_action[a]) == len(obs_only_if_actionable[a])
+            assert len(obs_always_action[a]) == len(obs_only_if_entering[a])
             for i in range(len(obs_always_action[a])):
-                assert np.array_equal(obs_always_action[a][i], obs_only_if_actionable[a][i])
-            assert np.array_equal(rewards_always_action[a], rewards_only_if_actionable[a])
-            assert np.array_equal(done_always_action[a], done_only_if_actionable[a])
-            assert info_always_action['actionable_agents'][a] == info_only_if_actionable['actionable_agents'][a]
+                assert np.array_equal(obs_always_action[a][i], obs_only_if_entering[a][i])
+            assert np.array_equal(rewards_always_action[a], rewards_only_if_entering[a])
+            assert np.array_equal(done_always_action[a], done_only_if_entering[a])
+            assert info_always_action['entering'][a] == info_only_if_entering['entering'][a]
 
         env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
 
         if done_always_action['__all__']:
             break
+
+
+def test_rail_env_malfunction_info():
+    np.random.seed(0)
+    stochastic_data = {'prop_malfunction': 0.5,  # Percentage of defective agents
+                       'malfunction_rate': 30,  # Rate of malfunction occurrence
+                       'min_duration': 3,  # Minimal duration of malfunction
+                       'max_duration': 10  # Max duration of malfunction
+                       }
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                       num_intersections=10,
+                                                       # Number of intersections in map
+                                                       num_trainstations=50,
+                                                       # Number of possible start/targets on map
+                                                       min_node_dist=6,  # Minimal distance of nodes
+                                                       node_radius=3,
+                                                       # Proximity of stations to city center
+                                                       num_neighb=3,
+                                                       # Number of connections to other cities
+                                                       seed=5,  # Random seed
+                                                       grid_mode=False  # Ordered distribution of nodes
+                                                       ),
+                  schedule_generator=sparse_schedule_generator(),
+                  number_of_agents=10,
+                  obs_builder_object=GlobalObsForRailEnv(),
+                  stochastic_data=stochastic_data)
+
+    env_renderer = RenderTool(env, gl="PILSVG", )
+    for step in range(100):
+        action_dict = dict()
+        # Choose an action for each agent in the environment
+        for a in range(env.get_num_agents()):
+            action = np.random.choice(np.arange(4))
+            action_dict.update({a: action})
+
+        obs, rewards, done, info = env.step(
+            action_dict)
+
+        assert 'malfunction' in info
+        for a in range(env.get_num_agents()):
+            assert info['malfunction'][a] >= 0
+
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+
+        if done['__all__']:
+            break
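A possible follow-up to `test_rail_env_malfunction_info`, not included in this patch, would be to assert that a broken-down agent actually stays put. The sketch below assumes that an agent whose `malfunction` counter is still above 1 after a step was stopped for that whole step; the exact decrement timing of `malfunction_data['malfunction']` should be verified against the implementation before relying on this.

```
# Hypothetical extra assertion (sketch): agents still broken after the step
# should not have moved during it.
positions_before = {a: env.agents[a].position for a in range(env.get_num_agents())}
obs, rewards, done, info = env.step(action_dict)
for a in range(env.get_num_agents()):
    if info['malfunction'][a] > 1:  # malfunction counter still running
        assert env.agents[a].position == positions_before[a]
```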