From dceceb865cf9105cba83de7305384d233d4e9fd1 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Mon, 29 Apr 2019 15:28:33 +0200
Subject: [PATCH] added curves to rail_env.py step function

---
 examples/training_navigation.py |  6 +++---
 flatland/envs/rail_env.py       | 15 +++++++++++++++
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 0c35fb14..41dcf779 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -30,13 +30,13 @@ env = RailEnv(width=20,
               rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=5)
 
-"""
+
 env = RailEnv(width=20,
               height=20,
               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
-                  ['../env-data/tests/test_rail.npy']),
+                  ['../env-data/tests/train_simple.npy']),
               number_of_agents=1)
-"""
+
 
 env_renderer = RenderTool(env, gl="QT")
 handle = env.get_agent_handles()
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 30e474d6..3171838c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -1053,6 +1053,21 @@ class RailEnv(Environment):
                             direction = reverse_direction
                             movement = reverse_direction
                             is_deadend = True
+                    if nbits == 2:
+                        # straigt or curve
+                        valid_transition = self.rail.get_transition(
+                            (pos[0], pos[1], direction),
+                            movement)
+                        reverse_direction = (direction + 2) % 4
+                        curv_dir = (movement + 1) % 4
+                        while not valid_transition:
+                                if curv_dir != reverse_direction:
+                                    valid_transition = self.rail.get_transition(
+                                        (pos[0], pos[1], direction),
+                                        curv_dir)
+                                curv_dir = (curv_dir+1) % 4
+                                if valid_transition:
+                                    movement = curv_dir
                 new_position = self._new_position(pos, movement)
                 # Is it a legal move?  1) transition allows the movement in the
                 # cell,  2) the new cell is not empty (case 0),  3) the cell is
-- 
GitLab