From d5ea50cded849574561fa251f9519d237f01d703 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 08:20:39 -0400
Subject: [PATCH] updated observation builder as well as normalization!!! You
 need to adjust your code to work with the new features!

---
 .../observation_builders/observations.py      | 345 +++++++++++++++++-
 utils/observation_utils.py                    |   5 +-
 2 files changed, 331 insertions(+), 19 deletions(-)

diff --git a/torch_training/observation_builders/observations.py b/torch_training/observation_builders/observations.py
index e3d52d3..10bd1f0 100644
--- a/torch_training/observation_builders/observations.py
+++ b/torch_training/observation_builders/observations.py
@@ -25,14 +25,13 @@ class TreeObsForRailEnv(ObservationBuilder):
     def __init__(self, max_depth, predictor=None):
         super().__init__()
         self.max_depth = max_depth
-        self.observation_dim = 9
+        self.observation_dim = 11
         # Compute the size of the returned observation vector
         size = 0
         pow4 = 1
         for i in range(self.max_depth + 1):
             size += pow4
             pow4 *= 4
-        self.observation_dim = 9
         self.observation_space = [size * self.observation_dim]
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
@@ -219,7 +218,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored.
 
-        #2: if another agents target is detected the distance in number of cells from the agents current locaiton
+        #2: if another agents target is detected the distance in number of cells from the agents current location
             is stored
 
         #3: if another agent is detected the distance in number of cells from current agent position is stored.
@@ -246,6 +245,15 @@ class TreeObsForRailEnv(ObservationBuilder):
                 (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
             0 = no agent present other direction than myself
 
+        #10: malfunctioning/blocking agents
+            n = number of time steps the observed agent remains blocked
+
+        #11: slowest observed speed of an agent in same direction
+            1 if no agent is observed
+
+            min_fractional speed otherwise
+
+
 
 
 
@@ -253,13 +261,17 @@ class TreeObsForRailEnv(ObservationBuilder):
         Missing values in present node are filled in with +inf (truncated).
 
 
-        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
+        Root node values: [0, 0, 0, 0, 0, 0, distance from agent to target, 0, 0, own malfunction, own speed].
         In case the target node is reached, the values are [0, 0, 0, 0, 0].
         """
 
         # Update local lookup table for all agents' positions
         self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
         self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
+        self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
+        self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
+                                               self.env.agents}
+
         if handle > len(self.env.agents):
             print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
         agent = self.env.agents[handle]  # TODO: handle being treated as index
@@ -267,9 +279,12 @@ class TreeObsForRailEnv(ObservationBuilder):
         num_transitions = np.count_nonzero(possible_transitions)
 
         # Root node - current position
-        observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
+        # Here information about the agent itself is stored
+        observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
+                       agent.malfunction_data['malfunction'], agent.speed_data['speed']]
 
         visited = set()
+
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # If only one transition is possible, the tree is oriented with this transition as the forward branch.
@@ -289,6 +304,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 # add cells filled with infinity if no transition is possible
                 observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
         self.env.dev_obs_dict[handle] = visited
+
         return observation
 
     def _num_cells_to_fill_in(self, remaining_depth):
@@ -306,6 +322,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         We walk along the branch and collect the information documented in the get() function.
         If there is a branching point a new node is created and each possible branch is explored.
         """
+
         # [Recursive branch opened]
         if depth >= self.max_depth + 1:
             return [], []
@@ -321,6 +338,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         visited = set()
         agent = self.env.agents[handle]
+        time_per_cell = np.reciprocal(agent.speed_data["speed"])
         own_target_encountered = np.inf
         other_agent_encountered = np.inf
         other_target_encountered = np.inf
@@ -328,7 +346,8 @@ class TreeObsForRailEnv(ObservationBuilder):
         unusable_switch = np.inf
         other_agent_same_direction = 0
         other_agent_opposite_direction = 0
-
+        malfunctioning_agent = 0
+        min_fractional_speed = 1.
         num_steps = 1
         while exploring:
             # #############################
@@ -339,10 +358,19 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if tot_dist < other_agent_encountered:
                     other_agent_encountered = tot_dist
 
+                # Check if any of the observed agents is malfunctioning, store agent with longest duration left
+                if self.location_has_agent_malfunction[position] > malfunctioning_agent:
+                    malfunctioning_agent = self.location_has_agent_malfunction[position]
+
                 if self.location_has_agent_direction[position] == direction:
                     # Cummulate the number of agents on branch with same direction
                     other_agent_same_direction += 1
 
+                    # Check fractional speed of agents
+                    current_fractional_speed = self.location_has_agent_speed[position]
+                    if current_fractional_speed < min_fractional_speed:
+                        min_fractional_speed = current_fractional_speed
+
                 if self.location_has_agent_direction[position] != direction:
                     # Cummulate the number of agents on branch with other direction
                     other_agent_opposite_direction += 1
@@ -356,17 +384,21 @@ class TreeObsForRailEnv(ObservationBuilder):
                 crossing_found = True
 
             # Register possible future conflict
-            if self.predictor and num_steps < self.max_prediction_depth:
+            predicted_time = int(tot_dist * time_per_cell)
+            if self.predictor and predicted_time < self.max_prediction_depth:
                 int_position = coordinate_to_position(self.env.width, [position])
                 if tot_dist < self.max_prediction_depth:
-                    pre_step = max(0, tot_dist - 1)
-                    post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
 
-                    # Look for conflicting paths at distance num_step
-                    if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
-                        conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
+                    pre_step = max(0, predicted_time - 1)
+                    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)
+
+                    # Look for conflicting paths at the predicted arrival time (predicted_time)
+                    if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
+                        conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
+                                self._reverse_dir(
+                                    self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
@@ -375,7 +407,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                     elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[pre_step][ca] \
+                                    and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
+                                    and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
@@ -384,7 +418,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                     elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
+                                    self.predicted_dir[post_step][ca])] == 1 \
+                                    and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
@@ -462,7 +498,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                            tot_dist,
                            0,
                            other_agent_same_direction,
-                           other_agent_opposite_direction
+                           other_agent_opposite_direction,
+                           malfunctioning_agent,
+                           min_fractional_speed
                            ]
 
         elif last_is_terminal:
@@ -474,7 +512,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                            np.inf,
                            self.distance_map[handle, position[0], position[1], direction],
                            other_agent_same_direction,
-                           other_agent_opposite_direction
+                           other_agent_opposite_direction,
+                           malfunctioning_agent,
+                           min_fractional_speed
                            ]
 
         else:
@@ -487,6 +527,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            self.distance_map[handle, position[0], position[1], direction],
                            other_agent_same_direction,
                            other_agent_opposite_direction,
+                           malfunctioning_agent,
+                           min_fractional_speed
                            ]
         # #############################
         # #############################
@@ -565,3 +607,272 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.env = env
         if self.predictor:
             self.predictor._set_env(self.env)
+
+    def _reverse_dir(self, direction):
+        return int((direction + 2) % 4)  # half turn: +2 modulo 4 maps 0<->2 and 1<->3
+
+
+class GlobalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    `get()` returns the tuple (rail_obs, obs_agents_state, obs_targets, direction):
+
+        - rail_obs: transition map array with dimensions (env.height, env.width, 16),
+          assuming 16 bits encoding of transitions.
+
+        - obs_agents_state: array (env.height, env.width, 8); channels 0-3 hold the one hot direction of the
+          given agent at its own cell, channels 4-7 the one hot direction of every other agent at its cell.
+
+        - obs_targets: array (env.height, env.width, 2); channel 0 marks the target of the given agent,
+          channel 1 counts the targets of the other agents on each cell.
+        - direction: the value returned by `self._get_one_hot_for_agent_direction(agent)`.
+    """
+
+    def __init__(self):
+        self.observation_space = ()
+        super(GlobalObsForRailEnv, self).__init__()
+
+    def _set_env(self, env):
+        super()._set_env(env)
+
+        self.observation_space = [4, self.env.height, self.env.width]
+
+    def reset(self):
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))  # one channel per transition bit
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist  # left-pad with zeros up to 16 bits
+                self.rail_obs[i, j] = np.array(bitlist)
+
+    def get(self, handle):
+        obs_targets = np.zeros((self.env.height, self.env.width, 2))
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 8))
+        agents = self.env.agents
+        agent = agents[handle]
+        # Hand-built one hot encoding of the agent's own direction, written into channels 0-3 of its cell.
+        direction = np.zeros(4)
+        direction[agent.direction] = 1
+        agent_pos = agents[handle].position
+        obs_agents_state[agent_pos][:4] = direction  # assumes agent_pos is a (row, col) tuple - TODO confirm
+        obs_targets[agent.target][0] += 1
+        # Every other agent: one hot direction in channels 4-7 of its cell, target counted in channel 1.
+        for i in range(len(agents)):
+            if i != handle:  # TODO: handle used as index...?
+                agent2 = agents[i]
+                obs_agents_state[agent2.position][4 + agent2.direction] = 1
+                obs_targets[agent2.target][1] += 1
+        # `direction` is rebound here; the array built above was only needed for obs_agents_state.
+        direction = self._get_one_hot_for_agent_direction(agent)
+
+        return self.rail_obs, obs_agents_state, obs_targets, direction
+
+
+class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    `get()` returns the tuple (rail_obs, obs_agents_state, obs_targets, direction):
+
+        - rail_obs: transition map array with dimensions (env.height, env.width, 16),
+          assuming 16 bits encoding of transitions, channel-shifted and flipped depending on the direction
+          of the agent (intent: the agent is always heading north on the flipped view).
+
+        - obs_agents_state: array (env.height, env.width, 5); channel 0 marks the cell of the given agent,
+          channels 1-4 hold, at the cell of every other agent, its direction relative to the given agent.
+
+        - obs_targets: array (env.height, env.width, 2); channel 0 marks the target of the given agent,
+          channel 1 counts the targets of the other agents. NOTE(review): `get()` flips only rail_obs.
+
+        - direction: the value returned by `self._get_one_hot_for_agent_direction(agent)`.
+    """
+
+    def __init__(self):
+        self.observation_space = ()
+        super(GlobalObsForRailEnvDirectionDependent, self).__init__()
+
+    def _set_env(self, env):
+        super()._set_env(env)
+
+        self.observation_space = [4, self.env.height, self.env.width]
+
+    def reset(self):
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))  # one channel per transition bit
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist  # left-pad with zeros up to 16 bits
+                self.rail_obs[i, j] = np.array(bitlist)
+
+    def get(self, handle):
+        obs_targets = np.zeros((self.env.height, self.env.width, 2))
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 5))
+        agents = self.env.agents
+        agent = agents[handle]
+        direction = agent.direction
+        # Doubled index 0..15,0..15: a 16-wide window starting at direction * 4 is a cyclic channel shift.
+        idx = np.tile(np.arange(16), 2)
+
+        rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]]
+        # NOTE(review): np.flip without `axis` (direction == 2) reverses every axis, channels included - confirm.
+        if direction == 1:
+            rail_obs = np.flip(rail_obs, axis=1)
+        elif direction == 2:
+            rail_obs = np.flip(rail_obs)
+        elif direction == 3:
+            rail_obs = np.flip(rail_obs, axis=0)
+
+        agent_pos = agents[handle].position
+        obs_agents_state[agent_pos][0] = 1  # assumes agent_pos is a (row, col) tuple - TODO confirm
+        obs_targets[agent.target][0] += 1
+        # idx[4 + d] equals d % 4 for d in -3..3: the other agent's direction relative to this agent.
+        idx = np.tile(np.arange(4), 2)
+        for i in range(len(agents)):
+            if i != handle:  # TODO: handle used as index...?
+                agent2 = agents[i]
+                obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1
+                obs_targets[agent2.target][1] += 1
+        # `direction` is rebound from the integer heading to the helper's one hot encoding.
+        direction = self._get_one_hot_for_agent_direction(agent)
+
+        return rail_obs, obs_agents_state, obs_targets, direction
+
+
+class LocalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a local observation of the rail environment around the agent.
+    `get()` returns the tuple (local_rail_obs, obs_map_state, obs_other_agents_state, direction):
+
+        - local_rail_obs: transition map array of the local environment around the given agent,
+          with dimensions (view_height, 2*view_width+1, 16),
+          assuming 16 bits encoding of transitions.
+
+        - obs_map_state: array (view_height, 2*view_width+1, 2); for the cells in the agent's vision range,
+          channel 0 marks its own target and channel 1 marks the other cells that are a target of some agent.
+
+        - obs_other_agents_state: array (view_height, 2*view_width+1, 4) with the one hot encoding of directions
+          of the other agents at their position coordinates, if they are in the agent's vision range.
+
+        - direction: a 4 elements array with one hot encoding of the direction of the given agent.
+
+    Use the parameters view_width and view_height to define the rectangular view of the agent.
+    The center parameter moves the agent along the height axis of this rectangle. If it is 0 the agent only has
+    observation in front of it.
+    """
+
+    def __init__(self, view_width, view_height, center):
+
+        super(LocalObsForRailEnv, self).__init__()
+        self.view_width = view_width
+        self.view_height = view_height
+        self.center = center
+        self.max_padding = max(self.view_width, self.view_height - self.center)
+
+    def reset(self):
+        # Cache the 16 bit transition encoding of every cell. NOTE(review): no padding is added here, and
+        # max_padding is recomputed without `center` (unlike __init__); it is not read elsewhere in this class.
+        self.max_padding = max(self.view_width, self.view_height)
+        self.rail_obs = np.zeros((self.env.height,
+                                  self.env.width, 16))
+        for i in range(self.env.height):
+            for j in range(self.env.width):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist
+                self.rail_obs[i, j] = np.array(bitlist)
+
+    def get(self, handle):
+        agents = self.env.agents
+        agent = agents[handle]
+
+        # Correct agents position for padding
+        # agent_rel_pos[0] = agent.position[0] + self.max_padding
+        # agent_rel_pos[1] = agent.position[1] + self.max_padding
+
+        # Collect the in-bounds cells of the field of view (absolute coordinates) and their
+        # coordinates relative to the view rectangle.
+        visited, rel_coords = self.field_of_view(agent.position, agent.direction)
+
+        # Add the visible cells to the observed cells
+        self.env.dev_obs_dict[handle] = set(visited)
+
+        # Locate observed agents and their corresponding targets
+        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
+        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
+        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
+        # `visited` and `rel_coords` are parallel lists: an absolute cell and its coordinates inside the view.
+        for pos, curr_rel_coord in zip(visited, rel_coords):
+            # Copy the transition bits of the visible cell into the local view.
+            local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
+            if pos == agent.target:
+                obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
+            else:
+                for tmp_agent in agents:
+                    if pos == tmp_agent.target:
+                        obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
+            if pos != agent.position:
+                for tmp_agent in agents:
+                    if pos == tmp_agent.position:
+                        obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
+                            tmp_agent.direction]
+
+        # One hot encoding of the observing agent's own direction.
+        direction = np.identity(4)[agent.direction]
+
+        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
+
+    def get_many(self, handles=None):
+        """
+        Compute the observation of every agent whose handle is listed in `handles`.
+
+        `handles` is an iterable of agent handles; `None` (the default) is treated as an empty list.
+        Returns a dict mapping each handle to the value returned by `self.get(handle)`.
+        """
+        # Iterating the default `None` directly used to raise a TypeError.
+        handles = [] if handles is None else handles
+        return {h: self.get(h) for h in handles}
+
+    def field_of_view(self, position, direction, state=None):
+        # Compute the cells of the local field of view that lie inside the grid, plus their (h, w) view offsets.
+        data_collection = False
+        if state is not None:
+            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
+            data_collection = True
+        if direction == 0:
+            origin = (position[0] + self.center, position[1] - self.view_width)
+        elif direction == 1:
+            origin = (position[0] - self.view_width, position[1] - self.center)
+        elif direction == 2:
+            origin = (position[0] - self.center, position[1] + self.view_width)
+        else:
+            origin = (position[0] + self.view_width, position[1] + self.center)
+        visible = []
+        rel_coords = []
+        for h in range(self.view_height):
+            for w in range(2 * self.view_width + 1):
+                if direction == 0:
+                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
+                        visible.append((origin[0] - h, origin[1] + w))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
+                elif direction == 1:
+                    if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
+                        visible.append((origin[0] + w, origin[1] + h))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
+                elif direction == 2:
+                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
+                        visible.append((origin[0] + h, origin[1] - w))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
+                else:
+                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
+                        visible.append((origin[0] - w, origin[1] - h))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
+        if data_collection:
+            return temp_visible_data  # NOTE(review): always all zeros, the fill code above is commented out
+        else:
+            return visible, rel_coords
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 7891c28..b4badeb 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -97,11 +97,12 @@ def split_tree(tree, num_features_per_node, current_depth=0):
             agent_data.extend(tmp_agent_data)
     return tree_data, distance_data, agent_data
 
-def normalize_observation(observation, num_features_per_node=9, observation_radius=0):
+
+def normalize_observation(observation, num_features_per_node=11, observation_radius=0):
     data, distance, agent_data = split_tree(tree=np.array(observation), num_features_per_node=num_features_per_node,
                                             current_depth=0)
     data = norm_obs_clip(data, fixed_radius=observation_radius)
     distance = norm_obs_clip(distance, normalize_to_range=True)
-    agent_data = np.clip(agent_data, -1, 1)
+    agent_data = np.clip(agent_data, -1, 20)
     normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))
     return normalized_obs
-- 
GitLab