diff --git a/docs/tutorials/01_gettingstarted.rst b/docs/tutorials/01_gettingstarted.rst index 9ca370a0e1ed6a698d3e37111c70b97ac0ad2aa8..a7a8e5514e03d2f9feb31408d433badc7f3767c2 100644 --- a/docs/tutorials/01_gettingstarted.rst +++ b/docs/tutorials/01_gettingstarted.rst @@ -160,7 +160,7 @@ Once we are set with the environment we can load our preferred agent from either .. code-block:: python - agent = RandomAgent(env.action_space, env.observation_space) + agent = RandomAgent(state_size, action_size) We start every trial by resetting the environment diff --git a/docs/tutorials/02_observationbuilder.rst b/docs/tutorials/02_observationbuilder.rst index fd5decae31139a883289155131d71fd6870e0a74..d1c287fedf880fd091a1b24921292010eb01359e 100644 --- a/docs/tutorials/02_observationbuilder.rst +++ b/docs/tutorials/02_observationbuilder.rst @@ -18,7 +18,7 @@ base class and must implement two methods, :code:`reset(self)` and :code:`get(se .. _`flatland.core.env_observation_builder.ObservationBuilder` : https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/core/env_observation_builder.py#L13 -Below is a simple example that returns observation vectors of size :code:`observation_space = 5` featuring only the ID (handle) of the agent whose +Below is a simple example that returns observation vectors of size 5 featuring only the ID (handle) of the agent whose observation vector is being computed: .. code-block:: python @@ -28,14 +28,12 @@ observation vector is being computed: Simplest observation builder. The object returns observation vectors with 5 identical components, all equal to the ID of the respective agent. """ - def __init__(self): - self.observation_space = [5] def reset(self): return def get(self, handle): - observation = handle * np.ones((self.observation_space[0],)) + observation = handle * np.ones(5) return observation We can pass an instance of our custom observation builder :code:`SimpleObs` to the :code:`RailEnv` creator as follows: @@ -85,7 +83,6 @@ Note that this simple strategy fails when multiple agents are present, as each a super().__init__(max_depth=0) # We set max_depth=0 in because we only need to look at the current # position of the agent to decide what direction is shortest. - self.observation_space = [3] def reset(self): # Recompute the distance map, if the environment has changed. 
@@ -189,7 +186,6 @@ In contrast to the previous examples we also implement the :code:`def get_many(s def __init__(self, predictor): super().__init__(max_depth=0) - self.observation_space = [10] self.predictor = predictor def reset(self): diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py index 705169e95c71137e93f92e8026f82a34d29d2182..8c12886777c0d47952db4aeea8314488d13c19c3 100644 --- a/examples/custom_observation_example_01_SimpleObs.py +++ b/examples/custom_observation_example_01_SimpleObs.py @@ -18,7 +18,6 @@ class SimpleObs(ObservationBuilder): def __init__(self): super().__init__() - self.observation_space = [5] def reset(self): return diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index e9c2a84eea5e375c024b96a35934e136bb5d40b5..7ddfcd899f747f038471cbe3921e6df76fff37ee 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -28,7 +28,6 @@ class SingleAgentNavigationObs(ObservationBuilder): def __init__(self): super().__init__() - self.observation_space = [3] def reset(self): pass diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 2ed47a5f8c18894c8e0a2108b7bc0a8d73b783f9..855d1f5dffbef29da26ca1dead933af846b863bd 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -28,7 +28,6 @@ class ObservePredictions(ObservationBuilder): def __init__(self, predictor): super().__init__() - self.observation_space = [10] self.predictor = predictor def reset(self): diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 7cb7d9623c79d154e549114e779d39c138cf788d..a52eeed47c5cb1fe75c87d430d93f30f50336fbf 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -25,7 +25,6 @@ class SingleAgentNavigationObs(ObservationBuilder): def __init__(self): super().__init__() - self.observation_space = [3] def reset(self): pass diff --git a/flatland/core/env.py b/flatland/core/env.py index 2dc983c5edd49e4ebb5033877b69f36f293141b2..32b688ca78e35b1e36aac85c0da4a4ee22246d1b 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -11,7 +11,6 @@ class Environment: Derived environments should implement the following attributes: action_space: tuple with the dimensions of the actions to be passed to the step method - observation_space: tuple with the dimensions of the observations returned by reset and step Agents are identified by agent ids (handles). Examples: @@ -46,7 +45,6 @@ class Environment: def __init__(self): self.action_space = () - self.observation_space = () pass def reset(self): diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 2d4df089eed08ee17f3d5f89147735b1b8570a7d..3cc21966162dd28d183493a97cd6072a34abb738 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -18,13 +18,9 @@ from flatland.core.env import Environment class ObservationBuilder: """ ObservationBuilder base class. - - Derived objects must implement and `observation_space` attribute as a tuple with the dimensions of the returned - observations. 
""" def __init__(self): - self.observation_space = () self.env = None def set_env(self, env: Environment): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index de9ee2a45ebdeabec5202be4f12593a82b4e20e4..c23d4345a03c761ad4c4ac1d936db817f8acc529 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -1,8 +1,8 @@ """ Collection of environment-specific ObservationBuilder. """ -import pprint -from typing import Optional, List, Dict, T, Tuple +import collections +from typing import Optional, List, Dict, Tuple import numpy as np @@ -15,6 +15,22 @@ from flatland.utils.ordered_set import OrderedSet class TreeObsForRailEnv(ObservationBuilder): + + Node = collections.namedtuple('Node', 'dist_own_target_encountered ' + 'dist_other_target_encountered ' + 'dist_other_agent_encountered ' + 'dist_potential_conflict ' + 'dist_unusable_switch ' + 'dist_to_next_branch ' + 'dist_min_to_target ' + 'num_agents_same_direction ' + 'num_agents_opposite_direction ' + 'num_agents_malfunctioning ' + 'speed_min_fractional ' + 'childs') + + tree_explorted_actions_char = ['L', 'F', 'R', 'B'] + """ TreeObsForRailEnv object. @@ -29,24 +45,15 @@ class TreeObsForRailEnv(ObservationBuilder): super().__init__() self.max_depth = max_depth self.observation_dim = 11 - # Compute the size of the returned observation vector - size = 0 - pow4 = 1 - for i in range(self.max_depth + 1): - size += pow4 - pow4 *= 4 - self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} self.predictor = predictor self.location_has_target = None - self.tree_explored_actions = [1, 2, 3, 0] - self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] def reset(self): self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} - def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]: + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]: """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. 
@@ -75,7 +82,7 @@ class TreeObsForRailEnv(ObservationBuilder): observations[h] = self.get(h) return observations - def get(self, handle: int = 0) -> List[int]: + def get(self, handle: int = 0) -> Node: """ Computes the current observation for agent `handle` in env @@ -165,11 +172,18 @@ class TreeObsForRailEnv(ObservationBuilder): possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) - # Root node - current position # Here information about the agent itself is stored distance_map = self.env.distance_map.get() - observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0, - agent.malfunction_data['malfunction'], agent.speed_data['speed']] + + root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0, + dist_other_agent_encountered=0, dist_potential_conflict=0, + dist_unusable_switch=0, dist_to_next_branch=0, + dist_min_to_target=distance_map[(handle, *agent.position, + agent.direction)], + num_agents_same_direction=0, num_agents_opposite_direction=0, + num_agents_malfunctioning=agent.malfunction_data['malfunction'], + speed_min_fractional=agent.speed_data['speed'], + childs={}) visited = OrderedSet() @@ -181,28 +195,22 @@ class TreeObsForRailEnv(ObservationBuilder): if num_transitions == 1: orientation = np.argmax(possible_transitions) - for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: + for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): + if possible_transitions[branch_direction]: new_cell = get_new_position(agent.position, branch_direction) + branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) - observation = observation + branch_observation + root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation + visited |= branch_visited else: # add cells filled with infinity if no transition is possible - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) + root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf self.env.dev_obs_dict[handle] = visited - return observation - - def _num_cells_to_fill_in(self, remaining_depth): - """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" - num_observations = 0 - pow4 = 1 - for i in range(remaining_depth): - num_observations += pow4 - pow4 *= 4 - return num_observations * self.observation_dim + return root_node_observation def _explore_branch(self, handle, position, direction, tot_dist, depth): """ @@ -378,53 +386,35 @@ class TreeObsForRailEnv(ObservationBuilder): # Modify here to append new / different features for each visited cell! 
if last_is_target: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - 0, - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - + dist_to_next_branch = tot_dist + dist_min_to_target = 0 elif last_is_terminal: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - np.inf, - self.env.distance_map.get()[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - + dist_to_next_branch = np.inf + dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] else: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - self.env.distance_map.get()[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] + dist_to_next_branch = tot_dist + dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] + + node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered, + dist_other_target_encountered=other_target_encountered, + dist_other_agent_encountered=other_agent_encountered, + dist_potential_conflict=potential_conflict, + dist_unusable_switch=unusable_switch, + dist_to_next_branch=dist_to_next_branch, + dist_min_to_target=dist_min_to_target, + num_agents_same_direction=other_agent_same_direction, + num_agents_opposite_direction=other_agent_opposite_direction, + num_agents_malfunctioning=malfunctioning_agent, + speed_min_fractional=min_fractional_speed, + childs={}) + # ############################# # ############################# # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # Get the possible transitions possible_transitions = self.env.rail.get_transitions(*position, direction) - for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: + for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]): if last_is_dead_end and self.env.rail.get_transition((*position, direction), (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes @@ -435,7 +425,7 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction + 2) % 4, tot_dist + 1, depth + 1) - observation = observation + branch_observation + node.childs[self.tree_explorted_actions_char[i]] = branch_observation if len(branch_visited) != 0: visited |= branch_visited elif last_is_switch and possible_transitions[branch_direction]: @@ -445,51 +435,45 @@ class TreeObsForRailEnv(ObservationBuilder): branch_direction, tot_dist + 1, depth + 1) - observation = observation + branch_observation + node.childs[self.tree_explorted_actions_char[i]] = branch_observation if len(branch_visited) != 0: visited |= branch_visited else: # no exploring possible, add just cells with infinity - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) + node.childs[self.tree_explorted_actions_char[i]] = -np.inf - return observation, visited + if depth == self.max_depth: + node.childs.clear() 
+ return node, visited - def util_print_obs_subtree(self, tree): + def util_print_obs_subtree(self, tree: Node): """ - Utility function to pretty-print tree observations returned by this object. + Utility function to print tree observations returned by this object. """ - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(self.unfold_observation_tree(tree)) + self.print_node_features(tree, "root", "") + for direction in self.tree_explorted_actions_char: + self.print_subtree(tree.childs[direction], direction, "\t") + + @staticmethod + def print_node_features(node: Node, label, indent): + print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ", + node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ", + node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ", + node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction, + ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional) + + def print_subtree(self, node, label, indent): + if node == -np.inf or not node: + print(indent, "Direction ", label, ": -np.inf") + return - def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): - """ - Utility function to pretty-print tree observations returned by this object. - """ - if len(tree) < self.observation_dim: + self.print_node_features(node, label, indent) + + if not node.childs: return - depth = 0 - tmp = len(tree) / self.observation_dim - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - - unfolded = {} - unfolded[''] = tree[0:self.observation_dim] - child_size = (len(tree) - self.observation_dim) // 4 - for child in range(4): - child_tree = tree[(self.observation_dim + child * child_size): - (self.observation_dim + (child + 1) * child_size)] - observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) - if observation_tree is not None: - if actions_for_display: - label = self.tree_explorted_actions_char[child] - else: - label = self.tree_explored_actions[child] - unfolded[label] = observation_tree - return unfolded + for direction in self.tree_explorted_actions_char: + self.print_subtree(node.childs[direction], direction, indent + "\t") def set_env(self, env: Environment): super().set_env(env) @@ -508,23 +492,21 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16),\ assuming 16 bits encoding of transitions. + - A 3D array (map_height, map_width, 4) with + - first channel containing the agent's position and direction + - second channel containing the other agents' positions and directions + - third channel containing agent/other agent malfunctions + - fourth channel containing agent/other agent fractional speeds + - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ target and the positions of the other agents targets.
- - - A 3D array (map_height, map_width, 4) wtih - - first channel containing the agents position and direction - - second channel containing the other agents positions and diretions - - third channel containing agent malfunctions - - fourth channel containing agent fractional speeds """ def __init__(self): - self.observation_space = () super(GlobalObsForRailEnv, self).__init__() def set_env(self, env: Environment): super().set_env(env) - self.observation_space = [4, self.env.height, self.env.width] def reset(self): self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) @@ -535,22 +517,21 @@ class GlobalObsForRailEnv(ObservationBuilder): self.rail_obs[i, j] = np.array(bitlist) def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): + obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - agents = self.env.agents - agent = agents[handle] + obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1 - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = agents[handle].direction + agent = self.env.agents[handle] + obs_agents_state[agent.position][0] = agent.direction obs_targets[agent.target][0] = 1 - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][1] = agent2.direction - obs_targets[agent2.target][1] = 1 - obs_agents_state[agents[i].position][2] = agents[i].malfunction_data['malfunction'] - obs_agents_state[agents[i].position][3] = agents[i].speed_data['speed'] + for i in range(len(self.env.agents)): + other_agent = self.env.agents[i] + if i != handle: + obs_agents_state[other_agent.position][1] = other_agent.direction + obs_targets[other_agent.target][1] = 1 + obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] return self.rail_obs, obs_agents_state, obs_targets diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6fefc3f4d76b27744356f4c378be678303cf234c..98fa4d2e4b557edced94c6eb3692f4949abe9221 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -189,7 +189,6 @@ class RailEnv(Environment): self.distance_map = DistanceMap(self.agents, self.height, self.width) self.action_space = [1] - self.observation_space = self.obs_builder.observation_space # updated on resets? # Stochastic train malfunctioning parameters if stochastic_data is not None: @@ -300,7 +299,6 @@ class RailEnv(Environment): # Reset the state of the observation builder with the new environment self.obs_builder.reset() - self.observation_space = self.obs_builder.observation_space # <-- change on reset? 
self.distance_map.reset(self.agents, self.rail) # Return the new observation vectors for each agent diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index d2663916a17a70597d10e489da7aead4f8932dc4..0d6d309765690b1f95c681d7d109a13071d7f86b 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -41,7 +41,9 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell - assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) + obs_agents_state = global_obs[0][1] + obs_agents_state = obs_agents_state + 1 + assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0) def _step_along_shortest_path(env, obs_builder, rail): diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 7ee0fd4aadf72b12b591259a71af8b408145418f..f4ab68bc45a82b8f196fcfbebb26fd68a36c37a4 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -281,9 +281,10 @@ def test_shortest_path_predictor_conflicts(rendering=False): # get the trees to test obs_builder: TreeObsForRailEnv = env.obs_builder pp = pprint.PrettyPrinter(indent=4) - tree_0 = obs_builder.unfold_observation_tree(observations[0]) - tree_1 = obs_builder.unfold_observation_tree(observations[1]) - pp.pprint(tree_0) + tree_0 = observations[0] + tree_1 = observations[1] + env.obs_builder.util_print_obs_subtree(tree_0) + env.obs_builder.util_print_obs_subtree(tree_1) # check the expectations expected_conflicts_0 = [('F', 'R')] @@ -292,11 +293,18 @@ def test_shortest_path_predictor_conflicts(rendering=False): _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ") -def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''): - assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt) +def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''): + assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt) for a_1 in obs_builder.tree_explorted_actions_char: - conflict = tree_0[a_1][''][8] - assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) + if tree.childs[a_1] == -np.inf: + assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) + continue + else: + conflict = tree.childs[a_1].num_agents_opposite_direction + assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) for a_2 in obs_builder.tree_explorted_actions_char: - conflict = tree_0[a_1][a_2][''][8] - assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) + if tree.childs[a_1].childs[a_2] == -np.inf: + assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) + else: + conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction + assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 86c206097648078b33531fb1bad3b4f091b25ba4..cc61150325ee86710ea3ee3820d8386c7b926da6 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -22,7 +22,6 @@ class SingleAgentNavigationObs(ObservationBuilder): def __init__(self): super().__init__() - self.observation_space = [3] def 
reset(self): pass diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py new file mode 100644 index 0000000000000000000000000000000000000000..7213560f9e9873ea4488b96d30223bab8128b37b --- /dev/null +++ b/tests/test_global_observation.py @@ -0,0 +1,64 @@ +import numpy as np + +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator + + +def test_get_global_observation(): + np.random.seed(1) + number_of_agents = 20 + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurrence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 20 # Max duration of malfunction + } + + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + + env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=25, + # Number of cities in map (where train stations are) + num_intersections=10, + # Number of intersections (no start / target) + num_trainstations=50, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=4, # Proximity of stations to city center + num_neighb=4, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=GlobalObsForRailEnv()) + + obs, all_rewards, done, _ = env.step({0: 0}) + + for i in range(len(env.agents)): + obs_agents_state = obs[i][1] + obs_targets = obs[i][2] + + nr_agents = np.count_nonzero(obs_targets[:, :, 0]) + nr_agents_other = np.count_nonzero(obs_targets[:, :, 1]) + assert nr_agents == 1 + assert nr_agents_other == (number_of_agents - 1) + + # since the array is initialized with -1 add one in order to use np.count_nonzero + obs_agents_state += 1 + obs_agents_state_0 = np.count_nonzero(obs_agents_state[:, :, 0]) + obs_agents_state_1 = np.count_nonzero(obs_agents_state[:, :, 1]) + obs_agents_state_2 = np.count_nonzero(obs_agents_state[:, :, 2]) + obs_agents_state_3 = np.count_nonzero(obs_agents_state[:, :, 3]) + assert obs_agents_state_0 == 1 + assert obs_agents_state_1 == (number_of_agents - 1) + assert obs_agents_state_2 == number_of_agents + assert obs_agents_state_3 == number_of_agents
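A note on consuming the new tree observation format (illustration only, not part of the diff above): after this change, TreeObsForRailEnv.get() returns a Node namedtuple whose childs dict is keyed by tree_explorted_actions_char ('L', 'F', 'R', 'B'), with -np.inf marking branches that could not be explored, instead of the flat vector previously sized by the removed observation_space. The sketch below shows one possible way for agents that still expect the old fixed-size vector to flatten such a tree; flatten_tree_obs and its obs_dim = 11 constant are hypothetical helpers written to mirror the removed observation_dim-based layout, not part of Flatland.

.. code-block:: python

    import numpy as np

    from flatland.envs.observations import TreeObsForRailEnv


    def flatten_tree_obs(node, max_depth, depth=0):
        """Depth-first flattening of a TreeObsForRailEnv.Node tree (hypothetical helper)."""
        obs_dim = 11  # scalar features per node, mirroring TreeObsForRailEnv.observation_dim

        if isinstance(node, float) and node == -np.inf:
            # Unexplored branch: pad the whole missing subtree with -inf,
            # as the removed fixed-size code did.
            missing_nodes = sum(4 ** i for i in range(max_depth - depth + 1))
            return [-np.inf] * (missing_nodes * obs_dim)

        data = [node.dist_own_target_encountered,
                node.dist_other_target_encountered,
                node.dist_other_agent_encountered,
                node.dist_potential_conflict,
                node.dist_unusable_switch,
                node.dist_to_next_branch,
                node.dist_min_to_target,
                node.num_agents_same_direction,
                node.num_agents_opposite_direction,
                node.num_agents_malfunctioning,
                node.speed_min_fractional]

        if depth == max_depth:
            # Leaf level: the builder clears node.childs here, nothing left to descend into.
            return data

        for direction in TreeObsForRailEnv.tree_explorted_actions_char:  # 'L', 'F', 'R', 'B'
            data += flatten_tree_obs(node.childs.get(direction, -np.inf), max_depth, depth + 1)

        return data

    # Example usage, assuming an env built with obs_builder_object=TreeObsForRailEnv(max_depth=2, ...):
    # flat = np.array(flatten_tree_obs(obs[handle], max_depth=2))

For max_depth=2, for example, this yields a vector of length (1 + 4 + 16) * 11 = 231, the same size the removed observation_space computation reported.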