diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py index 9edc63143e14ec443384f0f93908f683762504b8..d4a2c3240f81a5bbda3b2d41aa53680fe5a8e5f6 100644 --- a/flatland/contrib/wrappers/flatland_wrappers.py +++ b/flatland/contrib/wrappers/flatland_wrappers.py @@ -254,4 +254,50 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper): # resets decision cells, switches, etc. These can change with an env.reset(...)! # needs to be done after env.reset(). self.reset_cells() + return obs, info + + + +class DeadlockWrapper(RailEnvWrapper): + def __init__(self, env:RailEnv, deadlock_reward=-100) -> None: + super().__init__(env) + self.deadlock_reward = deadlock_reward + self.deadlock_checker = Deadlock_Checker(env=self.env) + + @property + def deadlocked_agents(self): + return self.deadlock_checker.deadlocked_agents + + @property + def immediate_deadlocks(self): + return [agent.handle for agent in self.deadlock_checker.immediate_deadlocked] + + # make sure to assign the deadlock reward only once to each deadlocked agent... + def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: + + # agents which are already deadlocked from previous steps + already_deadlocked_ids = [agent.handle for agent in self.deadlocked_agents] + + # step environment + obs, rewards, dones, info = self.env.step(action_dict) + + # compute new list of deadlocked agents (ids) after stepping the environment + deadlocked_agents = self.deadlock_checker.check_deadlocks() # also stored in self.deadlocked_checker.deadlocked_agents + deadlocked_agents_ids = [agent.handle for agent in deadlocked_agents] + # immediate deadlocked ids only used for prints + immediate_deadlocked_ids = [agent.handle for agent in self.deadlock_checker.immediate_deadlocked] + print(f"immediate deadlocked: {immediate_deadlocked_ids}") + print(f"total deadlocked: {deadlocked_agents_ids}") + newly_deadlocked_agents_ids = [agent_id for agent_id in deadlocked_agents_ids if agent_id not in already_deadlocked_ids] + + # assign deadlock rewards + for agent_id in newly_deadlocked_agents_ids: + print(f"assigning deadlock reward of {self.deadlock_reward} to agent {agent_id}") + rewards[agent_id] = self.deadlock_reward + + return obs, rewards, dones, info + + def reset(self, **kwargs) -> Tuple[Dict, Dict]: + self.deadlock_checker.reset() # sets all lists of deadlocked agents to empty list + obs, info = super().reset(**kwargs) return obs, info \ No newline at end of file