Skip to content
Snippets Groups Projects
Commit 28e339bb authored by Erik Nygren's avatar Erik Nygren
Browse files

taking new observation features into account

parent 2e05cbe1
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.envs.generators import complex_rail_generator
from utils.observation_utils import norm_obs_clip, split_tree
random.seed(1)
......@@ -40,26 +40,26 @@ env = RailEnv(width=15,
height=15,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
number_of_agents=1)
"""
env = RailEnv(width=10,
height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.load("./railway/complex_scene.pkl")
file_load = True
"""
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
env = RailEnv(width=10,
height=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=15)
number_of_agents=3)
file_load = False
env.reset(True, True)
"""
"""
env_renderer = RenderTool(env, gl="PILSVG",)
handle = env.get_agent_handles()
state_size = 168 * 2
features_per_node = 9
state_size = features_per_node*21 * 2
action_size = 5
n_trials = 15000
max_steps = int(3 * (env.height + env.width))
......@@ -77,9 +77,9 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
demo = True
demo = False
record_images = False
......@@ -97,8 +97,7 @@ for trials in range(1, n_trials + 1):
final_obs = obs.copy()
final_obs_next = obs.copy()
for a in range(env.get_num_agents()):
print(a)
data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=8,
data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
current_depth=0)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
......@@ -136,7 +135,7 @@ for trials in range(1, n_trials + 1):
next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=8,
data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
current_depth=0)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
......
......@@ -48,7 +48,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
def split_tree(tree, num_features_per_node=8, current_depth=0):
def split_tree(tree, num_features_per_node=9, current_depth=0):
"""
Splits the tree observation into different sub groups that need the same normalization.
This is necessary because the tree observation includes two different distance:
......@@ -80,10 +80,9 @@ def split_tree(tree, num_features_per_node=8, current_depth=0):
Here we split the node features into the different classes of distances and binary values.
Pay close attention to this part if you modify any of the features in the tree observation.
"""
tree_data = tree[:4].tolist()
distance_data = [tree[4]]
agent_data = tree[5:num_features_per_node].tolist()
tree_data = tree[:6].tolist()
distance_data = [tree[6]]
agent_data = tree[7:num_features_per_node].tolist()
# Split each child of the current node and continue to next depth level
for children in range(4):
child_tree = tree[(num_features_per_node + children * child_size):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment