sparse reward and deadlock reward environment wrappers

Closed metataro requested to merge sparse-and-deadlock-rewards into master
6 files   + 232   − 7
+from collections import defaultdict
 from typing import Dict, Any, Optional, Set, List
 import gym
 import numpy as np
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from envs.flatland.utils.gym_env import StepOutput
@@ -87,11 +88,13 @@ def find_all_cells_where_agent_can_choose(rail_env: RailEnv):
 class SkipNoChoiceCellsWrapper(gym.Wrapper):

-    def __init__(self, env) -> None:
+    def __init__(self, env, accumulate_skipped_rewards) -> None:
         super().__init__(env)
         self._switches = None
         self._switches_neighbors = None
         self._decision_cells = None
+        self._accumulate_skipped_rewards = accumulate_skipped_rewards
+        self._skipped_rewards = defaultdict(float)

     def _on_decision_cell(self, agent: EnvAgent):
         return agent.position is None or agent.position in self._decision_cells
@@ -112,6 +115,11 @@ class SkipNoChoiceCellsWrapper(gym.Wrapper):
                     r[agent_id] = reward[agent_id]
                     d[agent_id] = done[agent_id]
                     i[agent_id] = info[agent_id]
+                    if self._accumulate_skipped_rewards:
+                        # pay out the reward that was banked while this agent's steps were skipped
+                        r[agent_id] += self._skipped_rewards[agent_id]
+                        self._skipped_rewards[agent_id] = 0.
+                elif self._accumulate_skipped_rewards:
+                    # agent is on a no-choice cell and gets skipped -> bank its reward for later
+                    self._skipped_rewards[agent_id] += reward[agent_id]
             d['__all__'] = done['__all__']
             action_dict = {}
         return StepOutput(o, r, d, i)
@@ -122,3 +130,107 @@ class SkipNoChoiceCellsWrapper(gym.Wrapper):
             find_all_cells_where_agent_can_choose(self.unwrapped.rail_env)
         return obs
+
+
+class SparseRewardWrapper(gym.Wrapper):
+
+    def __init__(self, env, finished_reward=1, not_finished_reward=-1) -> None:
+        super().__init__(env)
+        self._finished_reward = finished_reward
+        self._not_finished_reward = not_finished_reward
+
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> StepOutput:
+        rail_env: RailEnv = self.unwrapped.rail_env
+        obs, reward, done, info = self.env.step(action_dict)
+
+        o, r, d, i = {}, {}, {}, {}
+        for agent_id, agent_obs in obs.items():
+            o[agent_id] = obs[agent_id]
+            d[agent_id] = done[agent_id]
+            i[agent_id] = info[agent_id]
+            if done[agent_id]:
+                if rail_env.agents[agent_id].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:
+                    # agent is done because it reached its target -> finished reward
+                    r[agent_id] = self._finished_reward
+                else:
+                    # agent is done without having reached its target -> not_finished reward
+                    r[agent_id] = self._not_finished_reward
+            else:
+                r[agent_id] = 0
+        d['__all__'] = done['__all__'] or all(d.values())
+        return StepOutput(o, r, d, i)
+
+    def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
+        return self.env.reset(random_seed)
+
+
+class DeadlockWrapper(gym.Wrapper):
+
+    def __init__(self, env, deadlock_reward=-1) -> None:
+        super().__init__(env)
+        self._deadlock_reward = deadlock_reward
+        self._deadlocked_agents = []
+
+    def check_deadlock(self):  # -> List[int]
+        rail_env: RailEnv = self.unwrapped.rail_env
+        new_deadlocked_agents = []
+        for agent in rail_env.agents:
+            if agent.status == RailAgentStatus.ACTIVE and agent.handle not in self._deadlocked_agents:
+                position = agent.position
+                direction = agent.direction
+                while position is not None:
+                    possible_transitions = rail_env.rail.get_transitions(*position, direction)
+                    num_transitions = np.count_nonzero(possible_transitions)
+                    if num_transitions == 1:
+                        # single outgoing transition: check the cell ahead
+                        new_direction_me = np.argmax(possible_transitions)
+                        new_cell_me = get_new_position(position, new_direction_me)
+                        opp_agent = rail_env.agent_positions[new_cell_me]
+                        if opp_agent != -1:
+                            opp_position = rail_env.agents[opp_agent].position
+                            opp_direction = rail_env.agents[opp_agent].direction
+                            opp_possible_transitions = rail_env.rail.get_transitions(*opp_position, opp_direction)
+                            opp_num_transitions = np.count_nonzero(opp_possible_transitions)
+                            if opp_num_transitions == 1:
+                                if opp_direction != direction:
+                                    # blocking agent cannot divert and faces a different direction -> deadlock
+                                    self._deadlocked_agents.append(agent.handle)
+                                    new_deadlocked_agents.append(agent.handle)
+                                    position = None
+                                else:
+                                    # blocking agent moves in the same direction -> keep following the chain
+                                    position = new_cell_me
+                                    direction = new_direction_me
+                            else:
+                                # blocking agent can choose another transition -> keep following the chain
+                                position = new_cell_me
+                                direction = new_direction_me
+                        else:
+                            # cell ahead is free -> no deadlock for this agent
+                            position = None
+                    else:
+                        # more than one outgoing transition (decision cell) -> stop checking
+                        position = None
+        return new_deadlocked_agents
+
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> StepOutput:
+        obs, reward, done, info = self.env.step(action_dict)
+
+        if self._deadlock_reward != 0:
+            new_deadlocked_agents = self.check_deadlock()
+        else:
+            new_deadlocked_agents = []
+
+        o, r, d, i = {}, {}, {}, {}
+        for agent_id, agent_obs in obs.items():
+            if agent_id not in self._deadlocked_agents or agent_id in new_deadlocked_agents:
+                o[agent_id] = obs[agent_id]
+                d[agent_id] = done[agent_id]
+                i[agent_id] = info[agent_id]
+                r[agent_id] = reward[agent_id]
+                if agent_id in new_deadlocked_agents:
+                    # agent is newly deadlocked -> give deadlock reward and mark it as done
+                    r[agent_id] += self._deadlock_reward
+                    d[agent_id] = True
+        d['__all__'] = done['__all__'] or all(d.values())
+        return StepOutput(o, r, d, i)
+
+    def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
+        self._deadlocked_agents = []
+        return self.env.reset(random_seed)
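
The wrappers chain like any other gym.Wrapper, with the outermost wrapper seeing the rewards produced by the inner ones. Below is a minimal sketch of one possible wiring; it is not part of this merge request, and `base_env` stands in for the repo's gym-style Flatland environment (the one that returns StepOutput), whose construction is left out here.

    from flatland.envs.rail_env import RailEnvActions

    # base_env: assumed to be this repo's gym-style Flatland env returning StepOutput
    env = SkipNoChoiceCellsWrapper(base_env, accumulate_skipped_rewards=True)
    env = SparseRewardWrapper(env, finished_reward=1, not_finished_reward=-1)
    env = DeadlockWrapper(env, deadlock_reward=-1)   # outermost, so the penalty is added on top of the sparse reward

    obs = env.reset()
    done = {'__all__': False}
    while not done['__all__']:
        # drive every still-active agent forward; a trained policy would go here instead
        actions = {agent_id: RailEnvActions.MOVE_FORWARD for agent_id in obs}
        obs, reward, done, info = env.step(actions)

With this ordering, SparseRewardWrapper replaces the per-step rewards with the finished/not-finished signal, and DeadlockWrapper then adds deadlock_reward for newly deadlocked agents and marks them as done.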