From 72c69df1274bd2b2e24160b8cf7c2f0e25441683 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 5 May 2019 12:33:15 +0200
Subject: [PATCH] minor test in navigation training

---
 examples/training_navigation.py | 12 ++++++++----
 flatland/envs/generators.py     |  2 +-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index ddf10b1..4748708 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -12,19 +12,23 @@ np.random.seed(1)
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
 transition_probability = [5,  # empty cell - Case 0
-                          15,  # Case 1 - straight
+                          1,  # Case 1 - straight
                           5,  # Case 2 - simple switch
                           1,  # Case 3 - diamond crossing
                           1,  # Case 4 - single slip
                           1,  # Case 5 - double slip
                           1,  # Case 6 - symmetrical
-                          0]  # Case 7 - dead end
+                          0,  # Case 7 - dead end
+                          15,  # Case 1b (8)  - simple turn right
+                          15,  # Case 1c (9)  - simple turn left
+                          15]  # Case 2b (10) - simple switch mirrored
+
 
 # Example generate a random rail
 env = RailEnv(width=10,
               height=10,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-              number_of_agents=1)
+              number_of_agents=3)
 """
 env = RailEnv(width=20,
               height=20,
@@ -57,7 +61,7 @@ action_prob = [0] * 4
 agent = Agent(state_size, action_size, "FC", 0)
 agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth'))
 
-demo = True
+demo = False
 
 
 def max_lt(seq, val):
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index fed7901..baba6ba 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -250,7 +250,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
 
         transitions_templates_ = []
         transition_probabilities = []
-        for i in range(len(t_utils.transitions) - 4):  # don't include dead-ends
+        for i in range(len(t_utils.transitions)-4):  # don't include dead-ends
             all_transitions = 0
             for dir_ in range(4):
                 trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
-- 
GitLab