From 8e79b68f17891c2a889203b6ac886ac46f0bbbbd Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 15 May 2019 14:18:50 +0200
Subject: [PATCH] Updated training: New state has two time frames

---
 examples/training_navigation.py | 35 ++++++++++++++++++++++-----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 0262640..9d45cd1 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -45,7 +45,7 @@ env = RailEnv(width=20,
 env_renderer = RenderTool(env, gl="QT")
 handle = env.get_agent_handles()
 
-state_size = 105
+state_size = 105 * 2
 action_size = 4
 n_trials = 15000
 eps = 1.
@@ -55,13 +55,16 @@ action_dict = dict()
 final_action_dict = dict()
 scores_window = deque(maxlen=100)
 done_window = deque(maxlen=100)
+time_obs = deque(maxlen=2)
 scores = []
 dones_list = []
 action_prob = [0] * 4
+agent_obs = [None] * env.get_num_agents()
+agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
-demo = True
+demo = False
 
 
 def max_lt(seq, val):
@@ -103,11 +106,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     max_obs = max(1, max_lt(obs, 1000))
     min_obs = max(0, min_lt(obs, 0))
     if max_obs == min_obs:
-        return np.clip(np.array(obs)/ max_obs, clip_min, clip_max)
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
     norm = np.abs(max_obs - min_obs)
     if norm == 0:
         norm = 1.
-    return np.clip((np.array(obs)-min_obs)/ norm, clip_min, clip_max)
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
 for trials in range(1, n_trials + 1):
@@ -115,13 +118,18 @@ for trials in range(1, n_trials + 1):
     # Reset environment
     obs = env.reset()
     final_obs = obs.copy()
-    final_obs_next = obs.copy()
+    final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
         data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
         obs[a] = np.concatenate((data, distance))
+
+    for i in range(2):
+        time_obs.append(obs)
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+    for a in range(env.get_num_agents()):
+        agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
 
     score = 0
     env_done = 0
@@ -134,7 +142,8 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if demo:
                 eps = 0
-            action = agent.act(np.array(obs[a]), eps=eps)
+            # action = agent.act(np.array(obs[a]), eps=eps)
+            action = agent.act(agent_obs[a])
             action_prob[action] += 1
             action_dict.update({a: action})
 
@@ -148,17 +157,21 @@ for trials in range(1, n_trials + 1):
             distance = norm_obs_clip(distance)
             next_obs[a] = np.concatenate((data, distance))
 
+        time_obs.append(next_obs)
+
         # Update replay buffer and train agent
        for a in range(env.get_num_agents()):
+            agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
             if done[a]:
-                final_obs[a] = obs[a].copy()
-                final_obs_next[a] = next_obs[a].copy()
+                final_obs[a] = agent_obs[a].copy()
+                final_obs_next[a] = agent_next_obs[a].copy()
                 final_action_dict.update({a: action_dict[a]})
             if not demo and not done[a]:
-                agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+                agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
             score += all_rewards[a]
 
-        obs = next_obs.copy()
+        agent_obs = agent_next_obs.copy()
         if done['__all__']:
             env_done = 1
             for a in range(env.get_num_agents()):
--
GitLab
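
A minimal sketch of the frame-stacking idea this commit introduces, separate from the patch itself: the last two per-agent observations are kept in a `deque(maxlen=2)` and concatenated into one state vector of twice the single-frame size (hence `state_size = 105 * 2`). The `num_agents` value and the `fake_observations` helper below are assumptions for illustration only, not part of the commit.

```python
# Illustrative sketch (not from the patch): stack the two most recent
# per-agent observations into one state vector, mirroring the roles of
# time_obs, agent_obs and agent_next_obs in training_navigation.py.
from collections import deque

import numpy as np

num_agents = 3   # assumed value for the example
obs_size = 105   # single-frame observation size used in the script
time_obs = deque(maxlen=2)


def fake_observations():
    # Stand-in for the normalized tree observations built in the script.
    return {a: np.random.rand(obs_size) for a in range(num_agents)}


# On reset, the same first observation fills both time slots.
first_obs = fake_observations()
for _ in range(2):
    time_obs.append(first_obs)

agent_obs = [np.concatenate((time_obs[0][a], time_obs[1][a]))
             for a in range(num_agents)]
assert agent_obs[0].shape == (2 * obs_size,)  # matches state_size = 105 * 2

# After each environment step, append the new observation; the deque drops
# the oldest frame, so the stacked state always covers the last two frames.
time_obs.append(fake_observations())
agent_next_obs = [np.concatenate((time_obs[0][a], time_obs[1][a]))
                  for a in range(num_agents)]
```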