diff --git a/docs/flatland_2.0.md b/docs/flatland_2.0.md index fb9277b25636f6de619dc92e550c88ff5f74e0d1..78649e18d3585d7a55d1c5e293cc241be995e31e 100644 --- a/docs/flatland_2.0.md +++ b/docs/flatland_2.0.md @@ -165,6 +165,18 @@ The different speed profiles can be generated using the `schedule_generator`, wh Keep in mind that the *fastest speed* is 1 and all slower speeds must be between 1 and 0. For the submission scoring you can assume that there will be no more than 5 speed profiles. + + +Later versions of **Flat**land might have varying speeds during episodes. Therefore, we return the agent speeds. +Notice that we do not guarantee that the speed will be computed at each step, but if not costly we will return it at each step. +In your controller, you can get the agents' speed from the `info` returned by `step`: +``` +obs, rew, done, info = env.step(actions) +... +for a in range(env.get_num_agents()): + speed = info['speed'][a] +``` + ## Actions and observation with different speed levels Because the different speeds are implemented as fractions the agents ability to perform actions has been updated. @@ -177,7 +189,7 @@ This action is then executed when a step to the next cell is valid. For example - Agents can make observations at any time step. Make sure to discard observations without any information. See this [example](https://gitlab.aicrowd.com/flatland/baselines/blob/master/torch_training/training_navigation.py) for a simple implementation. - The environment checks if agent is allowed to move to next cell only at the time of the switch to the next cell -In your controller, you can check whether an action has an effect in the environment's next step: +In your controller, you can check whether an agent is entering by checking `info`: ``` obs, rew, done, info = env.step(actions) ... diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index b9d925443731c2da5f717bfcc1bd1660539b40fd..323742f7316dbda50ca8835cb0c7831f31cce01a 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -314,6 +314,7 @@ class RailEnv(Environment): info_dict = { 'entering': {i: False for i in range(self.get_num_agents())}, 'malfunction': {i: 0 for i in range(self.get_num_agents())}, + 'speed': {i: 0 for i in range(self.get_num_agents())} } return self._get_observations(), self.rewards_dict, self.dones, info_dict @@ -426,18 +427,17 @@ class RailEnv(Environment): if agent.speed_data['position_fraction'] >= 1.0: - # Perform stored action to transition to the next cell + # Perform stored action to transition to the next cell as soon as cell is free cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent) - # Check that everything is still free and that the agent can move - if all([new_cell_valid, transition_valid, cell_free]): + if all([new_cell_valid, transition_valid, cell_free]) and agent.malfunction_data['malfunction'] == 0: agent.position = new_position agent.direction = new_direction agent.speed_data['position_fraction'] = 0.0 - # else: - # # If the agent cannot move due to any reason, we set its state to not moving - # agent.moving = False + elif not transition_valid or not new_cell_valid: + # If the agent cannot move due to an invalid transition, we set its state to not moving + agent.moving = False if np.equal(agent.position, agent.target).all(): self.dones[i_agent] = True @@ -455,16 +455,18 @@ class RailEnv(Environment): for k in self.dones.keys(): self.dones[k] = True - entering_agents = {i: self.agents[i].speed_data['position_fraction'] <= epsilon \ - for i in range(self.get_num_agents()) - } - malfunction_agents = {i: self.agents[i].malfunction_data['malfunction'] \ - for i in range(self.get_num_agents()) - } + entering_agents = { + i: self.agents[i].speed_data['position_fraction'] <= epsilon for i in range(self.get_num_agents()) + } + malfunction_agents = { + i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) + } + speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())} info_dict = { 'entering': entering_agents, - 'malfunction': malfunction_agents + 'malfunction': malfunction_agents, + 'speed': speed_agents } return self._get_observations(), self.rewards_dict, self.dones, info_dict diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 3aa9ea58f1db705c2a355910f70c14c44b58ff8d..a56873d12b44fd6e57a5b08a784f3a0807ced77e 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -55,24 +55,24 @@ def test_rail_env_entering_info(): obs_builder_object=GlobalObsForRailEnv()) np.random.seed(0) env_only_if_entering = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map - num_intersections=10, - # Number of interesections in map - num_trainstations=50, - # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, - # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities - seed=5, # Random seed - grid_mode=False - # Ordered distribution of nodes - ), - schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + height=50, + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + num_intersections=10, + # Number of interesections in map + num_trainstations=50, + # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, + # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities + seed=5, # Random seed + grid_mode=False + # Ordered distribution of nodes + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) env_renderer = RenderTool(env_always_action, gl="PILSVG", ) for step in range(100): @@ -108,7 +108,7 @@ def test_rail_env_entering_info(): break -def test_rail_env_malfunction_info(): +def test_rail_env_malfunction_speed_info(): np.random.seed(0) stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents 'malfunction_rate': 30, # Rate of malfunction occurence @@ -149,6 +149,8 @@ def test_rail_env_malfunction_info(): assert 'malfunction' in info for a in range(env.get_num_agents()): assert info['malfunction'][a] >= 0 + assert info['speed'][a] >= 0 and info['speed'][a] <= 1 + assert info['speed'][a] == env.agents[a].speed_data['speed'] env_renderer.render_env(show=True, show_observations=False, show_predictions=False)