diff --git a/flatland/core/env.py b/flatland/core/env.py index 1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5..d1f814f35159b70d93e5b7109bc5119a4264da29 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -84,6 +84,7 @@ class Environment: """ raise NotImplementedError() + def get_agent_handles(self): """ Returns a list of agents' handles to be used as keys in the step() diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 76bed8a46b7b906a79574fbc76e64296475fd6c1..3100e8b72277de54bcf10a8a899dd18517fcbb81 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -31,7 +31,6 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent = {} self.location_has_agent_direction = {} self.predictor = predictor - self.agents_previous_reset = None def reset(self): @@ -173,6 +172,11 @@ class TreeObsForRailEnv(ObservationBuilder): Called whenever an observation has to be computed for the `env' environment, for each agent with handle in the `handles' list. """ + + self.predictions = [] + if self.predictor: + for a in range(len(handles)): + self.predictions.append(self.predictor.get(a)) observations = {} for h in handles: observations[h] = self.get(h) @@ -218,6 +222,8 @@ class TreeObsForRailEnv(ObservationBuilder): (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself + #8: possible conflict detected + Missing/padding nodes are filled in with -inf (truncated). Missing values in present node are filled in with +inf (truncated). @@ -252,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) - branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) observation = observation + branch_observation @@ -290,6 +295,7 @@ class TreeObsForRailEnv(ObservationBuilder): other_target_encountered = np.inf other_agent_same_direction = 0 other_agent_opposite_direction = 0 + possible_conflict = 0 num_steps = 1 while exploring: