From 69cd0073519334f3d9c5b4e7ffb01a9d863416c9 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 18 Jul 2019 12:59:28 -0400 Subject: [PATCH] enhanced functionality --- scoring/score_test.py | 4 +- scoring/utils/misc_utils.py | 2 +- scoring/utils/observation_utils.py | 101 ----------------------------- 3 files changed, 4 insertions(+), 103 deletions(-) delete mode 100644 scoring/utils/observation_utils.py diff --git a/scoring/score_test.py b/scoring/score_test.py index 79f0ee6..5665d44 100644 --- a/scoring/score_test.py +++ b/scoring/score_test.py @@ -6,7 +6,8 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from torch_training.dueling_double_dqn import Agent -from utils.misc_utils import run_test +from scoring.utils.misc_utils import run_test +from utils.observation_utils import normalize_observation with open('parameters.txt', 'r') as inf: parameters = eval(inf.read()) @@ -40,6 +41,7 @@ score_board = [] for test_nr in parameters: current_parameters = parameters[test_nr] test_score, test_dones, test_time = run_test(current_parameters, agent, observation_builder=observation_builder, + observation_wrapper=normalize_observation, test_nr=test_nr, nr_trials_per_test=10) print('{} score was {:.3f} with {:.2f}% environments solved. Test took {:.2f} Seconds to complete.\n'.format( test_nr, diff --git a/scoring/utils/misc_utils.py b/scoring/utils/misc_utils.py index de30bad..b15476d 100644 --- a/scoring/utils/misc_utils.py +++ b/scoring/utils/misc_utils.py @@ -63,7 +63,7 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No env = RailEnv(width=3, height=3, rail_generator=rail_from_file(file_name), - obs_builder_object=observation_builder(), + obs_builder_object=observation_builder, number_of_agents=1, ) diff --git a/scoring/utils/observation_utils.py b/scoring/utils/observation_utils.py deleted file mode 100644 index 787dfcf..0000000 --- a/scoring/utils/observation_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -import numpy as np - - -def max_lt(seq, val): - """ - Return greatest item in seq for which item < val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - max = 0 - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: - max = seq[idx] - idx -= 1 - return max - - -def min_lt(seq, val): - """ - Return smallest item in seq for which item > val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - min = np.inf - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] >= val and seq[idx] < min: - min = seq[idx] - idx -= 1 - return min - - -def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0): - """ - This function returns the difference between min and max value of an observation - :param obs: Observation that should be normalized - :param clip_min: min value where observation will be clipped - :param clip_max: max value where observation will be clipped - :return: returnes normalized and clipped observatoin - """ - if fixed_radius > 0: - max_obs = fixed_radius - else: - max_obs = max(1, max_lt(obs, 1000)) - - min_obs = 0 #min(max_obs, min_lt(obs, 0)) - - if max_obs == min_obs: - return np.clip(np.array(obs) / max_obs, clip_min, clip_max) - norm = np.abs(max_obs - min_obs) - if norm == 0: - norm = 1. - return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) - - -def split_tree(tree, num_features_per_node=9, current_depth=0): - """ - Splits the tree observation into different sub groups that need the same normalization. - This is necessary because the tree observation includes two different distance: - 1. Distance from the agent --> This is measured in cells from current agent location - 2. Distance to targer --> This is measured as distance from cell to agent target - 3. Binary data --> Contains information about presence of object --> No normalization necessary - Number 1. will depend on the depth and size of the tree search - Number 2. will depend on the size of the map and thus the max distance on the map - Number 3. Is independent of tree depth and map size and thus must be handled differently - Therefore we split the tree into these two classes for better normalization. - :param tree: Tree that needs to be split - :param num_features_per_node: Features per node ATTENTION! this parameter is vital to correct splitting of the tree. - :param current_depth: Keeping track of the current depth in the tree - :return: Returns the three different groups of distance and binary values. - """ - - if len(tree) < num_features_per_node: - return [], [], [] - - depth = 0 - tmp = len(tree) / num_features_per_node - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - child_size = (len(tree) - num_features_per_node) // 4 - """ - Here we split the node features into the different classes of distances and binary values. - Pay close attention to this part if you modify any of the features in the tree observation. - """ - tree_data = tree[:6].tolist() - distance_data = [tree[6]] - agent_data = tree[7:num_features_per_node].tolist() - # Split each child of the current node and continue to next depth level - for children in range(4): - child_tree = tree[(num_features_per_node + children * child_size): - (num_features_per_node + (children + 1) * child_size)] - tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree, - num_features_per_node, - current_depth=current_depth + 1) - if len(tmp_tree_data) > 0: - tree_data.extend(tmp_tree_data) - distance_data.extend(tmp_distance_data) - agent_data.extend(tmp_agent_data) - return tree_data, distance_data, agent_data -- GitLab