Commit 2ff77f0e authored by mmarti

refactored SkipNoChoiceCellsWrapper by removing the Skipper class

parent bdfd4f09
@@ -181,110 +181,89 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
     decision_cells = switches + switches_neighbors
     return tuple(map(set, (switches, switches_neighbors, decision_cells)))


-class NoChoiceCellsSkipper:
+class SkipNoChoiceCellsWrapper(RailEnvWrapper):
+    # env can be a real RailEnv, or anything that shares the same interface
+    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
     def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        self.env = env
+        super().__init__(env)
+        # save these so they can be inspected easier.
+        self.accumulate_skipped_rewards = accumulate_skipped_rewards
+        self.discounting = discounting
         self.switches = None
         self.switches_neighbors = None
         self.decision_cells = None
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
         self.skipped_rewards = defaultdict(list)

-        # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
-        #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-        # compute and initialize value for switches, switches_neighbors, and decision_cells.
+        # sets initial values for switches, decision_cells, etc.
         self.reset_cells()

     def on_decision_cell(self, agent: EnvAgent) -> bool:
         return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells

     def on_switch(self, agent: EnvAgent) -> bool:
         return agent.position in self.switches

     def next_to_switch(self, agent: EnvAgent) -> bool:
         return agent.position in self.switches_neighbors

-    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        o, r, d, i = {}, {}, {}, {}
-
-        # NEED TO INITIALIZE i["..."]
-        # as we will access i["..."][agent_id]
-        i["action_required"] = dict()
-        i["malfunction"] = dict()
-        i["speed"] = dict()
-        i["state"] = dict()
-
-        while len(o) == 0:
-            obs, reward, done, info = self.env.step(action_dict)
-
-            for agent_id, agent_obs in obs.items():
-                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
-                    o[agent_id] = agent_obs
-                    r[agent_id] = reward[agent_id]
-                    d[agent_id] = done[agent_id]
-                    i["action_required"][agent_id] = info["action_required"][agent_id]
-                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
-                    i["speed"][agent_id] = info["speed"][agent_id]
-                    i["state"][agent_id] = info["state"][agent_id]
-
-                    if self.accumulate_skipped_rewards:
-                        discounted_skipped_reward = r[agent_id]
-                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
-                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
-                        r[agent_id] = discounted_skipped_reward
-                        self.skipped_rewards[agent_id] = []
-
-                elif self.accumulate_skipped_rewards:
-                    self.skipped_rewards[agent_id].append(reward[agent_id])
-            # end of for-loop
-
-            d['__all__'] = done['__all__']
-            action_dict = {}
-        # end of while-loop
-
-        return o, r, d, i
-
-    def reset_cells(self) -> None:
-        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
-
-# IMPORTANT: rail env should be reset() / initialized before put into this one!
-class SkipNoChoiceCellsWrapper(RailEnvWrapper):
-    # env can be a real RailEnv, or anything that shares the same interface
-    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
-    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        super().__init__(env)
-        # save these so they can be inspected easier.
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
-        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
-
-        self.skipper.reset_cells()
-        self.switches = self.skipper.switches
-        self.switches_neighbors = self.skipper.switches_neighbors
-        self.decision_cells = self.skipper.decision_cells
-        self.skipped_rewards = self.skipper.skipped_rewards
+    def reset_cells(self) -> None:
+        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)

     def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
-        return obs, rewards, dones, info
+        o, r, d, i = {}, {}, {}, {}
+
+        # NEED TO INITIALIZE i["..."]
+        # as we will access i["..."][agent_id]
+        i["action_required"] = dict()
+        i["malfunction"] = dict()
+        i["speed"] = dict()
+        i["state"] = dict()
+
+        while len(o) == 0:
+            obs, reward, done, info = self.env.step(action_dict)
+
+            for agent_id, agent_obs in obs.items():
+                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
+                    o[agent_id] = agent_obs
+                    r[agent_id] = reward[agent_id]
+                    d[agent_id] = done[agent_id]
+                    i["action_required"][agent_id] = info["action_required"][agent_id]
+                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
+                    i["speed"][agent_id] = info["speed"][agent_id]
+                    i["state"][agent_id] = info["state"][agent_id]
+
+                    if self.accumulate_skipped_rewards:
+                        discounted_skipped_reward = r[agent_id]
+                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
+                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
+                        r[agent_id] = discounted_skipped_reward
+                        self.skipped_rewards[agent_id] = []
+
+                elif self.accumulate_skipped_rewards:
+                    self.skipped_rewards[agent_id].append(reward[agent_id])
+            # end of for-loop
+
+            d['__all__'] = done['__all__']
+            action_dict = {}
+        # end of while-loop
+
+        return o, r, d, i

     # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
     def reset(self, **kwargs) -> Tuple[Dict, Dict]:
         obs, info = self.env.reset(**kwargs)
         # resets decision cells, switches, etc. These can change with an env.reset(...)!
         # needs to be done after env.reset().
-        self.skipper.reset_cells()
+        self.reset_cells()
         return obs, info
\ No newline at end of file
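Note: the following is a minimal usage sketch of the refactored wrapper, not part of the commit. It assumes the wrapper is importable from flatland.contrib.wrappers.flatland_wrappers and a typical flatland-rl 3.x environment setup; the generator arguments, grid size, and discounting value are illustrative only.

# Minimal usage sketch (assumption: module paths and constructor arguments as in a
# typical flatland-rl 3.x install; adjust to your version of the library).
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper

env = RailEnv(
    width=30,
    height=30,
    rail_generator=sparse_rail_generator(max_num_cities=3),
    line_generator=sparse_line_generator(),
    number_of_agents=2,
    obs_builder_object=TreeObsForRailEnv(max_depth=2),
)

# The rail grid must exist before the wrapper scans it for decision cells,
# so reset the raw env first, then wrap it and reset again through the wrapper.
env.reset()
wrapped = SkipNoChoiceCellsWrapper(env, accumulate_skipped_rewards=True, discounting=0.99)
obs, info = wrapped.reset()

done = {"__all__": False}
while not done["__all__"]:
    # Illustrative policy: always move forward. Only steps on decision cells
    # (or terminal steps) are returned; intermediate rewards are accumulated
    # with the configured discounting.
    action_dict = {agent_id: 2 for agent_id in obs}  # 2 == RailEnvActions.MOVE_FORWARD
    obs, rewards, done, info = wrapped.step(action_dict)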