diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 855d1f5dffbef29da26ca1dead933af846b863bd..f75cb74537f03dc6fb1aecbadff37a183432e55a 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -54,11 +54,9 @@ class ObservePredictions(ObservationBuilder): pos_list.append(self.predictions[a][t][1:3]) # We transform (x,y) coodrinates to a single integer number for simpler comparison self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) - observations = {} - # Collect all the different observation for all the agents - for h in handles: - observations[h] = self.get(h) + observations = super().get_many(handles) + return observations def get(self, handle: int = 0) -> np.ndarray: diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 3cc21966162dd28d183493a97cd6072a34abb738..2302fff9cdeeaf162b7ab38b1e67f5052c76b3ef 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -24,7 +24,7 @@ class ObservationBuilder: self.env = None def set_env(self, env: Environment): - self.env = env + self.env: Environment = env def reset(self): """ diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index f659ec8436a941606b6d649e24d2481e5be9b66d..d8c05c2020a9524ae3e6eab232a3e320bc699187 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,5 +1,6 @@ +from enum import IntEnum from itertools import starmap -from typing import Tuple +from typing import Tuple, Optional import numpy as np from attr import attrs, attrib, Factory @@ -7,6 +8,13 @@ from attr import attrs, attrib, Factory from flatland.core.grid.grid4 import Grid4TransitionsEnum +class RailAgentStatus(IntEnum): + READY_TO_DEPART = 0 # not in grid yet (position is None) -> prediction as if it were at initial position + ACTIVE = 1 # in grid (position is not None), not done -> prediction is remaining path + DONE = 2 # in grid (position is not None), but done -> prediction is stay at target forever + DONE_REMOVED = 3 # removed from grid (position is None) -> prediction is None + + @attrs class EnvAgentStatic(object): """ EnvAgentStatic - Stores initial position, direction and target. @@ -14,7 +22,7 @@ class EnvAgentStatic(object): rather than where it is at the moment. The target should also be stored here. """ - position = attrib(type=Tuple[int, int]) + initial_position = attrib(type=Tuple[int, int]) direction = attrib(type=Grid4TransitionsEnum) target = attrib(type=Tuple[int, int]) moving = attrib(default=False, type=bool) @@ -33,6 +41,9 @@ class EnvAgentStatic(object): lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, 'moving_before_malfunction': False}))) + status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus) + position = attrib(default=None, type=Optional[Tuple[int, int]]) + @classmethod def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None): """ Create a list of EnvAgentStatics from lists of positions, directions and targets @@ -65,7 +76,7 @@ class EnvAgentStatic(object): # I can't find an expression which works on both tuples, lists and ndarrays # which converts them all to a list of native python ints. 
-        lPos = self.position
+        lPos = self.initial_position
         if type(lPos) is np.ndarray:
             lPos = lPos.tolist()
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index c13c4cb7c6e3926a522adbd0055e53132079a2c1..8adf94d3d5b01946aa42659c19b6060f0d593868 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -11,11 +11,20 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
 from flatland.utils.ordered_set import OrderedSet


 class TreeObsForRailEnv(ObservationBuilder):
+    """
+    TreeObsForRailEnv object.
+
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the graph structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+    For details about the features in the tree observation see the get() function.
+    """
     Node = collections.namedtuple('Node', 'dist_own_target_encountered '
                                           'dist_other_target_encountered '
                                           'dist_other_agent_encountered '
@@ -27,19 +36,10 @@
                                           'num_agents_opposite_direction '
                                           'num_agents_malfunctioning '
                                           'speed_min_fractional '
+                                          'num_agents_ready_to_depart '
                                           'childs')

-    tree_explorted_actions_char = ['L', 'F', 'R', 'B']
-
-    """
-    TreeObsForRailEnv object.
-
-    This object returns observation vectors for agents in the RailEnv environment.
-    The information is local to each agent and exploits the graph structure of the rail
-    network to simplify the representation of the state of the environment for each agent.
-
-    For details about the features in the tree observation see the get() function.
-    """
+    tree_explored_actions_char = ['L', 'F', 'R', 'B']

     def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
         super().__init__()
@@ -67,19 +67,21 @@
             self.predicted_dir = {}
             self.predictions = self.predictor.get()
             if self.predictions:
-
+                # TODO hacky hacky: `range(len(self.predictions[0]))` does not seem safe!!
                 for t in range(len(self.predictions[0])):
                     pos_list = []
                     dir_list = []
                     for a in handles:
+                        if self.predictions[a] is None:
+                            continue
                         pos_list.append(self.predictions[a][t][1:3])
                         dir_list.append(self.predictions[a][t][3])
                     self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                     self.predicted_dir.update({t: dir_list})
                 self.max_prediction_depth = len(self.predicted_pos)
-        observations = {}
-        for h in handles:
-            observations[h] = self.get(h)
+
+        observations = super().get_many(handles)
+
         return observations

     def get(self, handle: int = 0) -> Node:
@@ -150,6 +152,8 @@
             1 if no agent is observed
             min_fractional speed otherwise
+        #12:
+            number of agents ready to depart but not yet active
         Missing/padding nodes are filled in with -inf (truncated).
         Missing values in present node are filled in with +inf (truncated).
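# --- Illustrative usage sketch (outside the patch) ---------------------------------------
# The hunks above add feature #12 (`num_agents_ready_to_depart`) to the tree observation
# and rename the misspelled `tree_explorted_actions_char` to `tree_explored_actions_char`.
# A minimal consumer of the resulting observation might look like the following; the rail
# helpers and generator names are the ones the tests in this diff import, everything else
# is assumed flatland 2.x API rather than taken verbatim from this patch.

import numpy as np

from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail

tree_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
              rail_generator=rail_from_grid_transition_map(rail),
              schedule_generator=random_schedule_generator(),
              number_of_agents=1,
              obs_builder_object=tree_builder)
obs = env.reset()  # in this version, reset() returns the observation dict

root = obs[0]  # a TreeObsForRailEnv.Node for agent 0
print(root.dist_min_to_target, root.num_agents_ready_to_depart)

# child nodes are keyed by the renamed action characters 'L', 'F', 'R', 'B';
# -np.inf marks branches without a feasible transition
for action_char in tree_builder.tree_explored_actions_char:
    child = root.childs[action_char]
    if child != -np.inf:
        print(action_char, child.dist_to_next_branch)
# ------------------------------------------------------------------------------------------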
@@ -160,28 +164,41 @@ class TreeObsForRailEnv(ObservationBuilder): """ # Update local lookup table for all agents' positions - self.location_has_agent = dict() - self.location_has_agent_direction = dict() + # ignore other agents not in the grid (only status active and done) + self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if + agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} + self.location_has_agent_ready_to_depart = {} for agent in self.env.agents: - if tuple(agent.position) in self.location_has_agent: - self.location_has_agent[tuple(agent.position)] = self.location_has_agent[tuple(agent.position)] + 1 - else: - self.location_has_agent[tuple(agent.position)] = 1 - - if (agent.position, agent.direction) in self.location_has_agent_direction: - self.location_has_agent_direction[(agent.position, agent.direction)] = \ - self.location_has_agent_direction[(agent.position, agent.direction)] + 1 - else: - self.location_has_agent_direction[(agent.position, agent.direction)] = 1 - - self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} - self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in - self.env.agents} + if agent.status == RailAgentStatus.READY_TO_DEPART: + self.location_has_agent_ready_to_depart[tuple(agent.initial_position)] = \ + self.location_has_agent_ready_to_depart.get(tuple(agent.initial_position), 0) + 1 + self.location_has_agent_direction = { + tuple(agent.position): agent.direction + for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] + } + self.location_has_agent_speed = { + tuple(agent.position): agent.speed_data['speed'] + for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] + } + self.location_has_agent_malfunction = { + tuple(agent.position): agent.malfunction_data['malfunction'] + for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] + } if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index - possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) + + if agent.status == RailAgentStatus.READY_TO_DEPART: + _agent_initial_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + _agent_initial_position = agent.position + elif agent.status == RailAgentStatus.DONE: + _agent_initial_position = agent.target + else: + return None + + possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) # Here information about the agent itself is stored @@ -190,11 +207,13 @@ class TreeObsForRailEnv(ObservationBuilder): root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0, dist_other_agent_encountered=0, dist_potential_conflict=0, dist_unusable_switch=0, dist_to_next_branch=0, - dist_min_to_target=distance_map[(handle, *agent.position, - agent.direction)], + dist_min_to_target=distance_map[ + (handle, *_agent_initial_position, + agent.direction)], num_agents_same_direction=0, num_agents_opposite_direction=0, num_agents_malfunctioning=agent.malfunction_data['malfunction'], speed_min_fractional=agent.speed_data['speed'], + num_agents_ready_to_depart=0, childs={}) visited = 
OrderedSet()
@@ -210,16 +229,16 @@
         for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):

             if possible_transitions[branch_direction]:
-                new_cell = get_new_position(agent.position, branch_direction)
+                new_cell = get_new_position(_agent_initial_position, branch_direction)

                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
-                root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation

                 visited |= branch_visited
             else:
                 # add cells filled with infinity if no transition is possible
-                root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf

         self.env.dev_obs_dict[handle] = visited

         return root_node_observation
@@ -257,6 +276,7 @@
         malfunctioning_agent = 0
         min_fractional_speed = 1.
         num_steps = 1
+        other_agent_ready_to_depart_encountered = 0
         while exploring:
             # #############################
             # #############################
@@ -270,9 +290,11 @@
                 if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                     malfunctioning_agent = self.location_has_agent_malfunction[position]

-                if (position, direction) in self.location_has_agent_direction:
+                other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)
+
+                if self.location_has_agent_direction[position] == direction:
                     # Cummulate the number of agents on branch with same direction
-                    other_agent_same_direction += self.location_has_agent_direction[(position, direction)]
+                    other_agent_same_direction += self.location_has_agent.get(position, 0)

                     # Check fractional speed of agents
                     current_fractional_speed = self.location_has_agent_speed[position]
@@ -281,9 +303,9 @@

                     # Other direction agents
                     # TODO: Test that this behavior is as expected
-                    other_agent_opposite_direction += self.location_has_agent[position] - \
-                                                      self.location_has_agent_direction[
-                                                          (position, direction)]
+                    other_agent_opposite_direction += \
+                        self.location_has_agent[position] - self.location_has_agent.get(position, 0)

                 else:
                     # If no agent in the same direction was found all agents in that position are other direction
@@ -314,7 +336,7 @@
                                 self._reverse_dir(
                                     self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist

                         # Look for conflicting paths at distance num_step-1
@@ -325,7 +347,7 @@
                                 and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                                 and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist

                         # Look for conflicting paths at distance num_step+1
@@ -336,7 +358,7 @@
                                 self.predicted_dir[post_step][ca])] == 1 \
                                 and tot_dist < potential_conflict:  # noqa: E125
potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: + if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict: potential_conflict = tot_dist if position in self.location_has_target and position != agent.target: @@ -424,6 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder): num_agents_opposite_direction=other_agent_opposite_direction, num_agents_malfunctioning=malfunctioning_agent, speed_min_fractional=min_fractional_speed, + num_agents_ready_to_depart=other_agent_ready_to_depart_encountered, childs={}) # ############################# @@ -443,7 +466,7 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction + 2) % 4, tot_dist + 1, depth + 1) - node.childs[self.tree_explorted_actions_char[i]] = branch_observation + node.childs[self.tree_explored_actions_char[i]] = branch_observation if len(branch_visited) != 0: visited |= branch_visited elif last_is_switch and possible_transitions[branch_direction]: @@ -453,12 +476,12 @@ class TreeObsForRailEnv(ObservationBuilder): branch_direction, tot_dist + 1, depth + 1) - node.childs[self.tree_explorted_actions_char[i]] = branch_observation + node.childs[self.tree_explored_actions_char[i]] = branch_observation if len(branch_visited) != 0: visited |= branch_visited else: # no exploring possible, add just cells with infinity - node.childs[self.tree_explorted_actions_char[i]] = -np.inf + node.childs[self.tree_explored_actions_char[i]] = -np.inf if depth == self.max_depth: node.childs.clear() @@ -469,7 +492,7 @@ class TreeObsForRailEnv(ObservationBuilder): Utility function to print tree observations returned by this object. """ self.print_node_features(tree, "root", "") - for direction in self.tree_explorted_actions_char: + for direction in self.tree_explored_actions_char: self.print_subtree(tree.childs[direction], direction, "\t") @staticmethod @@ -478,7 +501,8 @@ class TreeObsForRailEnv(ObservationBuilder): node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ", node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ", node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction, - ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional) + ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional, ", ", + node.num_agents_ready_to_depart) def print_subtree(self, node, label, indent): if node == -np.inf or not node: @@ -490,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder): if not node.childs: return - for direction in self.tree_explorted_actions_char: + for direction in self.tree_explored_actions_char: self.print_subtree(node.childs[direction], direction, indent + "\t") def set_env(self, env: Environment): @@ -510,15 +534,15 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16),\ assuming 16 bits encoding of transitions. 
-    - A 3D array (map_height, map_width, 5) with
+    - obs_agents_state: A 3D array (map_height, map_width, 5) with
         - first channel containing the agents position and direction
-        - second channel containing the other agents positions and diretion
+        - second channel containing the other agents' positions and directions
         - third channel containing agent/other agent malfunctions
         - fourth channel containing agent/other agent fractional speeds
-        ' fifth channel containing number of agents in cell (only larger then one at start position)
+        - fifth channel containing number of other agents ready to depart

-    - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
-     target and the positions of the other agents targets.
+    - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent's\
+     target and the positions of the other agents' targets (flag only, no counter!).
    """

    def __init__(self):
@@ -537,20 +561,36 @@
    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):

+        agent = self.env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            _agent_initial_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            _agent_initial_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            _agent_initial_position = agent.target
+        else:
+            return None
+
        obs_targets = np.zeros((self.env.height, self.env.width, 2))
        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1

-        agent = self.env.agents[handle]
-        obs_agents_state[agent.position][0] = agent.direction
+
+        obs_agents_state[_agent_initial_position][0] = agent.direction
        obs_targets[agent.target][0] = 1

        for i in range(len(self.env.agents)):
-            other_agent = self.env.agents[i]
-            obs_agents_state[other_agent.position][4] += 1
-            if i != handle:
+            other_agent: EnvAgent = self.env.agents[i]
+
+            # ignore other agents not in the grid any more
+            if other_agent.status == RailAgentStatus.DONE_REMOVED:
+                continue
+
+            obs_targets[other_agent.target][1] = 1
+
+            # third to fifth channel only if different agent and in the grid
+            if i != handle and other_agent.position is not None:
                obs_agents_state[other_agent.position][1] = other_agent.direction
-                obs_targets[other_agent.target][1] = 1
-                obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
-                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+                obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
+                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+
+            # fifth channel, as promised by the docstring above: count other agents
+            # that are ready to depart at this position
+            if i != handle and other_agent.status == RailAgentStatus.READY_TO_DEPART:
+                obs_agents_state[other_agent.initial_position][4] += 1

        return self.rail_obs, obs_agents_state, obs_targets
@@ -640,18 +680,14 @@
            direction = np.identity(4)[agent.direction]
            return local_rail_obs, obs_map_state, obs_other_agents_state, direction

-    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
+        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
        """
        Called whenever an observation has to be computed for the `env` environment,
        for each agent with handle in the `handles` list.
""" - observations = {} - if handles is None: - handles = [] - for h in handles: - observations[h] = self.get(h) - return observations + return super().get_many(handles) def field_of_view(self, position, direction, state=None): # Compute the local field of view for an agent in the environment diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 76095a2a2e1d9532951600118c6a777612641101..29b6947c28fa054cd1f4c44a204c740e4f536181 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder. import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env import RailEnvActions from flatland.envs.rail_env_shortest_paths import get_shortest_paths @@ -47,6 +48,9 @@ class DummyPredictorForRailEnv(PredictionBuilder): prediction_dict = {} for agent in agents: + if agent.status != RailAgentStatus.ACTIVE: + # TODO make this generic + continue action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] _agent_initial_position = agent.position _agent_initial_direction = agent.direction @@ -122,7 +126,17 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): prediction_dict = {} for agent in agents: - _agent_initial_position = agent.position + + if agent.status == RailAgentStatus.READY_TO_DEPART: + _agent_initial_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + _agent_initial_position = agent.position + elif agent.status == RailAgentStatus.DONE: + _agent_initial_position = agent.target + else: + prediction_dict[agent.handle] = None + continue + _agent_initial_direction = agent.direction agent_speed = agent.speed_data["speed"] times_per_cell = int(np.reciprocal(agent_speed)) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 98fa4d2e4b557edced94c6eb3692f4949abe9221..bf4eff2137b45bd87c33d6d35edb3d2d1ed2cb18 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -16,7 +16,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import Vec2dOperations from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent +from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator @@ -235,12 +235,18 @@ class RailEnv(Environment): self.agents_static.append(agent_static) return len(self.agents_static) - 1 + def set_agent_active(self, handle: int): + agent = self.agents[handle] + if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position): + agent.status = RailAgentStatus.ACTIVE + agent.position = agent.initial_position + def restart_agents(self): """ Reset the agents to their starting positions defined in agents_static """ self.agents = EnvAgent.list_from_static(self.agents_static) - def reset(self, regen_rail=True, replace_agents=True): + def reset(self, regen_rail=True, replace_agents=True, activate_agents=False): """ if regen_rail then regenerate the rails. if replace_agents then regenerate the agents static. 
        Relies on the rail_generator returning agent_static lists (pos, dir, target)
@@ -277,8 +283,13 @@
                *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))

        self.restart_agents()

-        for i_agent in range(self.get_num_agents()):
-            agent = self.agents[i_agent]
+        if activate_agents:
+            for i_agent in range(self.get_num_agents()):
+                self.set_agent_active(i_agent)
+
+        for i_agent, agent in enumerate(self.agents):
+            if agent.status != RailAgentStatus.ACTIVE:
+                continue

            # A proportion of agent in the environment will receive a positive malfunction rate
            if np.random.random() < self.proportion_malfunctioning_trains:
@@ -366,7 +377,8 @@
            info_dict = {
                'action_required': {i: False for i in range(self.get_num_agents())},
                'malfunction': {i: 0 for i in range(self.get_num_agents())},
-                'speed': {i: 0 for i in range(self.get_num_agents())}
+                'speed': {i: 0 for i in range(self.get_num_agents())},
+                'status': {i: agent.status for i, agent in enumerate(self.agents)}
            }

            return self._get_observations(), self.rewards_dict, self.dones, info_dict
@@ -380,21 +392,19 @@
            self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
        if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
            self.dones["__all__"] = True
-            for k in self.dones.keys():
-                self.dones[k] = True
-
-        action_required_agents = {
-            i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents())
-        }
-        malfunction_agents = {
-            i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
-        }
-        speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}
+            for i in range(self.get_num_agents()):
+                self.agents[i].status = RailAgentStatus.DONE
+                self.dones[i] = True

        info_dict = {
-            'action_required': action_required_agents,
-            'malfunction': malfunction_agents,
-            'speed': speed_agents
+            'action_required': {
+                i: (agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)
+                for i, agent in enumerate(self.agents)},
+            'malfunction': {
+                i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
+            },
+            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
+            'status': {i: agent.status for i, agent in enumerate(self.agents)}
        }

        return self._get_observations(), self.rewards_dict, self.dones, info_dict
@@ -412,10 +422,19 @@
        action_dict_ : Dict[int,RailEnvActions]

        """
-        if self.dones[i_agent]:  # this agent has already completed...
+        agent = self.agents[i_agent]
+        if agent.status == RailAgentStatus.DONE:  # this agent has already completed...
            return

-        agent = self.agents[i_agent]
+        # the agent becomes active on its first MOVE_* action, and only if its initial cell is free
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
+                          RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
+                agent.status = RailAgentStatus.ACTIVE
+                agent.position = agent.initial_position
+            else:
+                return
+
        agent.old_direction = agent.direction
        agent.old_position = agent.position
@@ -508,6 +527,7 @@
        # has the agent reached its target?
if np.equal(agent.position, agent.target).all(): + agent.status = RailAgentStatus.DONE self.dones[i_agent] = True agent.moving = False @@ -558,9 +578,15 @@ class RailEnv(Environment): # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) - cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + cell_free = self.cell_free(new_position) return cell_free, new_cell_valid, new_direction, new_position, transition_valid + def cell_free(self, position): + + agent_positions = [agent.position for agent in self.agents if agent.position is not None] + ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1)) + return ret + def check_action(self, agent: EnvAgent, action: RailEnvActions): """ diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py index 793601d4d18ac38b729d15883089d5acbfc41ed3..7944a49daf388403c8a226dbf866ac810d1286a3 100644 --- a/flatland/envs/rail_env_shortest_paths.py +++ b/flatland/envs/rail_env_shortest_paths.py @@ -7,6 +7,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions from flatland.utils.ordered_set import OrderedSet @@ -92,7 +93,15 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non shortest_paths = dict() def _shortest_path_for_agent(agent): - position = agent.position + if agent.status == RailAgentStatus.READY_TO_DEPART: + position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + position = agent.position + elif agent.status == RailAgentStatus.DONE: + position = agent.target + else: + shortest_paths[agent.handle] = None + return direction = agent.direction shortest_paths[agent.handle] = [] distance = math.inf diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index bdd4b48b356d69fc7afb5cb76bf12ee08880516d..05c6cc0da831a571c4f0b4f328ff6f3f4703cff1 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -223,7 +223,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]] # setup with loaded data - agents_position = [a.position for a in agents_static] + agents_position = [a.initial_position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] if len(data['agents_static'][0]) > 5: diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index a26aa5ef3191e0b968c5cd1396f397d6cafd1dd9..c238319ad17bb09d9dbaea80335685ce1283feb1 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -147,6 +147,9 @@ class RenderTool(object): Plot a simple agent. Assumes a working graphics layer context (cf a MPL figure). 
""" + if position_row_col is None: + return + rt = self.__class__ direction_row_col = rt.transitions_row_col[direction] # agent direction in RC @@ -537,7 +540,7 @@ class RenderTool(object): for agent_idx, agent in enumerate(self.env.agents): - if agent is None: + if agent is None or agent.position is None: continue if self.agent_render_variant == AgentRenderVariant.BOX_ONLY: diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py new file mode 100644 index 0000000000000000000000000000000000000000..b4396c101d1879bf8f5f74f9a79f9a8541d45e31 --- /dev/null +++ b/tests/test_flaltland_rail_agent_status.py @@ -0,0 +1,124 @@ +import numpy as np + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.simple_rail import make_simple_rail +from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay + +np.random.seed(1) + + +def test_initial_status(): + """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + set_penalties_for_replay(env) + test_config = ReplayConfig( + replay=[ + Replay( + position=None, # not entered grid yet + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.READY_TO_DEPART, + action=RailEnvActions.DO_NOTHING, + reward=0, + + ), + Replay( + position=None, # not entered grid yet before step + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.READY_TO_DEPART, + action=RailEnvActions.MOVE_LEFT, + reward=env.start_penalty + env.step_penalty * 0.5, # auto-correction left to forward without penalty! + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.ACTIVE, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5, # running at speed 0.5 + + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + status=RailAgentStatus.ACTIVE + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_RIGHT, + reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty! 
+ status=RailAgentStatus.ACTIVE + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # done + status=RailAgentStatus.ACTIVE + ), + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE + ), + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE + ), + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE + ) + + ], + initial_position=(3, 9), # east dead-end + initial_direction=Grid4TransitionsEnum.EAST, + target=(3, 5), + speed=0.5 + ) + + run_replay_config(env, [test_config], activate_agents=False) diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 0d6d309765690b1f95c681d7d109a13071d7f86b..52b047244850a5b1c6b39dadd74568fc5f92deec 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.agent_utils import EnvAgent +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions @@ -29,6 +29,9 @@ def test_global_obs(): global_obs = env.reset() + # we have to take step for the agent to enter the grid. + global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD}) + assert (global_obs[0][0].shape == rail_map.shape + (16,)) rail_map_recons = np.zeros_like(rail_map) @@ -109,12 +112,14 @@ def test_reward_function_conflict(rendering=False): agent.direction = 0 # north agent.target = (3, 9) # east dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE agent = env.agents_static[1] agent.position = (3, 8) # east dead-end agent.direction = 3 # west agent.target = (6, 6) # south dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE # reset to set agents from agents_static env.reset(False, False) @@ -184,16 +189,20 @@ def test_reward_function_waiting(rendering=False): # set the initial position agent = env.agents_static[0] + agent.initial_position = (3, 8) # east dead-end agent.position = (3, 8) # east dead-end agent.direction = 3 # west agent.target = (3, 1) # west dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE agent = env.agents_static[1] + agent.initial_position = (5, 6) # south dead-end agent.position = (5, 6) # south dead-end agent.direction = 0 # north agent.target = (3, 8) # east dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE # reset to set agents from agents_static env.reset(False, False) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index f4ab68bc45a82b8f196fcfbebb26fd68a36c37a4..2517fb84de00c1346e5c601f903a9a1a677cfae4 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -5,6 +5,7 @@ import pprint import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.observations import TreeObsForRailEnv from 
flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction @@ -31,12 +32,13 @@ def test_dummy_predictor(rendering=False): env.reset() # set initial position and direction for testing... - env.agents_static[0].position = (5, 6) + env.agents_static[0].initial_position = (5, 6) env.agents_static[0].direction = 0 env.agents_static[0].target = (3, 0) # reset to set agents from agents_static env.reset(False, False) + env.set_agent_active(0) if rendering: renderer = RenderTool(env, gl="PILSVG") @@ -124,10 +126,12 @@ def test_shortest_path_predictor(rendering=False): # set the initial position agent = env.agents_static[0] + agent.initial_position = (5, 6) # south dead-end agent.position = (5, 6) # south dead-end agent.direction = 0 # north agent.target = (3, 9) # east dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE # reset to set agents from agents_static env.reset(False, False) @@ -139,9 +143,9 @@ def test_shortest_path_predictor(rendering=False): # compute the observations and predictions distance_map = env.distance_map.get() - assert distance_map[0, agent.position[0], agent.position[ - 1], agent.direction] == 5.0, "found {} instead of {}".format( - distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) + assert distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] == 5.0, \ + "found {} instead of {}".format( + distance_map[agent.handle, agent.initial_position[0], agent.position[1], agent.direction], 5.0) paths = get_shortest_paths(env.distance_map)[0] assert paths == [ @@ -259,19 +263,23 @@ def test_shortest_path_predictor_conflicts(rendering=False): # set the initial position agent = env.agents_static[0] + agent.initial_position = (5, 6) # south dead-end agent.position = (5, 6) # south dead-end agent.direction = 0 # north agent.target = (3, 9) # east dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE agent = env.agents_static[1] + agent.initial_position = (3, 8) # east dead-end agent.position = (3, 8) # east dead-end agent.direction = 3 # west agent.target = (6, 6) # south dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE # reset to set agents from agents_static - observations = env.reset(False, False) + observations = env.reset(False, False, True) if rendering: renderer = RenderTool(env, gl="PILSVG") @@ -295,14 +303,14 @@ def test_shortest_path_predictor_conflicts(rendering=False): def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''): assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt) - for a_1 in obs_builder.tree_explorted_actions_char: + for a_1 in obs_builder.tree_explored_actions_char: if tree.childs[a_1] == -np.inf: assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) continue else: conflict = tree.childs[a_1].num_agents_opposite_direction assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) - for a_2 in obs_builder.tree_explorted_actions_char: + for a_2 in obs_builder.tree_explored_actions_char: if tree.childs[a_1].childs[a_2] == -np.inf: assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) else: diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 
0fefd3e212ddb5f084c1e219f4063079e03dabdf..e0281bb0d3f21a8e25fdff86ade5bed05f5dea13 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -85,7 +85,7 @@ def test_rail_environment_single_agent(): obs_builder_object=GlobalObsForRailEnv()) for _ in range(200): - _ = rail_env.reset() + _ = rail_env.reset(False, False, True) # We do not care about target for the moment agent = rail_env.agents[0] @@ -130,9 +130,6 @@ def test_rail_environment_single_agent(): done = dones['__all__'] -test_rail_environment_single_agent() - - def test_dead_end(): transitions = RailEnvTransitions() @@ -164,32 +161,12 @@ def test_dead_end(): number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) - def check_consistency(rail_env): - # We run step to check that trains do not move anymore - # after being done. - # TODO: GIACOMO: this is deprecated and should be updated; thenew behavior is that agents keep moving - # until they are manually stopped. - for i in range(7): - prev_pos = rail_env.agents[0].position - - # The train cannot turn, so we check that when it tries, - # it stays where it is. - _ = rail_env.step({0: 1}) - _ = rail_env.step({0: 3}) - assert (rail_env.agents[0].position == prev_pos) - _, _, dones, _ = rail_env.step({0: 2}) - - if i < 5: - assert (not dones[0] and not dones['__all__']) - else: - assert (dones[0] and dones['__all__']) - # We try the configuration in the 4 directions: rail_env.reset() - rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)] + rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=1, target=(0, 0), moving=False)] rail_env.reset() - rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)] + rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=3, target=(0, 4), moving=False)] # In the vertical configuration: rail_map = np.array( @@ -210,10 +187,12 @@ def test_dead_end(): obs_builder_object=GlobalObsForRailEnv()) rail_env.reset() - rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0), moving=False)] + rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=2, target=(0, 0), moving=False)] rail_env.reset() - rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)] + rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=0, target=(4, 0), moving=False)] + + # TODO make assertions def test_get_entry_directions(): diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 4600c4a3002995e1238a0ccbda762501ac985408..65d2d68c45155efda24536ecfd776bef5ebaab0c 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -25,6 +25,7 @@ def test_get_shortest_paths_unreachable(): # set the initial position agent = env.agents_static[0] agent.position = (3, 1) # west dead-end + agent.initial_position = (3, 1) # west dead-end agent.direction = Grid4TransitionsEnum.WEST agent.target = (3, 9) # east dead-end agent.moving = True diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 416efea5c909a5c89a9e6b8d782e18a1dd4bae62..ecba37641b740e0d2d71028b188b48f936730e74 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -21,6 +21,7 @@ def test_sparse_rail_generator(): schedule_generator=sparse_schedule_generator(), number_of_agents=10, 
obs_builder_object=GlobalObsForRailEnv()) + env.reset(False, False, True) expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type()) expected_grid_map[1][33] = 8192 expected_grid_map[2][33] = 32800 @@ -1522,6 +1523,9 @@ def test_rail_env_action_required_info(): obs_builder_object=GlobalObsForRailEnv()) env_renderer = RenderTool(env_always_action, gl="PILSVG", ) + env_always_action.reset(False, False, True) + env_only_if_action_required.reset(False, False, True) + for step in range(100): print("step {}".format(step)) @@ -1575,6 +1579,7 @@ def test_rail_env_malfunction_speed_info(): number_of_agents=10, obs_builder_object=GlobalObsForRailEnv(), stochastic_data=stochastic_data) + env.reset(False, False, True) env_renderer = RenderTool(env, gl="PILSVG", ) for step in range(100): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 1bcf9c5112a4423fd22e1102500890a39f853c74..9800a53735b2698fc8cf30d59fb85a5e631f5580 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -6,6 +6,7 @@ import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator @@ -29,7 +30,16 @@ class SingleAgentNavigationObs(ObservationBuilder): def get(self, handle: int = 0) -> List[int]: agent = self.env.agents[handle] - possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) + if agent.status == RailAgentStatus.READY_TO_DEPART: + _agent_initial_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + _agent_initial_position = agent.position + elif agent.status == RailAgentStatus.DONE: + _agent_initial_position = agent.target + else: + return None + + possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) # Start from the current orientation, and see which transitions are available; @@ -41,7 +51,7 @@ class SingleAgentNavigationObs(ObservationBuilder): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = get_new_position(agent.position, direction) + new_position = get_new_position(_agent_initial_position, direction) min_distances.append( self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) else: @@ -70,16 +80,19 @@ def test_malfunction_process(): obs_builder_object=SingleAgentNavigationObs(), stochastic_data=stochastic_data) - obs = env.reset() + obs = env.reset(False, False, True) # Check that a initial duration for malfunction was assigned assert env.agents[0].malfunction_data['next_malfunction'] > 0 + for agent in env.agents: + agent.status = RailAgentStatus.ACTIVE agent_halts = 0 total_down_time = 0 agent_old_position = env.agents[0].position for step in range(100): actions = {} + for i in range(len(obs)): actions[i] = np.argmax(obs[i]) + 1 @@ -104,7 +117,8 @@ def test_malfunction_process(): total_down_time += env.agents[0].malfunction_data['malfunction'] # Check that the appropriate number of malfunctions is achieved - assert 
env.agents[0].malfunction_data['nr_malfunctions'] == 21 + assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format( + env.agents[0].malfunction_data['nr_malfunctions']) # Check that 20 stops where performed assert agent_halts == 20 @@ -120,8 +134,6 @@ def test_malfunction_process_statistically(): 'malfunction_rate': 2, 'min_duration': 3, 'max_duration': 3} - np.random.seed(5) - random.seed(0) env = RailEnv(width=20, height=20, @@ -131,8 +143,9 @@ def test_malfunction_process_statistically(): number_of_agents=2, obs_builder_object=SingleAgentNavigationObs(), stochastic_data=stochastic_data) - - env.reset() + np.random.seed(5) + random.seed(0) + env.reset(False, False, True) nb_malfunction = 0 for step in range(100): action_dict: Dict[int, RailEnvActions] = {} @@ -149,9 +162,6 @@ def test_malfunction_process_statistically(): def test_initial_malfunction(): - random.seed(0) - np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction @@ -162,7 +172,8 @@ def test_initial_malfunction(): 1. / 2.: 0., # Fast freight train 1. / 3.: 0., # Slow commuter train 1. / 4.: 0.} # Slow freight train - + np.random.seed(5) + random.seed(0) env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, @@ -218,15 +229,15 @@ def test_initial_malfunction(): ) ], speed=env.agents[0].speed_data['speed'], - target=env.agents[0].target + target=env.agents[0].target, + initial_position=(28, 5), + initial_direction=Grid4TransitionsEnum.EAST, ) + run_replay_config(env, [replay_config]) def test_initial_malfunction_stop_moving(): - random.seed(0) - np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction @@ -253,19 +264,21 @@ def test_initial_malfunction_stop_moving(): replay_config = ReplayConfig( replay=[ Replay( - position=(28, 5), + position=None, direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.DO_NOTHING, + action=RailEnvActions.MOVE_FORWARD, set_malfunction=3, malfunction=3, - reward=env.step_penalty # full step penalty when stopped + reward=env.step_penalty, # full step penalty when stopped + status=RailAgentStatus.READY_TO_DEPART ), Replay( position=(28, 5), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, malfunction=2, - reward=env.step_penalty # full step penalty when stopped + reward=env.step_penalty, # full step penalty when stopped + status=RailAgentStatus.ACTIVE ), # malfunction stops in the next step and we're still at the beginning of the cell # --> if we take action STOP_MOVING, agent should restart without moving @@ -275,7 +288,8 @@ def test_initial_malfunction_stop_moving(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.STOP_MOVING, malfunction=1, - reward=env.step_penalty # full step penalty while stopped + reward=env.step_penalty, # full step penalty while stopped + status=RailAgentStatus.ACTIVE ), # we have stopped and do nothing --> should stand still Replay( @@ -283,7 +297,8 @@ def test_initial_malfunction_stop_moving(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, malfunction=0, - reward=env.step_penalty # full step penalty while stopped + reward=env.step_penalty, # full step penalty while stopped + status=RailAgentStatus.ACTIVE ), # we start to move forward --> should go to next 
cell now Replay( @@ -291,21 +306,24 @@ def test_initial_malfunction_stop_moving(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, - reward=env.start_penalty + env.step_penalty * 1.0 # full step penalty while stopped + reward=env.start_penalty + env.step_penalty * 1.0, # full step penalty while stopped + status=RailAgentStatus.ACTIVE ), Replay( position=(28, 6), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, - reward=env.step_penalty * 1.0 # full step penalty while stopped + reward=env.step_penalty * 1.0, # full step penalty while stopped + status=RailAgentStatus.ACTIVE ) ], speed=env.agents[0].speed_data['speed'], - target=env.agents[0].target + target=env.agents[0].target, + initial_position=(28, 5), + initial_direction=Grid4TransitionsEnum.EAST, ) - - run_replay_config(env, [replay_config]) + run_replay_config(env, [replay_config], activate_agents=False) def test_initial_malfunction_do_nothing(): @@ -336,20 +354,23 @@ def test_initial_malfunction_do_nothing(): ) set_penalties_for_replay(env) replay_config = ReplayConfig( - replay=[Replay( - position=(28, 5), - direction=Grid4TransitionsEnum.EAST, - action=RailEnvActions.DO_NOTHING, - set_malfunction=3, - malfunction=3, - reward=env.step_penalty # full step penalty while malfunctioning - ), + replay=[ + Replay( + position=None, + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + set_malfunction=3, + malfunction=3, + reward=env.step_penalty, # full step penalty while malfunctioning + status=RailAgentStatus.READY_TO_DEPART + ), Replay( position=(28, 5), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, malfunction=2, - reward=env.step_penalty # full step penalty while malfunctioning + reward=env.step_penalty, # full step penalty while malfunctioning + status=RailAgentStatus.ACTIVE ), # malfunction stops in the next step and we're still at the beginning of the cell # --> if we take action DO_NOTHING, agent should restart without moving @@ -359,7 +380,8 @@ def test_initial_malfunction_do_nothing(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, malfunction=1, - reward=env.step_penalty # full step penalty while stopped + reward=env.step_penalty, # full step penalty while stopped + status=RailAgentStatus.ACTIVE ), # we haven't started moving yet --> stay here Replay( @@ -367,7 +389,8 @@ def test_initial_malfunction_do_nothing(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, malfunction=0, - reward=env.step_penalty # full step penalty while stopped + reward=env.step_penalty, # full step penalty while stopped + status=RailAgentStatus.ACTIVE ), # we start to move forward --> should go to next cell now Replay( @@ -375,21 +398,25 @@ def test_initial_malfunction_do_nothing(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, - reward=env.start_penalty + env.step_penalty * 1.0 # start penalty + step penalty for speed 1.0 + reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0 + status=RailAgentStatus.ACTIVE ), Replay( position=(28, 6), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, - reward=env.step_penalty * 1.0 # step penalty for speed 1.0 + reward=env.step_penalty * 1.0, # step penalty for speed 1.0 + status=RailAgentStatus.ACTIVE ) ], speed=env.agents[0].speed_data['speed'], - target=env.agents[0].target + target=env.agents[0].target, + 
initial_position=(28, 5), + initial_direction=Grid4TransitionsEnum.EAST, ) - run_replay_config(env, [replay_config]) + run_replay_config(env, [replay_config], activate_agents=False) def test_initial_nextmalfunction_not_below_zero(): diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index 9035062bad47dbdf432c6cd980c9d2361d42de6e..393d0e0087d3068050759575684d749d7224bc6b 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -1,7 +1,8 @@ import numpy as np +from flatland.envs.agent_utils import EnvAgent from flatland.envs.observations import GlobalObsForRailEnv -from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator @@ -31,24 +32,27 @@ def test_get_global_observation(): schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) - obs, all_rewards, done, _ = env.step({0: 0}) + obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) for i in range(len(env.agents)): + agent: EnvAgent = env.agents[i] + print("[{}] status={}, position={}, target={}, initial_position={}".format(i, agent.status, agent.position, + agent.target, + agent.initial_position)) + + for i, agent in enumerate(env.agents): obs_agents_state = obs[i][1] obs_targets = obs[i][2] + # nr_agents = np.count_nonzero(obs_targets[:, :, 0]) - nr_agents_other = np.count_nonzero(obs_targets[:, :, 1]) assert nr_agents == 1 - assert nr_agents_other == (number_of_agents - 1) - - # since the array is initialized with -1 add one in order to used np.count_nonzero - obs_agents_state += 1 - obs_agents_state_0 = np.count_nonzero(obs_agents_state[:, :, 0]) - obs_agents_state_1 = np.count_nonzero(obs_agents_state[:, :, 1]) - obs_agents_state_2 = np.count_nonzero(obs_agents_state[:, :, 2]) - obs_agents_state_3 = np.count_nonzero(obs_agents_state[:, :, 3]) - assert obs_agents_state_0 == 1 - assert obs_agents_state_1 == (number_of_agents - 1) - assert obs_agents_state_2 == number_of_agents - assert obs_agents_state_3 == number_of_agents + + for r in range(env.height): + for c in range(env.width): + _other_agent_target = 0 + for other_i, other_agent in enumerate(env.agents): + if other_agent.target == (r, c): + _other_agent_target = 1 + break + assert obs_targets[(r, c)][1] == _other_agent_target, "agent {} at {} expected {}".format(i, (r,c), _other_agent_target) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index b0f274ba4c4b5453140fcc50bc6137e39e8e4f04..3cd0a4c1d9812c728089255eb1150111b783a463 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -63,7 +63,8 @@ def test_multi_speed_init(): # Set all the different speeds # Reset environment and get initial observations for all agents - env.reset() + env.reset(False, False, True) + # Here you can also further enhance the provided observation by means of normalization # See training navigation example in the baseline repository old_pos = [] @@ -188,7 +189,9 @@ def test_multispeed_actions_no_malfunction_no_blocking(): ), ], target=(3, 0), # west dead-end - speed=0.5 + speed=0.5, + initial_position=(3, 9), # east dead-end + initial_direction=Grid4TransitionsEnum.EAST, ) run_replay_config(env, [test_config]) @@ -285,7 +288,10 @@ def 
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index b0f274ba4c4b5453140fcc50bc6137e39e8e4f04..3cd0a4c1d9812c728089255eb1150111b783a463 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -63,7 +63,8 @@ def test_multi_speed_init():
 
     # Set all the different speeds
     # Reset environment and get initial observations for all agents
-    env.reset()
+    env.reset(False, False, True)
+
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
     old_pos = []
@@ -188,7 +189,9 @@ def test_multispeed_actions_no_malfunction_no_blocking():
             ),
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
@@ -285,7 +288,10 @@ def test_multispeed_actions_no_malfunction_blocking():
             )
         ],
         target=(3, 0),  # west dead-end
-        speed=1 / 3),
+        speed=1 / 3,
+        initial_position=(3, 8),
+        initial_direction=Grid4TransitionsEnum.WEST,
+    ),
     ReplayConfig(
         replay=[
             Replay(
@@ -369,7 +375,9 @@ def test_multispeed_actions_no_malfunction_blocking():
 
             ),
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
     ]
@@ -505,7 +513,9 @@ def test_multispeed_actions_malfunction_no_blocking():
 
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
@@ -587,7 +597,9 @@ def test_multispeed_actions_no_malfunction_invalid_actions():
 
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
    )
 
     run_replay_config(env, [test_config])
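Note: several call sites now invoke env.reset with three positional booleans, e.g. env.reset(False, False, True). Assuming the parameters are regen_rail, replace_agents and activate_agents (the third flag lines up with the activate_agents argument threaded through run_replay_config below), keyword arguments would make the intent self-documenting; a sketch under that assumption:

    # Equivalent call with explicit keywords; the parameter names are assumed,
    # inferred from the activate_agents flag used in tests/test_utils.py.
    env.reset(regen_rail=False, replace_agents=False, activate_agents=True)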
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 903120d868aa65833e7c2393ddfcc821c26da4f6..f5d1cd5c957d1e81ca2f2465fb8265c6b2795342 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,7 +5,7 @@
 import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.rail_env import RailEnvActions, RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -18,6 +18,7 @@ class Replay(object):
     malfunction = attrib(default=0, type=int)
     set_malfunction = attrib(default=None, type=Optional[int])
     reward = attrib(default=None, type=Optional[float])
+    status = attrib(default=None, type=Optional[RailAgentStatus])
 
 
 @attrs
@@ -25,6 +26,8 @@ class ReplayConfig(object):
     replay = attrib(type=List[Replay])
     target = attrib(type=Tuple[int, int])
     speed = attrib(type=float)
+    initial_position = attrib(type=Tuple[int, int])
+    initial_direction = attrib(type=Grid4TransitionsEnum)
 
 
 # ensure that env is working correctly with start/stop/invalidaction penalty different from 0
@@ -35,7 +38,7 @@ def set_penalties_for_replay(env: RailEnv):
     env.invalid_action_penalty = -29
 
 
-def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True):
     """
     Runs the replay configs and checks assertions.
@@ -47,10 +50,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
     - position, direction before step are verified
     - optionally, set_malfunction is applied
     - malfunction is verified
+    - status is verified (optionally)
 
     *After each step*
     - reward is verified after step
+
     Parameters
     ----------
     env
@@ -67,18 +72,20 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
     for step in range(len(test_configs[0].replay)):
         if step == 0:
             for a, test_config in enumerate(test_configs):
-                agent: EnvAgent = env.agents[a]
-                replay = test_config.replay[0]
+                agent: EnvAgent = env.agents_static[a]
                 # set the initial position
-                agent.position = replay.position
-                agent.direction = replay.direction
+                agent.initial_position = test_config.initial_position
+                agent.direction = test_config.initial_direction
                 agent.target = test_config.target
                 agent.speed_data['speed'] = test_config.speed
+            env.reset(False, False, activate_agents)
 
         def _assert(a, actual, expected, msg):
-            assert np.allclose(actual, expected), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
-                                                                                                    actual,
-                                                                                                    expected)
+            print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
+            assert (actual == expected) or (
+                np.allclose(actual, expected)), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
+                                                                                                  actual,
+                                                                                                  expected)
 
         action_dict = {}
@@ -88,26 +95,29 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
 
+            if replay.status is not None:
+                _assert(a, agent.status, replay.status, 'status')
             if replay.action is not None:
-                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
+                assert info_dict['action_required'][
+                           a] == True or agent.status == RailAgentStatus.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
                     step, a, True)
                 action_dict[a] = replay.action
             else:
-                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
-                    step, a, False)
+                assert info_dict['action_required'][
+                           a] == False, "[{}] agent {} expecting action_required={}, but found {}".format(
+                    step, a, False, info_dict['action_required'][a])
 
             if replay.set_malfunction is not None:
                 agent.malfunction_data['malfunction'] = replay.set_malfunction
                 agent.malfunction_data['moving_before_malfunction'] = agent.moving
             _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
+        print(step)
         _, rewards_dict, _, info_dict = env.step(action_dict)
         if rendering:
             renderer.render_env(show=True, show_observations=True)
 
         for a, test_config in enumerate(test_configs):
             replay = test_config.replay[step]
-            _assert(a, rewards_dict[a], replay.reward, 'reward')
-
+            _assert(a, rewards_dict[a], replay.reward, 'reward')
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 4c925789e6560077d637e2a594c736df8850d00a..9bfe3a47d0871cbfbbabb4e35eac0bd6ecaaedf0 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -102,6 +102,7 @@ def test_rail_from_grid_transition_map():
                   schedule_generator=random_schedule_generator(),
                   number_of_agents=n_agents
                   )
+    env.reset(False, False, True)
     nr_rail_elements = np.count_nonzero(env.rail.grid)
 
     # Check if the number of non-empty rail cells is ok
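Note: taken together, the test_utils.py changes mean a ReplayConfig now describes the agent's full initial state (initial_position, initial_direction, target, speed), and run_replay_config seats the static agents and resets the environment itself. A minimal end-to-end usage sketch under those assumptions (env, coordinates and direction are illustrative, not taken from a specific test):

    # Hypothetical one-step replay: the agent departs from (3, 9) heading east.
    test_config = ReplayConfig(
        replay=[
            Replay(position=None,
                   direction=Grid4TransitionsEnum.EAST,
                   action=RailEnvActions.MOVE_FORWARD,
                   malfunction=0,
                   reward=env.step_penalty,
                   status=RailAgentStatus.READY_TO_DEPART),
        ],
        target=(3, 0),
        speed=1.0,
        initial_position=(3, 9),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    # activate_agents=False leaves departure to the replayed MOVE_FORWARD action.
    run_replay_config(env, [test_config], activate_agents=False)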