From 98d00d0b681dfb0189c2336599c0b74914f68482 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Mon, 22 Jun 2020 16:01:09 +0200
Subject: [PATCH] .

---
 src/extra.py        | 6 +++++-
 src/observations.py | 5 +++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/extra.py b/src/extra.py
index a6e2371..044c32f 100644
--- a/src/extra.py
+++ b/src/extra.py
@@ -3,7 +3,6 @@ from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_env import RailEnvActions
-from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 
 from src.agent.dueling_double_dqn import Agent
 from src.observations import normalize_observation
@@ -150,6 +149,7 @@ class Extra:
                     position = self.env.agents[a].initial_position
                     first_step = True
                 direction = self.env.agents[a].direction
+                cnt = 0
                 while position is not None:  # and position != self.env.agents[a].target:
                     possible_transitions = self.env.rail.get_transitions(*position, direction)
                     # num_transitions = np.count_nonzero(possible_transitions)
@@ -200,6 +200,10 @@ class Extra:
                     else:
                         position = None
 
+                cnt += 1
+                if cnt > 100:
+                    position = None
+
         return agents_with_deadlock
 
     def generate_state(self, handle: int, root, max_depth: int):
diff --git a/src/observations.py b/src/observations.py
index cc89198..fd9659c 100644
--- a/src/observations.py
+++ b/src/observations.py
@@ -264,6 +264,7 @@ class MyTreeObsForRailEnv(ObservationBuilder):
         last_is_a_decision_cell = False
         target_encountered = 0
 
+        cnt = 0
         while exploring:
 
             dist_min_to_target = min(dist_min_to_target, self.env.distance_map.get()[handle, position[0], position[1],
@@ -327,6 +328,10 @@ class MyTreeObsForRailEnv(ObservationBuilder):
             direction = np.argmax(cell_transitions)
             position = get_new_position(position, direction)
 
+            cnt += 1
+            if cnt > 1000:
+                exploring = False
+
         # #############################
         # #############################
         # Modify here to append new / different features for each visited cell!
-- 
GitLab