diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a3d88d773db9edaa0777e2aee94593a0392a956c..867424fc0c7b6e78951bbb0f7747b0c6744aebc4 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -31,7 +31,6 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent = {} self.location_has_agent_direction = {} self.predictor = predictor - self.agents_previous_reset = None def reset(self): @@ -175,8 +174,10 @@ class TreeObsForRailEnv(ObservationBuilder): """ # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object. + self.predictions = [] if self.predictor: - print(self.predictor.get(0)) + for a in range(len(handles)): + self.predictions.append(self.predictor.get(a)) observations = {} for h in handles: observations[h] = self.get(h) @@ -222,6 +223,8 @@ class TreeObsForRailEnv(ObservationBuilder): (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself + #8: possible conflict detected + Missing/padding nodes are filled in with -inf (truncated). Missing values in present node are filled in with +inf (truncated). @@ -256,7 +259,6 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) - branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) observation = observation + branch_observation @@ -294,6 +296,7 @@ class TreeObsForRailEnv(ObservationBuilder): other_target_encountered = np.inf other_agent_same_direction = 0 other_agent_opposite_direction = 0 + possible_conflict = 0 num_steps = 1 while exploring: