diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py
index d07fd06ae867bedce34237815f88a17280cf0651..a60b186f3781e69116b03d2c3b61d238c0fda546 100644
--- a/flatland/contrib/wrappers/flatland_wrappers.py
+++ b/flatland/contrib/wrappers/flatland_wrappers.py
@@ -181,110 +181,89 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
     decision_cells = switches + switches_neighbors
     return tuple(map(set, (switches, switches_neighbors, decision_cells)))
 
+
-class NoChoiceCellsSkipper:
+
+class SkipNoChoiceCellsWrapper(RailEnvWrapper):
+
+    # env can be a real RailEnv, or anything that shares the same interface
+    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
     def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        self.env = env
+        super().__init__(env)
+        # save these so they can be inspected easier.
+        self.accumulate_skipped_rewards = accumulate_skipped_rewards
+        self.discounting = discounting
         self.switches = None
         self.switches_neighbors = None
         self.decision_cells = None
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
         self.skipped_rewards = defaultdict(list)
-        # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
-        #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
-        # compute and initialize value for switches, switches_neighbors, and decision_cells.
+        # sets initial values for switches, decision_cells, etc.
         self.reset_cells()
+
     def on_decision_cell(self, agent: EnvAgent) -> bool:
-        return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
+        return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
 
     def on_switch(self, agent: EnvAgent) -> bool:
-        return agent.position in self.switches
+        return agent.position in self.switches
 
     def next_to_switch(self, agent: EnvAgent) -> bool:
-        return agent.position in self.switches_neighbors
-
-    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        o, r, d, i = {}, {}, {}, {}
-
-        # NEED TO INITIALIZE i["..."]
-        # as we will access i["..."][agent_id]
-        i["action_required"] = dict()
-        i["malfunction"] = dict()
-        i["speed"] = dict()
-        i["state"] = dict()
-
-        while len(o) == 0:
-            obs, reward, done, info = self.env.step(action_dict)
-
-            for agent_id, agent_obs in obs.items():
-                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
-
-                    o[agent_id] = agent_obs
-                    r[agent_id] = reward[agent_id]
-                    d[agent_id] = done[agent_id]
-
-
-                    i["action_required"][agent_id] = info["action_required"][agent_id]
-                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
-                    i["speed"][agent_id] = info["speed"][agent_id]
-                    i["state"][agent_id] = info["state"][agent_id]
-
-                    if self.accumulate_skipped_rewards:
-                        discounted_skipped_reward = r[agent_id]
-                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
-                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
-                        r[agent_id] = discounted_skipped_reward
-                        self.skipped_rewards[agent_id] = []
-
-                elif self.accumulate_skipped_rewards:
-                    self.skipped_rewards[agent_id].append(reward[agent_id])
-            # end of for-loop
-
-            d['__all__'] = done['__all__']
-            action_dict = {}
-        # end of while-loop
-
-        return o, r, d, i
-
-
-    def reset_cells(self) -> None:
-        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
+        return agent.position in self.switches_neighbors
-# IMPORTANT: rail env should be reset() / initialized before put into this one!
-class SkipNoChoiceCellsWrapper(RailEnvWrapper):
-
-    # env can be a real RailEnv, or anything that shares the same interface
-    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
-    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        super().__init__(env)
-        # save these so they can be inspected easier.
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
-        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
-        self.skipper.reset_cells()
+    def reset_cells(self) -> None:
+        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-        self.switches = self.skipper.switches
-        self.switches_neighbors = self.skipper.switches_neighbors
-        self.decision_cells = self.skipper.decision_cells
-        self.skipped_rewards = self.skipper.skipped_rewards
-
     def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
-        return obs, rewards, dones, info
+        o, r, d, i = {}, {}, {}, {}
+
+        # NEED TO INITIALIZE i["..."]
+        # as we will access i["..."][agent_id]
+        i["action_required"] = dict()
+        i["malfunction"] = dict()
+        i["speed"] = dict()
+        i["state"] = dict()
+
+        while len(o) == 0:
+            obs, reward, done, info = self.env.step(action_dict)
+
+            for agent_id, agent_obs in obs.items():
+                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
+                    o[agent_id] = agent_obs
+                    r[agent_id] = reward[agent_id]
+                    d[agent_id] = done[agent_id]
+
+                    i["action_required"][agent_id] = info["action_required"][agent_id]
+                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
+                    i["speed"][agent_id] = info["speed"][agent_id]
+                    i["state"][agent_id] = info["state"][agent_id]
+
+                    if self.accumulate_skipped_rewards:
+                        discounted_skipped_reward = r[agent_id]
+
+                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
+                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
+
+                        r[agent_id] = discounted_skipped_reward
+                        self.skipped_rewards[agent_id] = []
+
+                elif self.accumulate_skipped_rewards:
+                    self.skipped_rewards[agent_id].append(reward[agent_id])
+            # end of for-loop
+
+            d['__all__'] = done['__all__']
+            action_dict = {}
+        # end of while-loop
+
+        return o, r, d, i
-
-    # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
+
     def reset(self, **kwargs) -> Tuple[Dict, Dict]:
-        obs, info = self.env.reset(**kwargs)
-        # resets decision cells, switches, etc. These can change with an env.reset(...)!
-        # needs to be done after env.reset().
-        self.skipper.reset_cells()
-        return obs, info
+        obs, info = self.env.reset(**kwargs)
+        # resets decision cells, switches, etc. These can change with an env.reset(...)!
+        # needs to be done after env.reset().
+        self.reset_cells()
+        return obs, info
\ No newline at end of file
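
Below is a minimal usage sketch of the merged wrapper, not part of the patch itself. It assumes an already constructed RailEnv instance named env and a placeholder my_policy callable that maps an agent observation to a RailEnvActions value; both names, as well as the discounting value, are illustrative only.

    from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper

    env = ...  # build and configure a RailEnv here (rail/line generators, observation builder, etc.)

    # discounting is applied to the rewards collected on the skipped no-choice cells
    wrapped_env = SkipNoChoiceCellsWrapper(env, accumulate_skipped_rewards=True, discounting=0.99)

    obs, info = wrapped_env.reset()
    done = {"__all__": False}
    while not done["__all__"]:
        # observations are only returned for agents on decision cells (or done agents);
        # the cells in between are stepped through internally by the wrapper.
        action_dict = {agent_id: my_policy(agent_obs) for agent_id, agent_obs in obs.items()}
        obs, rewards, done, info = wrapped_env.step(action_dict)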