Commit c2110f62 authored by Erik Nygren

Added split_tree function to simplify normalization of the tree observation.

parent a590c6b3
@@ -23,7 +23,6 @@ transition_probability = [15, # empty cell - Case 0
                          1, # Case 1c (9) - simple turn left
                          1] # Case 2b (10) - simple switch mirrored
# Example generate a random rail
"""
env = RailEnv(width=10,
@@ -31,11 +30,10 @@ env = RailEnv(width=10,
              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
              number_of_agents=1)
"""
-env = RailEnv(width=15,
-              height=15,
+env = RailEnv(width=10,
+              height=10,
              rail_generator=complex_rail_generator(nr_start_goal=3, min_dist=5, max_dist=99999, seed=0),
              number_of_agents=3)
"""
env = RailEnv(width=20,
height=20,
@@ -61,7 +59,7 @@ scores = []
dones_list = []
action_prob = [0] * 4
agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+#agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint1300.pth'))
demo = False
@@ -94,15 +92,34 @@ def min_lt(seq, val):
    return min


def norm_obs_clip(obs, clip_min=-1, clip_max=1):
    """
    Normalize an observation to the range of its finite values and clip the result.
    :param obs: observation that should be normalized
    :param clip_min: min value to which the observation is clipped
    :param clip_max: max value to which the observation is clipped
    :return: normalized and clipped observation
    """
    max_obs = max(1, max_lt(obs, 1000))
    min_obs = max(0, min_lt(obs, 0))
    if max_obs == min_obs:
        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
    norm = np.abs(max_obs - min_obs)
    if norm == 0:
        norm = 1.
    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
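As a quick sanity check of the new normalization (illustrative values only; this assumes max_lt/min_lt return the largest value below and the smallest value above the given threshold, as in this script):

import numpy as np

sample = np.array([0., 5., 10., np.inf])
# max_obs = 10 and min_obs = 5, so finite entries map to (x - 5) / 5,
# while the np.inf placeholder saturates at clip_max.
print(norm_obs_clip(sample))  # expected roughly: [-1.  0.  1.  1.]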
for trials in range(1, n_trials + 1):

    # Reset environment
    obs = env.reset()
    final_obs = obs.copy()
    for a in range(env.get_num_agents()):
-        norm = max(1, max_lt(obs[a], np.inf))
-        obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
+        data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
+        data = norm_obs_clip(data)
+        distance = norm_obs_clip(distance)
+        obs[a] = np.concatenate((data, distance))
    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
    score = 0
@@ -119,13 +136,17 @@ for trials in range(1, n_trials + 1):
            action = agent.act(np.array(obs[a]), eps=eps)
            action_prob[action] += 1
            action_dict.update({a: action})
            # env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5)

        # Environment step
        next_obs, all_rewards, done, _ = env.step(action_dict)
        for a in range(env.get_num_agents()):
-            norm = max(1, max_lt(next_obs[a], np.inf))
-            next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
+            data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
+                                                        current_depth=0)
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            next_obs[a] = np.concatenate((data, distance))

        # Update replay buffer and train agent
        for a in range(env.get_num_agents()):
            if done[a]:
@@ -135,7 +156,6 @@ for trials in range(1, n_trials + 1):
            agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
            score += all_rewards[a]

        obs = next_obs.copy()
        if done['__all__']:
            env_done = 1
@@ -140,7 +140,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                new_cell = self._new_position(position, neigh_direction)

                if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
                        new_cell[1] >= 0 and new_cell[1] < self.env.width:
                    desired_movement_from_new_cell = (neigh_direction + 2) % 4
@@ -270,7 +270,6 @@ class TreeObsForRailEnv(ObservationBuilder):
            num_cells_to_fill_in += pow4
            pow4 *= 4
        observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in

        return observation
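Because missing branches are padded with five -np.inf entries per node, every serialized tree keeps the complete fixed layout of a 4-ary tree. A small helper (my own, purely illustrative) gives the resulting flat length, which the split_tree function added below also relies on:

def num_tree_nodes(max_depth):
    # 1 root + 4 children + 16 grandchildren + ... per additional depth level
    return sum(4 ** d for d in range(max_depth + 1))

assert num_tree_nodes(2) * 5 == 105  # depth-2 tree: 21 nodes * 5 features per node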
    def _explore_branch(self, handle, position, direction, root_observation, depth):
@@ -292,21 +291,26 @@ class TreeObsForRailEnv(ObservationBuilder):
        visited = set()

-        other_agent_encountered = False
-        other_target_encountered = False
+        # other_agent_encountered = False
+        # other_target_encountered = False
+        other_agent_encountered = np.inf
+        other_target_encountered = np.inf
        num_steps = 1
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:
-                other_agent_encountered = True
+                # other_agent_encountered = True
+                if num_steps < other_agent_encountered:
+                    other_agent_encountered = num_steps

            if position in self.location_has_target:
-                other_target_encountered = True
+                # other_target_encountered = True
+                if num_steps < other_target_encountered:
+                    other_target_encountered = num_steps
            # #############################
            # #############################
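Replacing the booleans with np.inf plus a running minimum makes these features record the distance, in cells along the branch, to the first other agent or target, and np.inf naturally encodes "never encountered" (which norm_obs_clip later saturates to the clip bound). A standalone sketch of that bookkeeping, with hypothetical cell positions:

import numpy as np

steps_with_agent = {3, 7}          # hypothetical: agents sit 3 and 7 cells down the branch
other_agent_encountered = np.inf   # np.inf means not encountered yet
for num_steps in range(1, 10):
    if num_steps in steps_with_agent and num_steps < other_agent_encountered:
        other_agent_encountered = num_steps
print(other_agent_encountered)     # 3: distance to the first agent on the branch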
@@ -361,7 +365,7 @@ class TreeObsForRailEnv(ObservationBuilder):
            # #############################
            # #############################
            # Modify here to append new / different features for each visited cell!

            """
            if last_isTarget:
                observation = [0,
                               1 if other_target_encountered else 0,
@@ -381,12 +385,30 @@ 1 if other_agent_encountered else 0,
                               1 if other_agent_encountered else 0,
                               root_observation[3] + num_steps,
                               self.distance_map[handle, position[0], position[1], direction]]
            """
            if last_isTarget:
                observation = [0,
                               other_target_encountered,
                               other_agent_encountered,
                               root_observation[3] + num_steps,
                               0]
            elif last_isTerminal:
                observation = [0,
                               other_target_encountered,
                               other_agent_encountered,
                               np.inf,
                               np.inf]
            else:
                observation = [0,
                               other_target_encountered,
                               other_agent_encountered,
                               root_observation[3] + num_steps,
                               self.distance_map[handle, position[0], position[1], direction]]
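Each node of the tree observation now carries five entries, with the distance-map value kept last, which is what lets split_tree peel the distances off separately. The field labels below are my own summary, not names from the source:

# idx 0: reserved, always 0 here
# idx 1: distance (in cells) to the first target of another agent on the branch, np.inf if none
# idx 2: distance (in cells) to the first other agent on the branch, np.inf if none
# idx 3: total distance from the tree root to the end of this branch
# idx 4: distance-map distance from the branch end to the agent's own target
NUM_FEATURES_PER_NODE = 5  # matches num_features_per_node=5 used throughout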
            # #############################
            # #############################

            new_root_observation = observation[:]

            # Start from the current orientation, and see which transitions are available;
            # organize them as [left, forward, right, back], relative to the current orientation
            # Get the possible transitions
@@ -450,6 +472,40 @@ class TreeObsForRailEnv(ObservationBuilder):
                                               prompt=prompt_[children],
                                               current_depth=current_depth + 1)
    def split_tree(self, tree, num_features_per_node=5, current_depth=0):
        """
        Split a serialized observation tree into its node features and its distance entries.
        :param tree: flat array holding the tree observation, num_features_per_node values per node
        :param num_features_per_node: number of features stored per node; the last one is the distance
        :param current_depth: depth of the current subtree root, used for the recursion
        :return: tuple (tree_data, distance_data) with the non-distance features and the distances of all nodes
        """
        if len(tree) < num_features_per_node:
            return [], []

        depth = 0
        tmp = len(tree) / num_features_per_node - 1
        pow4 = 4
        while tmp > 0:
            tmp -= pow4
            depth += 1
            pow4 *= 4
        child_size = (len(tree) - num_features_per_node) // 4
        tree_data = tree[0:num_features_per_node - 1].tolist()
        distance_data = [tree[num_features_per_node - 1]]
        for children in range(4):
            child_tree = tree[(num_features_per_node + children * child_size):
                              (num_features_per_node + (children + 1) * child_size)]
            tmp_tree_data, tmp_distance_data = self.split_tree(child_tree,
                                                               num_features_per_node,
                                                               current_depth=current_depth + 1)
            if len(tmp_tree_data) > 0:
                tree_data.extend(tmp_tree_data)
                distance_data.extend(tmp_distance_data)
        return tree_data, distance_data
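To see what the split produces, here is a minimal check on the smallest non-trivial tree, a root plus four children with five features per node (25 values; obs_builder stands for any TreeObsForRailEnv instance, and the numbers are made up):

import numpy as np

tree = np.array([0, 1, 2, 3, 10,   # root:    features [0, 1, 2, 3], distance 10
                 0, 1, 2, 3, 11,   # child 0: features [0, 1, 2, 3], distance 11
                 0, 1, 2, 3, 12,   # child 1
                 0, 1, 2, 3, 13,   # child 2
                 0, 1, 2, 3, 14],  # child 3
                dtype=float)

data, distance = obs_builder.split_tree(tree, num_features_per_node=5, current_depth=0)
# data     -> the four non-distance features of every node, 20 values in depth-first order
# distance -> [10, 11, 12, 13, 14], ready to be normalized separately by norm_obs_clip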
class GlobalObsForRailEnv(ObservationBuilder):
"""
@@ -490,7 +546,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
        obs[1][agent.target] += 1

        for i in range(len(agents)):
            if i != handle:  # TODO: handle used as index...?
                agent2 = agents[i]
                obs[3][agent2.position] += 1
                obs[2][agent2.target] += 1
@@ -160,9 +160,9 @@
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "ve367",
"display_name": "Python 3",
"language": "python",
"name": "ve367"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -174,7 +174,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
"version": "3.6.5"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,