diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py index 37dcd0df93671b8b8bcb67ce68695c62b425db93..4a371350333bbe6e8295331ada53e8f8ada83b3b 100644 --- a/utils/dead_lock_avoidance_agent.py +++ b/utils/dead_lock_avoidance_agent.py @@ -67,7 +67,7 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker): class DeadLockAvoidanceAgent(Policy): - def _init__(self, env: RailEnv, action_size, show_debug_plot=False): + def __init__(self, env: RailEnv, action_size, show_debug_plot=False): self.env = env self.memory = None self.loss = 0 diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index fa72bcf23b2a1a887c20ea7aec7e60248f4a932b..3238ee54bdf64986ea77c8a766e958d86a8c34eb 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -36,7 +36,7 @@ class FastTreeObs(ObservationBuilder): self.debug_render_path_list = [] if self.env is not None: self.find_all_cell_where_agent_can_choose() - self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5) + self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False) else: self.dead_lock_avoidance_agent = None