diff --git a/examples/env_generators.py b/examples/env_generators.py
index a65733514bff7b601825b695715faabe66ac84ed..38c6d987acbd8d8c5996d7e4130b7f4ead4bc502 100644
--- a/examples/env_generators.py
+++ b/examples/env_generators.py
@@ -224,10 +224,10 @@ def _after_step(self, observation, reward, done, info):
 
 def perc_completion(env):
     tasks_finished = 0
-    if isinstance(env, RailEnv):
-        agent_data = env.agents
-    else:
+    if hasattr(env, "agents_data"):
         agent_data = env.agents_data
+    else:
+        agent_data = env.agents
     for current_agent in agent_data:
         if current_agent.status == RailAgentStatus.DONE:
             tasks_finished += 1
diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py
index cf0f8ed1e1c5259650c8ac47d6a35e071c739ec8..da28811df7272cfde3867cd55a85f86e8428eae6 100644
--- a/tests/test_pettingzoo_interface.py
+++ b/tests/test_pettingzoo_interface.py
@@ -3,17 +3,417 @@ import numpy as np
 import os
 import PIL
 import shutil
+# MICHEL: my own imports
+import unittest
+import typing
+from collections import defaultdict
+from typing import Dict, Any, Optional, Set, List, Tuple
+
 from examples import flatland_env
 from examples import env_generators
 from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.core.grid.grid4_utils import get_new_position
 # First of all we import the Flatland rail environment
-from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+
+
+def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
+    agent = env.agents[handle]
+
+    if agent.status == RailAgentStatus.READY_TO_DEPART:
+        agent_virtual_position = agent.initial_position
+    elif agent.status == RailAgentStatus.ACTIVE:
+        agent_virtual_position = agent.position
+    elif agent.status == RailAgentStatus.DONE:
+        agent_virtual_position = agent.target
+    else:
+        print("no action possible!")
+        if agent.status == RailAgentStatus.DONE_REMOVED:
+            print(f"agent status: DONE_REMOVED for agent {agent.handle}")
+            print("to solve this problem, do not input actions for removed agents!")
+            return [(RailEnvActions.DO_NOTHING, 0)] * 2
+        print("agent status:")
+        print(RailAgentStatus(agent.status))
+        #return None
+        # NEW: if the agent is at its target, return DO_NOTHING with distance zero.
+        # NEW: (needs to be tested...)
+        return [(RailEnvActions.DO_NOTHING, 0)] * 2
+
+    possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
+    print(f"possible transitions: {possible_transitions}")
+    distance_map = env.distance_map.get()[handle]
+    possible_steps = []
+    for movement in range(4):
+        # MICHEL: TODO: discuss with the author of this code how it works, and why it breaks down in my test!
+        # It should be commented or structured much better to be readable!
+        if possible_transitions[movement]:
+            if movement == agent.direction:
+                action = RailEnvActions.MOVE_FORWARD
+            elif movement == (agent.direction + 1) % 4:
+                action = RailEnvActions.MOVE_RIGHT
+            elif movement == (agent.direction - 1) % 4:
+                action = RailEnvActions.MOVE_LEFT
+            else:
+                # MICHEL: prints for debugging
+                print(f"An error occurred. movement is: {movement}, agent direction is: {agent.direction}")
+                if movement == (agent.direction + 2) % 4 or movement == (agent.direction - 2) % 4:
+                    print("it seems that we are turning by 180 degrees. Turning in a dead end?")
+
+                # MICHEL: can this happen when we turn 180 degrees in a dead end?
+                # i.e. can we then have movement == (agent.direction + 2) % 4 (resp. (agent.direction - 2) % 4)?
+
+                # TRY OUT: ASSIGN MOVE_FORWARD HERE...
+                action = RailEnvActions.MOVE_FORWARD
+                print("Here we would previously have raised a ValueError...")
+                #raise ValueError(f"Unexpected movement {movement} for agent direction {agent.direction}.")
+
+            distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
+            possible_steps.append((action, distance))
+    possible_steps = sorted(possible_steps, key=lambda step: step[1])
+
+    # If there is only one path to the target, it is both the shortest and the
+    # second-shortest path, so it is returned twice.
+    if len(possible_steps) == 1:
+        return possible_steps * 2
+    else:
+        return possible_steps
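+
+
+# A minimal usage sketch of the helper above (illustrative only, not used by
+# the tests; the helper name is hypothetical). Assumes `env` has been reset.
+# Index 0 of the returned list is the step with the smallest distance to the target.
+def shortest_path_action_dict(env: RailEnv) -> Dict[int, RailEnvActions]:
+    action_dict = {}
+    for handle in env.get_agent_handles():
+        # first entry of the distance-sorted list = action along the shortest path
+        action_dict[handle] = possible_actions_sorted_by_distance(env, handle)[0][0]
+    return action_dict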
+
+
+class RailEnvWrapper:
+    def __init__(self, env: RailEnv):
+        self.env = env
+
+        assert self.env is not None
+        assert self.env.rail is not None, "Reset original environment first!"
+        assert self.env.agents is not None, "Reset original environment first!"
+        assert len(self.env.agents) > 0, "Reset original environment first!"
+
+        # rail can be seen as part of the interface to RailEnv.
+        # It is used by several wrappers, e.g. to access rail.get_valid_transitions(...).
+        #self.rail = self.env.rail
+        # same for env.agents
+        # MICHEL: DOES THIS CAUSE A PROBLEM with the agent status not being updated?
+        #self.agents = self.env.agents
+        #assert self.env.agents == self.agents
+        #print(f"agents of RailEnvWrapper are: {self.agents}")
+        #self.width = self.rail.width
+        #self.height = self.rail.height
+
+    # TODO: maybe do this in a generic way, like "for each method of self.env, ..."
+    # maybe using dir(self.env) (gives the list of member names)
+
+    # MICHEL: caching attributes like the above breaks after each env.reset(..) call;
+    # the cached names then refer to the wrong objects and are out of sync.
+    # This is probably because RailEnv.reset() rebinds these attributes to new objects,
+    # and Python treats rebinding differently from in-place modification.
+    # Simple example: a = [1, 2, 3]; b = a. After rebinding a = [0], b still refers to [1, 2, 3].
+    # It's better to use properties here!
+
+    # @property
+    # def number_of_agents(self):
+    #     return self.env.number_of_agents
+
+    # @property
+    # def agents(self):
+    #     return self.env.agents
+
+    # @property
+    # def _seed(self):
+    #     return self.env._seed
+
+    # @property
+    # def obs_builder(self):
+    #     return self.env.obs_builder
+
+    def __getattr__(self, name):
+        """Expose any other attributes of the underlying environment."""
+        # __getattr__ is only called when normal attribute lookup fails,
+        # so we can simply delegate to the wrapped environment here.
+        return getattr(self.env, name)
+
+    @property
+    def rail(self):
+        return self.env.rail
+
+    @property
+    def width(self):
+        return self.env.width
+
+    @property
+    def height(self):
+        return self.env.height
+
+    @property
+    def agent_positions(self):
+        return self.env.agent_positions
+
+    def get_num_agents(self):
+        return self.env.get_num_agents()
+
+    def get_agent_handles(self):
+        return self.env.get_agent_handles()
+
+    def step(self, action_dict: Dict[int, RailEnvActions]):
+        #self.agents = self.env.agents
+        # ERROR: something is wrong with the references for self.agents...
+        #assert self.env.agents == self.agents
+        return self.env.step(action_dict)
+
+    def reset(self, **kwargs):
+        # MICHEL: I suspect that env.reset() does not simply change the values of variables,
+        # but assigns new objects; cached attributes would then not be properly updated here,
+        # because Python treats assignment differently from in-place modification.
+        #assert self.env.agents == self.agents
+        obs, info = self.env.reset(**kwargs)
+        #assert self.env.agents == self.agents, "after resetting the internal env, self.agents names the wrong object..."
+        #self.reset_attributes()
+        #print(f"calling RailEnvWrapper.reset()")
+        #print(f"obs: {obs}, info: {info}")
+        return obs, info
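+
+
+# Illustrative sketch (hypothetical helper, not called by the tests): how the
+# wrapper is meant to delegate attribute access. The generator and the tree
+# observation builder are assumptions modelled on the test further below.
+def railenv_wrapper_delegation_demo(seed: int = 0) -> RailEnvWrapper:
+    obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+    env = env_generators.small_v0(seed, obs_builder)
+    env.reset(random_seed=seed)  # wrap only after the first reset (see the asserts in __init__)
+    wrapped = RailEnvWrapper(env)
+    # properties and __getattr__ fall through to the underlying RailEnv
+    assert wrapped.rail is env.rail
+    assert wrapped.get_num_agents() == env.get_num_agents()
+    return wrapped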
+
+
+class ShortestPathActionWrapper(RailEnvWrapper):
+
+    def __init__(self, env: RailEnv):
+        super().__init__(env)
+        #self.action_space = gym.spaces.Discrete(n=3)  # 0: stop, 1: shortest path, 2: other direction
+
+    # MICHEL: we have to make sure that no agents with agent.status == DONE_REMOVED are in the action dict;
+    # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash.
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+        ########## MICHEL: NEW (just for debugging) ########
+        for agent_id, action in action_dict.items():
+            agent = self.agents[agent_id]
+            # assert agent.status != RailAgentStatus.DONE_REMOVED  # this comes with agent.position == None...
+            # assert agent.status != RailAgentStatus.DONE  # not sure about this one...
+            print(f"agent: {agent} with status: {agent.status}")
+        ######################################################
+
+        # input: action dict with actions in [0, 1, 2].
+        transformed_action_dict = {}
+        for agent_id, action in action_dict.items():
+            if action == 0:
+                transformed_action_dict[agent_id] = action
+            else:
+                assert action in [1, 2]
+                # action - 1 indexes into the distance-sorted list returned above:
+                # action 1 -> shortest path, action 2 -> second-shortest path.
+                # MICHEL: THIS USED TO CRASH WITH a "NoneType is not subscriptable" error
+                # when possible_actions_sorted_by_distance() returned None for removed agents.
+                possible_actions = possible_actions_sorted_by_distance(self.env, agent_id)
+                assert possible_actions is not None
+                assert possible_actions[action - 1] is not None
+                transformed_action_dict[agent_id] = possible_actions[action - 1][0]
+        obs, rewards, dones, info = self.env.step(transformed_action_dict)
+        return obs, rewards, dones, info
+
+    #def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
+    #    return self.rail_env.reset(random_seed)
+
+    # MICHEL: should not be needed, as we inherit reset() from RailEnvWrapper...
+    #def reset(self, **kwargs) -> Tuple[Dict, Dict]:
+    #    obs, info = self.env.reset(**kwargs)
+    #    return obs, info
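+
+
+# Hedged sketch (hypothetical helper, not used by the tests) of sampling the
+# reduced action space of ShortestPathActionWrapper:
+# 0 = stop, 1 = shortest path, 2 = second-shortest path.
+def sample_reduced_actions(env: ShortestPathActionWrapper,
+                           rng: np.random.RandomState) -> Dict[int, int]:
+    # only include agents that are not DONE_REMOVED, as required by step() above
+    return {agent.handle: int(rng.randint(0, 3))
+            for agent in env.agents
+            if agent.status != RailAgentStatus.DONE_REMOVED}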
+
+
+def find_all_cells_where_agent_can_choose(env: RailEnv):
+    """
+    input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv),
+    WHICH HAS BEEN RESET ALREADY!
+    (Otherwise env.rail is still None, and accessing it below crashes.)
+    """
+    switches = []
+    switches_neighbors = []
+    directions = list(range(4))
+    for h in range(env.height):
+        for w in range(env.width):
+
+            # MICHEL: THIS SEEMS TO BE A BUG: WRONG ORDER OF COORDINATES.
+            # It will not show up in square environments (height == width).
+            # It should be pos = (h, w).
+            #pos = (w, h)
+
+            # MICHEL: changed this
+            pos = (h, w)
+
+            is_switch = False
+            # Check for a switch: there is more than one outgoing transition.
+            for orientation in directions:
+                #print(f"env is: {env}")
+                #print(f"env.rail is: {env.rail}")
+                possible_transitions = env.rail.get_transitions(*pos, orientation)
+                num_transitions = np.count_nonzero(possible_transitions)
+                if num_transitions > 1:
+                    switches.append(pos)
+                    is_switch = True
+                    break
+            if is_switch:
+                # Add all neighbouring rails, if pos is a switch.
+                for orientation in directions:
+                    possible_transitions = env.rail.get_transitions(*pos, orientation)
+                    for movement in directions:
+                        if possible_transitions[movement]:
+                            switches_neighbors.append(get_new_position(pos, movement))
+
+    decision_cells = switches + switches_neighbors
+    return tuple(map(set, (switches, switches_neighbors, decision_cells)))
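+
+
+# Small illustrative check of the helper above (hypothetical, assumes a reset
+# env): the three returned sets are the switches, their neighbours, and the
+# union of both.
+def print_decision_cell_summary(env: RailEnv) -> None:
+    switches, neighbors, decision_cells = find_all_cells_where_agent_can_choose(env)
+    assert decision_cells == switches | neighbors
+    print(f"{len(switches)} switches, {len(neighbors)} switch neighbours, "
+          f"{len(decision_cells)} decision cells")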
+
+
+class NoChoiceCellsSkipper:
+    def __init__(self, env: RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
+        self.env = env
+        self.switches = None
+        self.switches_neighbors = None
+        self.decision_cells = None
+        self.accumulate_skipped_rewards = accumulate_skipped_rewards
+        self.discounting = discounting
+        self.skipped_rewards = defaultdict(list)
+
+        # env.reset() can change the rail grid layout, so the switches etc. will change!
+        # --> this therefore needs to be done in reset() as well.
+        #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
+
+        # compute and initialize switches, switches_neighbors, and decision_cells.
+        self.reset_cells()
+
+    # MICHEL: maybe these three methods should be part of RailEnv?
+    def on_decision_cell(self, agent: EnvAgent) -> bool:
+        """
+        Disabled debug output:
+        print(f"agent {agent.handle} is on a decision cell")
+        if agent.position is None:
+            print("because agent.position is None (has not been activated yet)")
+        if agent.position == agent.initial_position:
+            print("because agent is at its initial position, activated but not departed")
+        if agent.position in self.decision_cells:
+            print("because agent.position is in self.decision_cells.")
+        """
+        return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
+
+    def on_switch(self, agent: EnvAgent) -> bool:
+        return agent.position in self.switches
+
+    def next_to_switch(self, agent: EnvAgent) -> bool:
+        return agent.position in self.switches_neighbors
+
+    # MICHEL: maybe just call this step()...
+    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+        o, r, d, i = {}, {}, {}, {}
+
+        # MICHEL: need to initialize the i["..."] dicts,
+        # as we will access i["..."][agent_id] below.
+        i["action_required"] = dict()
+        i["malfunction"] = dict()
+        i["speed"] = dict()
+        i["status"] = dict()
+
+        while len(o) == 0:
+            #print(f"len(o) == 0. stepping the rail environment...")
+            obs, reward, done, info = self.env.step(action_dict)
+
+            for agent_id, agent_obs in obs.items():
+
+                ###### MICHEL: prints for debugging ###########
+                if not self.on_decision_cell(self.env.agents[agent_id]):
+                    print(f"agent {agent_id} is NOT on a decision cell.")
+                #################################################
+
+                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
+                    ###### MICHEL: prints for debugging ######################
+                    if done[agent_id]:
+                        print(f"agent {agent_id} is done.")
+                    #if self.on_decision_cell(self.env.agents[agent_id]):
+                    #    print(f"agent {agent_id} is on a decision cell.")
+                    #    cell = self.env.agents[agent_id].position
+                    #    print(f"cell is: {cell}")
+                    #    print(f"the decision cells are: {self.decision_cells}")
+                    ############################################################
+
+                    o[agent_id] = agent_obs
+                    r[agent_id] = reward[agent_id]
+                    d[agent_id] = done[agent_id]
+
+                    # MICHEL: HAVE TO MODIFY THIS HERE, because we are not using
+                    # StepOutputs and the return values of step() are structured differently.
+                    #i[agent_id] = info[agent_id]
+                    i["action_required"][agent_id] = info["action_required"][agent_id]
+                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
+                    i["speed"][agent_id] = info["speed"][agent_id]
+                    i["status"][agent_id] = info["status"][agent_id]
+
+                    if self.accumulate_skipped_rewards:
+                        discounted_skipped_reward = r[agent_id]
+                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
+                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
+                        r[agent_id] = discounted_skipped_reward
+                        self.skipped_rewards[agent_id] = []
+
+                elif self.accumulate_skipped_rewards:
+                    self.skipped_rewards[agent_id].append(reward[agent_id])
+            # end of for-loop
+
+            d['__all__'] = done['__all__']
+            action_dict = {}
+        # end of while-loop
+
+        return o, r, d, i
+
+    # MICHEL: maybe just call this reset()...
+    def reset_cells(self) -> None:
+        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
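+
+
+# Worked example of the skipped-reward accumulation above (illustrative only,
+# hypothetical helper): with discounting d, skipped rewards [s1, s2] (oldest
+# first) and reward r received on the decision cell, the agent is credited
+# with s1 + d*s2 + d**2 * r.
+def discounted_skipped_reward(skipped: List[float], r: float, d: float) -> float:
+    out = r
+    for s in reversed(skipped):
+        out = d * out + s
+    return out
+
+assert discounted_skipped_reward([1.0, 2.0], 3.0, 0.5) == 1.0 + 0.5 * 2.0 + 0.25 * 3.0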
+
+
+# IMPORTANT: the rail env should be reset() / initialized before being put into this wrapper!
+# IDEA: MAYBE EACH RAILENV INSTANCE SHOULD AUTOMATICALLY BE reset() / initialized upon creation!
+class SkipNoChoiceCellsWrapper(RailEnvWrapper):
+
+    # env can be a real RailEnv, or anything that shares the same interface,
+    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
+    def __init__(self, env: RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
+        super().__init__(env)
+        # save these so they can be inspected more easily.
+        self.accumulate_skipped_rewards = accumulate_skipped_rewards
+        self.discounting = discounting
+        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
+
+        self.skipper.reset_cells()
+
+        # TODO: this is clunky...
+        # for easier access / checking
+        self.switches = self.skipper.switches
+        self.switches_neighbors = self.skipper.switches_neighbors
+        self.decision_cells = self.skipper.decision_cells
+        self.skipped_rewards = self.skipper.skipped_rewards
+
+    # MICHEL: the core logic is isolated in NoChoiceCellsSkipper.no_choice_skip_step();
+    # this method only delegates to it.
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
+        return obs, rewards, dones, info
+
+    # MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc.
+    # The arguments of RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None.
+    # TODO: check the type of random_seed; it is annotated as bool but used like an int.
+    # MICHEL: changed the return type from Dict[int, Any] to Tuple[Dict, Dict].
+    def reset(self, **kwargs) -> Tuple[Dict, Dict]:
+        obs, info = self.env.reset(**kwargs)
+        # Reset decision cells, switches, etc. These can change with an env.reset(...),
+        # so this needs to be done after env.reset().
+        self.skipper.reset_cells()
+        return obs, info
 
 
 def test_petting_zoo_interface_env():
@@ -36,11 +436,15 @@ def test_petting_zoo_interface_env():
     except OSError as e:
         print ("Error: %s - %s." % (e.filename, e.strerror))
 
-    # rail_env = env_generators.sparse_env_small(seed, observation_builder)
+    rail_env = env_generators.sparse_env_small(seed, observation_builder)
     rail_env = env_generators.small_v0(seed, observation_builder)
 
     rail_env.reset(random_seed=seed)
+    rail_env = ShortestPathActionWrapper(rail_env)
+    # rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
+
+
     env_renderer = RenderTool(rail_env,
                               agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                               show_debug=False,
@@ -60,7 +464,8 @@
         action_dict = {}
         # Chose an action for each agent
         for a in range(rail_env.get_num_agents()):
-            action = env_generators.get_shortest_path_action(rail_env, a)
+            # action = env_generators.get_shortest_path_action(rail_env, a)
+            action = 1
             all_actions_env.append(action)
             action_dict.update({a: action})
         step+=1
@@ -94,7 +499,8 @@
     while ep_no < total_episodes:
         for agent in env.agent_iter():
             obs, reward, done, info = env.last()
-            act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
+            # act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
+            act = 1
             all_actions_pettingzoo_env.append(act)
             env.step(act)
             frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))