Commit e0808569 authored by Erik Nygren

reformatting file

parent a0636d2a
@@ -20,8 +20,8 @@ transition_probability = [5,  # empty cell - Case 0
                           0]  # Case 7 - dead end

 # Example generate a random rail
-env = RailEnv(width=10,
-              height=10,
+env = RailEnv(width=15,
+              height=15,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=3)
 env_renderer = RenderTool(env, gl="QT")
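Note: the hunk above is the only place the environment is configured, and it is not self-contained (the imports and the full transition_probability list sit outside the diff). A minimal sketch of the same setup, where the import paths and the uniform cell distribution are assumptions rather than part of the commit, could look like:

# Hypothetical, self-contained reconstruction of the environment setup shown in the hunk above.
# The import paths follow the flatland package layout of this period and are an assumption.
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import random_rail_generator
from flatland.utils.rendertools import RenderTool

# Relative frequency of the 8 rail cell types (Case 0 = empty cell, ..., Case 7 = dead end).
# The script uses hand-tuned values; a uniform choice stands in for them here.
transition_probability = [1.0] * 8

env = RailEnv(width=15,  # the commit bumps the random rail from 10x10 to 15x15
              height=15,
              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
              number_of_agents=3)
env.reset()

env_renderer = RenderTool(env, gl="QT")  # Qt-based renderer, as in the script
env_renderer.renderEnv(show=True)        # pop up a window showing the generated rail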
@@ -38,37 +38,41 @@ scores_window = deque(maxlen=100)
 done_window = deque(maxlen=100)
 scores = []
 dones_list = []
-action_prob = [0]*4
+action_prob = [0] * 4
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint13900.pth'))
+demo = True
+demo = False

 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
     None is returned if seq was empty or all items in seq were >= val.
     """
     max = 0
-    idx = len(seq)-1
+    idx = len(seq) - 1
     while idx >= 0:
         if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
             max = seq[idx]
         idx -= 1
     return max


 def min_lt(seq, val):
     """
     Return smallest item in seq for which item > val applies.
     None is returned if seq was empty or all items in seq were >= val.
     """
     min = np.inf
-    idx = len(seq)-1
+    idx = len(seq) - 1
     while idx >= 0:
         if seq[idx] > val and seq[idx] < min:
             min = seq[idx]
         idx -= 1
     return min


 for trials in range(1, n_trials + 1):

     # Reset environment
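Side note on the two helpers reformatted above: both docstrings promise None when no qualifying element exists, but the bodies as shown fall back to 0 and np.inf, and min_lt's docstring repeats the ">= val" wording from max_lt. The commit only adjusts spacing, so this is an observation about the snippet rather than about the change; a hypothetical pair of helpers that actually matches the docstrings could look like:

def max_lt(seq, val):
    """Return the greatest non-negative item in seq that is < val, or None if there is none."""
    best = None
    for item in seq:
        if 0 <= item < val and (best is None or item > best):
            best = item
    return best


def min_lt(seq, val):
    """Return the smallest item in seq that is > val, or None if there is none."""
    best = None
    for item in seq:
        if item > val and (best is None or item < best):
            best = item
    return best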
@@ -86,7 +90,7 @@ for trials in range(1, n_trials + 1):
     for step in range(100):
         if demo:
             env_renderer.renderEnv(show=True)
-        #print(step)
+        # print(step)
         # Action
         for a in range(env.number_of_agents):
             if demo:
@@ -117,17 +121,17 @@ for trials in range(1, n_trials + 1):
     scores.append(np.mean(scores_window))
     dones_list.append((np.mean(done_window)))
-    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
-        env.number_of_agents,
-        trials,
-        np.mean(
-            scores_window),
-        100 * np.mean(
-            done_window),
-        eps, action_prob/np.sum(action_prob)),
-        end=" ")
+    print(
+        '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+            env.number_of_agents,
+            trials,
+            np.mean(
+                scores_window),
+            100 * np.mean(
+                done_window),
+            eps, action_prob / np.sum(action_prob)),
+        end=" ")
     if trials % 100 == 0:
         print(
             '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                 env.number_of_agents,
@@ -139,4 +143,4 @@ for trials in range(1, n_trials + 1):
                 eps, action_prob / np.sum(action_prob)))
         torch.save(agent.qnetwork_local.state_dict(),
                    '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
-        action_prob = [1]*4
+        action_prob = [1] * 4
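The last two hunks only re-wrap the progress printout and the periodic checkpoint save, but they expose a small pattern worth spelling out: action_prob is printed as a normalized distribution (action_prob / np.sum(action_prob)) and then re-seeded with [1] * 4 rather than zeros, which, presumably, keeps the normalization well defined even if some action is never chosen in the next block of episodes. A stripped-down, runnable sketch of that reporting cadence, with a random draw standing in for the agent's actual action counts:

import numpy as np

action_prob = [1] * 4                 # seeded with ones so the division below never hits a zero sum
n_trials = 300                        # stand-in episode budget

for trials in range(1, n_trials + 1):
    # In the real script this counter is bumped once per action the agent takes;
    # a random draw stands in for that here.
    action_prob[np.random.randint(4)] += 1
    if trials % 100 == 0:
        dist = np.array(action_prob) / np.sum(action_prob)   # empirical action frequencies
        print('\rEpisode {}\tAction Probabilities: {}'.format(trials, dist), end=" ")
        action_prob = [1] * 4         # reset the counters for the next block of episodes

In the script itself the same trials % 100 == 0 branch also writes the current weights to ../flatland/baselines/Nets/avoid_checkpoint<trials>.pth via torch.save(agent.qnetwork_local.state_dict(), ...), matching the avoid_checkpoint files loaded at the top of the diff.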