From c0939ad62d7c7e93ddc615c6a0ffe24153abbad0 Mon Sep 17 00:00:00 2001 From: spiglerg <spiglerg@gmail.com> Date: Fri, 19 Apr 2019 14:28:11 +0200 Subject: [PATCH] added features to treesearch --- examples/temporary_example.py | 8 ++-- flatland/core/env_observation_builder.py | 52 +++++++++++++++++++----- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 9d6046ec..03b5ebd0 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -26,14 +26,14 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], env = RailEnv(width=6, height=2, rail_generator=rail_from_manual_specifications_generator(specs), - number_of_agents=1, + number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) handle = env.get_agent_handles() -env.agents_position = [[1, 4]] -env.agents_target = [[1, 1]] -env.agents_direction = [1] +env.agents_position[0] = [1, 4] +env.agents_target[0] = [1, 1] +env.agents_direction[0] = 1 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! env.obs_builder.reset() diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index e2e74b54..24f213ab 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -69,6 +69,11 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.env.number_of_agents): self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i) + # Update local lookup table for all agents' target locations + self.location_has_target = {} + for loc in self.env.agents_target: + self.location_has_target[(loc[0],loc[1])] = 1 + def _distance_map_walker(self, position, target_nr): """ Utility function to compute distance maps from each cell in the rail network (and each possible @@ -210,13 +215,14 @@ class TreeObsForRailEnv(ObservationBuilder): #1: - #2: + #2: 1 if a target of another agent is detected between the previous node and the current one. - #3: + #3: 1 if another agent is detected between the previous node and the current one. #4: - #5: minimum distance from node to the agent's target + #5: minimum distance from node to the agent's target (when landing to the node following the corresponding + branch. Missing/padding nodes are filled in with -inf (truncated). Missing values in present node are filled in with +inf (truncated). @@ -226,6 +232,11 @@ class TreeObsForRailEnv(ObservationBuilder): In case the target node is reached, the values are [0, 0, 0, 0, 0]. """ + # Update local lookup table for all agents' positions + self.location_has_agent = {} + for loc in self.env.agents_position: + self.location_has_agent[(loc[0], loc[1])] = 1 + position = self.env.agents_position[handle] orientation = self.env.agents_direction[handle] @@ -264,25 +275,31 @@ class TreeObsForRailEnv(ObservationBuilder): exploring = True last_isSwitch = False last_isDeadEnd = False - # TODO: last_isTerminal = False # dead-end - # TODO: last_isTarget = False + # TODO: last_isTerminal = False # wrong cell encountered + last_isTarget = False + + other_agent_encountered = False + other_target_encountered = False while exploring: # ############################# # ############################# # Modify here to compute any useful data required to build the end node's features. This code is called # for each cell visited between the previous branching node and the next switch / target / dead-end. - # TODO: update the current variables according to the current cell in the path - # (store info about other agents and targets) + if position in self.location_has_agent: + other_agent_encountered = True + + if position in self.location_has_target: + other_target_encountered = True + - # TODO: [[[for efficiency, [make dict for hashed-lookup of coords] -- do it in the reset function!]]] # ############################# # ############################# # If the target node is encountered, pick that as node. Also, no further branching is possible. if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: - # TODO: last_isTarget = True + last_isTarget = True break cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction)) @@ -323,9 +340,24 @@ class TreeObsForRailEnv(ObservationBuilder): # ############################# # Modify here to append new / different features for each visited cell! - observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], direction]] + if last_isTarget: + observation = [0, + 1 if other_target_encountered else 0, + 1 if other_agent_encountered else 0, + 0, + 0] + + else: + observation = [0, + 1 if other_target_encountered else 0, + 1 if other_agent_encountered else 0, + 0, + self.distance_map[handle, position[0], position[1], direction]] + + # TODO: + # ############################# # ############################# -- GitLab