diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py index d4a2c3240f81a5bbda3b2d41aa53680fe5a8e5f6..ce3b2a232d7b1dfdb2c7da3bd8c3638cf57fdb3b 100644 --- a/flatland/contrib/wrappers/flatland_wrappers.py +++ b/flatland/contrib/wrappers/flatland_wrappers.py @@ -258,6 +258,7 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper): + class DeadlockWrapper(RailEnvWrapper): def __init__(self, env:RailEnv, deadlock_reward=-100) -> None: super().__init__(env) @@ -282,7 +283,7 @@ class DeadlockWrapper(RailEnvWrapper): obs, rewards, dones, info = self.env.step(action_dict) # compute new list of deadlocked agents (ids) after stepping the environment - deadlocked_agents = self.deadlock_checker.check_deadlocks() # also stored in self.deadlocked_checker.deadlocked_agents + deadlocked_agents = self.deadlock_checker.check_deadlocks(action_dict) # also stored in self.deadlocked_checker.deadlocked_agents deadlocked_agents_ids = [agent.handle for agent in deadlocked_agents] # immediate deadlocked ids only used for prints immediate_deadlocked_ids = [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]