From c794eb4eced3fb3ea42364169d45977fe5a25ee9 Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume.mollard2@gmail.com> Date: Thu, 18 Apr 2019 17:35:03 +0200 Subject: [PATCH] fixed build failing --- flatland/core/env_observation_builder.py | 50 ++++++++++++++++-------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index c51dff31..f8341e27 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -1,7 +1,7 @@ import numpy as np from collections import deque -## TODO: add docstrings, pylint, etc... +# TODO: add docstrings, pylint, etc... class ObservationBuilder: @@ -36,12 +36,16 @@ class TreeObsForRailEnv(ObservationBuilder): self.distance_map[target_nr, position[0], position[1]] = 0 # Fill in the (up to) 4 neighboring nodes - # nodes_queue = [] # list of tuples (row, col, direction, distance); direction is the direction of movement, meaning that at least a possible orientation of an agent in cell (row,col) allows a movement in direction `direction' - nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) + # nodes_queue = [] # list of tuples (row, col, direction, distance); + # direction is the direction of movement, meaning that at least a possible orientation + # of an agent in cell (row,col) allows a movement in direction `direction' + nodes_queue = deque(self._get_and_update_neighbors(position, + target_nr, 0, enforce_target_direction=-1)) # BFS from target `position' to all the reachable nodes in the grid # Stop the search if the target position is re-visited, in any direction - visited = set([(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), (position[0], position[1], 3)]) + visited = set([(position[0], position[1], 0), (position[0], position[1], 1), + (position[0], position[1], 2), (position[0], position[1], 3)]) max_distance = 0 @@ -56,8 +60,11 @@ class TreeObsForRailEnv(ObservationBuilder): if node_id not in visited: visited.add(node_id) - # From the list of possible neighbors that have at least a path to the current node, only keep those whose new orientation in the current cell would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) + # From the list of possible neighbors that have at least a path to the + # current node, only keep those whose new orientation in the current cell + # would allow a transition to direction node[2] + valid_neighbors = self._get_and_update_neighbors( + (node[0], node[1]), target_nr, node[3], node[2]) for n in valid_neighbors: nodes_queue.append(n) @@ -74,7 +81,8 @@ class TreeObsForRailEnv(ObservationBuilder): for direction in range(4): new_cell = self._new_position(position, (direction+2)%4) - if new_cell[0]>=0 and new_cell[0]<self.env.height and new_cell[1]>=0 and new_cell[1]<self.env.width: + if new_cell[0]>=0 and new_cell[0]<self.env.height and\ + new_cell[1]>=0 and new_cell[1]<self.env.width: # Check if the two cells are connected by a valid transition transitionValid = False for orientation in range(4): @@ -86,12 +94,17 @@ class TreeObsForRailEnv(ObservationBuilder): if not transitionValid: continue - # Check if a transition in direction node[2] is possible if an agent lands in the current cell with orientation `direction'; this only applies to cells that are not dead-ends! + # Check if a transition in direction node[2] is possible if an agent + # lands in the current cell with orientation `direction'; this only + # applies to cells that are not dead-ends! directionMatch = True if enforce_target_direction>=0: - directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction), enforce_target_direction) + directionMatch = self.env.rail.get_transition( + (new_cell[0], new_cell[1], direction), enforce_target_direction) - # If transition is found to invalid, check if perhaps it is a dead-end, in which case the direction of movement is rotated 180 degrees (moving forward turns the agents and makes it step in the previous cell) + # If transition is found to invalid, check if perhaps it + # is a dead-end, in which case the direction of movement is rotated + # 180 degrees (moving forward turns the agents and makes it step in the previous cell) if not directionMatch: # If cell is a dead-end, append previous node with reversed # orientation! @@ -102,11 +115,14 @@ class TreeObsForRailEnv(ObservationBuilder): tmp = tmp >> 1 if nbits == 1: # Dead-end! - # Check if transition is possible in new_cell with orientation (direction+2)%4 in direction `direction' - directionMatch = directionMatch or self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2)%4), direction) + # Check if transition is possible in new_cell + # with orientation (direction+2)%4 in direction `direction' + directionMatch = directionMatch or self.env.rail.get_transition( + (new_cell[0], new_cell[1], (direction+2)%4), direction) if transitionValid and directionMatch: - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1]], current_distance+1) + new_distance = min(self.distance_map[target_nr, + new_cell[0], new_cell[1]], current_distance+1) neighbors.append((new_cell[0], new_cell[1], direction, new_distance)) self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance @@ -125,7 +141,7 @@ class TreeObsForRailEnv(ObservationBuilder): def get(self, handle): # TODO: compute the observation for agent `handle' - + return [] @@ -179,7 +195,8 @@ class TreeObsForRailEnv(ObservationBuilder): if child_idx != forbidden_action or in_tree_state.direction == -1: tree_state = copy.deepcopy(in_tree_state) tree_state.direction = child_idx - current_position, invalid_move = self._detect_path(tree_state.position, tree_state.direction) + current_position, invalid_move = self._detect_path( + tree_state.position, tree_state.direction) if tree_state.initial_direction == None: tree_state.initial_direction = child_idx if not invalid_move: @@ -211,7 +228,8 @@ class TreeObsForRailEnv(ObservationBuilder): if self._switch_detection(tree_state.position): tree_state.depth += 1 new_tree_state = copy.deepcopy(tree_state) - new_node = parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction) + new_node = parent_node.insert(tree_state.position, + tree_state.data, tree_state.initial_direction) new_tree_state.initial_direction = None new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf] self._tree_search(new_tree_state, new_node, agent) -- GitLab