Skip to content
Snippets Groups Projects
Commit c0939ad6 authored by spiglerg's avatar spiglerg
Browse files

added features to treesearch

parent 9b5a1655
No related branches found
No related tags found
No related merge requests found
......@@ -26,14 +26,14 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
env = RailEnv(width=6,
height=2,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
handle = env.get_agent_handles()
env.agents_position = [[1, 4]]
env.agents_target = [[1, 1]]
env.agents_direction = [1]
env.agents_position[0] = [1, 4]
env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
......
......@@ -69,6 +69,11 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.env.number_of_agents):
self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
# Update local lookup table for all agents' target locations
self.location_has_target = {}
for loc in self.env.agents_target:
self.location_has_target[(loc[0],loc[1])] = 1
def _distance_map_walker(self, position, target_nr):
"""
Utility function to compute distance maps from each cell in the rail network (and each possible
......@@ -210,13 +215,14 @@ class TreeObsForRailEnv(ObservationBuilder):
#1:
#2:
#2: 1 if a target of another agent is detected between the previous node and the current one.
#3:
#3: 1 if another agent is detected between the previous node and the current one.
#4:
#5: minimum distance from node to the agent's target
#5: minimum distance from node to the agent's target (when landing to the node following the corresponding
branch.
Missing/padding nodes are filled in with -inf (truncated).
Missing values in present node are filled in with +inf (truncated).
......@@ -226,6 +232,11 @@ class TreeObsForRailEnv(ObservationBuilder):
In case the target node is reached, the values are [0, 0, 0, 0, 0].
"""
# Update local lookup table for all agents' positions
self.location_has_agent = {}
for loc in self.env.agents_position:
self.location_has_agent[(loc[0], loc[1])] = 1
position = self.env.agents_position[handle]
orientation = self.env.agents_direction[handle]
......@@ -264,25 +275,31 @@ class TreeObsForRailEnv(ObservationBuilder):
exploring = True
last_isSwitch = False
last_isDeadEnd = False
# TODO: last_isTerminal = False # dead-end
# TODO: last_isTarget = False
# TODO: last_isTerminal = False # wrong cell encountered
last_isTarget = False
other_agent_encountered = False
other_target_encountered = False
while exploring:
# #############################
# #############################
# Modify here to compute any useful data required to build the end node's features. This code is called
# for each cell visited between the previous branching node and the next switch / target / dead-end.
# TODO: update the current variables according to the current cell in the path
# (store info about other agents and targets)
if position in self.location_has_agent:
other_agent_encountered = True
if position in self.location_has_target:
other_target_encountered = True
# TODO: [[[for efficiency, [make dict for hashed-lookup of coords] -- do it in the reset function!]]]
# #############################
# #############################
# If the target node is encountered, pick that as node. Also, no further branching is possible.
if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
# TODO: last_isTarget = True
last_isTarget = True
break
cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
......@@ -323,9 +340,24 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# Modify here to append new / different features for each visited cell!
observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], direction]]
if last_isTarget:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
0]
else:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
self.distance_map[handle, position[0], position[1], direction]]
# TODO:
# #############################
# #############################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment