Skip to content
Snippets Groups Projects
Commit 6980b390 authored by Erik Nygren's avatar Erik Nygren
Browse files

code cleanup

parent 9c66a35f
No related branches found
No related tags found
No related merge requests found
...@@ -3,12 +3,11 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv ...@@ -3,12 +3,11 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import * from flatland.utils.rendertools import *
from flatland.baselines.dueling_double_dqn import Agent from flatland.baselines.dueling_double_dqn import Agent
from collections import deque from collections import deque
import torch,random import torch, random
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
# Example generate a rail given a manual specification, # Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation) # a map of tuples (cell_type, rotation)
transition_probability = [1.0, # empty cell - Case 0 transition_probability = [1.0, # empty cell - Case 0
...@@ -48,13 +47,12 @@ for trials in range(1, n_trials + 1): ...@@ -48,13 +47,12 @@ for trials in range(1, n_trials + 1):
obs = env.reset() obs = env.reset()
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
score = 0 score = 0
env_done = 0 env_done = 0
# Run episode # Run episode
for step in range(100): for step in range(100):
if trials >= 114: if trials > 114:
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True)
# Action # Action
...@@ -63,9 +61,7 @@ for trials in range(1, n_trials + 1): ...@@ -63,9 +61,7 @@ for trials in range(1, n_trials + 1):
action_dict.update({a: action}) action_dict.update({a: action})
# Environment step # Environment step
print(trials,step)
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
print("stepped")
# Update replay buffer and train agent # Update replay buffer and train agent
for a in range(env.number_of_agents): for a in range(env.number_of_agents):
...@@ -85,21 +81,24 @@ for trials in range(1, n_trials + 1): ...@@ -85,21 +81,24 @@ for trials in range(1, n_trials + 1):
scores.append(np.mean(scores_window)) scores.append(np.mean(scores_window))
dones_list.append((np.mean(done_window))) dones_list.append((np.mean(done_window)))
print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents, print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
trials, env.number_of_agents,
np.mean( trials,
scores_window), np.mean(
100 * np.mean( scores_window),
done_window), 100 * np.mean(
eps), done_window),
eps),
end=" ") end=" ")
if trials % 100 == 0: if trials % 100 == 0:
print( print(
'\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents, '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
trials, env.number_of_agents,
np.mean( trials,
scores_window), np.mean(
100 * np.mean( scores_window),
done_window), 100 * np.mean(
eps)) done_window),
torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') eps))
torch.save(agent.qnetwork_local.state_dict(),
'../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment