From def798afa817d2ddc60dd84b41f1927ef19d0eef Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Thu, 18 Apr 2019 14:43:00 +0200
Subject: [PATCH] distance_map code for TreeObsForRailEnv

---
 flatland/core/env_observation_builder.py | 304 ++++++++++++++++++++++-
 1 file changed, 301 insertions(+), 3 deletions(-)

diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 17061854..02deab12 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -1,3 +1,6 @@
+import numpy as np
+from collections import deque
+
 ## TODO: add docstrings, pylint, etc...
 
 
@@ -14,10 +17,112 @@ class ObservationBuilder:
 
 
 class TreeObsForRailEnv(ObservationBuilder):
+    def __init__(self, env):
+        self.env = env
+
     def reset(self):
-        # TODO: precompute distances, etc...
-        #raise NotImplementedError()
-        pass
+        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
+                                                    self.env.height,
+                                                    self.env.width))
+        self.max_dist = np.zeros(self.env.number_of_agents)
+
+        for i in range(self.env.number_of_agents):
+            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+
+
+    def _distance_map_walker(self, position, target_nr):
+        # Returns max distance to target, from the farthest away node, while filling in distance_map
+
+        for ori in range(4):
+            self.distance_map[target_nr, position[0], position[1]] = 0
+
+        # Fill in the (up to) 4 neighboring nodes
+        # nodes_queue = []  # list of tuples (row, col, direction, distance);  direction is the direction of movement, meaning that at least a possible orientation of an agent in cell (row,col) allows a movement in direction `direction'
+        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
+
+        # BFS from target `position' to all the reachable nodes in the grid
+        # Stop the search if the target position is re-visited, in any direction
+        visited = set([(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), (position[0], position[1], 3)])
+
+        max_distance = 0
+
+        while nodes_queue:
+            node = nodes_queue.popleft()
+
+            node_id = (node[0], node[1], node[2])
+
+            #print(node_id, visited, (node_id in visited))
+            #print(nodes_queue)
+
+            if node_id not in visited:
+                visited.add(node_id)
+
+                # From the list of possible neighbors that have at least a path to the current node, only keep those whose new orientation in the current cell would allow a transition to direction node[2]
+                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
+
+                for n in valid_neighbors:
+                    nodes_queue.append(n)
+
+                if len(valid_neighbors)>0:
+                    max_distance = max(max_distance, node[3]+1)
+
+        return max_distance
+
+
+    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
+        neighbors = []
+
+        for direction in range(4):
+            new_cell = self._new_position(position, (direction+2)%4)
+
+            if new_cell[0]>=0 and new_cell[0]<self.env.height and new_cell[1]>=0 and new_cell[1]<self.env.width:
+                # Check if the two cells are connected by a valid transition
+                transitionValid = False
+                for orientation in range(4):
+                    moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
+                    if moves[direction]:
+                        transitionValid = True
+                        break
+
+                if not transitionValid:
+                    continue
+
+                # Check if a transition in direction node[2] is possible if an agent lands in the current cell with orientation `direction'; this only applies to cells that are not dead-ends!
+                directionMatch = True
+                if enforce_target_direction>=0:
+                    directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction), enforce_target_direction)
+
+                # If transition is found to invalid, check if perhaps it is a dead-end, in which case the direction of movement is rotated 180 degrees (moving forward turns the agents and makes it step in the previous cell)
+                if not directionMatch:
+                    # If cell is a dead-end, append previous node with reversed
+                    # orientation!
+                    nbits = 0
+                    tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        # Dead-end!
+                        # Check if transition is possible in new_cell with orientation (direction+2)%4 in direction `direction'
+                        directionMatch = directionMatch or self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2)%4), direction)
+
+                if transitionValid and directionMatch:
+                    new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1]], current_distance+1)
+                    neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
+                    self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
+
+        return neighbors
+
+    def _new_position(self, position, movement):
+        if movement == 0:    # NORTH
+            return (position[0]-1, position[1])
+        elif movement == 1:  # EAST
+            return (position[0], position[1] + 1)
+        elif movement == 2:  # SOUTH
+            return (position[0]+1, position[1])
+        elif movement == 3:  # WEST
+            return (position[0], position[1] - 1)
+
 
     def get(self, handle):
         # TODO: compute the observation for agent `handle'
@@ -25,3 +130,196 @@ class TreeObsForRailEnv(ObservationBuilder):
         return []
 
 
+
+
+
+"""
+
+    def get_observation(self, agent):
+        # Get the current observation for an agent
+        current_position = self.internal_position[agent]
+        #target_heading = self._compass(agent).tolist()
+        coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
+        agent_distance = self.distance_map[agent][coordinate][0]
+        # Start tree search
+        if current_position == self.target[agent]:
+            agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1])
+        else:
+            agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance])
+
+        initial_tree_state = Tree_State(agent, current_position, -1, 0, 0)
+        self._tree_search(initial_tree_state, agent_tree, agent)
+        observation = []
+        distance_data = []
+
+        self._flatten_tree(agent_tree, observation, distance_data,  self.max_depth+1)
+        # This is probably very slow!!!!
+        #max_obs = np.max([i for i in observation if i < np.inf])
+        #if max_obs != 0:
+        #    observation = np.array(observation)/ max_obs
+
+        #print([i for i in distance_data if i >= 0])
+        observation = np.concatenate((observation, distance_data))
+        #observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])]))
+        #return np.clip(observation / self.max_dist[agent], -1, 1)
+        return np.clip(observation / 15., -1, 1)
+
+
+
+
+    def _tree_search(self, in_tree_state, parent_node, agent):
+        if in_tree_state.depth >= self.max_depth:
+            return
+        target_distance = np.inf
+        other_target = np.inf
+        other_agent = np.inf
+        coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position])))
+        curr_target_dist = self.distance_map[agent][coordinate][0]
+        forbidden_action = (in_tree_state.direction + 2) % 4
+        # Update the position
+        failed_move = 0
+        leaf_distance = in_tree_state.distance
+        for child_idx in range(4):
+            if child_idx != forbidden_action or in_tree_state.direction == -1:
+                tree_state = copy.deepcopy(in_tree_state)
+                tree_state.direction = child_idx
+                current_position, invalid_move = self._detect_path(tree_state.position, tree_state.direction)
+                if tree_state.initial_direction == None:
+                    tree_state.initial_direction = child_idx
+                if not invalid_move:
+                    coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
+                    curr_target_dist = self.distance_map[agent][coordinate][0]
+                    #if tree_state.initial_direction == None:
+                    #    tree_state.initial_direction = child_idx
+                    tree_state.position = current_position
+                    tree_state.distance += 1
+
+
+                    # Collect information at the current position
+                    detection_distance = tree_state.distance
+                    if current_position == self.target[tree_state.agent]:
+                        target_distance = detection_distance
+
+                    elif current_position in self.target:
+                        other_target = detection_distance
+
+                    if current_position in self.internal_position:
+                        other_agent = detection_distance
+
+                    tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0])
+                    tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1])
+                    tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2])
+                    tree_state.data[3] = tree_state.distance
+                    tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
+
+                    if self._switch_detection(tree_state.position):
+                        tree_state.depth += 1
+                        new_tree_state = copy.deepcopy(tree_state)
+                        new_node = parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
+                        new_tree_state.initial_direction = None
+                        new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
+                        self._tree_search(new_tree_state, new_node, agent)
+                    else:
+                        self._tree_search(tree_state, parent_node, agent)
+                else:
+                    failed_move += 1
+            if failed_move == 3 and in_tree_state.direction != -1:
+                tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
+                parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
+                return
+        return
+
+    def _flatten_tree(self, node, observation_vector, distance_sensor, depth):
+        if depth <= 0:
+            return
+        if node != None:
+            observation_vector.extend(node.data[:-1])
+            distance_sensor.extend([node.data[-1]])
+        else:
+            observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf])
+            distance_sensor.extend([-np.inf])
+        for child_idx in range(4):
+            if node != None:
+                child = node.children[child_idx]
+            else:
+                child = None
+            self._flatten_tree(child, observation_vector, distance_sensor,  depth -1)
+
+
+
+    def _switch_detection(self, position):
+        # Hack to detect switches
+        # This can later directly be derived from the transition matrix
+        paths = 0
+        for i in range(4):
+            _, invalid_move = self._detect_path(position, i)
+            if not invalid_move:
+                paths +=1
+            if paths >= 3:
+                return True
+        return False
+
+
+
+
+    def _min_greater_zero(self, x, y):
+        if x <= 0 and y <= 0:
+            return 0
+        if x < 0:
+            return y
+        if y < 0:
+            return x
+        return min(x, y)
+
+
+
+"""
+
+
+class Tree_State:
+    """
+    Keep track of the current state while building the tree
+    """
+    def __init__(self, agent, position, direction, depth, distance):
+        self.agent = agent
+        self.position = position
+        self.direction = direction
+        self.depth = depth
+        self.initial_direction = None
+        self.distance = distance
+        self.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
+
+class Node():
+    """
+    Define a tree node to get populated during search
+    """
+    def __init__(self, position, data):
+        self.n_children = 4
+        self.children = [None]*self.n_children
+        self.data = data
+        self.position = position
+
+    def insert(self, position, data, child_idx):
+        """
+        Insert new node with data
+
+        @param data node data object to insert
+        """
+        new_node = Node(position, data)
+        self.children[child_idx] = new_node
+        return new_node
+
+    def print_tree(self, i=0, depth = 0):
+        """
+        Print tree content inorder
+        """
+        current_i = i
+        curr_depth = depth+1
+        if i < self.n_children:
+            if self.children[i] != None:
+                self.children[i].print_tree(depth=curr_depth)
+            current_i += 1
+            if self.children[i] != None:
+                self.children[i].print_tree(i, depth=curr_depth)
+
+
-- 
GitLab