From a712961050c185022444d676eb8c38bd4131d9f6 Mon Sep 17 00:00:00 2001 From: spiglerg <spiglerg@gmail.com> Date: Tue, 23 Apr 2019 12:24:29 +0200 Subject: [PATCH] important fixes to treesearch --- examples/temporary_example.py | 29 +++++++++++++++++++++--- flatland/core/env_observation_builder.py | 26 ++++++++++++++++----- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 02c282c..67d0fe6 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -17,7 +17,7 @@ transition_probability = [1.0, # empty cell - Case 0 0.5, # Case 5 - double slip 0.2, # Case 6 - symmetrical 0.0] # Case 7 - dead end - +""" # Example generate a random rail env = RailEnv(width=20, height=20, @@ -33,7 +33,7 @@ env.reset() env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) - +""" """ # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) @@ -52,12 +52,35 @@ env.agents_target[0] = [1, 1] env.agents_direction[0] = 1 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! env.obs_builder.reset() +""" + + +specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], + [(7, 270), (1, 90), (1, 90), (2, 270), (2, 0), (0, 0)], + [(0, 0), (0, 0), (0, 0), (2, 180), (2, 90), (7, 90)], + [(0, 0), (0, 0), (0, 0), (7, 180), (0, 0), (0, 0)]] + +env = RailEnv(width=6, + height=4, + rail_generator=rail_from_manual_specifications_generator(specs), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) + +handle = env.get_agent_handles() +env.agents_position[0] = [1, 3] +env.agents_target[0] = [1, 1] +env.agents_direction[0] = 1 +# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! +env.obs_builder.reset() + + + """ env = RailEnv(width=7, height=7, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=2) - +""" # Print the distance map of each cell to the target of the first agent # for i in range(4): # print(env.obs_builder.distance_map[0, :, :, i]) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index b0730d8..fd71bb9 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -169,10 +169,10 @@ class TreeObsForRailEnv(ObservationBuilder): (new_cell[0], new_cell[1], (direction+2) % 4), direction) if transitionValid and directionMatch: - new_distance = min(self.distance_map[target_nr, - new_cell[0], new_cell[1]], current_distance+1) + new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction], + current_distance+1) neighbors.append((new_cell[0], new_cell[1], direction, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance + self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance possible_directions = [0, 1, 2, 3] if enforce_target_direction >= 0: @@ -319,13 +319,15 @@ class TreeObsForRailEnv(ObservationBuilder): exploring = True last_isSwitch = False last_isDeadEnd = False - # TODO: last_isTerminal = False # wrong cell encountered + last_isTerminal = False # wrong cell encountered OR cycle encountered; either way, we don't want the agent + # to land here last_isTarget = False + visited = set([position[0], position[1], direction]) + other_agent_encountered = False other_target_encountered = False while exploring: - # ############################# # ############################# # Modify here to compute any useful data required to build the end node's features. This code is called @@ -340,6 +342,10 @@ class TreeObsForRailEnv(ObservationBuilder): # ############################# # ############################# + if (position[0], position[1], direction) in visited: + last_isTerminal = True + break + # If the target node is encountered, pick that as node. Also, no further branching is possible. if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: last_isTarget = True @@ -380,9 +386,11 @@ class TreeObsForRailEnv(ObservationBuilder): elif num_transitions == 0: # Wrong cell type, but let's cover it and treat it as a dead-end, just in case - # TODO: last_isTerminal = True + last_isTerminal = True break + visited.add((position[0], position[1], direction)) + # `position' is either a terminal node or a switch observation = [] @@ -398,6 +406,12 @@ class TreeObsForRailEnv(ObservationBuilder): 0, 0] + elif last_isTerminal: + observation = [0, + 1 if other_target_encountered else 0, + 1 if other_agent_encountered else 0, + 0, + np.inf] else: observation = [0, 1 if other_target_encountered else 0, -- GitLab