Skip to content
Snippets Groups Projects
Commit c93e9d56 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

error introduced for christian to test

parent 23f5ddf1
No related branches found
No related tags found
No related merge requests found
......@@ -65,7 +65,7 @@ action_prob = [0] * 4
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint10400.pth'))
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
demo = True
......@@ -119,18 +119,18 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset(False,False)
obs = env.reset(False, False)
print(len(obs[0]))
final_obs = obs.copy()
final_obs_next = obs.copy()
for a in range(env.get_num_agents()):
data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7, current_depth=0)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
obs[a] = np.concatenate((data, distance))
obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
print(len(data) + len(distance) + len(agent_data), len(obs[a]))
for i in range(2):
time_obs.append(obs)
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
......@@ -156,11 +156,11 @@ for trials in range(1, n_trials + 1):
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()):
data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7,
current_depth=0)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
next_obs[a] = np.concatenate((data, distance))
next_obs[a] = np.concatenate((np.concatenate((data, distance)),agent_data))
time_obs.append(next_obs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment