From 6980b3908717d64535530937ba9e1b1639ca9546 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Tue, 23 Apr 2019 10:45:40 +0200
Subject: [PATCH] code cleanup

---
 examples/training_navigation.py | 41 ++++++++++++++++-----------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index b78851c..2071a5d 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -3,12 +3,11 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
-import torch,random
+import torch, random
 
 random.seed(1)
 np.random.seed(1)
 
-
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
 transition_probability = [1.0,  # empty cell - Case 0
@@ -48,13 +47,12 @@ for trials in range(1, n_trials + 1):
     obs = env.reset()
 
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
-
     score = 0
     env_done = 0
 
     # Run episode
     for step in range(100):
-        if trials >= 114:
+        if trials > 114:
             env_renderer.renderEnv(show=True)
 
         # Action
@@ -63,9 +61,7 @@ for trials in range(1, n_trials + 1):
             action_dict.update({a: action})
 
         # Environment step
-        print(trials,step)
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        print("stepped")
 
         # Update replay buffer and train agent
         for a in range(env.number_of_agents):
@@ -85,21 +81,24 @@ for trials in range(1, n_trials + 1):
     scores.append(np.mean(scores_window))
     dones_list.append((np.mean(done_window)))
 
-    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents,
-                                                                                                             trials,
-                                                                                                             np.mean(
-                                                                                                                 scores_window),
-                                                                                                             100 * np.mean(
-                                                                                                                 done_window),
-                                                                                                             eps),
+    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
+        env.number_of_agents,
+        trials,
+        np.mean(
+            scores_window),
+        100 * np.mean(
+            done_window),
+        eps),
           end=" ")
     if trials % 100 == 0:
         print(
-            '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents,
-                                                                                                               trials,
-                                                                                                               np.mean(
-                                                                                                                   scores_window),
-                                                                                                               100 * np.mean(
-                                                                                                                   done_window),
-                                                                                                               eps))
-        torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
+            '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
+                env.number_of_agents,
+                trials,
+                np.mean(
+                    scores_window),
+                100 * np.mean(
+                    done_window),
+                eps))
+        torch.save(agent.qnetwork_local.state_dict(),
+                   '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
--
GitLab
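
Side note, not part of the patch: even after the re-wrapping, the same eight-line '.format(...)' status message is still built twice per episode. A possible follow-up cleanup, sketched below on the assumption of Python 3.6+ (f-strings) and the variable names already used in this script (env, trials, scores_window, done_window, eps), would compute the line once in a small hypothetical helper:

    import numpy as np  # assumed available; the script already relies on np.mean


    def progress_line(env, trials, scores_window, done_window, eps):
        # Hypothetical helper (not in this commit): builds the training status
        # string used by both the per-episode and the every-100-episodes print.
        return (f'\rTraining {env.number_of_agents} Agents.\t'
                f'Episode {trials}\t'
                f'Average Score: {np.mean(scores_window):.0f}\t'
                f'Dones: {100 * np.mean(done_window):.2f}%\t'
                f'Epsilon: {eps:.2f}')


    # Usage inside the training loop (sketch):
    print(progress_line(env, trials, scores_window, done_window, eps), end=" ")
    if trials % 100 == 0:
        print(progress_line(env, trials, scores_window, done_window, eps))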