Commit 632b2842 authored by mmarti's avatar mmarti
Browse files

added DeadlockWrapper class

parent c2292772
Pipeline #9389 failed with stages
in 12 minutes and 26 seconds
......@@ -254,4 +254,50 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# resets decision cells, switches, etc. These can change with an env.reset(...)!
# needs to be done after env.reset().
return obs, info
class DeadlockWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv, deadlock_reward=-100) -> None:
self.deadlock_reward = deadlock_reward
self.deadlock_checker = Deadlock_Checker(env=self.env)
def deadlocked_agents(self):
return self.deadlock_checker.deadlocked_agents
def immediate_deadlocks(self):
return [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]
# make sure to assign the deadlock reward only once to each deadlocked agent...
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
# agents which are already deadlocked from previous steps
already_deadlocked_ids = [agent.handle for agent in self.deadlocked_agents]
# step environment
obs, rewards, dones, info = self.env.step(action_dict)
# compute new list of deadlocked agents (ids) after stepping the environment
deadlocked_agents = self.deadlock_checker.check_deadlocks() # also stored in self.deadlocked_checker.deadlocked_agents
deadlocked_agents_ids = [agent.handle for agent in deadlocked_agents]
# immediate deadlocked ids only used for prints
immediate_deadlocked_ids = [agent.handle for agent in self.deadlock_checker.immediate_deadlocked]
print(f"immediate deadlocked: {immediate_deadlocked_ids}")
print(f"total deadlocked: {deadlocked_agents_ids}")
newly_deadlocked_agents_ids = [agent_id for agent_id in deadlocked_agents_ids if agent_id not in already_deadlocked_ids]
# assign deadlock rewards
for agent_id in newly_deadlocked_agents_ids:
print(f"assigning deadlock reward of {self.deadlock_reward} to agent {agent_id}")
rewards[agent_id] = self.deadlock_reward
return obs, rewards, dones, info
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
self.deadlock_checker.reset() # sets all lists of deadlocked agents to empty list
obs, info = super().reset(**kwargs)
return obs, info
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment