Commit 98eb461f authored by Erik Nygren
Browse files

Update: Action forward now always moves agent along when only one transition is possible

parent b1e76401
Pipeline #469 failed after 1 minute and 54 seconds
...@@ -1021,7 +1021,8 @@ class RailEnv(Environment): ...@@ -1021,7 +1021,8 @@ class RailEnv(Environment):
nbits = 0 nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1])) tmp = self.rail.get_transitions((pos[0], pos[1]))
print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),self.rail.get_transitions((pos[0], pos[1],direction)),self.rail.get_transitions((pos[0], pos[1])),(pos[0], pos[1],direction)) possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction))
# print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),self.rail.get_transitions((pos[0], pos[1],direction)),self.rail.get_transitions((pos[0], pos[1])),(pos[0], pos[1],direction))
while tmp > 0: while tmp > 0:
nbits += (tmp & 1) nbits += (tmp & 1)
...@@ -1029,12 +1030,12 @@ class RailEnv(Environment): ...@@ -1029,12 +1030,12 @@ class RailEnv(Environment):
movement = direction movement = direction
if action == 1: if action == 1:
movement = direction - 1 movement = direction - 1
if nbits <= 2: if nbits <= 2 or np.sum(possible_transitions) <= 1:
transition_isValid = False transition_isValid = False
elif action == 3: elif action == 3:
movement = direction + 1 movement = direction + 1
if nbits <= 2: if nbits <= 2 or np.sum(possible_transitions) <= 1:
transition_isValid = False transition_isValid = False
if movement < 0: if movement < 0:
movement += 4 movement += 4
...@@ -1064,12 +1065,14 @@ class RailEnv(Environment): ...@@ -1064,12 +1065,14 @@ class RailEnv(Environment):
direction = reverse_direction direction = reverse_direction
movement = reverse_direction movement = reverse_direction
is_deadend = True is_deadend = True
if nbits == 2: if np.sum(possible_transitions) == 1:
# Checking for curves # Checking for curves
curv_dir = np.argmax(possible_transitions)
valid_transition = self.rail.get_transition( #valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction), # (pos[0], pos[1], direction),
movement) # movement)
movement = curv_dir
"""
reverse_direction = (direction + 2) % 4 reverse_direction = (direction + 2) % 4
curv_dir = (movement + 1) % 4 curv_dir = (movement + 1) % 4
while not valid_transition: while not valid_transition:
...@@ -1080,9 +1083,9 @@ class RailEnv(Environment): ...@@ -1080,9 +1083,9 @@ class RailEnv(Environment):
if valid_transition: if valid_transition:
movement = curv_dir movement = curv_dir
curv_dir = (curv_dir + 1) % 4 curv_dir = (curv_dir + 1) % 4
"""
new_position = self._new_position(pos, movement) new_position = self._new_position(pos, movement)
print(pos,new_position)
# Is it a legal move? 1) transition allows the movement in the # Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is # cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell # free, i.e., no agent is currently in that cell
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment