diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index f659ec8436a941606b6d649e24d2481e5be9b66d..5e6bcb20a0fd1a23513ea7737b66401417e63274 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,3 +1,4 @@
+from enum import IntEnum
 from itertools import starmap
 from typing import Tuple

@@ -7,6 +8,12 @@
 from attr import attrs, attrib, Factory

 from flatland.core.grid.grid4 import Grid4TransitionsEnum


+class RailAgentStatus(IntEnum):
+    READY_TO_DEPART = 0
+    ACTIVE = 1
+    DONE = 2
+
+
 @attrs
 class EnvAgentStatic(object):
     """ EnvAgentStatic - Stores initial position, direction and target.
@@ -18,6 +25,7 @@ class EnvAgentStatic(object):
     direction = attrib(type=Grid4TransitionsEnum)
     target = attrib(type=Tuple[int, int])
     moving = attrib(default=False, type=bool)
+    # position = attrib(default=None, type=Optional[Tuple[int, int]])

     # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
     # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
@@ -33,6 +41,8 @@ class EnvAgentStatic(object):
         lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
                       'moving_before_malfunction': False})))

+    status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
+
    @classmethod
    def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
        """ Create a list of EnvAgentStatics from lists of positions, directions and targets
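The patch hinges on this new three-state lifecycle: an agent is created READY_TO_DEPART, becomes ACTIVE on its first MOVE_* action, and ends up DONE once it reaches its target or the episode times out. A self-contained sketch of the transition logic, mirroring the checks added to rail_env.py below (the advance_status helper is illustrative only and not part of the patch; the real step logic also handles malfunctions and fractional speeds):

    from enum import IntEnum


    class RailAgentStatus(IntEnum):  # same values as in agent_utils.py above
        READY_TO_DEPART = 0
        ACTIVE = 1
        DONE = 2


    def advance_status(status, is_move_action, reached_target):
        """Simplified status transition: READY_TO_DEPART -> ACTIVE -> DONE."""
        if status == RailAgentStatus.READY_TO_DEPART:
            # agents only start running once they receive a MOVE_* action
            return RailAgentStatus.ACTIVE if is_move_action else status
        if status == RailAgentStatus.ACTIVE and reached_target:
            return RailAgentStatus.DONE
        return status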
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index c23d4345a03c761ad4c4ac1d936db817f8acc529..2612fd10b3b8b4f7c8f861983c52b993dd97766c 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -11,11 +11,11 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.utils.ordered_set import OrderedSet


 class TreeObsForRailEnv(ObservationBuilder):
-
     Node = collections.namedtuple('Node', 'dist_own_target_encountered '
                                           'dist_other_target_encountered '
                                           'dist_other_agent_encountered '
@@ -296,7 +296,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                             self._reverse_dir(
                                 self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                             potential_conflict = tot_dist
-                        if self.env.dones[ca] and tot_dist < potential_conflict:
+                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                             potential_conflict = tot_dist

                     # Look for conflicting paths at distance num_step-1
@@ -307,7 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                             and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                             and tot_dist < potential_conflict:  # noqa: E125
                             potential_conflict = tot_dist
-                        if self.env.dones[ca] and tot_dist < potential_conflict:
+                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                             potential_conflict = tot_dist

                     # Look for conflicting paths at distance num_step+1
@@ -318,7 +318,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                             self.predicted_dir[post_step][ca])] == 1 \
                             and tot_dist < potential_conflict:  # noqa: E125
                             potential_conflict = tot_dist
-                        if self.env.dones[ca] and tot_dist < potential_conflict:
+                        if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                             potential_conflict = tot_dist

                 if position in self.location_has_target and position != agent.target:
@@ -621,7 +621,8 @@ class LocalObsForRailEnv(ObservationBuilder):
             direction = np.identity(4)[agent.direction]
             return local_rail_obs, obs_map_state, obs_other_agents_state, direction

-    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
+        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
         """
         Called whenever an observation has to be computed for the `env` environment,
         for each agent with handle in the `handles` list.
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 4fb1c55e5a179e99e6c4985fff3c5596ba1f0a66..43e095fb9f68b851667b75ed4ef7d62693e78a7c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -15,7 +15,7 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
+from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
@@ -354,7 +354,8 @@ class RailEnv(Environment):
             info_dict = {
                 'action_required': {i: False for i in range(self.get_num_agents())},
                 'malfunction': {i: 0 for i in range(self.get_num_agents())},
-                'speed': {i: 0 for i in range(self.get_num_agents())}
+                'speed': {i: 0 for i in range(self.get_num_agents())},
+                'status': {i: agent.status for i, agent in enumerate(self.agents)}
             }

             return self._get_observations(), self.rewards_dict, self.dones, info_dict
@@ -369,21 +370,18 @@
         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
-            for k in self.dones.keys():
-                self.dones[k] = True
-
-        action_required_agents = {
-            i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents())
-        }
-        malfunction_agents = {
-            i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
-        }
-        speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}
+            for i in range(self.get_num_agents()):
+                self.agents[i].status = RailAgentStatus.DONE
+                self.dones[i] = True

         info_dict = {
-            'action_required': action_required_agents,
-            'malfunction': malfunction_agents,
-            'speed': speed_agents
+            'action_required': {i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in
+                                range(self.get_num_agents())},
+            'malfunction': {
+                i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
+            },
+            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
+            'status': {i: agent.status for i, agent in enumerate(self.agents)}
         }

         return self._get_observations(), self.rewards_dict, self.dones, info_dict
@@ -401,10 +399,18 @@
        action_dict_ : Dict[int,RailEnvActions]

        """
-        if self.dones[i_agent]:  # this agent has already completed...
+        agent = self.agents[i_agent]
+        if agent.status == RailAgentStatus.DONE:  # this agent has already completed...
             return
-        agent = self.agents[i_agent]

+        # agent becomes active on a MOVE_* action (cell-free check currently disabled)
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
+                          RailEnvActions.MOVE_FORWARD]:  # and self.cell_free(agent.position):
+                agent.status = RailAgentStatus.ACTIVE
+            else:
+                return
+
         agent.old_direction = agent.direction
         agent.old_position = agent.position
@@ -497,6 +503,7 @@

                 # has the agent reached its target?
                 if np.equal(agent.position, agent.target).all():
+                    agent.status = RailAgentStatus.DONE
                     self.dones[i_agent] = True
                     agent.moving = False
                 else:
@@ -543,9 +550,12 @@

         # Check the new position is not the same as any of the existing agent positions
         # (including itself, for simplicity, since it is moving)
-        cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+        cell_free = self.cell_free(new_position)
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid

+    def cell_free(self, position):
+        return not np.any(np.equal(position, [agent.position for agent in self.agents]).all(1))
+
     def check_action(self, agent: EnvAgent, action: RailEnvActions):
         """

@@ -591,7 +601,7 @@
         return self.obs_dict

     def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
-        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col))
+        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))

     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()
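Two of the rail_env.py changes above deserve a closer look. First, the extracted cell_free helper is a numpy broadcasting idiom: np.equal compares the candidate cell against every agent's (row, col) pair, .all(1) collapses each row to "same cell?", and np.any asks whether any agent matches. A standalone illustration of the same expression with toy positions (not environment data):

    import numpy as np

    def cell_free(position, agent_positions):
        # same expression as RailEnv.cell_free above
        return not np.any(np.equal(position, agent_positions).all(1))

    positions = [(3, 9), (5, 2)]         # toy data
    print(cell_free((3, 9), positions))  # False: cell occupied
    print(cell_free((0, 0), positions))  # True: cell free

Second, because READY_TO_DEPART agents now ignore everything except MOVE_* actions, a control loop has to issue an explicit departure command, and it can use the new 'status' entry of the info dict to decide what each agent needs. A hypothetical per-step action selector (choose_action is a placeholder policy and not part of the patch; only the info keys shown above are assumed):

    from flatland.envs.agent_utils import RailAgentStatus
    from flatland.envs.rail_env import RailEnvActions

    def select_actions(obs, info, choose_action):
        """Pick one action per agent based on the 'status' info entry."""
        actions = {}
        for handle, status in info['status'].items():
            if status == RailAgentStatus.READY_TO_DEPART:
                # departure requires a MOVE_* action
                actions[handle] = RailEnvActions.MOVE_FORWARD
            elif status == RailAgentStatus.ACTIVE and info['action_required'][handle]:
                actions[handle] = choose_action(obs[handle])  # placeholder policy
            # DONE agents need no action; _step_agent returns early for them
        return actions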
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e137f3603dab77a535fc1845b2c64de418a3a8
--- /dev/null
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -0,0 +1,120 @@
+import numpy as np
+
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.simple_rail import make_simple_rail
+from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
+
+np.random.seed(1)
+
+
+def test_initial_status():
+    """Test that the agent lifecycle works correctly: READY_TO_DEPART -> ACTIVE -> DONE."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    set_penalties_for_replay(env)
+    test_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(3, 9),  # east dead-end
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                reward=0,
+                status=RailAgentStatus.READY_TO_DEPART
+            ),
+            Replay(
+                position=(3, 9),  # east dead-end
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.start_penalty + env.step_penalty * 0.5,  # auto-correction from left to forward without penalty!
+                status=RailAgentStatus.READY_TO_DEPART
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # done
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            )
+        ],
+        target=(3, 5),
+        speed=0.5
+    )
+
+    run_replay_config(env, [test_config])
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index fa2920a9fa7c78331cc7c32ae5308633b4d3f8da..083b15ea143285c2d46191dc5de54e5217cd92ad 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -2,14 +2,15 @@ import random
 from typing import Dict, List

 import numpy as np
-from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay

 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay


 class SingleAgentNavigationObs(ObservationBuilder):
@@ -271,17 +272,19 @@ def test_initial_malfunction_stop_moving():
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.DO_NOTHING,
+                action=RailEnvActions.MOVE_FORWARD,
                 set_malfunction=3,
                 malfunction=3,
-                reward=env.step_penalty  # full step penalty when stopped
+                reward=env.step_penalty,  # full step penalty when stopped
+                status=RailAgentStatus.READY_TO_DEPART
             ),
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
-                reward=env.step_penalty  # full step penalty when stopped
+                reward=env.step_penalty,  # full step penalty when stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action STOP_MOVING, agent should restart without moving
@@ -291,7 +294,8 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.STOP_MOVING,
                 malfunction=1,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we have stopped and do nothing --> should stand still
             Replay(
@@ -299,7 +303,8 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -307,14 +312,16 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.start_penalty + env.step_penalty * 1.0  # full step penalty while stopped
+                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(28, 4),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # full step penalty while stopped
+                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             )
         ],
         speed=env.agents[0].speed_data['speed'],
@@ -363,17 +370,19 @@ def test_initial_malfunction_do_nothing():
         replay=[Replay(
             position=(28, 5),
             direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
+            action=RailEnvActions.MOVE_FORWARD,
             set_malfunction=3,
             malfunction=3,
-            reward=env.step_penalty  # full step penalty while malfunctioning
+            reward=env.step_penalty,  # full step penalty while malfunctioning
+            status=RailAgentStatus.READY_TO_DEPART
         ),
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
-                reward=env.step_penalty  # full step penalty while malfunctioning
+                reward=env.step_penalty,  # full step penalty while malfunctioning
+                status=RailAgentStatus.ACTIVE
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action DO_NOTHING, agent should restart without moving
@@ -383,7 +392,8 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=1,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we haven't started moving yet --> stay here
             Replay(
@@ -391,7 +401,8 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -399,14 +410,16 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(28, 4),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             )
         ],
         speed=env.agents[0].speed_data['speed'],
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 903120d868aa65833e7c2393ddfcc821c26da4f6..3da01998857acfac460a5bfa0df01b4845b212d9 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,7 +5,7 @@ import numpy as np
 from attr import attrs, attrib

 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.rail_env import RailEnvActions, RailEnv
 from flatland.utils.rendertools import RenderTool

@@ -18,6 +18,7 @@ class Replay(object):
     malfunction = attrib(default=0, type=int)
     set_malfunction = attrib(default=None, type=Optional[int])
     reward = attrib(default=None, type=Optional[float])
+    status = attrib(default=None, type=Optional[RailAgentStatus])


 @attrs
@@ -47,10 +48,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
     - position, direction before step are verified
     - optionally, set_malfunction is applied
     - malfunction is verified
+    - status is verified (optionally)

     *After each step*

     - reward is verified after step
+
     Parameters
     ----------
     env
@@ -77,8 +80,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:

     def _assert(a, actual, expected, msg):
         assert np.allclose(actual, expected), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
-                                                                                                 actual,
-                                                                                                 expected)
+                                                                                                actual,
+                                                                                                expected)

     action_dict = {}

@@ -88,6 +91,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:

             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
+            if replay.status is not None:
+                _assert(a, agent.status, replay.status, 'status')

             if replay.action is not None:
                 assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
@@ -109,5 +114,3 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
         for a, test_config in enumerate(test_configs):
             replay = test_config.replay[step]
             _assert(a, rewards_dict[a], replay.reward, 'reward')
-
-
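Taken together, the harness drives a fixed action script through the environment and checks agent state before each step and the reward after it. A condensed view of one verification cycle, distilled from run_replay_config above (paraphrased, not the verbatim implementation; rendering and malfunction handling are omitted):

    # One replay step, condensed from run_replay_config (paraphrase).
    for step in range(len(test_configs[0].replay)):
        action_dict = {}
        for a, test_config in enumerate(test_configs):
            agent, replay = env.agents[a], test_config.replay[step]
            _assert(a, agent.position, replay.position, 'position')     # before step
            _assert(a, agent.direction, replay.direction, 'direction')
            if replay.status is not None:
                _assert(a, agent.status, replay.status, 'status')       # before step
            if replay.action is not None:
                action_dict[a] = replay.action
        _, rewards_dict, _, info_dict = env.step(action_dict)
        for a, test_config in enumerate(test_configs):
            _assert(a, rewards_dict[a], test_config.replay[step].reward, 'reward')  # after step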