Commit 366b4fe5 authored by Erik Nygren

Updated training to new state

parent d5376a20
@@ -44,11 +44,11 @@ env = RailEnv(width=10,
               height=20)
 env.load("./railway/complex_scene.pkl")
-env = RailEnv(width=15,
-              height=15,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
-              number_of_agents=1)
-env.reset(False, False)
+env = RailEnv(width=8,
+              height=8,
+              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0),
+              number_of_agents=2)
+env.reset(True, True)
 env_renderer = RenderTool(env, gl="PILSVG")
 handle = env.get_agent_handles()
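For context, a minimal sketch of how the environment configured by the added lines in this hunk could be driven with random actions. The import paths, the 5-action discrete action space, the return value of env.reset, and the done['__all__'] key are assumptions about the flatland-rl API of this period, not part of the diff.

import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import complex_rail_generator
from flatland.utils.rendertools import RenderTool

# Same parameters as the added lines above
env = RailEnv(width=8,
              height=8,
              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5,
                                                    max_dist=99999, seed=0),
              number_of_agents=2)
obs = env.reset(True, True)
env_renderer = RenderTool(env, gl="PILSVG")

for step in range(100):
    # Hypothetical random policy: one of 5 discrete actions per agent (assumption)
    action_dict = {a: np.random.randint(0, 5) for a in range(env.get_num_agents())}
    next_obs, all_rewards, done, _ = env.step(action_dict)
    env_renderer.renderEnv(show=True, show_observations=True)
    if done['__all__']:
        break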
@@ -70,11 +70,10 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
 demo = False

 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
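The body of max_lt is cut off by the hunk boundary above. A minimal sketch that matches its docstring (greatest item strictly below val) could look like the following; this is an illustration, not the committed implementation.

def max_lt(seq, val):
    """Return the greatest item in seq for which item < val applies (None if no such item)."""
    candidates = [item for item in seq if item < val]
    return max(candidates) if candidates else None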
@@ -145,35 +144,34 @@ for trials in range(1, n_trials + 1):
     score = 0
     env_done = 0
     # Run episode
-    for step in range(360):
+    for step in range(100):
         if demo:
-            env_renderer.renderEnv(show=True,show_observations=False)
+            env_renderer.renderEnv(show=True, show_observations=True)
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
             if demo:
-                eps = 1
+                eps = 0
            # action = agent.act(np.array(obs[a]), eps=eps)
            action = agent.act(agent_obs[a], eps=eps)
            action_prob[action] += 1
            action_dict.update({a: action})
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
             data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7,
                                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
-            next_obs[a] = np.concatenate((np.concatenate((data, distance)),agent_data))
+            next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
         time_obs.append(next_obs)
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
             if done[a]:
                 final_obs[a] = agent_obs[a].copy()
                 final_obs_next[a] = agent_next_obs[a].copy()
@@ -214,4 +212,4 @@ for trials in range(1, n_trials + 1):
               action_prob / np.sum(action_prob)))
         torch.save(agent.qnetwork_local.state_dict(),
                    './Nets/avoid_checkpoint' + str(trials) + '.pth')
-        action_prob = [1] * 4
+        action_prob = [1] * action_size
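To summarise the per-agent preprocessing done inside the training loop above, here is a hypothetical helper (build_agent_obs is not a name from the script) that mirrors those lines; it assumes the same norm_obs_clip function and tree observation builder used in the diff.

import numpy as np

def build_agent_obs(tree_obs, obs_builder, num_features_per_node=7):
    # Split the tree observation into node data, distance and agent channels,
    # clip/normalise the first two, then stack everything into one vector,
    # as done inside the loop above.
    data, distance, agent_data = obs_builder.split_tree(tree=np.array(tree_obs),
                                                        num_features_per_node=num_features_per_node,
                                                        current_depth=0)
    data = norm_obs_clip(data)
    distance = norm_obs_clip(distance)
    return np.concatenate((np.concatenate((data, distance)), agent_data))

Two consecutive outputs of this preprocessing are then concatenated (time_obs[0][a] with time_obs[1][a]) to form the state passed to agent.act and stored in the replay buffer.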