diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index fd71bb989e37980b5d28294eb7a65550b6935435..86485ec2c068ea410d5d27997f6f037d3aab6c23 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -12,8 +12,6 @@ import numpy as np from collections import deque -# TODO: add docstrings, pylint, etc... - class ObservationBuilder: """ @@ -127,53 +125,6 @@ class TreeObsForRailEnv(ObservationBuilder): """ neighbors = [] - for direction in range(4): - new_cell = self._new_position(position, (direction+2) % 4) - - if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: - # Check if the two cells are connected by a valid transition - transitionValid = False - for orientation in range(4): - moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation)) - if moves[direction]: - transitionValid = True - break - - if not transitionValid: - continue - - # Check if a transition in direction node[2] is possible if an agent - # lands in the current cell with orientation `direction'; this only - # applies to cells that are not dead-ends! - directionMatch = True - if enforce_target_direction >= 0: - directionMatch = self.env.rail.get_transition( - (new_cell[0], new_cell[1], direction), enforce_target_direction) - - # If transition is found to invalid, check if perhaps it - # is a dead-end, in which case the direction of movement is rotated - # 180 degrees (moving forward turns the agents and makes it step in the previous cell) - if not directionMatch: - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # Dead-end! - # Check if transition is possible in new_cell - # with orientation (direction+2)%4 in direction `direction' - directionMatch = directionMatch or self.env.rail.get_transition( - (new_cell[0], new_cell[1], (direction+2) % 4), direction) - - if transitionValid and directionMatch: - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction], - current_distance+1) - neighbors.append((new_cell[0], new_cell[1], direction, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance - possible_directions = [0, 1, 2, 3] if enforce_target_direction >= 0: # The agent must land into the current cell with orientation `enforce_target_direction'. @@ -263,7 +214,7 @@ class TreeObsForRailEnv(ObservationBuilder): #3: 1 if another agent is detected between the previous node and the current one. - #4: + #4: distance of agent to the current branch node #5: minimum distance from node to the agent's target (when landing to the node following the corresponding branch. @@ -286,6 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Root node - current position observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] + root_observation = observation[:] # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation @@ -293,7 +245,7 @@ class TreeObsForRailEnv(ObservationBuilder): if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): new_cell = self._new_position(position, branch_direction) - branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1) + branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) observation = observation + branch_observation else: num_cells_to_fill_in = 0 @@ -305,7 +257,7 @@ class TreeObsForRailEnv(ObservationBuilder): return observation - def _explore_branch(self, handle, position, direction, depth): + def _explore_branch(self, handle, position, direction, root_observation, depth): """ Utility function to compute tree-based observations. """ @@ -319,14 +271,14 @@ class TreeObsForRailEnv(ObservationBuilder): exploring = True last_isSwitch = False last_isDeadEnd = False - last_isTerminal = False # wrong cell encountered OR cycle encountered; either way, we don't want the agent - # to land here + last_isTerminal = False # wrong cell OR cycle; either way, we don't want the agent to land here last_isTarget = False - visited = set([position[0], position[1], direction]) + visited = set() other_agent_encountered = False other_target_encountered = False + num_steps = 1 while exploring: # ############################# # ############################# @@ -345,6 +297,7 @@ class TreeObsForRailEnv(ObservationBuilder): if (position[0], position[1], direction) in visited: last_isTerminal = True break + visited.add((position[0], position[1], direction)) # If the target node is encountered, pick that as node. Also, no further branching is possible. if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: @@ -377,6 +330,7 @@ class TreeObsForRailEnv(ObservationBuilder): if cell_transitions[i]: position = self._new_position(position, i) direction = i + num_steps += 1 break elif num_transitions > 0: @@ -386,11 +340,10 @@ class TreeObsForRailEnv(ObservationBuilder): elif num_transitions == 0: # Wrong cell type, but let's cover it and treat it as a dead-end, just in case + print("WRONG CELL TYPE detected in tree-search (0 transitions possible)") last_isTerminal = True break - visited.add((position[0], position[1], direction)) - # `position' is either a terminal node or a switch observation = [] @@ -403,25 +356,27 @@ class TreeObsForRailEnv(ObservationBuilder): observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, - 0, + root_observation[3]+num_steps, 0] elif last_isTerminal: observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, - 0, + np.inf, np.inf] else: observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, - 0, + root_observation[3]+num_steps, self.distance_map[handle, position[0], position[1], direction]] # ############################# # ############################# + new_root_observation = observation[:] + # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]: @@ -431,14 +386,22 @@ class TreeObsForRailEnv(ObservationBuilder): # it back new_cell = self._new_position(position, (branch_direction+2) % 4) - branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1) + branch_observation = self._explore_branch(handle, + new_cell, + (branch_direction+2) % 4, + new_root_observation, + depth+1) observation = observation + branch_observation elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), branch_direction): new_cell = self._new_position(position, branch_direction) - branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1) + branch_observation = self._explore_branch(handle, + new_cell, + branch_direction, + new_root_observation, + depth+1) observation = observation + branch_observation else: diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 37a97ff80beff3a2d2abc40f297d456021b071f3..cf206033d7aab65571b18de0c95d729f5d09c65c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -486,6 +486,8 @@ class RailEnv(Environment): for handle in self.agents_handles: self.dones[handle] = False + # Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial + # agent's orientations that allow a valid solution. re_generate = True while re_generate: valid_positions = [] diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index 44dbfc6d8f14a7293e148f125b47eed66c9ca08d..db264c2975b75f32c2612aa19c0511076460ec6b 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import numpy as np + from flatland.core.env_observation_builder import GlobalObsForRailEnv from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator -from flatland.utils.rendertools import * """Tests for `flatland` package.""" @@ -45,14 +46,14 @@ def test_global_obs(): double_switch_south_horizontal_straight, 180) rail_map = np.array( - [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + - [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + - [[dead_end_from_east] + [horizontal_straight] * 2 + - [double_switch_north_horizontal_straight] + - [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + - [horizontal_straight] * 2 + [dead_end_from_west]] + - [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + - [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) @@ -81,17 +82,3 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell assert(np.sum(rail_map * global_obs[0][1][0]) > 0) - - - -test_global_obs() - - - - - - - - - -