From 98eb461f3dd0db44e9adb164581bb0368e3a74cb Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Wed, 1 May 2019 16:53:45 +0200 Subject: [PATCH] Update: Action forward now always moves agent along when only one transition is possible --- flatland/envs/rail_env.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 7424860..713eb40 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -1021,7 +1021,8 @@ class RailEnv(Environment): nbits = 0 tmp = self.rail.get_transitions((pos[0], pos[1])) - print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),self.rail.get_transitions((pos[0], pos[1],direction)),self.rail.get_transitions((pos[0], pos[1])),(pos[0], pos[1],direction)) + possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction)) + # print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),self.rail.get_transitions((pos[0], pos[1],direction)),self.rail.get_transitions((pos[0], pos[1])),(pos[0], pos[1],direction)) while tmp > 0: nbits += (tmp & 1) @@ -1029,12 +1030,12 @@ class RailEnv(Environment): movement = direction if action == 1: movement = direction - 1 - if nbits <= 2: + if nbits <= 2 or np.sum(possible_transitions) <= 1: transition_isValid = False elif action == 3: movement = direction + 1 - if nbits <= 2: + if nbits <= 2 or np.sum(possible_transitions) <= 1: transition_isValid = False if movement < 0: movement += 4 @@ -1064,12 +1065,14 @@ class RailEnv(Environment): direction = reverse_direction movement = reverse_direction is_deadend = True - if nbits == 2: + if np.sum(possible_transitions) == 1: # Checking for curves - - valid_transition = self.rail.get_transition( - (pos[0], pos[1], direction), - movement) + curv_dir = np.argmax(possible_transitions) + #valid_transition = self.rail.get_transition( + # (pos[0], pos[1], direction), + # movement) + movement = curv_dir + """ reverse_direction = (direction + 2) % 4 curv_dir = (movement + 1) % 4 while not valid_transition: @@ -1080,9 +1083,9 @@ class RailEnv(Environment): if valid_transition: movement = curv_dir curv_dir = (curv_dir + 1) % 4 - - + """ new_position = self._new_position(pos, movement) + print(pos,new_position) # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is # free, i.e., no agent is currently in that cell -- GitLab