From 8cec2335efd6f3f26434c05eda71a3d87ab2b13a Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 1 May 2019 08:34:49 +0200
Subject: [PATCH] updated action behavior at curve. deviation left and right
 not allowed at curve

---
 examples/training_navigation.py          |  2 +-
 flatland/core/env_observation_builder.py |  2 +-
 flatland/envs/rail_env.py                | 15 ++++++++++-----
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index e21afd9c..9fc83242 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -112,7 +112,7 @@ for trials in range(1, n_trials + 1):
             action = agent.act(np.array(obs[a]), eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
-            env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5)
+            # env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5)
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.number_of_agents):
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 2d22b243..aaf5177f 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -340,7 +340,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             elif num_transitions == 0:
                 # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
-                print("WRONG CELL TYPE detected in tree-search (0 transitions possible)")
+                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], position[1])
                 last_isTerminal = True
                 break
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 55e0e2ed..878892d7 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -998,7 +998,7 @@ class RailEnv(Environment):
 
         for i in range(len(self.agents_handles)):
             handle = self.agents_handles[i]
-
+            transition_isValid = None  # None = undecided; resolved from the rail transitions further down
             if handle not in action_dict:
                 continue
 
@@ -1027,9 +1027,13 @@ class RailEnv(Environment):
                 movement = direction
                 if action == 1:
                     movement = direction - 1
+                    if nbits <= 2:
+                        transition_isValid = False  # deviating left is rejected when nbits <= 2
+
                 elif action == 3:
                     movement = direction + 1
-
+                    if nbits <= 2:
+                        transition_isValid = False  # deviating right is rejected when nbits <= 2
                 if movement < 0:
                     movement += 4
                 if movement >= 4:
@@ -1089,9 +1093,10 @@ class RailEnv(Environment):
                 else:
                     new_cell_isValid = False
 
-                transition_isValid = self.rail.get_transition(
-                    (pos[0], pos[1], direction),
-                    movement) or is_deadend
+                if transition_isValid is None:  # only query the rail if the move was not already rejected
+                    transition_isValid = self.rail.get_transition(
+                        (pos[0], pos[1], direction),
+                        movement) or is_deadend
 
                 cell_isFree = True
                 for j in range(self.number_of_agents):
-- 
GitLab