diff --git a/examples/training_example.py b/examples/training_example.py index 5391dbbe4d1709eb7727a6f6fa62612f01439ce0..8f7d18f55ec63f2ed2d57584817d14f8bc722259 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -18,7 +18,7 @@ env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), obs_builder_object=TreeObservation, - number_of_agents=10) + number_of_agents=3) env_renderer = RenderTool(env, gl="PILSVG", ) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 4fd6bd655095bbf6279280cd0d4c040bfc7fd9f6..4158675cae63394ed768bfc36aaef9cd5f44da7e 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -376,7 +376,7 @@ class TreeObsForRailEnv(ObservationBuilder): self._reverse_dir( self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict: potential_conflict = tot_dist - if self.env.dones[ca] and predicted_time < potential_conflict: + if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist # Look for conflicting paths at distance num_step-1