Skip to content
Snippets Groups Projects
Commit 1e2d8f26 authored by spiglerg's avatar spiglerg
Browse files

fixed bad bugs in distance_map calculation + added distance from agent to branch node

parent 530db333
No related branches found
No related tags found
No related merge requests found
...@@ -127,53 +127,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -127,53 +127,6 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
neighbors = [] neighbors = []
for direction in range(4):
new_cell = self._new_position(position, (direction+2) % 4)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
# Check if the two cells are connected by a valid transition
transitionValid = False
for orientation in range(4):
moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
if moves[direction]:
transitionValid = True
break
if not transitionValid:
continue
# Check if a transition in direction node[2] is possible if an agent
# lands in the current cell with orientation `direction'; this only
# applies to cells that are not dead-ends!
directionMatch = True
if enforce_target_direction >= 0:
directionMatch = self.env.rail.get_transition(
(new_cell[0], new_cell[1], direction), enforce_target_direction)
# If the transition is found to be invalid, check whether the cell
# is perhaps a dead-end, in which case the direction of movement is rotated
# 180 degrees (moving forward turns the agent around and makes it step into the previous cell)
if not directionMatch:
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
# Check if transition is possible in new_cell
# with orientation (direction+2)%4 in direction `direction'
directionMatch = directionMatch or self.env.rail.get_transition(
(new_cell[0], new_cell[1], (direction+2) % 4), direction)
if transitionValid and directionMatch:
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction],
current_distance+1)
neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance
possible_directions = [0, 1, 2, 3] possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0: if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'. # The agent must land into the current cell with orientation `enforce_target_direction'.
...@@ -263,7 +216,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -263,7 +216,7 @@ class TreeObsForRailEnv(ObservationBuilder):
#3: 1 if another agent is detected between the previous node and the current one. #3: 1 if another agent is detected between the previous node and the current one.
#4: #4: distance of agent to the current branch node
#5: minimum distance from node to the agent's target (when landing to the node following the corresponding #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
branch. branch.
...@@ -286,6 +239,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -286,6 +239,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Root node - current position # Root node - current position
observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
root_observation = observation[:]
# Start from the current orientation, and see which transitions are available; # Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation # organize them as [left, forward, right, back], relative to the current orientation
...@@ -293,7 +247,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -293,7 +247,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
new_cell = self._new_position(position, branch_direction) new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1) branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
observation = observation + branch_observation observation = observation + branch_observation
else: else:
num_cells_to_fill_in = 0 num_cells_to_fill_in = 0
...@@ -305,7 +259,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -305,7 +259,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation return observation
def _explore_branch(self, handle, position, direction, depth): def _explore_branch(self, handle, position, direction, root_observation, depth):
""" """
Utility function to compute tree-based observations. Utility function to compute tree-based observations.
""" """
...@@ -323,10 +277,11 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -323,10 +277,11 @@ class TreeObsForRailEnv(ObservationBuilder):
# to land here # to land here
last_isTarget = False last_isTarget = False
visited = set([position[0], position[1], direction]) visited = set()
other_agent_encountered = False other_agent_encountered = False
other_target_encountered = False other_target_encountered = False
num_steps = 1
while exploring: while exploring:
# ############################# # #############################
# ############################# # #############################
...@@ -345,6 +300,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -345,6 +300,8 @@ class TreeObsForRailEnv(ObservationBuilder):
if (position[0], position[1], direction) in visited: if (position[0], position[1], direction) in visited:
last_isTerminal = True last_isTerminal = True
break break
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible. # If the target node is encountered, pick that as node. Also, no further branching is possible.
if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
...@@ -377,6 +334,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -377,6 +334,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if cell_transitions[i]: if cell_transitions[i]:
position = self._new_position(position, i) position = self._new_position(position, i)
direction = i direction = i
num_steps += 1
break break
elif num_transitions > 0: elif num_transitions > 0:
...@@ -386,11 +344,10 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -386,11 +344,10 @@ class TreeObsForRailEnv(ObservationBuilder):
elif num_transitions == 0: elif num_transitions == 0:
# Wrong cell type, but let's cover it and treat it as a dead-end, just in case # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
print("WRONG CELL TYPE detected in tree-search (0 transitions possible)")
last_isTerminal = True last_isTerminal = True
break break
visited.add((position[0], position[1], direction))
# `position' is either a terminal node or a switch # `position' is either a terminal node or a switch
observation = [] observation = []
...@@ -403,25 +360,27 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -403,25 +360,27 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [0, observation = [0,
1 if other_target_encountered else 0, 1 if other_target_encountered else 0,
1 if other_agent_encountered else 0, 1 if other_agent_encountered else 0,
0, root_observation[3]+num_steps,
0] 0]
elif last_isTerminal: elif last_isTerminal:
observation = [0, observation = [0,
1 if other_target_encountered else 0, 1 if other_target_encountered else 0,
1 if other_agent_encountered else 0, 1 if other_agent_encountered else 0,
0, np.inf,
np.inf] np.inf]
else: else:
observation = [0, observation = [0,
1 if other_target_encountered else 0, 1 if other_target_encountered else 0,
1 if other_agent_encountered else 0, 1 if other_agent_encountered else 0,
0, root_observation[3]+num_steps,
self.distance_map[handle, position[0], position[1], direction]] self.distance_map[handle, position[0], position[1], direction]]
# ############################# # #############################
# ############################# # #############################
new_root_observation = observation[:]
# Start from the current orientation, and see which transitions are available; # Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation # organize them as [left, forward, right, back], relative to the current orientation
for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]: for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
...@@ -431,14 +390,14 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -431,14 +390,14 @@ class TreeObsForRailEnv(ObservationBuilder):
# it back # it back
new_cell = self._new_position(position, (branch_direction+2) % 4) new_cell = self._new_position(position, (branch_direction+2) % 4)
branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1) branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, new_root_observation, depth+1)
observation = observation + branch_observation observation = observation + branch_observation
elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
branch_direction): branch_direction):
new_cell = self._new_position(position, branch_direction) new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1) branch_observation = self._explore_branch(handle, new_cell, branch_direction, new_root_observation, depth+1)
observation = observation + branch_observation observation = observation + branch_observation
else: else:
......
...@@ -486,6 +486,8 @@ class RailEnv(Environment): ...@@ -486,6 +486,8 @@ class RailEnv(Environment):
for handle in self.agents_handles: for handle in self.agents_handles:
self.dones[handle] = False self.dones[handle] = False
# Use a TreeObsForRailEnv to compute the distance maps to each agent's target, in order to
# sample initial agent orientations that allow a valid solution.
re_generate = True re_generate = True
while re_generate: while re_generate:
valid_positions = [] valid_positions = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment