From 30d7994df3d1ba86a7bae2ad0b5a4bec8d20494b Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sat, 10 Aug 2019 12:08:59 -0400 Subject: [PATCH] Updated tree observation to take agent speed into account when checking for potential conlficts --- examples/training_example.py | 2 +- flatland/envs/observations.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/training_example.py b/examples/training_example.py index 5391dbbe..8f7d18f5 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -18,7 +18,7 @@ env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), obs_builder_object=TreeObservation, - number_of_agents=10) + number_of_agents=3) env_renderer = RenderTool(env, gl="PILSVG", ) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 4fd6bd65..4158675c 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -376,7 +376,7 @@ class TreeObsForRailEnv(ObservationBuilder): self._reverse_dir( self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict: potential_conflict = tot_dist - if self.env.dones[ca] and predicted_time < potential_conflict: + if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist # Look for conflicting paths at distance num_step-1 -- GitLab