Skip to content
Snippets Groups Projects
Commit 1e2d8f26 authored by spiglerg's avatar spiglerg
Browse files

fixed bad bugs in distance_map calculation + added distance from agent to branch node

parent 530db333
No related branches found
No related tags found
No related merge requests found
......@@ -127,53 +127,6 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
neighbors = []
for direction in range(4):
new_cell = self._new_position(position, (direction+2) % 4)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
# Check if the two cells are connected by a valid transition
transitionValid = False
for orientation in range(4):
moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
if moves[direction]:
transitionValid = True
break
if not transitionValid:
continue
# Check if a transition in direction node[2] is possible if an agent
# lands in the current cell with orientation `direction'; this only
# applies to cells that are not dead-ends!
directionMatch = True
if enforce_target_direction >= 0:
directionMatch = self.env.rail.get_transition(
(new_cell[0], new_cell[1], direction), enforce_target_direction)
# If transition is found to invalid, check if perhaps it
# is a dead-end, in which case the direction of movement is rotated
# 180 degrees (moving forward turns the agents and makes it step in the previous cell)
if not directionMatch:
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
# Check if transition is possible in new_cell
# with orientation (direction+2)%4 in direction `direction'
directionMatch = directionMatch or self.env.rail.get_transition(
(new_cell[0], new_cell[1], (direction+2) % 4), direction)
if transitionValid and directionMatch:
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction],
current_distance+1)
neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
......@@ -263,7 +216,7 @@ class TreeObsForRailEnv(ObservationBuilder):
#3: 1 if another agent is detected between the previous node and the current one.
#4:
#4: distance of agent to the current branch node
#5: minimum distance from node to the agent's target (when landing to the node following the corresponding
branch.
......@@ -286,6 +239,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
root_observation = observation[:]
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
......@@ -293,7 +247,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
observation = observation + branch_observation
else:
num_cells_to_fill_in = 0
......@@ -305,7 +259,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation
def _explore_branch(self, handle, position, direction, depth):
def _explore_branch(self, handle, position, direction, root_observation, depth):
"""
Utility function to compute tree-based observations.
"""
......@@ -323,10 +277,11 @@ class TreeObsForRailEnv(ObservationBuilder):
# to land here
last_isTarget = False
visited = set([position[0], position[1], direction])
visited = set()
other_agent_encountered = False
other_target_encountered = False
num_steps = 1
while exploring:
# #############################
# #############################
......@@ -345,6 +300,8 @@ class TreeObsForRailEnv(ObservationBuilder):
if (position[0], position[1], direction) in visited:
last_isTerminal = True
break
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
......@@ -377,6 +334,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if cell_transitions[i]:
position = self._new_position(position, i)
direction = i
num_steps += 1
break
elif num_transitions > 0:
......@@ -386,11 +344,10 @@ class TreeObsForRailEnv(ObservationBuilder):
elif num_transitions == 0:
# Wrong cell type, but let's cover it and treat it as a dead-end, just in case
print("WRONG CELL TYPE detected in tree-search (0 transitions possible)")
last_isTerminal = True
break
visited.add((position[0], position[1], direction))
# `position' is either a terminal node or a switch
observation = []
......@@ -403,25 +360,27 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
root_observation[3]+num_steps,
0]
elif last_isTerminal:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
np.inf,
np.inf]
else:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
root_observation[3]+num_steps,
self.distance_map[handle, position[0], position[1], direction]]
# #############################
# #############################
new_root_observation = observation[:]
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
......@@ -431,14 +390,14 @@ class TreeObsForRailEnv(ObservationBuilder):
# it back
new_cell = self._new_position(position, (branch_direction+2) % 4)
branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1)
branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, new_root_observation, depth+1)
observation = observation + branch_observation
elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
branch_direction):
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, new_root_observation, depth+1)
observation = observation + branch_observation
else:
......
......@@ -486,6 +486,8 @@ class RailEnv(Environment):
for handle in self.agents_handles:
self.dones[handle] = False
# Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial
# agent's orientations that allow a valid solution.
re_generate = True
while re_generate:
valid_positions = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment