From 993e6a9d5e748985d13aebd43cda047e6c083035 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Mon, 29 Apr 2019 15:42:48 +0200
Subject: [PATCH] Bug fix in handling curve. Now you can pass a curve by going
 straigt!

---
 env-data/tests/test1.npy        | Bin 328 -> 328 bytes
 examples/training_navigation.py |   4 ++--
 flatland/envs/rail_env.py       |  21 ++++++++++++---------
 notebooks/CanvasEditor.ipynb    |  26 +++++++++++++-------------
 4 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/env-data/tests/test1.npy b/env-data/tests/test1.npy
index 77e0288589171b8b03d828423ca456f2ac8395e3..f0cff3c9a1260facf073b88702da3f0557ab32f0 100644
GIT binary patch
literal 328
zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZQ);BHqoAIaUsO_*m=~X4l#&V(cT3DEP6dh=
zXCxM+0{I$-2099c2AVnwwF+bcE(WZC$$^mt0vUv03KSY}u|TQ?IAF3F4In)T)gW<&
g1`uEA0Kznw*aEN&OfN(ogoKGTFaWgzF^I$s0H)(9ivR!s

literal 328
zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZQ);BHqoAIaUsO_*m=~X4l#&V(cT3DEP6dh=
zXCxM+0{I$-2099c2AVnwwF+bcE(Qh$1t5lyAQlKSC^UdbF9)U|K^~?c<s%G?FtG*(
zkd&7LNQ{Lcf`i!wNC9O*Vv5sXEJhX(1nCy!frx=kQ~;X-BpDchnpgy+Aa)2yLF@yG
tDNF*1DNIsm$Y@{+0@=arqI`sr#rTZEq>Ki|mdqBAc@S5CNC;$L008-cH7o!C

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 41dcf77..2704e84 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -34,7 +34,7 @@ env = RailEnv(width=20,
 env = RailEnv(width=20,
               height=20,
               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
-                  ['../env-data/tests/train_simple.npy']),
+                  ['../env-data/tests/test1.npy']),
               number_of_agents=1)
 
 
@@ -109,7 +109,7 @@ for trials in range(1, n_trials + 1):
         for a in range(env.number_of_agents):
             if demo:
                 eps = 0
-            action = agent.act(np.array(obs[a]), eps=eps)
+            action = 2# agent.act(np.array(obs[a]), eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
         # Environment step
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index bb754ea..9b6ee96 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -1024,7 +1024,7 @@ class RailEnv(Environment):
 
                 is_deadend = False
                 if action == 2:
-                    # TODO: Check if cell is curve --> Compute correct transition and change in orientation
+
                     # compute number of possible transitions in the current
                     # cell
                     nbits = 0
@@ -1054,20 +1054,23 @@ class RailEnv(Environment):
                             movement = reverse_direction
                             is_deadend = True
                     if nbits == 2:
-                        # straigt or curve
+                        # Checking for curves
+
                         valid_transition = self.rail.get_transition(
                             (pos[0], pos[1], direction),
                             movement)
                         reverse_direction = (direction + 2) % 4
                         curv_dir = (movement + 1) % 4
                         while not valid_transition:
-                                if curv_dir != reverse_direction:
-                                    valid_transition = self.rail.get_transition(
-                                        (pos[0], pos[1], direction),
-                                        curv_dir)
-                                curv_dir = (curv_dir+1) % 4
-                                if valid_transition:
-                                    movement = curv_dir
+                            if curv_dir != reverse_direction:
+                                valid_transition = self.rail.get_transition(
+                                    (pos[0], pos[1], direction),
+                                    curv_dir)
+                            if valid_transition:
+                                print("Curve")
+                                movement = curv_dir
+                            curv_dir = (curv_dir + 1) % 4
+
                 new_position = self._new_position(pos, movement)
                 # Is it a legal move?  1) transition allows the movement in the
                 # cell,  2) the new cell is not empty (case 0),  3) the cell is
diff --git a/notebooks/CanvasEditor.ipynb b/notebooks/CanvasEditor.ipynb
index 159b85a..9856716 100644
--- a/notebooks/CanvasEditor.ipynb
+++ b/notebooks/CanvasEditor.ipynb
@@ -57,7 +57,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -79,7 +79,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -102,7 +102,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -116,7 +116,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -129,7 +129,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -138,7 +138,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
@@ -147,7 +147,7 @@
        "10"
       ]
      },
-     "execution_count": 10,
+     "execution_count": 9,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -165,7 +165,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -182,7 +182,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -222,7 +222,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [
     {
@@ -249,7 +249,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -302,7 +302,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 14,
    "metadata": {
     "scrolled": false
    },
@@ -310,7 +310,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "15e10f8984c44a529cd86cc57c697c7f",
+       "model_id": "0fe2a192b21c4eebb7aac3af8f3cb02e",
        "version_major": 2,
        "version_minor": 0
       },
-- 
GitLab