diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index d085a8b3d1a4d21408afa4e88b57fb92f3a1bab9..a6ee6134ffd5bd8006fe5fbc0d0ace76e4d13511 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -5,12 +5,13 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from dueling_double_dqn import Agent
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
+from utils.observation_utils import norm_obs_clip, split_tree
+
 random.seed(1)
 np.random.seed(1)
 
@@ -39,7 +40,6 @@ env = RailEnv(width=15,
               height=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
-
 """
 env = RailEnv(width=10,
               height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
@@ -54,6 +54,7 @@ env = RailEnv(width=20,
               number_of_agents=15)
 file_load = False
 env.reset(True, True)
+
 """
 env_renderer = RenderTool(env, gl="PILSVG",)
 handle = env.get_agent_handles()
@@ -81,51 +82,7 @@ agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pt
 demo = True
 record_images = True
 
-def max_lt(seq, val):
-    """
-    Return greatest item in seq for which item < val applies.
-    None is returned if seq was empty or all items in seq were >= val.
-    """
-    max = 0
-    idx = len(seq) - 1
-    while idx >= 0:
-        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
-            max = seq[idx]
-        idx -= 1
-    return max
-
-
-def min_lt(seq, val):
-    """
-    Return smallest item in seq for which item > val applies.
-    None is returned if seq was empty or all items in seq were >= val.
-    """
-    min = np.inf
-    idx = len(seq) - 1
-    while idx >= 0:
-        if seq[idx] >= val and seq[idx] < min:
-            min = seq[idx]
-        idx -= 1
-    return min
-
-
-def norm_obs_clip(obs, clip_min=-1, clip_max=1):
-    """
-    This function returns the difference between min and max value of an observation
-    :param obs: Observation that should be normalized
-    :param clip_min: min value where observation will be clipped
-    :param clip_max: max value where observation will be clipped
-    :return: returnes normalized and clipped observatoin
-    """
-    max_obs = max(1, max_lt(obs, 1000))
-    min_obs = min(max_obs, min_lt(obs, 0))
-
-    if max_obs == min_obs:
-        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
-    norm = np.abs(max_obs - min_obs)
-    if norm == 0:
-        norm = 1.
-    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
 
 for trials in range(1, n_trials + 1):
@@ -141,8 +98,8 @@ for trials in range(1, n_trials + 1):
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
         print(a)
-        data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=8,
-                                                                current_depth=0)
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=8,
+                                                current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
         agent_data = np.clip(agent_data, -1, 1)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..75e6b0ad2a0e6f82fbf62a15fbd25f7d6c95537e
--- /dev/null
+++ b/utils/observation_utils.py
@@ -0,0 +1,85 @@
+import numpy as np
+
+
+def max_lt(seq, val):
+    """
+    Return the greatest non-negative item in seq that is smaller than val.
+    0 is returned if seq is empty or no item satisfies the condition.
+    """
+    max = 0
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+
+
+def min_lt(seq, val):
+    """
+    Return the smallest item in seq that is greater than or equal to val.
+    np.inf is returned if seq is empty or no item satisfies the condition.
+    """
+    min = np.inf
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] >= val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min
+
+
+def norm_obs_clip(obs, clip_min=-1, clip_max=1):
+    """
+    Normalize an observation by the range between its min and max values and clip it.
+    :param obs: observation that should be normalized
+    :param clip_min: min value to which the observation is clipped
+    :param clip_max: max value to which the observation is clipped
+    :return: normalized and clipped observation
+    """
+    max_obs = max(1, max_lt(obs, 1000))
+    min_obs = min(max_obs, min_lt(obs, 0))
+
+    if max_obs == min_obs:
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
+    norm = np.abs(max_obs - min_obs)
+    if norm == 0:
+        norm = 1.
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
+
+def split_tree(tree, num_features_per_node=8, current_depth=0):
+    """
+    Split a flat tree observation into tree data, distance data and agent data.
+
+    :param tree: flattened tree observation of a single agent
+    :param num_features_per_node: number of features stored per tree node
+    :param current_depth: current recursion depth (0 for the root call)
+    :return: tuple (tree_data, distance_data, agent_data)
+    """
+
+    if len(tree) < num_features_per_node:
+        return [], [], []
+
+    depth = 0
+    tmp = len(tree) / num_features_per_node - 1
+    pow4 = 4
+    while tmp > 0:
+        tmp -= pow4
+        depth += 1
+        pow4 *= 4
+    child_size = (len(tree) - num_features_per_node) // 4
+    tree_data = tree[:4].tolist()
+    distance_data = [tree[4]]
+    agent_data = tree[5:num_features_per_node].tolist()
+    for children in range(4):
+        child_tree = tree[(num_features_per_node + children * child_size):
+                          (num_features_per_node + (children + 1) * child_size)]
+        tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree,
+                                                                      num_features_per_node,
+                                                                      current_depth=current_depth + 1)
+        if len(tmp_tree_data) > 0:
+            tree_data.extend(tmp_tree_data)
+            distance_data.extend(tmp_distance_data)
+            agent_data.extend(tmp_agent_data)
+    return tree_data, distance_data, agent_data
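
For reference, a minimal sketch of how the relocated helpers fit together, mirroring the updated loop in training_navigation.py. The per-node layout (entries 0-3 tree data, entry 4 distance, entries 5-7 agent data with num_features_per_node=8) is read directly off split_tree's slicing; the normalize_agent_obs wrapper and the final concatenation into one state vector are assumptions for illustration and are not part of this patch.

# Sketch only: normalizing one agent's TreeObsForRailEnv observation with the new utils.
# The wrapper name and the concatenation step are assumed; the hunks above end at np.clip.
import numpy as np

from utils.observation_utils import norm_obs_clip, split_tree


def normalize_agent_obs(raw_obs):
    # Per node: entries 0-3 are tree data, entry 4 is a distance, 5-7 are agent data.
    data, distance, agent_data = split_tree(tree=np.array(raw_obs),
                                            num_features_per_node=8,
                                            current_depth=0)
    data = norm_obs_clip(data)          # scale branch data into [-1, 1]
    distance = norm_obs_clip(distance)  # scale distances into [-1, 1]
    agent_data = np.clip(agent_data, -1, 1)
    # Assumed recombination into a single flat state vector for the DQN agent.
    return np.concatenate((np.concatenate((data, distance)), agent_data))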