diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py index 972c7eaf073f66de15b4839a2eb3fa5bcd18a68a..f2838ed16ca474545e4e2c8ae5de1d2c8acbaad3 100644 --- a/flatland/contrib/wrappers/flatland_wrappers.py +++ b/flatland/contrib/wrappers/flatland_wrappers.py @@ -16,7 +16,8 @@ from flatland.core.grid.grid4_utils import get_new_position # First of all we import the Flatland rail environment from flatland.utils.rendertools import RenderTool, AgentRenderVariant -from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.agent_utils import EnvAgent +from flatland.envs.step_utils.states import TrainState from flatland.envs.rail_env import RailEnv, RailEnvActions @@ -24,20 +25,13 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int): agent = env.agents[handle] - if agent.status == RailAgentStatus.READY_TO_DEPART: + if agent.state == TrainState.READY_TO_DEPART: agent_virtual_position = agent.initial_position - elif agent.status == RailAgentStatus.ACTIVE: + elif agent.state.is_on_map_state(): agent_virtual_position = agent.position - elif agent.status == RailAgentStatus.DONE: - agent_virtual_position = agent.target else: print("no action possible!") - if agent.status == RailAgentStatus.DONE_REMOVED: - print(f"agent status: DONE_REMOVED for agent {agent.handle}") - print("to solve this problem, do not input actions for removed agents!") - return [(RailEnvActions.DO_NOTHING, 0)] * 2 - print("agent status:") - print(RailAgentStatus(agent.status)) + print("agent status: ", agent.state) #return None # NEW: if agent is at target, DO_NOTHING, and distance is zero. # NEW: (needs to be tested...) @@ -58,25 +52,18 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int): elif movement == (agent.direction - 1) % 4: action = RailEnvActions.MOVE_LEFT else: - # MICHEL: prints for debugging print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}") if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4: print("it seems that we are turning by 180 degrees. Turning in a dead end?") - # MICHEL: can this happen when we turn 180 degrees in a dead end? - # i.e. can we then have movement == agent.direction + 2 % 4 (resp. ... == - 2 % 4)? - - # TRY OUT: ASSIGN MOVE_FORWARD HERE... action = RailEnvActions.MOVE_FORWARD print("Here we would have a ValueError...") - #raise ValueError("Wtf, debug this shit.") distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)] possible_steps.append((action, distance)) possible_steps = sorted(possible_steps, key=lambda step: step[1]) - # MICHEL: what is this doing? # if there is only one path to target, this is both the shortest one and the second shortest path. if len(possible_steps) == 1: return possible_steps * 2 @@ -186,16 +173,9 @@ class ShortestPathActionWrapper(RailEnvWrapper): super().__init__(env) #self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction - # MICHEL: we have to make sure that not agents with agent.status == DONE_REMOVED are in the action dict. + # MICHEL: we have to make sure that not agents with agent.state == DONE_REMOVED are in the action dict. # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash. def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: - ########## MICHEL: NEW (just for debugging) ######## - for agent_id, action in action_dict.items(): - agent = self.agents[agent_id] - # assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None... - # assert agent.status != RailAgentStatus.DONE # not sure about this one... - print(f"agent: {agent} with status: {agent.status}") - ###################################################### # input: action dict with actions in [0, 1, 2]. transformed_action_dict = {} @@ -207,21 +187,14 @@ class ShortestPathActionWrapper(RailEnvWrapper): # MICHEL: how exactly do the indices work here? #transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0] #print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}") - #assert agent.status != RailAgentStatus.DONE_REMOVED # MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error... assert possible_actions_sorted_by_distance(self.env, agent_id) is not None assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] + obs, rewards, dones, info = self.env.step(transformed_action_dict) return obs, rewards, dones, info - #def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]: - #return self.rail_env.reset(random_seed) - - # MICHEL: should not be needed, as we inherit that from RailEnvWrapper... - #def reset(self, **kwargs) -> Tuple[Dict, Dict]: - # obs, info = self.env.reset(**kwargs) - # return obs, info def find_all_cells_where_agent_can_choose(env: RailEnv): @@ -236,19 +209,11 @@ def find_all_cells_where_agent_can_choose(env: RailEnv): for h in range(env.height): for w in range(env.width): - # MICHEL: THIS SEEMS TO BE A BUG. WRONG ODER OF COORDINATES. - # will not show up in quadratic environments. - # should be pos = (h, w) - #pos = (w, h) - - # MICHEL: changed this pos = (h, w) is_switch = False # Check for switch: if there is more than one outgoing transition for orientation in directions: - #print(f"env is: {env}") - #print(f"env.rail is: {env.rail}") possible_transitions = env.rail.get_transitions(*pos, orientation) num_transitions = np.count_nonzero(possible_transitions) if num_transitions > 1: @@ -386,15 +351,12 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper): self.skipper.reset_cells() - # TODO: this is clunky.. - # for easier access / checking self.switches = self.skipper.switches self.switches_neighbors = self.skipper.switches_neighbors self.decision_cells = self.skipper.decision_cells self.skipped_rewards = self.skipper.skipped_rewards - # MICHEL: trying to isolate the core part and put it into a separate method. def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict) return obs, rewards, dones, info @@ -409,4 +371,4 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper): # resets decision cells, switches, etc. These can change with an env.reset(...)! # needs to be done after env.reset(). self.skipper.reset_cells() - return obs, info \ No newline at end of file + return obs, info