From 6980b3908717d64535530937ba9e1b1639ca9546 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Tue, 23 Apr 2019 10:45:40 +0200
Subject: [PATCH] code cleanup

---
 examples/training_navigation.py | 41 ++++++++++++++++-----------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index b78851c..2071a5d 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -3,12 +3,11 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
-import torch,random
+import torch, random
 
 random.seed(1)
 np.random.seed(1)
 
-
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
 transition_probability = [1.0,  # empty cell - Case 0
@@ -48,13 +47,12 @@ for trials in range(1, n_trials + 1):
     obs = env.reset()
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
 
-
     score = 0
     env_done = 0
 
     # Run episode
     for step in range(100):
-        if trials >= 114:
+        if trials > 114:
             env_renderer.renderEnv(show=True)
 
         # Action
@@ -63,9 +61,7 @@ for trials in range(1, n_trials + 1):
             action_dict.update({a: action})
 
         # Environment step
-        print(trials,step)
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        print("stepped")
 
         # Update replay buffer and train agent
         for a in range(env.number_of_agents):
@@ -85,21 +81,24 @@ for trials in range(1, n_trials + 1):
     scores.append(np.mean(scores_window))
     dones_list.append((np.mean(done_window)))
 
-    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents,
-                                                                                                             trials,
-                                                                                                             np.mean(
-                                                                                                                 scores_window),
-                                                                                                             100 * np.mean(
-                                                                                                                 done_window),
-                                                                                                             eps),
+    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
+        env.number_of_agents,
+        trials,
+        np.mean(
+            scores_window),
+        100 * np.mean(
+            done_window),
+        eps),
           end=" ")
     if trials % 100 == 0:
         print(
-            '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents,
-                                                                                                               trials,
-                                                                                                               np.mean(
-                                                                                                                   scores_window),
-                                                                                                               100 * np.mean(
-                                                                                                                   done_window),
-                                                                                                               eps))
-        torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
+            '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
+                env.number_of_agents,
+                trials,
+                np.mean(
+                    scores_window),
+                100 * np.mean(
+                    done_window),
+                eps))
+        torch.save(agent.qnetwork_local.state_dict(),
+                   '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
-- 
GitLab