From 98eb461f3dd0db44e9adb164581bb0368e3a74cb Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 1 May 2019 16:53:45 +0200
Subject: [PATCH] Update: Action forward now always moves agent along when only
 one transition is possible

---
 flatland/envs/rail_env.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 7424860..713eb40 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -1021,7 +1021,8 @@ class RailEnv(Environment):
 
                 nbits = 0
                 tmp = self.rail.get_transitions((pos[0], pos[1]))
-                print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),self.rail.get_transitions((pos[0], pos[1],direction)),self.rail.get_transitions((pos[0], pos[1])),(pos[0], pos[1],direction))
+                possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction))
+                # print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),self.rail.get_transitions((pos[0], pos[1],direction)),self.rail.get_transitions((pos[0], pos[1])),(pos[0], pos[1],direction))
 
                 while tmp > 0:
                     nbits += (tmp & 1)
@@ -1029,12 +1030,12 @@ class RailEnv(Environment):
                 movement = direction
                 if action == 1:
                     movement = direction - 1
-                    if nbits <= 2:
+                    if nbits <= 2 or np.sum(possible_transitions) <= 1:
                         transition_isValid = False
 
                 elif action == 3:
                     movement = direction + 1
-                    if nbits <= 2:
+                    if nbits <= 2 or np.sum(possible_transitions) <= 1:
                         transition_isValid = False
                 if movement < 0:
                     movement += 4
@@ -1064,12 +1065,14 @@ class RailEnv(Environment):
                             direction = reverse_direction
                             movement = reverse_direction
                             is_deadend = True
-                    if nbits == 2:
+                    if np.sum(possible_transitions) == 1:
                         # Checking for curves
-
-                        valid_transition = self.rail.get_transition(
-                            (pos[0], pos[1], direction),
-                            movement)
+                        curv_dir = np.argmax(possible_transitions)
+                        #valid_transition = self.rail.get_transition(
+                        #    (pos[0], pos[1], direction),
+                        #    movement)
+                        movement = curv_dir
+                        """
                         reverse_direction = (direction + 2) % 4
                         curv_dir = (movement + 1) % 4
                         while not valid_transition:
@@ -1080,9 +1083,9 @@ class RailEnv(Environment):
                             if valid_transition:
                                 movement = curv_dir
                             curv_dir = (curv_dir + 1) % 4
-
-
+                        """
                 new_position = self._new_position(pos, movement)
+                print(pos,new_position)
                 # Is it a legal move?  1) transition allows the movement in the
                 # cell,  2) the new cell is not empty (case 0),  3) the cell is
                 # free, i.e., no agent is currently in that cell
-- 
GitLab