Skip to content
Snippets Groups Projects
Commit 409797dd authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch '108_improving_tree_observation' into 'master'

108 improving tree observation

See merge request flatland/flatland!109
parents d87ade65 5dd5d2f5
No related branches found
No related tags found
No related merge requests found
......@@ -396,18 +396,16 @@ class TreeObsForRailEnv(ObservationBuilder):
cell_transitions = self.env.rail.get_transitions(*position, direction)
total_transitions = bin(self.env.rail.get_full_transitions(*position)).count("1")
num_transitions = np.count_nonzero(cell_transitions)
exploring = False
# Detect Switches that can only be used by other agents.
if total_transitions > 2 > num_transitions:
if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
unusable_switch = tot_dist
if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction
nbits = 0
tmp = self.env.rail.get_full_transitions(*position)
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
nbits = total_transitions
if nbits == 1:
# Dead-end!
last_is_dead_end = True
......@@ -434,8 +432,6 @@ class TreeObsForRailEnv(ObservationBuilder):
# `position' is either a terminal node or a switch
observation = []
# #############################
# #############################
# Modify here to append new / different features for each visited cell!
......@@ -463,6 +459,7 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [own_target_encountered,
other_target_encountered,
......
......@@ -446,7 +446,7 @@ class RailEnv(Environment):
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
if hasattr(self.obs_builder, 'distance_map'):
if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys():
self.obs_builder.distance_map = data[b"distance_maps"]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
......
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment