diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 855d1f5dffbef29da26ca1dead933af846b863bd..f75cb74537f03dc6fb1aecbadff37a183432e55a 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -54,11 +54,9 @@ class ObservePredictions(ObservationBuilder): pos_list.append(self.predictions[a][t][1:3]) # We transform (x,y) coodrinates to a single integer number for simpler comparison self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) - observations = {} - # Collect all the different observation for all the agents - for h in handles: - observations[h] = self.get(h) + observations = super().get_many(handles) + return observations def get(self, handle: int = 0) -> np.ndarray: diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 3cc21966162dd28d183493a97cd6072a34abb738..2302fff9cdeeaf162b7ab38b1e67f5052c76b3ef 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -24,7 +24,7 @@ class ObservationBuilder: self.env = None def set_env(self, env: Environment): - self.env = env + self.env: Environment = env def reset(self): """ diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index f659ec8436a941606b6d649e24d2481e5be9b66d..d8c05c2020a9524ae3e6eab232a3e320bc699187 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,5 +1,6 @@ +from enum import IntEnum from itertools import starmap -from typing import Tuple +from typing import Tuple, Optional import numpy as np from attr import attrs, attrib, Factory @@ -7,6 +8,13 @@ from attr import attrs, attrib, Factory from flatland.core.grid.grid4 import Grid4TransitionsEnum +class RailAgentStatus(IntEnum): + READY_TO_DEPART = 0 # not in grid yet (position is None) -> prediction as if it were at initial position + ACTIVE = 1 # in grid (position is not None), not done -> prediction is remaining path + DONE = 2 # in grid (position is not None), but done -> prediction is stay at target forever + DONE_REMOVED = 3 # removed from grid (position is None) -> prediction is None + + @attrs class EnvAgentStatic(object): """ EnvAgentStatic - Stores initial position, direction and target. @@ -14,7 +22,7 @@ class EnvAgentStatic(object): rather than where it is at the moment. The target should also be stored here. """ - position = attrib(type=Tuple[int, int]) + initial_position = attrib(type=Tuple[int, int]) direction = attrib(type=Grid4TransitionsEnum) target = attrib(type=Tuple[int, int]) moving = attrib(default=False, type=bool) @@ -33,6 +41,9 @@ class EnvAgentStatic(object): lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, 'moving_before_malfunction': False}))) + status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus) + position = attrib(default=None, type=Optional[Tuple[int, int]]) + @classmethod def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None): """ Create a list of EnvAgentStatics from lists of positions, directions and targets @@ -65,7 +76,7 @@ class EnvAgentStatic(object): # I can't find an expression which works on both tuples, lists and ndarrays # which converts them all to a list of native python ints. 
-        lPos = self.position
+        lPos = self.initial_position
         if type(lPos) is np.ndarray:
             lPos = lPos.tolist()

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index c23d4345a03c761ad4c4ac1d936db817f8acc529..5cc6a8c11c2b7d5045d66a820f312dc0fd61a492 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -11,11 +11,20 @@
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
 from flatland.utils.ordered_set import OrderedSet


 class TreeObsForRailEnv(ObservationBuilder):
+    """
+    TreeObsForRailEnv object.
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the graph structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+
+    For details about the features in the tree observation see the get() function.
+    """
     Node = collections.namedtuple('Node', 'dist_own_target_encountered '
                                           'dist_other_target_encountered '
                                           'dist_other_agent_encountered '
@@ -27,19 +36,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                                           'num_agents_opposite_direction '
                                           'num_agents_malfunctioning '
                                           'speed_min_fractional '
+                                          'num_agents_ready_to_depart '
                                           'childs')

-    tree_explorted_actions_char = ['L', 'F', 'R', 'B']
-
-    """
-    TreeObsForRailEnv object.
-
-    This object returns observation vectors for agents in the RailEnv environment.
-    The information is local to each agent and exploits the graph structure of the rail
-    network to simplify the representation of the state of the environment for each agent.
-
-    For details about the features in the tree observation see the get() function.
-    """
+    tree_explored_actions_char = ['L', 'F', 'R', 'B']

     def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
         super().__init__()
@@ -67,19 +67,21 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.predicted_dir = {}
             self.predictions = self.predictor.get()
             if self.predictions:
-
+                # TODO hacky hacky: `range(len(self.predictions[0]))` does not seem safe!!
                 for t in range(len(self.predictions[0])):
                     pos_list = []
                     dir_list = []
                     for a in handles:
+                        if self.predictions[a] is None:
+                            continue
                         pos_list.append(self.predictions[a][t][1:3])
                         dir_list.append(self.predictions[a][t][3])
                     self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                     self.predicted_dir.update({t: dir_list})
                 self.max_prediction_depth = len(self.predicted_pos)
-        observations = {}
-        for h in handles:
-            observations[h] = self.get(h)
+
+        observations = super().get_many(handles)
+
         return observations

     def get(self, handle: int = 0) -> Node:
@@ -150,6 +152,8 @@ class TreeObsForRailEnv(ObservationBuilder):
             1 if no agent is observed

             min_fractional speed otherwise
+        #12:
+            number of agents ready to depart but not yet active

         Missing/padding nodes are filled in with -inf (truncated).
         Missing values in present node are filled in with +inf (truncated).
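The dispatch from agent status to an observable position, implemented in the hunk below and repeated in GlobalObsForRailEnv, ShortestPathPredictorForRailEnv and get_shortest_paths, follows one pattern. A minimal sketch of that pattern as a standalone helper; the function name position_for_status is illustrative and not part of this changeset:

from typing import Optional, Tuple

from flatland.envs.agent_utils import EnvAgent, RailAgentStatus


def position_for_status(agent: EnvAgent) -> Optional[Tuple[int, int]]:
    """Resolve the cell an agent is observed/predicted from, given its status."""
    if agent.status == RailAgentStatus.READY_TO_DEPART:
        return agent.initial_position  # not yet on the grid: use the entry cell
    elif agent.status == RailAgentStatus.ACTIVE:
        return agent.position          # on the grid: use the actual cell
    elif agent.status == RailAgentStatus.DONE:
        return agent.target            # finished but still on the grid: sits on its target
    return None                        # DONE_REMOVED: off the grid, nothing to observe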
@@ -160,16 +164,41 @@ class TreeObsForRailEnv(ObservationBuilder): """ # Update local lookup table for all agents' positions - self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} - self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents} - self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} - self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in - self.env.agents} + # ignore other agents not in the grid (only status active and done) + self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if + agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} + self.location_has_agent_ready_to_depart = {} + for agent in self.env.agents: + if agent.status == RailAgentStatus.READY_TO_DEPART: + self.location_has_agent_ready_to_depart[tuple(agent.initial_position)] = \ + self.location_has_agent_ready_to_depart.get(tuple(agent.initial_position), 0) + 1 + self.location_has_agent_direction = { + tuple(agent.position): agent.direction + for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] + } + self.location_has_agent_speed = { + tuple(agent.position): agent.speed_data['speed'] + for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] + } + self.location_has_agent_malfunction = { + tuple(agent.position): agent.malfunction_data['malfunction'] + for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] + } if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index - possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) + + if agent.status == RailAgentStatus.READY_TO_DEPART: + _agent_initial_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + _agent_initial_position = agent.position + elif agent.status == RailAgentStatus.DONE: + _agent_initial_position = agent.target + else: + return None + + possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) # Here information about the agent itself is stored @@ -178,11 +207,13 @@ class TreeObsForRailEnv(ObservationBuilder): root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0, dist_other_agent_encountered=0, dist_potential_conflict=0, dist_unusable_switch=0, dist_to_next_branch=0, - dist_min_to_target=distance_map[(handle, *agent.position, - agent.direction)], + dist_min_to_target=distance_map[ + (handle, *_agent_initial_position, + agent.direction)], num_agents_same_direction=0, num_agents_opposite_direction=0, num_agents_malfunctioning=agent.malfunction_data['malfunction'], speed_min_fractional=agent.speed_data['speed'], + num_agents_ready_to_depart=0, childs={}) visited = OrderedSet() @@ -198,16 +229,16 @@ class TreeObsForRailEnv(ObservationBuilder): for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): if possible_transitions[branch_direction]: - new_cell = get_new_position(agent.position, branch_direction) + new_cell = get_new_position(_agent_initial_position, branch_direction) branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, 
branch_direction, 1, 1) - root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation + root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation visited |= branch_visited else: # add cells filled with infinity if no transition is possible - root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf + root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf self.env.dev_obs_dict[handle] = visited return root_node_observation @@ -245,6 +276,7 @@ class TreeObsForRailEnv(ObservationBuilder): malfunctioning_agent = 0 min_fractional_speed = 1. num_steps = 1 + other_agent_ready_to_depart_encountered = 0 while exploring: # ############################# # ############################# @@ -258,6 +290,8 @@ class TreeObsForRailEnv(ObservationBuilder): if self.location_has_agent_malfunction[position] > malfunctioning_agent: malfunctioning_agent = self.location_has_agent_malfunction[position] + other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0) + if self.location_has_agent_direction[position] == direction: # Cummulate the number of agents on branch with same direction other_agent_same_direction += 1 @@ -296,7 +330,7 @@ class TreeObsForRailEnv(ObservationBuilder): self._reverse_dir( self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict: potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: + if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict: potential_conflict = tot_dist # Look for conflicting paths at distance num_step-1 @@ -307,7 +341,7 @@ class TreeObsForRailEnv(ObservationBuilder): and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \ and tot_dist < potential_conflict: # noqa: E125 potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: + if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict: potential_conflict = tot_dist # Look for conflicting paths at distance num_step+1 @@ -318,7 +352,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_dir[post_step][ca])] == 1 \ and tot_dist < potential_conflict: # noqa: E125 potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: + if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict: potential_conflict = tot_dist if position in self.location_has_target and position != agent.target: @@ -406,6 +440,7 @@ class TreeObsForRailEnv(ObservationBuilder): num_agents_opposite_direction=other_agent_opposite_direction, num_agents_malfunctioning=malfunctioning_agent, speed_min_fractional=min_fractional_speed, + num_agents_ready_to_depart=other_agent_ready_to_depart_encountered, childs={}) # ############################# @@ -425,7 +460,7 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction + 2) % 4, tot_dist + 1, depth + 1) - node.childs[self.tree_explorted_actions_char[i]] = branch_observation + node.childs[self.tree_explored_actions_char[i]] = branch_observation if len(branch_visited) != 0: visited |= branch_visited elif last_is_switch and possible_transitions[branch_direction]: @@ -435,12 +470,12 @@ class TreeObsForRailEnv(ObservationBuilder): branch_direction, tot_dist + 1, depth + 1) - node.childs[self.tree_explorted_actions_char[i]] = branch_observation + node.childs[self.tree_explored_actions_char[i]] = branch_observation if len(branch_visited) != 0: 
visited |= branch_visited else: # no exploring possible, add just cells with infinity - node.childs[self.tree_explorted_actions_char[i]] = -np.inf + node.childs[self.tree_explored_actions_char[i]] = -np.inf if depth == self.max_depth: node.childs.clear() @@ -451,7 +486,7 @@ class TreeObsForRailEnv(ObservationBuilder): Utility function to print tree observations returned by this object. """ self.print_node_features(tree, "root", "") - for direction in self.tree_explorted_actions_char: + for direction in self.tree_explored_actions_char: self.print_subtree(tree.childs[direction], direction, "\t") @staticmethod @@ -460,7 +495,8 @@ class TreeObsForRailEnv(ObservationBuilder): node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ", node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ", node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction, - ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional) + ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional, ", ", + node.num_agents_ready_to_depart) def print_subtree(self, node, label, indent): if node == -np.inf or not node: @@ -472,7 +508,7 @@ class TreeObsForRailEnv(ObservationBuilder): if not node.childs: return - for direction in self.tree_explorted_actions_char: + for direction in self.tree_explored_actions_char: self.print_subtree(node.childs[direction], direction, indent + "\t") def set_env(self, env: Environment): @@ -497,6 +533,7 @@ class GlobalObsForRailEnv(ObservationBuilder): - second channel containing the other agents positions and diretion - third channel containing agent/other agent malfunctions - fourth channel containing agent/other agent fractional speeds + - fifth channel containing number of other agents ready to depart - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ target and the positions of the other agents targets. 
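A hedged usage sketch of these observation channels (the generator choices and grid size are arbitrary illustration; only the channel layout is taken from the docstring above):

from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator

env = RailEnv(width=20, height=20,
              rail_generator=random_rail_generator(),
              number_of_agents=2,
              obs_builder_object=GlobalObsForRailEnv())
obs = env.reset(activate_agents=True)  # place agents on the grid right away

transition_map, agents_state, targets = obs[0]
assert transition_map.shape == (env.height, env.width, 16)
assert agents_state.shape == (env.height, env.width, 5)
# channel 0: own direction, 1: other agents' directions, 2: malfunctions,
# 3: fractional speeds, 4: number of agents ready to depart per cell
ready_to_depart = agents_state[..., 4]
own_target, other_targets = targets[..., 0], targets[..., 1]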
@@ -518,18 +555,33 @@ def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
+        agent = self.env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            _agent_initial_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            _agent_initial_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            _agent_initial_position = agent.target
+        else:
+            return None
+
         obs_targets = np.zeros((self.env.height, self.env.width, 2))
-        obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1

-        agent = self.env.agents[handle]
-        obs_agents_state[agent.position][0] = agent.direction
+        obs_agents_state[_agent_initial_position][0] = agent.direction
         obs_targets[agent.target][0] = 1

         for i in range(len(self.env.agents)):
-            other_agent = self.env.agents[i]
+            other_agent: EnvAgent = self.env.agents[i]
+            # count agents ready to depart (fifth channel) before skipping agents that are not yet in the grid
+            if other_agent.status == RailAgentStatus.READY_TO_DEPART:
+                obs_agents_state[other_agent.initial_position][4] += 1
+            if other_agent.position is None:
+                continue
             if i != handle:
                 obs_agents_state[other_agent.position][1] = other_agent.direction
                 obs_targets[other_agent.target][1] = 1
+
             obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
             obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']

@@ -621,18 +673,14 @@ class LocalObsForRailEnv(ObservationBuilder):
             direction = np.identity(4)[agent.direction]
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction

-    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
+        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
         """
         Called whenever an observation has to be computed for the `env` environment,
         for each agent with handle in the `handles` list.
         """
-        observations = {}
-        if handles is None:
-            handles = []
-        for h in handles:
-            observations[h] = self.get(h)
-        return observations
+        return super().get_many(handles)

     def field_of_view(self, position, direction, state=None):
         # Compute the local field of view for an agent in the environment

diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 76095a2a2e1d9532951600118c6a777612641101..29b6947c28fa054cd1f4c44a204c740e4f536181 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env import RailEnvActions from flatland.envs.rail_env_shortest_paths import get_shortest_paths @@ -47,6 +48,9 @@ class DummyPredictorForRailEnv(PredictionBuilder): prediction_dict = {} for agent in agents: + if agent.status != RailAgentStatus.ACTIVE: + # TODO make this generic + continue action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] _agent_initial_position = agent.position _agent_initial_direction = agent.direction @@ -122,7 +126,17 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): prediction_dict = {} for agent in agents: - _agent_initial_position = agent.position + + if agent.status == RailAgentStatus.READY_TO_DEPART: + _agent_initial_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + _agent_initial_position = agent.position + elif agent.status == RailAgentStatus.DONE: + _agent_initial_position = agent.target + else: + prediction_dict[agent.handle] = None + continue + _agent_initial_direction = agent.direction agent_speed = agent.speed_data["speed"] times_per_cell = int(np.reciprocal(agent_speed)) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 4fb1c55e5a179e99e6c4985fff3c5596ba1f0a66..1b5ca23f36e3571aa75d9fd69431c119054ab05c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -15,7 +15,7 @@ from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent +from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator @@ -224,12 +224,18 @@ class RailEnv(Environment): self.agents_static.append(agent_static) return len(self.agents_static) - 1 + def set_agent_active(self, handle: int): + agent = self.agents[handle] + if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position): + agent.status = RailAgentStatus.ACTIVE + agent.position = agent.initial_position + def restart_agents(self): """ Reset the agents to their starting positions defined in agents_static """ self.agents = EnvAgent.list_from_static(self.agents_static) - def reset(self, regen_rail=True, replace_agents=True): + def reset(self, regen_rail=True, replace_agents=True, activate_agents=False): """ if regen_rail then regenerate the rails. if replace_agents then regenerate the agents static. 
        Relies on the rail_generator returning agent_static lists (pos, dir, target)
@@ -265,8 +271,13 @@ class RailEnv(Environment):
                 *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))

         self.restart_agents()

-        for i_agent in range(self.get_num_agents()):
-            agent = self.agents[i_agent]
+        if activate_agents:
+            for i_agent in range(self.get_num_agents()):
+                self.set_agent_active(i_agent)
+
+        for i_agent, agent in enumerate(self.agents):
+            if agent.status != RailAgentStatus.ACTIVE:
+                continue

             # A proportion of agent in the environment will receive a positive malfunction rate
             if np.random.random() < self.proportion_malfunctioning_trains:
@@ -354,7 +365,8 @@ class RailEnv(Environment):
             info_dict = {
                 'action_required': {i: False for i in range(self.get_num_agents())},
                 'malfunction': {i: 0 for i in range(self.get_num_agents())},
-                'speed': {i: 0 for i in range(self.get_num_agents())}
+                'speed': {i: 0 for i in range(self.get_num_agents())},
+                'status': {i: agent.status for i, agent in enumerate(self.agents)}
             }
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
@@ -369,21 +381,19 @@ class RailEnv(Environment):
         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
-            for k in self.dones.keys():
-                self.dones[k] = True
-
-        action_required_agents = {
-            i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents())
-        }
-        malfunction_agents = {
-            i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
-        }
-        speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}
+            for i in range(self.get_num_agents()):
+                self.agents[i].status = RailAgentStatus.DONE
+                self.dones[i] = True

         info_dict = {
-            'action_required': action_required_agents,
-            'malfunction': malfunction_agents,
-            'speed': speed_agents
+            'action_required': {
+                i: (agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)
+                for i, agent in enumerate(self.agents)},
+            'malfunction': {
+                i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
+            },
+            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
+            'status': {i: agent.status for i, agent in enumerate(self.agents)}
         }

         return self._get_observations(), self.rewards_dict, self.dones, info_dict
@@ -401,10 +411,19 @@ class RailEnv(Environment):
         action_dict_ : Dict[int,RailEnvActions]

         """
-        if self.dones[i_agent]:  # this agent has already completed...
+        agent = self.agents[i_agent]
+        if agent.status == RailAgentStatus.DONE:  # this agent has already completed...
             return

-        agent = self.agents[i_agent]
+        # agent gets active by a MOVE_* action and if the cell at its initial position is free
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
+                          RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
+                agent.status = RailAgentStatus.ACTIVE
+                agent.position = agent.initial_position
+            else:
+                return
+
         agent.old_direction = agent.direction
         agent.old_position = agent.position
@@ -497,6 +516,7 @@ class RailEnv(Environment):

         # has the agent reached its target?
if np.equal(agent.position, agent.target).all(): + agent.status = RailAgentStatus.DONE self.dones[i_agent] = True agent.moving = False else: @@ -543,9 +563,15 @@ class RailEnv(Environment): # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) - cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + cell_free = self.cell_free(new_position) return cell_free, new_cell_valid, new_direction, new_position, transition_valid + def cell_free(self, position): + + agent_positions = [agent.position for agent in self.agents if agent.position is not None] + ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1)) + return ret + def check_action(self, agent: EnvAgent, action: RailEnvActions): """ @@ -591,7 +617,7 @@ class RailEnv(Environment): return self.obs_dict def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: - return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col)) + return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) def get_full_state_msg(self): grid_data = self.rail.grid.tolist() diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py index 793601d4d18ac38b729d15883089d5acbfc41ed3..7944a49daf388403c8a226dbf866ac810d1286a3 100644 --- a/flatland/envs/rail_env_shortest_paths.py +++ b/flatland/envs/rail_env_shortest_paths.py @@ -7,6 +7,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions from flatland.utils.ordered_set import OrderedSet @@ -92,7 +93,15 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non shortest_paths = dict() def _shortest_path_for_agent(agent): - position = agent.position + if agent.status == RailAgentStatus.READY_TO_DEPART: + position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + position = agent.position + elif agent.status == RailAgentStatus.DONE: + position = agent.target + else: + shortest_paths[agent.handle] = None + return direction = agent.direction shortest_paths[agent.handle] = [] distance = math.inf diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 7f42feeacd0ef50b56846540a9b2af9d147eafb0..5442a0af191ce31e415964224557cb18d05407f1 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -235,7 +235,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]] # setup with loaded data - agents_position = [a.position for a in agents_static] + agents_position = [a.initial_position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] if len(data['agents_static'][0]) > 5: diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 99958bf38449ef8eb58c519990f1975106409c4e..e7b1e72679937bc5b3093cebfa58bd2e9894ba94 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -146,6 +146,9 @@ class 
RenderTool(object): Plot a simple agent. Assumes a working graphics layer context (cf a MPL figure). """ + if position_row_col is None: + return + rt = self.__class__ direction_row_col = rt.transitions_row_col[direction] # agent direction in RC @@ -535,7 +538,7 @@ class RenderTool(object): for agent_idx, agent in enumerate(self.env.agents): - if agent is None: + if agent is None or agent.position is None: continue if self.agent_render_variant == AgentRenderVariant.BOX_ONLY: diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py new file mode 100644 index 0000000000000000000000000000000000000000..b4396c101d1879bf8f5f74f9a79f9a8541d45e31 --- /dev/null +++ b/tests/test_flaltland_rail_agent_status.py @@ -0,0 +1,124 @@ +import numpy as np + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.simple_rail import make_simple_rail +from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay + +np.random.seed(1) + + +def test_initial_status(): + """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + set_penalties_for_replay(env) + test_config = ReplayConfig( + replay=[ + Replay( + position=None, # not entered grid yet + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.READY_TO_DEPART, + action=RailEnvActions.DO_NOTHING, + reward=0, + + ), + Replay( + position=None, # not entered grid yet before step + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.READY_TO_DEPART, + action=RailEnvActions.MOVE_LEFT, + reward=env.start_penalty + env.step_penalty * 0.5, # auto-correction left to forward without penalty! + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.ACTIVE, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5, # running at speed 0.5 + + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + status=RailAgentStatus.ACTIVE + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_RIGHT, + reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty! 
+ status=RailAgentStatus.ACTIVE + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # done + status=RailAgentStatus.ACTIVE + ), + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE + ), + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE + ), + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE + ) + + ], + initial_position=(3, 9), # east dead-end + initial_direction=Grid4TransitionsEnum.EAST, + target=(3, 5), + speed=0.5 + ) + + run_replay_config(env, [test_config], activate_agents=False) diff --git a/tests/test_flatland_envs_city_generator.py b/tests/test_flatland_envs_city_generator.py index fe39d785712c88f03af09c7a6d1dac715a585db3..1d386df225e7d025116752e26d5c55cf2a292214 100644 --- a/tests/test_flatland_envs_city_generator.py +++ b/tests/test_flatland_envs_city_generator.py @@ -28,274 +28,274 @@ def test_city_generator(): expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type()) - expected_grid_map[8][16]=4 - expected_grid_map[8][17]=5633 - expected_grid_map[8][18]=1025 - expected_grid_map[8][19]=1025 - expected_grid_map[8][20]=17411 - expected_grid_map[8][21]=1025 - expected_grid_map[8][22]=1025 - expected_grid_map[8][23]=1025 - expected_grid_map[8][24]=1025 - expected_grid_map[8][25]=1025 - expected_grid_map[8][26]=4608 - expected_grid_map[9][16]=16386 - expected_grid_map[9][17]=50211 - expected_grid_map[9][18]=1025 - expected_grid_map[9][19]=1025 - expected_grid_map[9][20]=3089 - expected_grid_map[9][21]=1025 - expected_grid_map[9][22]=256 - expected_grid_map[9][26]=32800 - expected_grid_map[10][6]=16386 - expected_grid_map[10][7]=1025 - expected_grid_map[10][8]=1025 - expected_grid_map[10][9]=1025 - expected_grid_map[10][10]=1025 - expected_grid_map[10][11]=1025 - expected_grid_map[10][12]=1025 - expected_grid_map[10][13]=1025 - expected_grid_map[10][14]=1025 - expected_grid_map[10][15]=1025 - expected_grid_map[10][16]=33825 - expected_grid_map[10][17]=34864 - expected_grid_map[10][26]=32800 - expected_grid_map[11][6]=32800 - expected_grid_map[11][16]=32800 - expected_grid_map[11][17]=32800 - expected_grid_map[11][26]=32800 - expected_grid_map[12][6]=32800 - expected_grid_map[12][16]=32800 - expected_grid_map[12][17]=32800 - expected_grid_map[12][26]=32800 - expected_grid_map[13][6]=32800 - expected_grid_map[13][16]=32800 - expected_grid_map[13][17]=32800 - expected_grid_map[13][26]=32800 - expected_grid_map[14][6]=32800 - expected_grid_map[14][16]=32800 - expected_grid_map[14][17]=32800 - expected_grid_map[14][26]=32800 - expected_grid_map[15][6]=32800 - expected_grid_map[15][16]=32800 - expected_grid_map[15][17]=32800 - expected_grid_map[15][26]=32800 - expected_grid_map[16][6]=32800 - expected_grid_map[16][16]=32800 - expected_grid_map[16][17]=32800 - expected_grid_map[16][26]=32800 - expected_grid_map[17][6]=32800 - expected_grid_map[17][16]=72 - expected_grid_map[17][17]=1097 - expected_grid_map[17][18]=1025 - expected_grid_map[17][19]=1025 - expected_grid_map[17][20]=1025 - expected_grid_map[17][21]=1025 - expected_grid_map[17][22]=1025 - expected_grid_map[17][23]=1025 - expected_grid_map[17][24]=1025 - expected_grid_map[17][25]=1025 - expected_grid_map[17][26]=33825 - 
expected_grid_map[17][27]=4608 - expected_grid_map[18][6]=32800 - expected_grid_map[18][26]=72 - expected_grid_map[18][27]=52275 - expected_grid_map[18][28]=5633 - expected_grid_map[18][29]=17411 - expected_grid_map[18][30]=1025 - expected_grid_map[18][31]=1025 - expected_grid_map[18][32]=256 - expected_grid_map[19][6]=32800 - expected_grid_map[19][25]=16386 - expected_grid_map[19][26]=1025 - expected_grid_map[19][27]=2136 - expected_grid_map[19][28]=1097 - expected_grid_map[19][29]=1097 - expected_grid_map[19][30]=5633 - expected_grid_map[19][31]=1025 - expected_grid_map[19][32]=256 - expected_grid_map[20][6]=32800 - expected_grid_map[20][25]=32800 - expected_grid_map[20][26]=16386 - expected_grid_map[20][27]=17411 - expected_grid_map[20][28]=1025 - expected_grid_map[20][29]=1025 - expected_grid_map[20][30]=3089 - expected_grid_map[20][31]=1025 - expected_grid_map[20][32]=256 - expected_grid_map[21][6]=32800 - expected_grid_map[21][16]=16386 - expected_grid_map[21][17]=1025 - expected_grid_map[21][18]=1025 - expected_grid_map[21][19]=1025 - expected_grid_map[21][20]=1025 - expected_grid_map[21][21]=1025 - expected_grid_map[21][22]=1025 - expected_grid_map[21][23]=1025 - expected_grid_map[21][24]=1025 - expected_grid_map[21][25]=33825 - expected_grid_map[21][26]=33825 - expected_grid_map[21][27]=2064 - expected_grid_map[22][6]=32800 - expected_grid_map[22][16]=32800 - expected_grid_map[22][25]=32800 - expected_grid_map[22][26]=32800 - expected_grid_map[23][6]=32800 - expected_grid_map[23][16]=32800 - expected_grid_map[23][25]=32800 - expected_grid_map[23][26]=32800 - expected_grid_map[24][6]=32800 - expected_grid_map[24][16]=32800 - expected_grid_map[24][25]=32800 - expected_grid_map[24][26]=32800 - expected_grid_map[25][6]=32800 - expected_grid_map[25][16]=32800 - expected_grid_map[25][25]=32800 - expected_grid_map[25][26]=32800 - expected_grid_map[26][6]=32800 - expected_grid_map[26][16]=32800 - expected_grid_map[26][25]=32800 - expected_grid_map[26][26]=32800 - expected_grid_map[27][6]=72 - expected_grid_map[27][7]=1025 - expected_grid_map[27][8]=1025 - expected_grid_map[27][9]=17411 - expected_grid_map[27][10]=1025 - expected_grid_map[27][11]=1025 - expected_grid_map[27][12]=1025 - expected_grid_map[27][13]=1025 - expected_grid_map[27][14]=1025 - expected_grid_map[27][15]=4608 - expected_grid_map[27][16]=72 - expected_grid_map[27][17]=17411 - expected_grid_map[27][18]=5633 - expected_grid_map[27][19]=1025 - expected_grid_map[27][20]=1025 - expected_grid_map[27][21]=1025 - expected_grid_map[27][22]=1025 - expected_grid_map[27][23]=1025 - expected_grid_map[27][24]=1025 - expected_grid_map[27][25]=33825 - expected_grid_map[27][26]=2064 - expected_grid_map[28][6]=4 - expected_grid_map[28][7]=1025 - expected_grid_map[28][8]=1025 - expected_grid_map[28][9]=3089 - expected_grid_map[28][10]=1025 - expected_grid_map[28][11]=1025 - expected_grid_map[28][12]=1025 - expected_grid_map[28][13]=1025 - expected_grid_map[28][14]=4608 - expected_grid_map[28][15]=72 - expected_grid_map[28][16]=1025 - expected_grid_map[28][17]=2136 - expected_grid_map[28][18]=1097 - expected_grid_map[28][19]=5633 - expected_grid_map[28][20]=5633 - expected_grid_map[28][21]=1025 - expected_grid_map[28][22]=256 - expected_grid_map[28][25]=32800 - expected_grid_map[29][6]=4 - expected_grid_map[29][7]=5633 - expected_grid_map[29][8]=20994 - expected_grid_map[29][9]=5633 - expected_grid_map[29][10]=1025 - expected_grid_map[29][11]=1025 - expected_grid_map[29][12]=1025 - expected_grid_map[29][13]=1025 - 
expected_grid_map[29][14]=1097 - expected_grid_map[29][15]=5633 - expected_grid_map[29][16]=1025 - expected_grid_map[29][17]=17411 - expected_grid_map[29][18]=5633 - expected_grid_map[29][19]=1097 - expected_grid_map[29][20]=3089 - expected_grid_map[29][21]=20994 - expected_grid_map[29][22]=1025 - expected_grid_map[29][23]=1025 - expected_grid_map[29][24]=1025 - expected_grid_map[29][25]=2064 - expected_grid_map[30][6]=16386 - expected_grid_map[30][7]=38505 - expected_grid_map[30][8]=3089 - expected_grid_map[30][9]=1097 - expected_grid_map[30][10]=1025 - expected_grid_map[30][11]=1025 - expected_grid_map[30][12]=256 - expected_grid_map[30][15]=32800 - expected_grid_map[30][16]=16386 - expected_grid_map[30][17]=52275 - expected_grid_map[30][18]=1097 - expected_grid_map[30][19]=1025 - expected_grid_map[30][20]=1025 - expected_grid_map[30][21]=3089 - expected_grid_map[30][22]=256 - expected_grid_map[31][6]=32800 - expected_grid_map[31][7]=32800 - expected_grid_map[31][15]=72 - expected_grid_map[31][16]=37408 - expected_grid_map[31][17]=32800 - expected_grid_map[32][6]=32800 - expected_grid_map[32][7]=32800 - expected_grid_map[32][16]=32800 - expected_grid_map[32][17]=32800 - expected_grid_map[33][6]=32800 - expected_grid_map[33][7]=32800 - expected_grid_map[33][16]=32800 - expected_grid_map[33][17]=32800 - expected_grid_map[34][6]=32800 - expected_grid_map[34][7]=32800 - expected_grid_map[34][16]=32800 - expected_grid_map[34][17]=32800 - expected_grid_map[35][6]=32800 - expected_grid_map[35][7]=32800 - expected_grid_map[35][16]=32800 - expected_grid_map[35][17]=32800 - expected_grid_map[36][6]=32800 - expected_grid_map[36][7]=32800 - expected_grid_map[36][16]=32800 - expected_grid_map[36][17]=32800 - expected_grid_map[37][6]=72 - expected_grid_map[37][7]=1097 - expected_grid_map[37][8]=1025 - expected_grid_map[37][9]=1025 - expected_grid_map[37][10]=1025 - expected_grid_map[37][11]=1025 - expected_grid_map[37][12]=1025 - expected_grid_map[37][13]=1025 - expected_grid_map[37][14]=1025 - expected_grid_map[37][15]=1025 - expected_grid_map[37][16]=33897 - expected_grid_map[37][17]=37408 - expected_grid_map[38][16]=72 - expected_grid_map[38][17]=52275 - expected_grid_map[38][18]=5633 - expected_grid_map[38][19]=17411 - expected_grid_map[38][20]=1025 - expected_grid_map[38][21]=1025 - expected_grid_map[38][22]=256 - expected_grid_map[39][16]=4 - expected_grid_map[39][17]=52275 - expected_grid_map[39][18]=3089 - expected_grid_map[39][19]=1097 - expected_grid_map[39][20]=5633 - expected_grid_map[39][21]=1025 - expected_grid_map[39][22]=256 - expected_grid_map[40][16]=4 - expected_grid_map[40][17]=1097 - expected_grid_map[40][18]=1025 - expected_grid_map[40][19]=1025 - expected_grid_map[40][20]=3089 - expected_grid_map[40][21]=1025 - expected_grid_map[40][22]=256 + expected_grid_map[8][16] = 4 + expected_grid_map[8][17] = 5633 + expected_grid_map[8][18] = 1025 + expected_grid_map[8][19] = 1025 + expected_grid_map[8][20] = 17411 + expected_grid_map[8][21] = 1025 + expected_grid_map[8][22] = 1025 + expected_grid_map[8][23] = 1025 + expected_grid_map[8][24] = 1025 + expected_grid_map[8][25] = 1025 + expected_grid_map[8][26] = 4608 + expected_grid_map[9][16] = 16386 + expected_grid_map[9][17] = 50211 + expected_grid_map[9][18] = 1025 + expected_grid_map[9][19] = 1025 + expected_grid_map[9][20] = 3089 + expected_grid_map[9][21] = 1025 + expected_grid_map[9][22] = 256 + expected_grid_map[9][26] = 32800 + expected_grid_map[10][6] = 16386 + expected_grid_map[10][7] = 1025 + expected_grid_map[10][8] = 1025 + 
expected_grid_map[10][9] = 1025 + expected_grid_map[10][10] = 1025 + expected_grid_map[10][11] = 1025 + expected_grid_map[10][12] = 1025 + expected_grid_map[10][13] = 1025 + expected_grid_map[10][14] = 1025 + expected_grid_map[10][15] = 1025 + expected_grid_map[10][16] = 33825 + expected_grid_map[10][17] = 34864 + expected_grid_map[10][26] = 32800 + expected_grid_map[11][6] = 32800 + expected_grid_map[11][16] = 32800 + expected_grid_map[11][17] = 32800 + expected_grid_map[11][26] = 32800 + expected_grid_map[12][6] = 32800 + expected_grid_map[12][16] = 32800 + expected_grid_map[12][17] = 32800 + expected_grid_map[12][26] = 32800 + expected_grid_map[13][6] = 32800 + expected_grid_map[13][16] = 32800 + expected_grid_map[13][17] = 32800 + expected_grid_map[13][26] = 32800 + expected_grid_map[14][6] = 32800 + expected_grid_map[14][16] = 32800 + expected_grid_map[14][17] = 32800 + expected_grid_map[14][26] = 32800 + expected_grid_map[15][6] = 32800 + expected_grid_map[15][16] = 32800 + expected_grid_map[15][17] = 32800 + expected_grid_map[15][26] = 32800 + expected_grid_map[16][6] = 32800 + expected_grid_map[16][16] = 32800 + expected_grid_map[16][17] = 32800 + expected_grid_map[16][26] = 32800 + expected_grid_map[17][6] = 32800 + expected_grid_map[17][16] = 72 + expected_grid_map[17][17] = 1097 + expected_grid_map[17][18] = 1025 + expected_grid_map[17][19] = 1025 + expected_grid_map[17][20] = 1025 + expected_grid_map[17][21] = 1025 + expected_grid_map[17][22] = 1025 + expected_grid_map[17][23] = 1025 + expected_grid_map[17][24] = 1025 + expected_grid_map[17][25] = 1025 + expected_grid_map[17][26] = 33825 + expected_grid_map[17][27] = 4608 + expected_grid_map[18][6] = 32800 + expected_grid_map[18][26] = 72 + expected_grid_map[18][27] = 52275 + expected_grid_map[18][28] = 5633 + expected_grid_map[18][29] = 17411 + expected_grid_map[18][30] = 1025 + expected_grid_map[18][31] = 1025 + expected_grid_map[18][32] = 256 + expected_grid_map[19][6] = 32800 + expected_grid_map[19][25] = 16386 + expected_grid_map[19][26] = 1025 + expected_grid_map[19][27] = 2136 + expected_grid_map[19][28] = 1097 + expected_grid_map[19][29] = 1097 + expected_grid_map[19][30] = 5633 + expected_grid_map[19][31] = 1025 + expected_grid_map[19][32] = 256 + expected_grid_map[20][6] = 32800 + expected_grid_map[20][25] = 32800 + expected_grid_map[20][26] = 16386 + expected_grid_map[20][27] = 17411 + expected_grid_map[20][28] = 1025 + expected_grid_map[20][29] = 1025 + expected_grid_map[20][30] = 3089 + expected_grid_map[20][31] = 1025 + expected_grid_map[20][32] = 256 + expected_grid_map[21][6] = 32800 + expected_grid_map[21][16] = 16386 + expected_grid_map[21][17] = 1025 + expected_grid_map[21][18] = 1025 + expected_grid_map[21][19] = 1025 + expected_grid_map[21][20] = 1025 + expected_grid_map[21][21] = 1025 + expected_grid_map[21][22] = 1025 + expected_grid_map[21][23] = 1025 + expected_grid_map[21][24] = 1025 + expected_grid_map[21][25] = 33825 + expected_grid_map[21][26] = 33825 + expected_grid_map[21][27] = 2064 + expected_grid_map[22][6] = 32800 + expected_grid_map[22][16] = 32800 + expected_grid_map[22][25] = 32800 + expected_grid_map[22][26] = 32800 + expected_grid_map[23][6] = 32800 + expected_grid_map[23][16] = 32800 + expected_grid_map[23][25] = 32800 + expected_grid_map[23][26] = 32800 + expected_grid_map[24][6] = 32800 + expected_grid_map[24][16] = 32800 + expected_grid_map[24][25] = 32800 + expected_grid_map[24][26] = 32800 + expected_grid_map[25][6] = 32800 + expected_grid_map[25][16] = 32800 + 
expected_grid_map[25][25] = 32800 + expected_grid_map[25][26] = 32800 + expected_grid_map[26][6] = 32800 + expected_grid_map[26][16] = 32800 + expected_grid_map[26][25] = 32800 + expected_grid_map[26][26] = 32800 + expected_grid_map[27][6] = 72 + expected_grid_map[27][7] = 1025 + expected_grid_map[27][8] = 1025 + expected_grid_map[27][9] = 17411 + expected_grid_map[27][10] = 1025 + expected_grid_map[27][11] = 1025 + expected_grid_map[27][12] = 1025 + expected_grid_map[27][13] = 1025 + expected_grid_map[27][14] = 1025 + expected_grid_map[27][15] = 4608 + expected_grid_map[27][16] = 72 + expected_grid_map[27][17] = 17411 + expected_grid_map[27][18] = 5633 + expected_grid_map[27][19] = 1025 + expected_grid_map[27][20] = 1025 + expected_grid_map[27][21] = 1025 + expected_grid_map[27][22] = 1025 + expected_grid_map[27][23] = 1025 + expected_grid_map[27][24] = 1025 + expected_grid_map[27][25] = 33825 + expected_grid_map[27][26] = 2064 + expected_grid_map[28][6] = 4 + expected_grid_map[28][7] = 1025 + expected_grid_map[28][8] = 1025 + expected_grid_map[28][9] = 3089 + expected_grid_map[28][10] = 1025 + expected_grid_map[28][11] = 1025 + expected_grid_map[28][12] = 1025 + expected_grid_map[28][13] = 1025 + expected_grid_map[28][14] = 4608 + expected_grid_map[28][15] = 72 + expected_grid_map[28][16] = 1025 + expected_grid_map[28][17] = 2136 + expected_grid_map[28][18] = 1097 + expected_grid_map[28][19] = 5633 + expected_grid_map[28][20] = 5633 + expected_grid_map[28][21] = 1025 + expected_grid_map[28][22] = 256 + expected_grid_map[28][25] = 32800 + expected_grid_map[29][6] = 4 + expected_grid_map[29][7] = 5633 + expected_grid_map[29][8] = 20994 + expected_grid_map[29][9] = 5633 + expected_grid_map[29][10] = 1025 + expected_grid_map[29][11] = 1025 + expected_grid_map[29][12] = 1025 + expected_grid_map[29][13] = 1025 + expected_grid_map[29][14] = 1097 + expected_grid_map[29][15] = 5633 + expected_grid_map[29][16] = 1025 + expected_grid_map[29][17] = 17411 + expected_grid_map[29][18] = 5633 + expected_grid_map[29][19] = 1097 + expected_grid_map[29][20] = 3089 + expected_grid_map[29][21] = 20994 + expected_grid_map[29][22] = 1025 + expected_grid_map[29][23] = 1025 + expected_grid_map[29][24] = 1025 + expected_grid_map[29][25] = 2064 + expected_grid_map[30][6] = 16386 + expected_grid_map[30][7] = 38505 + expected_grid_map[30][8] = 3089 + expected_grid_map[30][9] = 1097 + expected_grid_map[30][10] = 1025 + expected_grid_map[30][11] = 1025 + expected_grid_map[30][12] = 256 + expected_grid_map[30][15] = 32800 + expected_grid_map[30][16] = 16386 + expected_grid_map[30][17] = 52275 + expected_grid_map[30][18] = 1097 + expected_grid_map[30][19] = 1025 + expected_grid_map[30][20] = 1025 + expected_grid_map[30][21] = 3089 + expected_grid_map[30][22] = 256 + expected_grid_map[31][6] = 32800 + expected_grid_map[31][7] = 32800 + expected_grid_map[31][15] = 72 + expected_grid_map[31][16] = 37408 + expected_grid_map[31][17] = 32800 + expected_grid_map[32][6] = 32800 + expected_grid_map[32][7] = 32800 + expected_grid_map[32][16] = 32800 + expected_grid_map[32][17] = 32800 + expected_grid_map[33][6] = 32800 + expected_grid_map[33][7] = 32800 + expected_grid_map[33][16] = 32800 + expected_grid_map[33][17] = 32800 + expected_grid_map[34][6] = 32800 + expected_grid_map[34][7] = 32800 + expected_grid_map[34][16] = 32800 + expected_grid_map[34][17] = 32800 + expected_grid_map[35][6] = 32800 + expected_grid_map[35][7] = 32800 + expected_grid_map[35][16] = 32800 + expected_grid_map[35][17] = 32800 + expected_grid_map[36][6] 
= 32800 + expected_grid_map[36][7] = 32800 + expected_grid_map[36][16] = 32800 + expected_grid_map[36][17] = 32800 + expected_grid_map[37][6] = 72 + expected_grid_map[37][7] = 1097 + expected_grid_map[37][8] = 1025 + expected_grid_map[37][9] = 1025 + expected_grid_map[37][10] = 1025 + expected_grid_map[37][11] = 1025 + expected_grid_map[37][12] = 1025 + expected_grid_map[37][13] = 1025 + expected_grid_map[37][14] = 1025 + expected_grid_map[37][15] = 1025 + expected_grid_map[37][16] = 33897 + expected_grid_map[37][17] = 37408 + expected_grid_map[38][16] = 72 + expected_grid_map[38][17] = 52275 + expected_grid_map[38][18] = 5633 + expected_grid_map[38][19] = 17411 + expected_grid_map[38][20] = 1025 + expected_grid_map[38][21] = 1025 + expected_grid_map[38][22] = 256 + expected_grid_map[39][16] = 4 + expected_grid_map[39][17] = 52275 + expected_grid_map[39][18] = 3089 + expected_grid_map[39][19] = 1097 + expected_grid_map[39][20] = 5633 + expected_grid_map[39][21] = 1025 + expected_grid_map[39][22] = 256 + expected_grid_map[40][16] = 4 + expected_grid_map[40][17] = 1097 + expected_grid_map[40][18] = 1025 + expected_grid_map[40][19] = 1025 + expected_grid_map[40][20] = 3089 + expected_grid_map[40][21] = 1025 + expected_grid_map[40][22] = 256 assert np.array_equal(env.rail.grid, expected_grid_map), "actual={}, expected={}".format(env.rail.grid, expected_grid_map) - + s0 = 0 s1 = 0 for a in range(env.get_num_agents()): - s0 = Vec2d.get_manhattan_distance(env.agents[a].position, (0, 0)) - s1 = Vec2d.get_chebyshev_distance(env.agents[a].position, (0, 0)) + s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0)) + s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0)) assert s0 == 58, "actual={}".format(s0) assert s1 == 38, "actual={}".format(s1) diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 0d6d309765690b1f95c681d7d109a13071d7f86b..52b047244850a5b1c6b39dadd74568fc5f92deec 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.agent_utils import EnvAgent +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions @@ -29,6 +29,9 @@ def test_global_obs(): global_obs = env.reset() + # we have to take step for the agent to enter the grid. 
+ global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD}) + assert (global_obs[0][0].shape == rail_map.shape + (16,)) rail_map_recons = np.zeros_like(rail_map) @@ -109,12 +112,14 @@ def test_reward_function_conflict(rendering=False): agent.direction = 0 # north agent.target = (3, 9) # east dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE agent = env.agents_static[1] agent.position = (3, 8) # east dead-end agent.direction = 3 # west agent.target = (6, 6) # south dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE # reset to set agents from agents_static env.reset(False, False) @@ -184,16 +189,20 @@ def test_reward_function_waiting(rendering=False): # set the initial position agent = env.agents_static[0] + agent.initial_position = (3, 8) # east dead-end agent.position = (3, 8) # east dead-end agent.direction = 3 # west agent.target = (3, 1) # west dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE agent = env.agents_static[1] + agent.initial_position = (5, 6) # south dead-end agent.position = (5, 6) # south dead-end agent.direction = 0 # north agent.target = (3, 8) # east dead-end agent.moving = True + agent.status = RailAgentStatus.ACTIVE # reset to set agents from agents_static env.reset(False, False) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index f4ab68bc45a82b8f196fcfbebb26fd68a36c37a4..2517fb84de00c1346e5c601f903a9a1a677cfae4 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -5,6 +5,7 @@ import pprint import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction @@ -31,12 +32,13 @@ def test_dummy_predictor(rendering=False): env.reset() # set initial position and direction for testing... 
-    env.agents_static[0].position = (5, 6)
+    env.agents_static[0].initial_position = (5, 6)
     env.agents_static[0].direction = 0
     env.agents_static[0].target = (3, 0)

     # reset to set agents from agents_static
     env.reset(False, False)
+    env.set_agent_active(0)

     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
@@ -124,10 +126,12 @@ def test_shortest_path_predictor(rendering=False):

     # set the initial position
     agent = env.agents_static[0]
+    agent.initial_position = (5, 6)  # south dead-end
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE

     # reset to set agents from agents_static
     env.reset(False, False)
@@ -139,9 +143,9 @@ def test_shortest_path_predictor(rendering=False):

     # compute the observations and predictions
     distance_map = env.distance_map.get()
-    assert distance_map[0, agent.position[0], agent.position[
-        1], agent.direction] == 5.0, "found {} instead of {}".format(
-        distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
+    assert distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] == 5.0, \
+        "found {} instead of {}".format(
+            distance_map[agent.handle, agent.initial_position[0], agent.initial_position[1], agent.direction], 5.0)

     paths = get_shortest_paths(env.distance_map)[0]
     assert paths == [
@@ -259,19 +263,23 @@ def test_shortest_path_predictor_conflicts(rendering=False):

     # set the initial position
     agent = env.agents_static[0]
+    agent.initial_position = (5, 6)  # south dead-end
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE

     agent = env.agents_static[1]
+    agent.initial_position = (3, 8)  # east dead-end
     agent.position = (3, 8)  # east dead-end
     agent.direction = 3  # west
     agent.target = (6, 6)  # south dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE

     # reset to set agents from agents_static
-    observations = env.reset(False, False)
+    observations = env.reset(False, False, True)

     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
@@ -295,14 +303,14 @@ def test_shortest_path_predictor_conflicts(rendering=False):

 def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
     assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
-    for a_1 in obs_builder.tree_explorted_actions_char:
+    for a_1 in obs_builder.tree_explored_actions_char:
         if tree.childs[a_1] == -np.inf:
             assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
             continue
         else:
             conflict = tree.childs[a_1].num_agents_opposite_direction
             assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
-        for a_2 in obs_builder.tree_explorted_actions_char:
+        for a_2 in obs_builder.tree_explored_actions_char:
             if tree.childs[a_1].childs[a_2] == -np.inf:
                 assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
             else:
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 0fefd3e212ddb5f084c1e219f4063079e03dabdf..e0281bb0d3f21a8e25fdff86ade5bed05f5dea13 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -85,7 +85,7 @@ def test_rail_environment_single_agent():
                        obs_builder_object=GlobalObsForRailEnv())

     for _ in range(200):
-        _ = rail_env.reset()
+        _ = rail_env.reset(False, False, True)

         # We do not
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 0fefd3e212ddb5f084c1e219f4063079e03dabdf..e0281bb0d3f21a8e25fdff86ade5bed05f5dea13 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -85,7 +85,7 @@ def test_rail_environment_single_agent():
                        obs_builder_object=GlobalObsForRailEnv())
 
     for _ in range(200):
-        _ = rail_env.reset()
+        _ = rail_env.reset(False, False, True)
 
         # We do not care about target for the moment
         agent = rail_env.agents[0]
@@ -130,9 +130,6 @@ def test_rail_environment_single_agent():
             done = dones['__all__']
 
 
-test_rail_environment_single_agent()
-
-
 def test_dead_end():
     transitions = RailEnvTransitions()
 
@@ -164,32 +161,12 @@ def test_dead_end():
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
-    def check_consistency(rail_env):
-        # We run step to check that trains do not move anymore
-        # after being done.
-        # TODO: GIACOMO: this is deprecated and should be updated; thenew behavior is that agents keep moving
-        # until they are manually stopped.
-        for i in range(7):
-            prev_pos = rail_env.agents[0].position
-
-            # The train cannot turn, so we check that when it tries,
-            # it stays where it is.
-            _ = rail_env.step({0: 1})
-            _ = rail_env.step({0: 3})
-            assert (rail_env.agents[0].position == prev_pos)
-            _, _, dones, _ = rail_env.step({0: 2})
-
-            if i < 5:
-                assert (not dones[0] and not dones['__all__'])
-            else:
-                assert (dones[0] and dones['__all__'])
-
     # We try the configuration in the 4 directions:
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=1, target=(0, 0), moving=False)]
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=3, target=(0, 4), moving=False)]
 
     # In the vertical configuration:
     rail_map = np.array(
@@ -210,10 +187,12 @@ def test_dead_end():
                        obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=2, target=(0, 0), moving=False)]
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=0, target=(4, 0), moving=False)]
+
+    # TODO make assertions
 
 
 def test_get_entry_directions():
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index e164752483e2b4ad5896d754d378a5519c960237..8b2cdbea9431d970778acb8d973cc0002d5a90f5 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -25,6 +25,7 @@ def test_sparse_rail_generator():
                   schedule_generator=sparse_schedule_generator(),
                   number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv())
+    env.reset(False, False, True)
     expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type())
     expected_grid_map[1][33] = 8192
     expected_grid_map[2][33] = 32800
@@ -1549,6 +1550,9 @@ def test_rail_env_action_required_info():
                                         obs_builder_object=GlobalObsForRailEnv())
     env_renderer = RenderTool(env_always_action, gl="PILSVG", )
 
+    env_always_action.reset(False, False, True)
+    env_only_if_action_required.reset(False, False, True)
+
     for step in range(100):
         print("step {}".format(step))
 
@@ -1610,6 +1614,7 @@ def test_rail_env_malfunction_speed_info():
                   number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv(),
                   stochastic_data=stochastic_data)
+    env.reset(False, False, True)
    env_renderer = RenderTool(env, gl="PILSVG", )
 
     for step in range(100):
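Editor's note (not part of the patch): tests now construct the environment and immediately call env.reset(False, False, True) before touching env.rail or env.agents. Judging from how the flags are used, the positional arguments are (regen_rail, replace_agents, activate_agents); treat the names as an assumption and check RailEnv.reset for the authoritative signature:

    # Hedged reading of the reset call used throughout these tests:
    # keep the generated rail, keep the agents built from agents_static,
    # and activate every agent so that position is populated immediately.
    obs = env.reset(False, False, True)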
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index fa2920a9fa7c78331cc7c32ae5308633b4d3f8da..8b4fc8bb64b753532b96f0158d3fa754bf855e2b 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -2,14 +2,15 @@ import random
 from typing import Dict, List
 
 import numpy as np
-from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 
 class SingleAgentNavigationObs(ObservationBuilder):
@@ -29,7 +30,16 @@ class SingleAgentNavigationObs(ObservationBuilder):
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
-        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            _agent_initial_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            _agent_initial_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            _agent_initial_position = agent.target
+        else:
+            return None
+
+        possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
 
         # Start from the current orientation, and see which transitions are available;
@@ -41,7 +51,7 @@ class SingleAgentNavigationObs(ObservationBuilder):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = get_new_position(agent.position, direction)
+                    new_position = get_new_position(_agent_initial_position, direction)
                     min_distances.append(
                         self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
@@ -70,16 +80,19 @@ def test_malfunction_process():
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
 
-    obs = env.reset()
+    obs = env.reset(False, False, True)
 
     # Check that a initial duration for malfunction was assigned
     assert env.agents[0].malfunction_data['next_malfunction'] > 0
+    for agent in env.agents:
+        agent.status = RailAgentStatus.ACTIVE
 
     agent_halts = 0
     total_down_time = 0
     agent_old_position = env.agents[0].position
     for step in range(100):
         actions = {}
+
         for i in range(len(obs)):
             actions[i] = np.argmax(obs[i]) + 1
 
@@ -104,7 +117,8 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
+        env.agents[0].malfunction_data['nr_malfunctions'])
 
     # Check that 20 stops where performed
     assert agent_halts == 20
@@ -120,8 +134,6 @@ def test_malfunction_process_statistically():
                        'malfunction_rate': 2,
                        'min_duration': 3,
                        'max_duration': 3}
-    np.random.seed(5)
-    random.seed(0)
 
     env = RailEnv(width=20,
                   height=20,
@@ -131,8 +143,9 @@ def test_malfunction_process_statistically():
                   number_of_agents=2,
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
-
-    env.reset()
+    np.random.seed(5)
+    random.seed(0)
+    env.reset(False, False, True)
 
     nb_malfunction = 0
     for step in range(100):
         action_dict: Dict[int, RailEnvActions] = {}
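Editor's note (not part of the patch): the SingleAgentNavigationObs change above is the canonical status dispatch for observation builders: pick the cell to reason about from the agent's life-cycle stage. The same logic as a stand-alone helper (the helper name is illustrative, not part of the patch):

    from flatland.envs.agent_utils import RailAgentStatus

    def agent_query_position(agent):
        """Cell an observation should be computed from, or None if the agent
        has left the grid."""
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            return agent.initial_position  # not placed on the grid yet
        if agent.status == RailAgentStatus.ACTIVE:
            return agent.position          # somewhere on the rail
        if agent.status == RailAgentStatus.DONE:
            return agent.target            # reached and parked on its target
        return None                        # DONE_REMOVED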
@@ -149,9 +162,6 @@
 
 
 def test_initial_malfunction():
-    random.seed(0)
-    np.random.seed(0)
-
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                        'malfunction_rate': 70,  # Rate of malfunction occurence
                        'min_duration': 2,  # Minimal duration of malfunction
@@ -162,7 +172,8 @@
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train
-
+    np.random.seed(5)
+    random.seed(0)
     env = RailEnv(width=25,
                   height=30,
                   rail_generator=sparse_rail_generator(num_cities=5,
@@ -226,15 +237,15 @@
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
+
     run_replay_config(env, [replay_config])
 
 
 def test_initial_malfunction_stop_moving():
-    random.seed(0)
-    np.random.seed(0)
-
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                        'malfunction_rate': 70,  # Rate of malfunction occurence
                        'min_duration': 2,  # Minimal duration of malfunction
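Editor's note (not part of the patch): the hunks above relocate the RNG seeding from the top of each test to directly before the environment is built and reset. The malfunction schedule is drawn from the global numpy/random state, so any draw consumed between seeding and reset would shift the schedule and break the exact-count assertions. The pattern, assuming only that reset is where the schedule is sampled:

    import random
    import numpy as np

    # Nothing may consume a random draw between these seeds and the reset
    # that samples the malfunction schedule.
    np.random.seed(5)
    random.seed(0)
    env.reset(False, False, True)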
@@ -269,19 +280,21 @@
     replay_config = ReplayConfig(
         replay=[
             Replay(
-                position=(28, 5),
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.DO_NOTHING,
+                action=RailEnvActions.MOVE_FORWARD,
                 set_malfunction=3,
                 malfunction=3,
-                reward=env.step_penalty  # full step penalty when stopped
+                reward=env.step_penalty,  # full step penalty when stopped
+                status=RailAgentStatus.READY_TO_DEPART
             ),
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
-                reward=env.step_penalty  # full step penalty when stopped
+                reward=env.step_penalty,  # full step penalty when stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action STOP_MOVING, agent should restart without moving
@@ -291,7 +304,8 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.STOP_MOVING,
                 malfunction=1,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we have stopped and do nothing --> should stand still
             Replay(
@@ -299,7 +313,8 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -307,21 +322,24 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.start_penalty + env.step_penalty * 1.0  # full step penalty while stopped
+                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(28, 4),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # full step penalty while stopped
+                reward=env.step_penalty * 1.0,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
-
-    run_replay_config(env, [replay_config])
+    run_replay_config(env, [replay_config], activate_agents=False)
 
 
 def test_initial_malfunction_do_nothing():
@@ -360,20 +378,23 @@
                   )
     set_penalties_for_replay(env)
     replay_config = ReplayConfig(
-        replay=[Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            set_malfunction=3,
-            malfunction=3,
-            reward=env.step_penalty  # full step penalty while malfunctioning
-        ),
+        replay=[
+            Replay(
+                position=None,
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty,  # full step penalty while malfunctioning
+                status=RailAgentStatus.READY_TO_DEPART
+            ),
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
-                reward=env.step_penalty  # full step penalty while malfunctioning
+                reward=env.step_penalty,  # full step penalty while malfunctioning
+                status=RailAgentStatus.ACTIVE
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action DO_NOTHING, agent should restart without moving
@@ -383,7 +404,8 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=1,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we haven't started moving yet --> stay here
             Replay(
@@ -391,7 +413,8 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -399,21 +422,25 @@
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(28, 4),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
-    run_replay_config(env, [replay_config])
+    run_replay_config(env, [replay_config], activate_agents=False)
 
 
 def test_initial_nextmalfunction_not_below_zero():
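Editor's note (not part of the patch): with agents no longer placed at reset, the first entry of each replay now expects position=None and status READY_TO_DEPART, and issues MOVE_FORWARD to put the agent on the grid; from the next step on the agent sits at its initial_position with status ACTIVE. A hedged sketch of that first entry (values illustrative):

    from flatland.core.grid.grid4 import Grid4TransitionsEnum
    from flatland.envs.agent_utils import RailAgentStatus
    from flatland.envs.rail_env import RailEnvActions
    from test_utils import Replay

    first_step = Replay(
        position=None,                       # not on the grid yet
        direction=Grid4TransitionsEnum.EAST,
        action=RailEnvActions.MOVE_FORWARD,  # places the agent this step
        set_malfunction=3,
        malfunction=3,
        reward=env.step_penalty,
        status=RailAgentStatus.READY_TO_DEPART,
    )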
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index 7213560f9e9873ea4488b96d30223bab8128b37b..f29629ab7c7aabb9ca2989b02a3604239d9e6143 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 from flatland.envs.observations import GlobalObsForRailEnv
-from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 
@@ -40,7 +40,7 @@ def test_get_global_observation():
                   number_of_agents=number_of_agents,
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   obs_builder_object=GlobalObsForRailEnv())
-    obs, all_rewards, done, _ = env.step({0: 0})
+    obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
 
     for i in range(len(env.agents)):
         obs_agents_state = obs[i][1]
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index b0f274ba4c4b5453140fcc50bc6137e39e8e4f04..3cd0a4c1d9812c728089255eb1150111b783a463 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -63,7 +63,8 @@ def test_multi_speed_init():
 
     # Set all the different speeds
     # Reset environment and get initial observations for all agents
-    env.reset()
+    env.reset(False, False, True)
+
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
     old_pos = []
@@ -188,7 +189,9 @@
             ),
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
@@ -285,7 +288,10 @@
             )
         ],
         target=(3, 0),  # west dead-end
-        speed=1 / 3),
+        speed=1 / 3,
+        initial_position=(3, 8),
+        initial_direction=Grid4TransitionsEnum.WEST,
+        ),
         ReplayConfig(
             replay=[
                 Replay(
@@ -369,7 +375,9 @@
                 ),
             ],
             target=(3, 0),  # west dead-end
-            speed=0.5
+            speed=0.5,
+            initial_position=(3, 9),  # east dead-end
+            initial_direction=Grid4TransitionsEnum.EAST,
         )
     ]
 
@@ -505,7 +513,9 @@
 
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
@@ -587,7 +597,9 @@
 
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
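Editor's note (not part of the patch): in these configs, speed is the fraction of a cell an agent traverses per step, which is why the slow trains repeat positions in the replay sequences: an agent needs roughly round(1/speed) steps to clear a cell. A small illustration:

    # speed 1.0 -> 1 step per cell, 0.5 -> 2 steps, 1/3 -> 3 steps
    def steps_per_cell(speed: float) -> int:
        return int(round(1.0 / speed))

    assert steps_per_cell(1.0) == 1
    assert steps_per_cell(0.5) == 2
    assert steps_per_cell(1 / 3) == 3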
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 903120d868aa65833e7c2393ddfcc821c26da4f6..f5d1cd5c957d1e81ca2f2465fb8265c6b2795342 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,7 +5,7 @@ import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.rail_env import RailEnvActions, RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -18,6 +18,7 @@ class Replay(object):
     malfunction = attrib(default=0, type=int)
     set_malfunction = attrib(default=None, type=Optional[int])
     reward = attrib(default=None, type=Optional[float])
+    status = attrib(default=None, type=Optional[RailAgentStatus])
 
 
 @attrs
@@ -25,6 +26,8 @@ class ReplayConfig(object):
     replay = attrib(type=List[Replay])
     target = attrib(type=Tuple[int, int])
     speed = attrib(type=float)
+    initial_position = attrib(type=Tuple[int, int])
+    initial_direction = attrib(type=Grid4TransitionsEnum)
 
 
 # ensure that env is working correctly with start/stop/invalidaction penalty different from 0
@@ -35,7 +38,7 @@ def set_penalties_for_replay(env: RailEnv):
     env.invalid_action_penalty = -29
 
 
-def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True):
     """
     Runs the replay configs and checks assertions.
 
@@ -47,10 +50,12 @@
     - position, direction before step are verified
     - optionally, set_malfunction is applied
     - malfunction is verified
+    - status is verified (optionally)
 
     *After each step*
 
     - reward is verified after step
+
     Parameters
     ----------
     env
@@ -67,18 +72,20 @@
     for step in range(len(test_configs[0].replay)):
         if step == 0:
             for a, test_config in enumerate(test_configs):
-                agent: EnvAgent = env.agents[a]
-                replay = test_config.replay[0]
+                agent: EnvAgent = env.agents_static[a]
                 # set the initial position
-                agent.position = replay.position
-                agent.direction = replay.direction
+                agent.initial_position = test_config.initial_position
+                agent.direction = test_config.initial_direction
                 agent.target = test_config.target
                 agent.speed_data['speed'] = test_config.speed
+            env.reset(False, False, activate_agents)
 
         def _assert(a, actual, expected, msg):
-            assert np.allclose(actual, expected), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
-                                                                                                    actual,
-                                                                                                    expected)
+            print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
+            assert (actual == expected) or (
+                np.allclose(actual, expected)), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
+                                                                                                  actual,
+                                                                                                  expected)
 
         action_dict = {}
 
@@ -88,26 +95,29 @@
 
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
+            if replay.status is not None:
+                _assert(a, agent.status, replay.status, 'status')
 
             if replay.action is not None:
-                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
+                assert info_dict['action_required'][
+                           a] == True or agent.status == RailAgentStatus.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
                     step, a, True)
                 action_dict[a] = replay.action
             else:
-                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
-                    step, a, False)
+                assert info_dict['action_required'][
+                           a] == False, "[{}] agent {} expecting action_required={}, but found {}".format(
+                    step, a, False, info_dict['action_required'][a])
 
             if replay.set_malfunction is not None:
                 agent.malfunction_data['malfunction'] = replay.set_malfunction
                 agent.malfunction_data['moving_before_malfunction'] = agent.moving
             _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
+        print(step)
         _, rewards_dict, _, info_dict = env.step(action_dict)
         if rendering:
             renderer.render_env(show=True, show_observations=True)
 
         for a, test_config in enumerate(test_configs):
             replay = test_config.replay[step]
-            _assert(a, rewards_dict[a], replay.reward, 'reward')
-
+            _assert(a, rewards_dict[a], replay.reward, 'reward')
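Editor's note (not part of the patch): run_replay_config now owns the reset. Callers no longer poke replay[0] state into live agents; they declare the start state on the ReplayConfig, and the harness places agents via agents_static, resets with the requested activation, then replays step by step. A hedged usage sketch (values illustrative):

    test_config = ReplayConfig(
        replay=[first_step],  # e.g. the entry sketched after the malfunction tests
        target=(3, 0),
        speed=1.0,
        initial_position=(3, 9),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    # activate_agents=False keeps agents READY_TO_DEPART, so the replay itself
    # must issue the MOVE_FORWARD that places each agent on the grid.
    run_replay_config(env, [test_config], activate_agents=False)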
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 4c925789e6560077d637e2a594c736df8850d00a..9bfe3a47d0871cbfbbabb4e35eac0bd6ecaaedf0 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -102,6 +102,7 @@ def test_rail_from_grid_transition_map():
                   schedule_generator=random_schedule_generator(),
                   number_of_agents=n_agents
                   )
+    env.reset(False, False, True)
     nr_rail_elements = np.count_nonzero(env.rail.grid)
 
     # Check if the number of non-empty rail cells is ok
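Editor's note (not part of the patch): the thread running through this whole patch is that constructing a RailEnv no longer yields placed, active agents; every test resets explicitly and then either activates agents in bulk or drives them onto the grid with an initial MOVE_FORWARD. A closing sketch of the invariant the updated tests assume (flag order is, again, an assumption):

    from flatland.envs.agent_utils import RailAgentStatus

    def reset_and_activate(env):
        env.reset(False, False, True)  # keep rail and agents, activate all
        for agent in env.agents:
            assert agent.status == RailAgentStatus.ACTIVE
            assert agent.position is not None  # placed at initial_position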