Commit c45b8cee authored by Erik Nygren

Added simple score test (early alpha, with random parameters for testing purposes)

parent 28e339bb
{'Test_0':[10,10,1,3],
'Test_1':[10,10,3,4321],
'Test_2':[10,10,5,123],
'Test_3':[50,50,5,21],
'Test_4':[50,50,20,85],
'Test_5':[100,100,5,436],
'Test_6':[100,100,20,6487],
'Test_7':[100,100,50,567],
'Test_8':[100,10,20,3245],
'Test_9':[10,100,20,632]
}
\ No newline at end of file
import random
import time
from collections import deque
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch_training.dueling_double_dqn import Agent
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.envs.generators import complex_rail_generator
from utils.observation_utils import norm_obs_clip, split_tree
from utils.misc_utils import printProgressBar, RandomAgent
with open('parameters.txt', 'r') as inf:
    parameters = eval(inf.read())
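# Each entry in parameters maps a test name to [width, height, number_of_agents, seed]
# (inferred from how the values are indexed further below).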
# Parameter initialization
features_per_node = 9
state_size = features_per_node*21 * 2
action_size = 5
action_dict = dict()
nr_trials_per_test = 100
test_results = []
test_times = []
test_dones = []
# Load agent
#agent = Agent(state_size, action_size, "FC", 0)
#agent.qnetwork_local.load_state_dict(torch.load('./torch_training/Nets/avoid_checkpoint30000.pth'))
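# A random policy is used here as a simple scoring baseline; uncomment the two lines above
# to score a trained DQN agent instead.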
agent = RandomAgent(state_size, action_size)
start_time_scoring = time.time()
for test_nr in parameters:
    current_parameters = parameters[test_nr]
    print('\nRunning {} with (x_dim, y_dim) = ({},{}) and {} agents.'.format(test_nr, current_parameters[0],
                                                                             current_parameters[1],
                                                                             current_parameters[2]))

    # Reset all measurements
    time_obs = deque(maxlen=2)
    test_scores = []
    tot_dones = 0
    tot_test_score = 0

    # Reset environment
    random.seed(current_parameters[3])
    np.random.seed(current_parameters[3])
    nr_paths = max(2, current_parameters[2] + int(0.5 * current_parameters[2]))
    min_dist = int(min([current_parameters[0], current_parameters[1]]) * 0.75)
    env = RailEnv(width=current_parameters[0],
                  height=current_parameters[1],
                  rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist,
                                                        max_dist=99999, seed=current_parameters[3]),
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  number_of_agents=current_parameters[2])
    max_steps = int(3 * (env.height + env.width))
    agent_obs = [None] * env.get_num_agents()
    env_renderer = RenderTool(env, gl="PILSVG")
    printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
    start = time.time()
    for trial in range(nr_trials_per_test):
        # Reset the env
        printProgressBar(trial + 1, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
        obs = env.reset(True, True)
        # env_renderer.set_new_rail()
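        # Pre-process each agent's tree observation: split it into structural data, distance
        # information and agent data, normalize/clip the parts, then stack the two most recent
        # frames to form the input vector.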
        for a in range(env.get_num_agents()):
            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
                                                    current_depth=0)
            data = norm_obs_clip(data)
            distance = norm_obs_clip(distance)
            agent_data = np.clip(agent_data, -1, 1)
            obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))

        for i in range(2):
            time_obs.append(obs)

        for a in range(env.get_num_agents()):
            agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))

        # Run episode
        trial_score = 0
        for step in range(max_steps):
            for a in range(env.get_num_agents()):
                action = agent.act(agent_obs[a], eps=0)
                action_dict.update({a: action})

            # Environment step
            next_obs, all_rewards, done, _ = env.step(action_dict)
            for a in range(env.get_num_agents()):
                data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
                                                        num_features_per_node=features_per_node,
                                                        current_depth=0)
                data = norm_obs_clip(data)
                distance = norm_obs_clip(distance)
                agent_data = np.clip(agent_data, -1, 1)
                next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))

            time_obs.append(next_obs)
            for a in range(env.get_num_agents()):
                agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
                trial_score += all_rewards[a] / env.get_num_agents()

            if done['__all__']:
                tot_dones += 1
                break

        test_scores.append(trial_score / max_steps)

    end = time.time()
    comp_time = end - start
    tot_test_score = np.mean(test_scores)
    test_results.append(tot_test_score)
    test_times.append(comp_time)
    test_dones.append(tot_dones / nr_trials_per_test * 100)
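# At this point test_results holds the mean normalized score per test set, test_dones the
# percentage of episodes in which all agents reached their target, and test_times the
# wall-clock time spent on each test set.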
end_time_scoring = time.time()
tot_test_time = end_time_scoring-start_time_scoring
test_idx = 0
print('-----------------------------------------------')
print(' RESULTS')
print('-----------------------------------------------')
for test_nr in parameters:
    print('{} score was = {:.3f} with {:.2f}% environments solved. Test took {} seconds to complete.'.format(
        test_nr, test_results[test_idx], test_dones[test_idx], test_times[test_idx]))
    test_idx += 1
print('Total scoring duration was', tot_test_time)
\ No newline at end of file
No preview for this file type
@@ -41,17 +41,18 @@ env = RailEnv(width=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
 env = RailEnv(width=10,
               height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 env.load("./railway/complex_scene.pkl")
 file_load = True
 """
-env = RailEnv(width=10,
-              height=10,
-              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
+env = RailEnv(width=100,
+              height=100,
+              rail_generator=complex_rail_generator(nr_start_goal=100, nr_extra=5, min_dist=5, max_dist=99999, seed=0),
               obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-              number_of_agents=3)
+              number_of_agents=1)
 file_load = False
 env.reset(True, True)
 """
@@ -61,11 +62,11 @@ handle = env.get_agent_handles()
 features_per_node = 9
 state_size = features_per_node*21 * 2
 action_size = 5
-n_trials = 15000
+n_trials = 30000
 max_steps = int(3 * (env.height + env.width))
 eps = 1.
 eps_end = 0.005
-eps_decay = 0.9995
+eps_decay = 0.9997
 action_dict = dict()
 final_action_dict = dict()
 scores_window = deque(maxlen=100)
@@ -77,9 +78,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
-demo = False
+demo = True
 record_images = False
# Print iterations progress
import numpy as np
def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'):
    """
    Call in a loop to create a terminal progress bar.
    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
    """
    percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
    filledLength = int(length * iteration // total)
    bar = fill * filledLength + '_' * (length - filledLength)
    print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end=" ")
    # Print a new line on completion
    if iteration == total:
        print('')
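

# A minimal usage sketch (illustrative, not part of the original file): printProgressBar is
# meant to be called once per loop iteration, e.g.
#
#     for i in range(100):
#         do_work()
#         printProgressBar(i + 1, 100, prefix='Progress:', suffix='Complete', length=20)
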
class RandomAgent:

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state, eps=0):
        """
        :param state: input is the observation of the agent
        :return: returns an action
        """
        return np.random.choice(np.arange(self.action_size))

    def step(self, memories):
        """
        Step function to improve the agent by adjusting its policy given the observations
        :param memories: SARS tuple to learn from
        :return:
        """
        return

    def save(self, filename):
        # Store the current policy
        return

    def load(self, filename):
        # Load a policy
        return
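
# A minimal usage sketch (illustrative, not part of the original file). RandomAgent mirrors
# the act/step/save/load interface of the DQN Agent, so the scoring script can swap between
# the two without other changes:
#
#     agent = RandomAgent(state_size, action_size)
#     action = agent.act(state)   # uniformly samples one of the action_size actions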