Skip to content
Snippets Groups Projects
Commit c794eb4e authored by gmollard's avatar gmollard
Browse files

fixed build failing

parent 0434a32e
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
from collections import deque from collections import deque
## TODO: add docstrings, pylint, etc... # TODO: add docstrings, pylint, etc...
class ObservationBuilder: class ObservationBuilder:
...@@ -36,12 +36,16 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -36,12 +36,16 @@ class TreeObsForRailEnv(ObservationBuilder):
self.distance_map[target_nr, position[0], position[1]] = 0 self.distance_map[target_nr, position[0], position[1]] = 0
# Fill in the (up to) 4 neighboring nodes # Fill in the (up to) 4 neighboring nodes
# nodes_queue = [] # list of tuples (row, col, direction, distance); direction is the direction of movement, meaning that at least a possible orientation of an agent in cell (row,col) allows a movement in direction `direction' # nodes_queue = [] # list of tuples (row, col, direction, distance);
nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) # direction is the direction of movement, meaning that at least a possible orientation
# of an agent in cell (row,col) allows a movement in direction `direction'
nodes_queue = deque(self._get_and_update_neighbors(position,
target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid # BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction # Stop the search if the target position is re-visited, in any direction
visited = set([(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), (position[0], position[1], 3)]) visited = set([(position[0], position[1], 0), (position[0], position[1], 1),
(position[0], position[1], 2), (position[0], position[1], 3)])
max_distance = 0 max_distance = 0
...@@ -56,8 +60,11 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -56,8 +60,11 @@ class TreeObsForRailEnv(ObservationBuilder):
if node_id not in visited: if node_id not in visited:
visited.add(node_id) visited.add(node_id)
# From the list of possible neighbors that have at least a path to the current node, only keep those whose new orientation in the current cell would allow a transition to direction node[2] # From the list of possible neighbors that have at least a path to the
valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) # current node, only keep those whose new orientation in the current cell
# would allow a transition to direction node[2]
valid_neighbors = self._get_and_update_neighbors(
(node[0], node[1]), target_nr, node[3], node[2])
for n in valid_neighbors: for n in valid_neighbors:
nodes_queue.append(n) nodes_queue.append(n)
...@@ -74,7 +81,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -74,7 +81,8 @@ class TreeObsForRailEnv(ObservationBuilder):
for direction in range(4): for direction in range(4):
new_cell = self._new_position(position, (direction+2)%4) new_cell = self._new_position(position, (direction+2)%4)
if new_cell[0]>=0 and new_cell[0]<self.env.height and new_cell[1]>=0 and new_cell[1]<self.env.width: if new_cell[0]>=0 and new_cell[0]<self.env.height and\
new_cell[1]>=0 and new_cell[1]<self.env.width:
# Check if the two cells are connected by a valid transition # Check if the two cells are connected by a valid transition
transitionValid = False transitionValid = False
for orientation in range(4): for orientation in range(4):
...@@ -86,12 +94,17 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -86,12 +94,17 @@ class TreeObsForRailEnv(ObservationBuilder):
if not transitionValid: if not transitionValid:
continue continue
# Check if a transition in direction node[2] is possible if an agent lands in the current cell with orientation `direction'; this only applies to cells that are not dead-ends! # Check if a transition in direction node[2] is possible if an agent
# lands in the current cell with orientation `direction'; this only
# applies to cells that are not dead-ends!
directionMatch = True directionMatch = True
if enforce_target_direction>=0: if enforce_target_direction>=0:
directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction), enforce_target_direction) directionMatch = self.env.rail.get_transition(
(new_cell[0], new_cell[1], direction), enforce_target_direction)
# If transition is found to invalid, check if perhaps it is a dead-end, in which case the direction of movement is rotated 180 degrees (moving forward turns the agents and makes it step in the previous cell) # If transition is found to invalid, check if perhaps it
# is a dead-end, in which case the direction of movement is rotated
# 180 degrees (moving forward turns the agents and makes it step in the previous cell)
if not directionMatch: if not directionMatch:
# If cell is a dead-end, append previous node with reversed # If cell is a dead-end, append previous node with reversed
# orientation! # orientation!
...@@ -102,11 +115,14 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -102,11 +115,14 @@ class TreeObsForRailEnv(ObservationBuilder):
tmp = tmp >> 1 tmp = tmp >> 1
if nbits == 1: if nbits == 1:
# Dead-end! # Dead-end!
# Check if transition is possible in new_cell with orientation (direction+2)%4 in direction `direction' # Check if transition is possible in new_cell
directionMatch = directionMatch or self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2)%4), direction) # with orientation (direction+2)%4 in direction `direction'
directionMatch = directionMatch or self.env.rail.get_transition(
(new_cell[0], new_cell[1], (direction+2)%4), direction)
if transitionValid and directionMatch: if transitionValid and directionMatch:
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1]], current_distance+1) new_distance = min(self.distance_map[target_nr,
new_cell[0], new_cell[1]], current_distance+1)
neighbors.append((new_cell[0], new_cell[1], direction, new_distance)) neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
...@@ -125,7 +141,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -125,7 +141,7 @@ class TreeObsForRailEnv(ObservationBuilder):
def get(self, handle): def get(self, handle):
# TODO: compute the observation for agent `handle' # TODO: compute the observation for agent `handle'
return []
...@@ -179,7 +195,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -179,7 +195,8 @@ class TreeObsForRailEnv(ObservationBuilder):
if child_idx != forbidden_action or in_tree_state.direction == -1: if child_idx != forbidden_action or in_tree_state.direction == -1:
tree_state = copy.deepcopy(in_tree_state) tree_state = copy.deepcopy(in_tree_state)
tree_state.direction = child_idx tree_state.direction = child_idx
current_position, invalid_move = self._detect_path(tree_state.position, tree_state.direction) current_position, invalid_move = self._detect_path(
tree_state.position, tree_state.direction)
if tree_state.initial_direction == None: if tree_state.initial_direction == None:
tree_state.initial_direction = child_idx tree_state.initial_direction = child_idx
if not invalid_move: if not invalid_move:
...@@ -211,7 +228,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -211,7 +228,8 @@ class TreeObsForRailEnv(ObservationBuilder):
if self._switch_detection(tree_state.position): if self._switch_detection(tree_state.position):
tree_state.depth += 1 tree_state.depth += 1
new_tree_state = copy.deepcopy(tree_state) new_tree_state = copy.deepcopy(tree_state)
new_node = parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction) new_node = parent_node.insert(tree_state.position,
tree_state.data, tree_state.initial_direction)
new_tree_state.initial_direction = None new_tree_state.initial_direction = None
new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf] new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
self._tree_search(new_tree_state, new_node, agent) self._tree_search(new_tree_state, new_node, agent)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment