Commit 8e79b68f authored by Erik Nygren

Updated training: New state has two time frames

parent 08c6d12c
@@ -45,7 +45,7 @@ env = RailEnv(width=20,
 env_renderer = RenderTool(env, gl="QT")
 handle = env.get_agent_handles()
-state_size = 105
+state_size = 105 * 2
 action_size = 4
 n_trials = 15000
 eps = 1.
@@ -55,13 +55,16 @@ action_dict = dict()
 final_action_dict = dict()
 scores_window = deque(maxlen=100)
 done_window = deque(maxlen=100)
+time_obs = deque(maxlen=2)
 scores = []
 dones_list = []
 action_prob = [0] * 4
+agent_obs = [None] * env.get_num_agents()
+agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
-demo = True
+demo = False

 def max_lt(seq, val):
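
The commit doubles state_size because the network input is now two consecutive observation frames concatenated per agent, held in the new time_obs = deque(maxlen=2). A minimal sketch of that sliding two-frame buffer, where frame_size = 105 and the constant-valued frames are illustrative stand-ins for real observations:

# Sketch: the DQN input is two consecutive frames, hence state_size = 105 * 2.
# deque(maxlen=2) keeps only the latest two frames; a third append evicts the
# oldest. frame_size and the constant-valued frames are illustrative stand-ins.
from collections import deque

import numpy as np

frame_size = 105
state_size = frame_size * 2

time_obs = deque(maxlen=2)
for value in (0., 1., 2.):
    time_obs.append(np.full(frame_size, value))

# only the frames filled with 1. and 2. remain; the 0. frame was evicted
stacked = np.concatenate((time_obs[0], time_obs[1]))
print(stacked.shape)                    # (210,)
print(time_obs[0][0], time_obs[1][0])   # 1.0 2.0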
@@ -103,11 +106,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     max_obs = max(1, max_lt(obs, 1000))
     min_obs = max(0, min_lt(obs, 0))
     if max_obs == min_obs:
-        return np.clip(np.array(obs)/ max_obs, clip_min, clip_max)
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
     norm = np.abs(max_obs - min_obs)
     if norm == 0:
         norm = 1.
-    return np.clip((np.array(obs)-min_obs)/ norm, clip_min, clip_max)
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)

 for trials in range(1, n_trials + 1):
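
The hunk above only reformats the two return statements of norm_obs_clip. For context, the sketch below reconstructs the whole normalisation step as it appears in the diff; the max_lt / min_lt helpers are inferred from their call sites (largest element below a cutoff, smallest element above a floor) and are assumptions, not the baseline's exact code:

# Sketch of the observation normalisation used in the training script.
# max_lt / min_lt are reconstructed from how they are called and may differ
# from the real helpers.
import numpy as np

def max_lt(seq, val):
    # largest element of seq strictly below val (fallback 0 if none)
    candidates = [x for x in seq if x < val]
    return max(candidates) if candidates else 0

def min_lt(seq, val):
    # smallest element of seq strictly above val (fallback 0 if none)
    candidates = [x for x in seq if x > val]
    return min(candidates) if candidates else 0

def norm_obs_clip(obs, clip_min=-1, clip_max=1):
    # rescale obs into [clip_min, clip_max], ignoring sentinel values >= 1000
    max_obs = max(1, max_lt(obs, 1000))
    min_obs = max(0, min_lt(obs, 0))
    if max_obs == min_obs:
        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
    norm = np.abs(max_obs - min_obs)
    if norm == 0:
        norm = 1.
    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)

print(norm_obs_clip([0.5, 3.0, np.inf, 1000]))   # sentinels clip to 1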
@@ -115,13 +118,18 @@ for trials in range(1, n_trials + 1):
     # Reset environment
     obs = env.reset()
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
         data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
         obs[a] = np.concatenate((data, distance))
+    for i in range(2):
+        time_obs.append(obs)
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+    for a in range(env.get_num_agents()):
+        agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
     score = 0
     env_done = 0
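
At reset there is no previous frame yet, so the same processed observation is appended twice; the first stacked state is therefore just the frame repeated. A tiny illustration (the 4-element frame and single agent are placeholders):

# Sketch: seeding the two-frame buffer at episode reset.
from collections import deque

import numpy as np

frame = {0: np.arange(4, dtype=float)}   # stand-in for obs[a] after norm_obs_clip

time_obs = deque(maxlen=2)
for _ in range(2):
    time_obs.append(frame)

stacked = np.concatenate((time_obs[0][0], time_obs[1][0]))
print(stacked)   # [0. 1. 2. 3. 0. 1. 2. 3.] -> both halves identical at t=0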
@@ -134,7 +142,8 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if demo:
                 eps = 0
-            action = agent.act(np.array(obs[a]), eps=eps)
+            # action = agent.act(np.array(obs[a]), eps=eps)
+            action = agent.act(agent_obs[a])
             action_prob[action] += 1
             action_dict.update({a: action})
@@ -148,17 +157,21 @@ for trials in range(1, n_trials + 1):
             distance = norm_obs_clip(distance)
             next_obs[a] = np.concatenate((data, distance))
+        time_obs.append(next_obs)
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
+            agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
             if done[a]:
-                final_obs[a] = obs[a].copy()
+                final_obs[a] = agent_obs[a].copy()
-                final_obs_next[a] = next_obs[a].copy()
+                final_obs_next[a] = agent_next_obs[a].copy()
                 final_action_dict.update({a: action_dict[a]})
             if not demo and not done[a]:
-                agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+                agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
             score += all_rewards[a]
-        obs = next_obs.copy()
+        agent_obs = agent_next_obs.copy()
         if done['__all__']:
             env_done = 1
             for a in range(env.get_num_agents()):
...
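
Putting the pieces together, the per-step bookkeeping after this commit is: append the freshly processed next_obs so the two-frame window slides, rebuild agent_next_obs, store the stacked transition, then roll agent_obs forward. The sketch below mirrors that flow with the environment and DQN agent replaced by stubs; StubAgent, processed_frame, num_agents and the fixed reward are assumptions for illustration only:

# Sketch of the per-step flow with stacked states; environment and agent are stubbed.
from collections import deque

import numpy as np

num_agents = 2
frame_size = 105

class StubAgent:
    def act(self, state, eps=0.):
        assert state.shape == (frame_size * 2,)    # stacked two-frame input
        return 0
    def step(self, state, action, reward, next_state, done):
        # stand-in for the replay-buffer update and learning step
        assert state.shape == next_state.shape == (frame_size * 2,)

agent = StubAgent()
time_obs = deque(maxlen=2)
agent_obs = [None] * num_agents
agent_next_obs = [None] * num_agents

def processed_frame():
    # stand-in for split_tree + norm_obs_clip applied to each agent's observation
    return {a: np.random.rand(frame_size) for a in range(num_agents)}

# reset: seed the window with the first frame twice
obs = processed_frame()
for _ in range(2):
    time_obs.append(obs)
for a in range(num_agents):
    agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))

# one environment step (actions, reward and done flag are faked here)
actions = {a: agent.act(agent_obs[a]) for a in range(num_agents)}
next_obs = processed_frame()
time_obs.append(next_obs)                          # window slides by one frame
for a in range(num_agents):
    agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
    agent.step(agent_obs[a], actions[a], -1.0, agent_next_obs[a], False)
agent_obs = agent_next_obs.copy()                  # roll the stacked state forward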