From 993e6a9d5e748985d13aebd43cda047e6c083035 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Mon, 29 Apr 2019 15:42:48 +0200 Subject: [PATCH] Bug fix in handling curve. Now you can pass a curve by going straigt! --- env-data/tests/test1.npy | Bin 328 -> 328 bytes examples/training_navigation.py | 4 ++-- flatland/envs/rail_env.py | 21 ++++++++++++--------- notebooks/CanvasEditor.ipynb | 26 +++++++++++++------------- 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/env-data/tests/test1.npy b/env-data/tests/test1.npy index 77e0288589171b8b03d828423ca456f2ac8395e3..f0cff3c9a1260facf073b88702da3f0557ab32f0 100644 GIT binary patch literal 328 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZQ);BHqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-2099c2AVnwwF+bcE(WZC$$^mt0vUv03KSY}u|TQ?IAF3F4In)T)gW<& g1`uEA0Kznw*aEN&OfN(ogoKGTFaWgzF^I$s0H)(9ivR!s literal 328 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZQ);BHqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-2099c2AVnwwF+bcE(Qh$1t5lyAQlKSC^UdbF9)U|K^~?c<s%G?FtG*( zkd&7LNQ{Lcf`i!wNC9O*Vv5sXEJhX(1nCy!frx=kQ~;X-BpDchnpgy+Aa)2yLF@yG tDNF*1DNIsm$Y@{+0@=arqI`sr#rTZEq>Ki|mdqBAc@S5CNC;$L008-cH7o!C diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 41dcf77..2704e84 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -34,7 +34,7 @@ env = RailEnv(width=20, env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( - ['../env-data/tests/train_simple.npy']), + ['../env-data/tests/test1.npy']), number_of_agents=1) @@ -109,7 +109,7 @@ for trials in range(1, n_trials + 1): for a in range(env.number_of_agents): if demo: eps = 0 - action = agent.act(np.array(obs[a]), eps=eps) + action = 2# agent.act(np.array(obs[a]), eps=eps) action_prob[action] += 1 action_dict.update({a: action}) # Environment step diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index bb754ea..9b6ee96 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -1024,7 +1024,7 @@ class RailEnv(Environment): is_deadend = False if action == 2: - # TODO: Check if cell is curve --> Compute correct transition and change in orientation + # compute number of possible transitions in the current # cell nbits = 0 @@ -1054,20 +1054,23 @@ class RailEnv(Environment): movement = reverse_direction is_deadend = True if nbits == 2: - # straigt or curve + # Checking for curves + valid_transition = self.rail.get_transition( (pos[0], pos[1], direction), movement) reverse_direction = (direction + 2) % 4 curv_dir = (movement + 1) % 4 while not valid_transition: - if curv_dir != reverse_direction: - valid_transition = self.rail.get_transition( - (pos[0], pos[1], direction), - curv_dir) - curv_dir = (curv_dir+1) % 4 - if valid_transition: - movement = curv_dir + if curv_dir != reverse_direction: + valid_transition = self.rail.get_transition( + (pos[0], pos[1], direction), + curv_dir) + if valid_transition: + print("Curve") + movement = curv_dir + curv_dir = (curv_dir + 1) % 4 + new_position = self._new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is diff --git a/notebooks/CanvasEditor.ipynb b/notebooks/CanvasEditor.ipynb index 159b85a..9856716 100644 --- a/notebooks/CanvasEditor.ipynb +++ b/notebooks/CanvasEditor.ipynb @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -79,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -138,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -147,7 +147,7 @@ "10" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -165,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -182,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -222,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -249,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -302,7 +302,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": { "scrolled": false }, @@ -310,7 +310,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "15e10f8984c44a529cd86cc57c697c7f", + "model_id": "0fe2a192b21c4eebb7aac3af8f3cb02e", "version_major": 2, "version_minor": 0 }, -- GitLab