Skip to content
Snippets Groups Projects
Commit 8e79b68f authored by Erik Nygren's avatar Erik Nygren
Browse files

Updated training: New state has two time frames

parent 08c6d12c
No related branches found
No related tags found
No related merge requests found
......@@ -45,7 +45,7 @@ env = RailEnv(width=20,
env_renderer = RenderTool(env, gl="QT")
handle = env.get_agent_handles()
state_size = 105
state_size = 105 * 2
action_size = 4
n_trials = 15000
eps = 1.
......@@ -55,13 +55,16 @@ action_dict = dict()
final_action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
time_obs = deque(maxlen=2)
scores = []
dones_list = []
action_prob = [0] * 4
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
demo = True
demo = False
def max_lt(seq, val):
......@@ -103,11 +106,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
max_obs = max(1, max_lt(obs, 1000))
min_obs = max(0, min_lt(obs, 0))
if max_obs == min_obs:
return np.clip(np.array(obs)/ max_obs, clip_min, clip_max)
return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
norm = np.abs(max_obs - min_obs)
if norm == 0:
norm = 1.
return np.clip((np.array(obs)-min_obs)/ norm, clip_min, clip_max)
return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
for trials in range(1, n_trials + 1):
......@@ -115,13 +118,18 @@ for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset()
final_obs = obs.copy()
final_obs_next = obs.copy()
final_obs_next = obs.copy()
for a in range(env.get_num_agents()):
data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
obs[a] = np.concatenate((data, distance))
for i in range(2):
time_obs.append(obs)
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
for a in range(env.get_num_agents()):
agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
score = 0
env_done = 0
......@@ -134,7 +142,8 @@ for trials in range(1, n_trials + 1):
for a in range(env.get_num_agents()):
if demo:
eps = 0
action = agent.act(np.array(obs[a]), eps=eps)
# action = agent.act(np.array(obs[a]), eps=eps)
action = agent.act(agent_obs[a])
action_prob[action] += 1
action_dict.update({a: action})
......@@ -148,17 +157,21 @@ for trials in range(1, n_trials + 1):
distance = norm_obs_clip(distance)
next_obs[a] = np.concatenate((data, distance))
time_obs.append(next_obs)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
if done[a]:
final_obs[a] = obs[a].copy()
final_obs_next[a] = next_obs[a].copy()
final_obs[a] = agent_obs[a].copy()
final_obs_next[a] = agent_next_obs[a].copy()
final_action_dict.update({a: action_dict[a]})
if not demo and not done[a]:
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
score += all_rewards[a]
obs = next_obs.copy()
agent_obs = agent_next_obs.copy()
if done['__all__']:
env_done = 1
for a in range(env.get_num_agents()):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment