diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py
index 3dfa02a139e7b7092731e0cc86b2d58ed4ad42d0..6e4772b9dad23670cd86d9392e7440f45f47f35c 100644
--- a/flatland/contrib/wrappers/flatland_wrappers.py
+++ b/flatland/contrib/wrappers/flatland_wrappers.py
@@ -35,8 +35,6 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
     distance_map = env.distance_map.get()[handle]
     possible_steps = []
     for movement in list(range(4)):
-        # MICHEL: TODO: discuss with author of this code how it works, and why it breaks down in my test!
-        # should be much better commented or structured to be readable!
         if possible_transitions[movement]:
             if movement == agent.direction:
                 action = RailEnvActions.MOVE_FORWARD
@@ -49,8 +47,7 @@
 
                 if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4:
                     print("it seems that we are turning by 180 degrees. Turning in a dead end?")
-                    action = RailEnvActions.MOVE_FORWARD
-                    print("Here we would have a ValueError...")
+                    action = RailEnvActions.MOVE_FORWARD
 
             distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
             possible_steps.append((action, distance))
@@ -73,29 +70,6 @@ class RailEnvWrapper:
 
         assert self.env.agents is not None, "Reset original environment first!"
         assert len(self.env.agents) > 0, "Reset original environment first!"
-
-        # rail can be seen as part of the interface to RailEnv.
-        # is used by several wrappers, to e.g. access rail.get_valid_transitions(...)
-        #self.rail = self.env.rail
-        # same for env.agents
-        # MICHEL: DOES THIS HERE CAUSE A PROBLEM with agent status not being updated?
-        #self.agents = self.env.agents
-        #assert self.env.agents == self.agents
-        #print(f"agents of RailEnvWrapper are: {self.agents}")
-        #self.width = self.rail.width
-        #self.height = self.rail.height
-
-        # TODO: maybe do this in a generic way, like "for each method of self.env, ..."
-        # maybe using dir(self.env) (gives list of names of members)
-
-        # MICHEL: this seems to be needed after each env.reset(..) call
-        # otherwise, these attribute names refer to the wrong object and are out of sync...
-        # probably due to the reassignment of new objects to these variables by RailEnv, and how Python treats that.
-
-        # simple example: a = [1,2,3] b=a. But then a=[0]. Now we still have b==[1,2,3].
-
-        # it's better tou use properties here!
-
     # @property
     # def number_of_agents(self):
     #     return self.env.number_of_agents
@@ -143,20 +117,10 @@
         return self.env.get_agent_handles()
 
     def step(self, action_dict: Dict[int, RailEnvActions]):
-        #self.agents = self.env.agents
-        # ERROR. something is wrong with the references for self.agents...
-        #assert self.env.agents == self.agents
         return self.env.step(action_dict)
 
     def reset(self, **kwargs):
-        # MICHEL: I suspect that env.reset() does not simply change values of variables, but assigns new objects
-        # that might cause some attributes not be properly updated here, because of how Python treats assignments differently from modification..
-        #assert self.env.agents == self.agents
         obs, info = self.env.reset(**kwargs)
-        #assert self.env.agents == self.agents, "after resetting internal env, self.agents names wrong object..."
-        #self.reset_attributes()
-        #print(f"calling RailEnvWrapper.reset()")
-        #print(f"obs: {obs}, info:{info}")
         return obs, info
 
 
@@ -164,10 +128,7 @@ class ShortestPathActionWrapper(RailEnvWrapper):
 
     def __init__(self, env:RailEnv):
         super().__init__(env)
-        #self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction
-
-        # MICHEL: we have to make sure that not agents with agent.state == DONE_REMOVED are in the action dict.
-        # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash.
+
 
     def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
         # input: action dict with actions in [0, 1, 2].
@@ -176,13 +137,9 @@
             if action == 0:
                 transformed_action_dict[agent_id] = action
             else:
-                assert action in [1, 2]
-                # MICHEL: how exactly do the indices work here?
-                #transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0]
-                #print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}")
-                # MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error...
-                assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
-                assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
+                #assert action in [1, 2]
+                #assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
+                #assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
                 transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
 
         obs, rewards, dones, info = self.env.step(transformed_action_dict)
@@ -241,17 +198,7 @@
         # compute and initialize value for switches, switches_neighbors, and decision_cells.
         self.reset_cells()
 
-    # MICHEL: maybe these three methods should be part of RailEnv?
     def on_decision_cell(self, agent: EnvAgent) -> bool:
-        """
-        print(f"agent {agent.handle} is on decision cell")
-        if agent.position is None:
-            print("because agent.position is None (has not been activated yet)")
-        if agent.position == agent.initial_position:
-            print("because agent is at initial position, activated but not departed")
-        if agent.position in self.decision_cells:
-            print("because agent.position is in self.decision_cells.")
-        """
         return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
 
     def on_switch(self, agent: EnvAgent) -> bool:
@@ -260,52 +207,31 @@
     def next_to_switch(self, agent: EnvAgent) -> bool:
         return agent.position in self.switches_neighbors
 
-    # MICHEL: maybe just call this step()...
     def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
         o, r, d, i = {}, {}, {}, {}
 
-        # MICHEL: NEED TO INITIALIZE i["..."]
+        # NEED TO INITIALIZE i["..."]
         # as we will access i["..."][agent_id]
         i["action_required"] = dict()
         i["malfunction"] = dict()
         i["speed"] = dict()
-        i["status"] = dict()
+        i["status"] = dict() # TODO: change to "state"
 
         while len(o) == 0:
-            #print(f"len(o)==0. stepping the rail environment...")
             obs, reward, done, info = self.env.step(action_dict)
 
             for agent_id, agent_obs in obs.items():
-
-                ###### MICHEL: prints for debugging ###########
-                if not self.on_decision_cell(self.env.agents[agent_id]):
-                    print(f"agent {agent_id} is NOT on a decision cell.")
-                #################################################
-
-
                 if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
-                    ###### MICHEL: prints for debugging ######################
-                    if done[agent_id]:
-                        print(f"agent {agent_id} is done.")
-                    #if self.on_decision_cell(self.env.agents[agent_id]):
-                        #print(f"agent {agent_id} is on decision cell.")
-                        #cell = self.env.agents[agent_id].position
-                        #print(f"cell is: {cell}")
-                        #print(f"the decision cells are: {self.decision_cells}")
-                    ############################################################
-
                     o[agent_id] = agent_obs
                     r[agent_id] = reward[agent_id]
                     d[agent_id] = done[agent_id]
 
-                    # MICHEL: HAVE TO MODIFY THIS HERE
-                    # because we are not using StepOutputs, the return values of step() have a different structure.
-                    #i[agent_id] = info[agent_id]
+
                     i["action_required"][agent_id] = info["action_required"][agent_id]
                     i["malfunction"][agent_id] = info["malfunction"][agent_id]
                     i["speed"][agent_id] = info["speed"][agent_id]
-                    i["status"][agent_id] = info["status"][agent_id]
+                    i["status"][agent_id] = info["status"][agent_id] # TODO: change to "state"
 
                     if self.accumulate_skipped_rewards:
                         discounted_skipped_reward = r[agent_id]
@@ -324,13 +250,12 @@
         return o, r, d, i
 
 
-    # MICHEL: maybe just call this reset()...
+
    def reset_cells(self) -> None:
         self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
 
 
 # IMPORTANT: rail env should be reset() / initialized before put into this one!
-# IDEA: MAYBE EACH RAILENV INSTANCE SHOULD AUTOMATICALLY BE reset() / initialized upon creation!
 class SkipNoChoiceCellsWrapper(RailEnvWrapper):
     # env can be a real RailEnv, or anything that shares the same interface
 
@@ -355,10 +280,8 @@
         return obs, rewards, dones, info
 
 
-    # MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc.
+    # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
-    # TODO: check the type of random_seed. Is it bool or int?
-    # MICHEL: changed return type from Dict[int, Any] to Tuple[Dict, Dict].
 
     def reset(self, **kwargs) -> Tuple[Dict, Dict]:
         obs, info = self.env.reset(**kwargs)
         # resets decision cells, switches, etc. These can change with an env.reset(...)!