Commit 20d93455 authored by Erik Nygren's avatar Erik Nygren
Browse files

Added potential conflict to tree observation as an 8th feature. ATTENTION this...

Added potential conflict to tree observation as an 8th feature. ATTENTION: this means that the observation space dimension has increased! I will still check that this is handled correctly everywhere, but it looks good.
parent 01ced703
Pipeline #1081 passed with stages
in 13 minutes and 3 seconds
import numpy as np
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
np.random.seed(1)
......@@ -8,10 +10,13 @@ np.random.seed(1)
# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
#
env = RailEnv(width=15,
height=15,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
number_of_agents=5)
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv())
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=TreeObservation,
number_of_agents=2)
# Import your own Agent or use RLlib to train agents on Flatland
......@@ -56,6 +61,7 @@ n_trials = 5
# Empty dictionary for all agent action
action_dict = dict()
print("Starting Training...")
for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
......@@ -74,7 +80,8 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether they are done
next_obs, all_rewards, done, _ = env.step(action_dict)
TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
print(len(next_obs[0]))
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
......
......@@ -89,10 +89,12 @@ def coordinate_to_position(width, coords):
:param coords:
:return:
"""
position = []
position = np.empty(len(coords), dtype=int)
idx = 0
for t in coords:
position.append((t[1] * width + t[0]))
return np.asarray(position).flatten()
position[idx] = int(t[1] * width + t[0])
idx += 1
return position
class AStarNode():
......
......@@ -27,7 +27,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_dim = 7
self.observation_dim = 8
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
......@@ -187,10 +187,7 @@ class TreeObsForRailEnv(ObservationBuilder):
dir_list.append(self.predictions[a][t][3])
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
self.predicted_dir.update({t: dir_list})
pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0)
pred_pos = list(map(list, zip(*pred_pos)))
self.max_prediction_depth = len(self.predicted_pos)
observations = {}
for h in handles:
observations[h] = self.get(h)
......@@ -256,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder):
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0]
root_observation = observation[:]
visited = set()
......@@ -309,7 +306,7 @@ class TreeObsForRailEnv(ObservationBuilder):
other_target_encountered = np.inf
other_agent_same_direction = 0
other_agent_opposite_direction = 0
potential_conflict = 0
num_steps = 1
while exploring:
# #############################
......@@ -329,6 +326,10 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_opposite_direction += 1
# Register possible conflict
if self.predictor and num_steps < self.max_prediction_depth:
if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps],
handle):
potential_conflict = 1
if position in self.location_has_target:
if num_steps < other_target_encountered:
......@@ -430,7 +431,8 @@ class TreeObsForRailEnv(ObservationBuilder):
root_observation[3] + num_steps,
0,
other_agent_same_direction,
other_agent_opposite_direction
other_agent_opposite_direction,
potential_conflict
]
elif last_isTerminal:
......@@ -440,7 +442,8 @@ class TreeObsForRailEnv(ObservationBuilder):
np.inf,
np.inf,
other_agent_same_direction,
other_agent_opposite_direction
other_agent_opposite_direction,
potential_conflict
]
else:
observation = [0,
......@@ -449,7 +452,8 @@ class TreeObsForRailEnv(ObservationBuilder):
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction
other_agent_opposite_direction,
potential_conflict
]
# #############################
# #############################
......@@ -493,7 +497,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation, visited
def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0):
"""
Utility function to pretty-print tree observations returned by this object.
"""
......@@ -520,7 +524,7 @@ class TreeObsForRailEnv(ObservationBuilder):
prompt=prompt_[children],
current_depth=current_depth + 1)
def split_tree(self, tree, num_features_per_node=7, current_depth=0):
def split_tree(self, tree, num_features_per_node=8, current_depth=0):
"""
:param tree:
......@@ -541,9 +545,10 @@ class TreeObsForRailEnv(ObservationBuilder):
depth += 1
pow4 *= 4
child_size = (len(tree) - num_features_per_node) // 4
tree_data = tree[0:num_features_per_node - 3].tolist()
distance_data = [tree[num_features_per_node - 3]]
agent_data = tree[-2:].tolist()
tree_data = tree[0:4].tolist()
distance_data = [tree[4]]
agent_data = tree[-3:].tolist()
for children in range(4):
child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment