From b0ec2a2defe227d530365b4b8f02a8373a28ef2a Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Mon, 8 Jul 2019 15:20:39 -0400
Subject: [PATCH] updated training for simpler start

---
 torch_training/multi_agent_training.py | 9 +++++----
 torch_training/training_navigation.py  | 4 ++--
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index eb11fb1..7935576 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -64,9 +64,9 @@ agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
 agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
 
-demo = True
+demo = False
 record_images = False
-
+frame_step = 0
 for trials in range(1, n_trials + 1):
 
     if trials % 50 == 0 and not demo:
@@ -118,10 +118,11 @@ for trials in range(1, n_trials + 1):
     # Run episode
     for step in range(max_steps):
         if demo:
-            env_renderer.renderEnv(show=True, show_observations=True)
+            env_renderer.renderEnv(show=True, show_observations=False)
             # observation_helper.util_print_obs_subtree(obs_original[0])
             if record_images:
-                env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(step))
+                env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
+                frame_step += 1
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3152ecb..dd4d479 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,15 +1,15 @@
 import random
 from collections import deque
-
 import matplotlib.pyplot as plt
 import numpy as np
+
 import torch
 from dueling_double_dqn import Agent
+
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
-- 
GitLab