From c528e652f2fa570b12d4ed9c3fed2f72f4f803bb Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Tue, 23 Apr 2019 08:54:19 +0200
Subject: [PATCH] minor bugfixes in training script

---
 examples/training_navigation.py | 40 +++++++++------------------------
 1 file changed, 11 insertions(+), 29 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 3797d1cf..7b35cb21 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -3,7 +3,7 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
-import torch
+import torch,random
 
 random.seed(1)
 np.random.seed(1)
@@ -18,22 +18,22 @@ transition_probability = [1.0,  # empty cell - Case 0
                           1.0,  # Case 6 - symmetrical
                           1.0]  # Case 7 - dead end
 """
+# Relative proportion of each cell type (Case 0 to Case 7),
+# passed to random_rail_generator below
 transition_probability = [1.0,  # empty cell - Case 0
                           1.0,  # Case 1 - straight
-                          0.5,  # Case 2 - simple switch
-                          0.2,  # Case 3 - diamond drossing
+                          1.0,  # Case 2 - simple switch
+                          0.3,  # Case 3 - diamond crossing
                           0.5,  # Case 4 - single slip
-                          0.1,  # Case 5 - double slip
+                          0.5,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
-                          0.01]  # Case 7 - dead end
+                          0.0]  # Case 7 - dead end
 
 # Example generate a random rail
-env = RailEnv(width=20,
-              height=20,
+env = RailEnv(width=7,
+              height=7,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-              number_of_agents=10)
-env.reset()
-
+              number_of_agents=1)
 env_renderer = RenderTool(env)
 handle = env.get_agent_handles()
 
@@ -51,28 +51,10 @@ dones_list = []
 
 agent = Agent(state_size, action_size, "FC", 0)
 
-# Example generate a rail given a manual specification,
-# a map of tuples (cell_type, rotation)
-specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
-         [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
-
-env = RailEnv(width=6,
-              height=2,
-              rail_generator=rail_from_manual_specifications_generator(specs),
-              number_of_agents=1,
-              obs_builder_object=TreeObsForRailEnv(max_depth=2))
-
-env.agents_position[0] = [1, 4]
-env.agents_target[0] = [1, 1]
-env.agents_direction[0] = 1
-# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
-env.obs_builder.reset()
-
-
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs, all_rewards, done, _ = env.step({0: 0})
+    obs = env.reset()
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
 
 
-- 
GitLab