Skip to content
Snippets Groups Projects
Commit b42ef74e authored by Erik Nygren's avatar Erik Nygren
Browse files

updated training for navigation

parent 3cc3a0c6
No related branches found
No related tags found
No related merge requests found
...@@ -20,8 +20,8 @@ transition_probability = [10.0, # empty cell - Case 0 ...@@ -20,8 +20,8 @@ transition_probability = [10.0, # empty cell - Case 0
0.0] # Case 7 - dead end 0.0] # Case 7 - dead end
# Example generate a random rail # Example generate a random rail
env = RailEnv(width=7, env = RailEnv(width=5,
height=7, height=5,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1) number_of_agents=1)
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
...@@ -29,7 +29,7 @@ handle = env.get_agent_handles() ...@@ -29,7 +29,7 @@ handle = env.get_agent_handles()
state_size = 105 state_size = 105
action_size = 4 action_size = 4
n_trials = 5000 n_trials = 9999
eps = 1. eps = 1.
eps_end = 0.005 eps_end = 0.005
eps_decay = 0.998 eps_decay = 0.998
...@@ -40,14 +40,27 @@ scores = [] ...@@ -40,14 +40,27 @@ scores = []
dones_list = [] dones_list = []
action_prob = [0]*4 action_prob = [0]*4
agent = Agent(state_size, action_size, "FC", 0) agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint8000.pth'))
def max_lt(seq, val):
"""
Return greatest item in seq for which item < val applies.
None is returned if seq was empty or all items in seq were >= val.
"""
idx = len(seq)-1
while idx >= 0:
if seq[idx] < val and seq[idx] > 0:
return seq[idx]
idx -= 1
return None
for trials in range(1, n_trials + 1): for trials in range(1, n_trials + 1):
# Reset environment # Reset environment
obs = env.reset() obs = env.reset()
for a in range(env.number_of_agents): for a in range(env.number_of_agents):
if np.max(obs[a]) > 0 and np.max(obs[a]) < np.inf: norm = max(1, max_lt(obs[a],np.inf))
obs[a] = np.clip(obs[a] / np.max(obs[a]), -1, 1) obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
...@@ -55,21 +68,21 @@ for trials in range(1, n_trials + 1): ...@@ -55,21 +68,21 @@ for trials in range(1, n_trials + 1):
env_done = 0 env_done = 0
# Run episode # Run episode
for step in range(100): for step in range(50):
#if trials > 114: #if trials > 114:
# env_renderer.renderEnv(show=True) #env_renderer.renderEnv(show=True)
#print(step)
# Action # Action
for a in range(env.number_of_agents): for a in range(env.number_of_agents):
action = agent.act(np.array(obs[a]), eps=eps) action = agent.act(np.array(obs[a]), eps=0)
action_prob[action] += 1 action_prob[action] += 1
action_dict.update({a: action}) action_dict.update({a: action})
# Environment step # Environment step
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.number_of_agents): for a in range(env.number_of_agents):
if np.max(next_obs[a]) > 0 and np.max(next_obs[a]) < np.inf: norm = max(1, max_lt(next_obs[a], np.inf))
next_obs[a] = np.clip(next_obs[a] / np.max(next_obs[a]), -1, 1) next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
# Update replay buffer and train agent # Update replay buffer and train agent
for a in range(env.number_of_agents): for a in range(env.number_of_agents):
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
...@@ -108,3 +121,4 @@ for trials in range(1, n_trials + 1): ...@@ -108,3 +121,4 @@ for trials in range(1, n_trials + 1):
eps, action_prob / np.sum(action_prob))) eps, action_prob / np.sum(action_prob)))
torch.save(agent.qnetwork_local.state_dict(), torch.save(agent.qnetwork_local.state_dict(),
'../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment