From 28e339bb2848b38881f08f1accdc2751161f3e6b Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 3 Jul 2019 16:17:43 -0400
Subject: [PATCH] taking new observation features into account

---
 torch_training/training_navigation.py | 27 +++++++++++++--------------
 utils/observation_utils.py            |  9 ++++-----
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 0e5ad18..dc01f7f 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
+from flatland.envs.generators import complex_rail_generator
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -40,26 +40,26 @@ env = RailEnv(width=15,
               height=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
-"""
+
 env = RailEnv(width=10,
               height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 env.load("./railway/complex_scene.pkl")
 file_load = True
 """
-env = RailEnv(width=20,
-              height=20,
-              rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
+env = RailEnv(width=10,
+              height=10,
+              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
               obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-              number_of_agents=15)
+              number_of_agents=3)
 file_load = False
 env.reset(True, True)
-
+"""
 """
 env_renderer = RenderTool(env, gl="PILSVG",)
 handle = env.get_agent_handles()
-
-state_size = 168 * 2
+features_per_node = 9
+state_size = features_per_node*21 * 2
 action_size = 5
 n_trials = 15000
 max_steps = int(3 * (env.height + env.width))
@@ -77,9 +77,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
-demo = True
+demo = False
 record_images = False
 
@@ -97,8 +97,7 @@ for trials in range(1, n_trials + 1):
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
-        print(a)
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=8,
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
                                                 current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
@@ -136,7 +135,7 @@ for trials in range(1, n_trials + 1):
         next_obs, all_rewards, done, _ = env.step(action_dict)
 
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=8,
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 63adfff..0c97b18 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -48,7 +48,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
-def split_tree(tree, num_features_per_node=8, current_depth=0):
+def split_tree(tree, num_features_per_node=9, current_depth=0):
     """
     Splits the tree observation into different sub groups that need the same normalization.
     This is necessary because the tree observation includes two different distance:
@@ -80,10 +80,9 @@ def split_tree(tree, num_features_per_node=8, current_depth=0):
     Here we split the node features into the different classes of distances and binary values.
     Pay close attention to this part if you modify any of the features in the tree observation.
     """
-    tree_data = tree[:4].tolist()
-    distance_data = [tree[4]]
-    agent_data = tree[5:num_features_per_node].tolist()
-
+    tree_data = tree[:6].tolist()
+    distance_data = [tree[6]]
+    agent_data = tree[7:num_features_per_node].tolist()
     # Split each child of the current node and continue to next depth level
     for children in range(4):
         child_tree = tree[(num_features_per_node + children * child_size):
--
GitLab
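
Note (not part of the patch, placed after the signature delimiter so it is ignored by git am): the sketch below illustrates how the new 9-feature node layout is split and how the state_size used above is derived, under the assumption of a depth-2 tree with 4 branching directions. The names split_node and dummy_node are hypothetical and exist only for illustration; the index layout mirrors the updated split_tree in utils/observation_utils.py.

    # Illustrative sketch only -- not the patch's code and not the library's API.
    import numpy as np

    features_per_node = 9                 # was 8 before this change
    num_nodes = 1 + 4 + 16                # assumed: depth-2 tree, 4 children per node
    state_size = features_per_node * num_nodes * 2   # 378, matching the training script

    def split_node(node):
        # Mirrors the index layout used by split_tree: entries 0-5 and 6 are
        # distance-like values, entries 7-8 are the agent-related/binary values.
        tree_data = node[:6].tolist()
        distance_data = [node[6]]
        agent_data = node[7:features_per_node].tolist()
        return tree_data, distance_data, agent_data

    dummy_node = np.arange(features_per_node, dtype=float)   # hypothetical node vector
    print(split_node(dummy_node))
    print(state_size)  # -> 378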