Skip to content
Snippets Groups Projects
Commit 4b9c3107 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch '3-refactor-data-structure-of-treeobsforrailenv' into 'master'

refactor tree data structure and add unit test for the normalization of the features

Closes #3

See merge request !9
parents 9f20bd02 8e731f39
No related branches found
No related tags found
1 merge request!9refactor tree data structure and add unit test for the normalization of the features
Showing
with 37622 additions and 64 deletions
import numpy as np
from utils.observation_utils import split_tree, min_gt
from utils.observation_utils import split_tree_into_feature_groups, min_gt
class OrderedAgent:
......@@ -12,8 +12,7 @@ class OrderedAgent:
:param state: input is the observation of the agent
:return: returns an action
"""
_, distance, _ = split_tree(tree=np.array(state), num_features_per_node=11,
current_depth=0)
_, distance, _ = split_tree_into_feature_groups(state, 1)
distance = distance[1:]
min_dist = min_gt(distance, 0)
min_direction = np.where(distance == min_dist)
......
import random
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from utils.observation_utils import normalize_observation
def test_normalize_features():
random.seed(1)
np.random.seed(1)
max_depth = 4
for i in range(10):
tree_observer = TreeObsForRailEnv(max_depth=max_depth)
next_rand_number = random.randint(0, 100)
env = RailEnv(width=10,
height=10,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
seed=next_rand_number),
schedule_generator=complex_schedule_generator(),
number_of_agents=1,
obs_builder_object=tree_observer)
obs, all_rewards, done, _ = env.step({0: 0})
obs_new = tree_observer.get()
# data, distance, agent_data = split_tree(tree=np.array(obs_old), num_features_per_node=11)
data_normalized = normalize_observation(obs_new, max_depth, observation_radius=10)
filename = 'testdata/test_array_{}.csv'.format(i)
data_loaded = np.loadtxt(filename, delimiter=',')
assert np.allclose(data_loaded, data_normalized)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -172,7 +172,7 @@ def main(argv):
# Build agent specific observations
for a in range(env.get_num_agents()):
agent_obs[a] = agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
agent_obs[a] = agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
score = 0
env_done = 0
......@@ -194,7 +194,7 @@ def main(argv):
# Build agent specific observations and normalize
for a in range(env.get_num_agents()):
agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......
......@@ -18,7 +18,7 @@ from importlib_resources import path
# Import Torch and utility functions to normalize observation
import torch_training.Nets
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import norm_obs_clip, split_tree
from utils.observation_utils import norm_obs_clip, split_tree_into_feature_groups
def main(argv):
......@@ -130,9 +130,7 @@ def main(argv):
# Build agent specific observations
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(obs[a]),
num_features_per_node=features_per_node,
current_depth=0)
data, distance, agent_data = split_tree_into_feature_groups(obs[a], tree_depth)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
......@@ -163,9 +161,7 @@ def main(argv):
next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
num_features_per_node=features_per_node,
current_depth=0)
data, distance, agent_data = split_tree_into_feature_groups(next_obs[a], tree_depth)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
......
......@@ -9,7 +9,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from utils.observation_utils import norm_obs_clip, split_tree
from utils.observation_utils import norm_obs_clip, split_tree_into_feature_groups
def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'):
......@@ -104,8 +104,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
lp_reset(True, True)
obs = env.reset(True, True)
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(obs[a]),
current_depth=0)
data, distance, agent_data = split_tree_into_feature_groups(obs[a], tree_depth)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
......@@ -129,8 +128,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
next_obs, all_rewards, done, _ = lp_step(action_dict)
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
current_depth=0)
data, distance, agent_data = split_tree_into_feature_groups(next_obs[a], tree_depth)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
......
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
def max_lt(seq, val):
......@@ -53,57 +54,71 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_ran
return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
def split_tree(tree, num_features_per_node, current_depth=0):
def _split_node_into_feature_groups(node: TreeObsForRailEnv.Node) -> (np.ndarray, np.ndarray, np.ndarray):
data = np.zeros(6)
distance = np.zeros(1)
agent_data = np.zeros(4)
data[0] = node.dist_own_target_encountered
data[1] = node.dist_other_target_encountered
data[2] = node.dist_other_agent_encountered
data[3] = node.dist_potential_conflict
data[4] = node.dist_unusable_switch
data[5] = node.dist_to_next_branch
distance[0] = node.dist_min_to_target
agent_data[0] = node.num_agents_same_direction
agent_data[1] = node.num_agents_opposite_direction
agent_data[2] = node.num_agents_malfunctioning
agent_data[3] = node.speed_min_fractional
return data, distance, agent_data
def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
if node == -np.inf:
remaining_depth = max_tree_depth - current_tree_depth
# reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
num_remaining_nodes = int((4**(remaining_depth+1) - 1) / (4 - 1))
return [-np.inf] * num_remaining_nodes*6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes*4
data, distance, agent_data = _split_node_into_feature_groups(node)
if not node.childs:
return data, distance, agent_data
for direction in TreeObsForRailEnv.tree_explorted_actions_char:
sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
data = np.concatenate((data, sub_data))
distance = np.concatenate((distance, sub_distance))
agent_data = np.concatenate((agent_data, sub_agent_data))
return data, distance, agent_data
def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
"""
Splits the tree observation into different sub groups that need the same normalization.
This is necessary because the tree observation includes two different distance:
1. Distance from the agent --> This is measured in cells from current agent location
2. Distance to targer --> This is measured as distance from cell to agent target
3. Binary data --> Contains information about presence of object --> No normalization necessary
Number 1. will depend on the depth and size of the tree search
Number 2. will depend on the size of the map and thus the max distance on the map
Number 3. Is independent of tree depth and map size and thus must be handled differently
Therefore we split the tree into these two classes for better normalization.
:param tree: Tree that needs to be split
:param num_features_per_node: Features per node ATTENTION! this parameter is vital to correct splitting of the tree.
:param current_depth: Keeping track of the current depth in the tree
:return: Returns the three different groups of distance and binary values.
This function splits the tree into three difference arrays of values
"""
if len(tree) < num_features_per_node:
return [], [], []
depth = 0
tmp = len(tree) / num_features_per_node - 1
pow4 = 4
while tmp > 0:
tmp -= pow4
depth += 1
pow4 *= 4
child_size = (len(tree) - num_features_per_node) // 4
data, distance, agent_data = _split_node_into_feature_groups(tree)
for direction in TreeObsForRailEnv.tree_explorted_actions_char:
sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
data = np.concatenate((data, sub_data))
distance = np.concatenate((distance, sub_distance))
agent_data = np.concatenate((agent_data, sub_agent_data))
return data, distance, agent_data
def normalize_observation(observation: TreeObsForRailEnv.Node, tree_depth: int, observation_radius=0):
"""
Here we split the node features into the different classes of distances and binary values.
Pay close attention to this part if you modify any of the features in the tree observation.
This function normalizes the observation used by the RL algorithm
"""
tree_data = tree[:6].tolist()
distance_data = [tree[6]]
agent_data = tree[7:num_features_per_node].tolist()
# Split each child of the current node and continue to next depth level
for children in range(4):
child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)]
tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree, num_features_per_node,
current_depth=current_depth + 1)
if len(tmp_tree_data) > 0:
tree_data.extend(tmp_tree_data)
distance_data.extend(tmp_distance_data)
agent_data.extend(tmp_agent_data)
return tree_data, distance_data, agent_data
def normalize_observation(observation, num_features_per_node=11, observation_radius=0):
data, distance, agent_data = split_tree(tree=np.array(observation), num_features_per_node=num_features_per_node,
current_depth=0)
data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth)
data = norm_obs_clip(data, fixed_radius=observation_radius)
distance = norm_obs_clip(distance, normalize_to_range=True)
agent_data = np.clip(agent_data, -1, 1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment