From b0ec2a2defe227d530365b4b8f02a8373a28ef2a Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Mon, 8 Jul 2019 15:20:39 -0400 Subject: [PATCH] updated training for simpler start --- torch_training/multi_agent_training.py | 9 +++++---- torch_training/training_navigation.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index eb11fb1..7935576 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -64,9 +64,9 @@ agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth')) -demo = True +demo = False record_images = False - +frame_step = 0 for trials in range(1, n_trials + 1): if trials % 50 == 0 and not demo: @@ -118,10 +118,11 @@ for trials in range(1, n_trials + 1): # Run episode for step in range(max_steps): if demo: - env_renderer.renderEnv(show=True, show_observations=True) + env_renderer.renderEnv(show=True, show_observations=False) # observation_helper.util_print_obs_subtree(obs_original[0]) if record_images: - env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(step)) + env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step)) + frame_step += 1 # print(step) # Action for a in range(env.get_num_agents()): diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 3152ecb..dd4d479 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -1,15 +1,15 @@ import random from collections import deque - import matplotlib.pyplot as plt import numpy as np + import torch from dueling_double_dqn import Agent + from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool - from utils.observation_utils import norm_obs_clip, split_tree random.seed(1) -- GitLab