From dae5e4835c80157ae70b43515331756fb77b72ba Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 25 Apr 2019 11:58:03 +0200
Subject: [PATCH] Add demo mode to navigation example and require action 0 at target

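Make the navigation example usable as a demo: render only when the demo
flag is set, act greedily (eps = 0) while demoing, load the
avoid_checkpoint15000 checkpoint, enlarge the generated rail to 20x20
and raise n_trials to 15000. Rework max_lt to return a fallback of 0
instead of None, add a min_lt helper, and only mark an agent as done
once it is at its target and chooses action 0.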
---
 examples/training_navigation.py | 39 +++++++++++++++++++++++----------
 flatland/envs/rail_env.py       |  3 ++-
 2 files changed, 30 insertions(+), 12 deletions(-)
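
Note (kept below the --- marker, so not part of the commit message): a
minimal usage sketch of the reworked helpers, assuming the new fallback
return values (0 for max_lt, np.inf for min_lt); the input list is
illustrative only.

    seq = [2, 7, -1, 5]
    max_lt(seq, 6)    # -> 5    greatest non-negative item strictly below 6
    max_lt(seq, 1)    # -> 0    no item qualifies, fallback 0
    min_lt(seq, 4)    # -> 5    smallest item strictly above 4
    min_lt(seq, 10)   # -> inf  no item qualifies, fallback np.inf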

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index b103251..0e0f439 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -13,15 +13,15 @@ np.random.seed(1)
 transition_probability = [0.5,  # empty cell - Case 0
                           1.0,  # Case 1 - straight
                           1.0,  # Case 2 - simple switch
-                          0.3,  # Case 3 - diamond drossing
+                          0.3,  # Case 3 - diamond crossing
                           0.5,  # Case 4 - single slip
                           0.5,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
                           0.0]  # Case 7 - dead end
 
 # Example generate a random rail
-env = RailEnv(width=7,
-              height=7,
+env = RailEnv(width=20,
+              height=20,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=1)
 env_renderer = RenderTool(env)
@@ -29,7 +29,7 @@ handle = env.get_agent_handles()
 
 state_size = 105
 action_size = 4
-n_trials = 9999
+n_trials = 15000
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.998
@@ -40,19 +40,34 @@ scores = []
 dones_list = []
 action_prob = [0]*4
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+
+demo = True
 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
-    None is returned if seq was empty or all items in seq were >= val.
+    0 is returned if seq was empty or no non-negative item in seq was smaller than val.
     """
+    max = 0
+    idx = len(seq)-1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
 
+def min_lt(seq, val):
+    """
+    Return smallest item in seq for which item > val applies.
+    np.inf is returned if seq was empty or all items in seq were <= val.
+    """
+    min = np.inf
     idx = len(seq)-1
     while idx >= 0:
-        if seq[idx] < val and seq[idx] >= 0:
-            return seq[idx]
+        if seq[idx] > val and seq[idx] < min:
+            min = seq[idx]
         idx -= 1
-    return None
+    return min
 
 for trials in range(1, n_trials + 1):
 
@@ -69,12 +84,14 @@ for trials in range(1, n_trials + 1):
 
     # Run episode
     for step in range(50):
-        #if trials > 114:
-        env_renderer.renderEnv(show=True)
+        if demo:
+            env_renderer.renderEnv(show=True)
         #print(step)
         # Action
         for a in range(env.number_of_agents):
-            action = agent.act(np.array(obs[a]), eps=0)
+            if demo:
+                eps = 0
+            action = agent.act(np.array(obs[a]), eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cf20603..9fd8585 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -649,7 +649,8 @@ class RailEnv(Environment):
 
-            # if agent is not in target position, add step penalty
+            # an agent is done only if it is at its target and chooses action 0; otherwise add step penalty
             if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
+               self.agents_position[i][1] == self.agents_target[i][1] and \
+                action_dict[handle] == 0:
                 self.dones[handle] = True
             else:
                 self.rewards_dict[handle] += step_penalty
-- 
GitLab