From 1e2d8f2647b318488de0bcdf3b1366bd6bd955d4 Mon Sep 17 00:00:00 2001 From: spiglerg <spiglerg@gmail.com> Date: Tue, 23 Apr 2019 13:04:36 +0200 Subject: [PATCH] fixed bad bugs in distance_map calculation + added distance from agent to branch node --- flatland/core/env_observation_builder.py | 75 ++++++------------------ flatland/envs/rail_env.py | 2 + 2 files changed, 19 insertions(+), 58 deletions(-) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index fd71bb9..7271f68 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -127,53 +127,6 @@ class TreeObsForRailEnv(ObservationBuilder): """ neighbors = [] - for direction in range(4): - new_cell = self._new_position(position, (direction+2) % 4) - - if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: - # Check if the two cells are connected by a valid transition - transitionValid = False - for orientation in range(4): - moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation)) - if moves[direction]: - transitionValid = True - break - - if not transitionValid: - continue - - # Check if a transition in direction node[2] is possible if an agent - # lands in the current cell with orientation `direction'; this only - # applies to cells that are not dead-ends! - directionMatch = True - if enforce_target_direction >= 0: - directionMatch = self.env.rail.get_transition( - (new_cell[0], new_cell[1], direction), enforce_target_direction) - - # If transition is found to invalid, check if perhaps it - # is a dead-end, in which case the direction of movement is rotated - # 180 degrees (moving forward turns the agents and makes it step in the previous cell) - if not directionMatch: - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # Dead-end! - # Check if transition is possible in new_cell - # with orientation (direction+2)%4 in direction `direction' - directionMatch = directionMatch or self.env.rail.get_transition( - (new_cell[0], new_cell[1], (direction+2) % 4), direction) - - if transitionValid and directionMatch: - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction], - current_distance+1) - neighbors.append((new_cell[0], new_cell[1], direction, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance - possible_directions = [0, 1, 2, 3] if enforce_target_direction >= 0: # The agent must land into the current cell with orientation `enforce_target_direction'. @@ -263,7 +216,7 @@ class TreeObsForRailEnv(ObservationBuilder): #3: 1 if another agent is detected between the previous node and the current one. - #4: + #4: distance of agent to the current branch node #5: minimum distance from node to the agent's target (when landing to the node following the corresponding branch. @@ -286,6 +239,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Root node - current position observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] + root_observation = observation[:] # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation @@ -293,7 +247,7 @@ class TreeObsForRailEnv(ObservationBuilder): if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): new_cell = self._new_position(position, branch_direction) - branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1) + branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) observation = observation + branch_observation else: num_cells_to_fill_in = 0 @@ -305,7 +259,7 @@ class TreeObsForRailEnv(ObservationBuilder): return observation - def _explore_branch(self, handle, position, direction, depth): + def _explore_branch(self, handle, position, direction, root_observation, depth): """ Utility function to compute tree-based observations. """ @@ -323,10 +277,11 @@ class TreeObsForRailEnv(ObservationBuilder): # to land here last_isTarget = False - visited = set([position[0], position[1], direction]) + visited = set() other_agent_encountered = False other_target_encountered = False + num_steps = 1 while exploring: # ############################# # ############################# @@ -345,6 +300,8 @@ class TreeObsForRailEnv(ObservationBuilder): if (position[0], position[1], direction) in visited: last_isTerminal = True break + visited.add((position[0], position[1], direction)) + # If the target node is encountered, pick that as node. Also, no further branching is possible. if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: @@ -377,6 +334,7 @@ class TreeObsForRailEnv(ObservationBuilder): if cell_transitions[i]: position = self._new_position(position, i) direction = i + num_steps += 1 break elif num_transitions > 0: @@ -386,11 +344,10 @@ class TreeObsForRailEnv(ObservationBuilder): elif num_transitions == 0: # Wrong cell type, but let's cover it and treat it as a dead-end, just in case + print("WRONG CELL TYPE detected in tree-search (0 transitions possible)") last_isTerminal = True break - visited.add((position[0], position[1], direction)) - # `position' is either a terminal node or a switch observation = [] @@ -403,25 +360,27 @@ class TreeObsForRailEnv(ObservationBuilder): observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, - 0, + root_observation[3]+num_steps, 0] elif last_isTerminal: observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, - 0, + np.inf, np.inf] else: observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, - 0, + root_observation[3]+num_steps, self.distance_map[handle, position[0], position[1], direction]] # ############################# # ############################# + new_root_observation = observation[:] + # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]: @@ -431,14 +390,14 @@ class TreeObsForRailEnv(ObservationBuilder): # it back new_cell = self._new_position(position, (branch_direction+2) % 4) - branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1) + branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, new_root_observation, depth+1) observation = observation + branch_observation elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), branch_direction): new_cell = self._new_position(position, branch_direction) - branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1) + branch_observation = self._explore_branch(handle, new_cell, branch_direction, new_root_observation, depth+1) observation = observation + branch_observation else: diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 37a97ff..cf20603 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -486,6 +486,8 @@ class RailEnv(Environment): for handle in self.agents_handles: self.dones[handle] = False + # Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial + # agent's orientations that allow a valid solution. re_generate = True while re_generate: valid_positions = [] -- GitLab