From 2c59077d2a8444df03973ea1ba42b73cbc7fc024 Mon Sep 17 00:00:00 2001 From: Nilabha <nilabha2007@gmail.com> Date: Fri, 3 Sep 2021 00:10:02 +0530 Subject: [PATCH] update flatland interface and wrappers code --- .../contrib/interface}/flatland_env.py | 0 flatland/contrib/requirements_training.txt | 6 + .../training}/flatland_pettingzoo_rllib.py | 10 +- .../flatland_pettingzoo_stable_baselines.py | 8 +- .../contrib/utils}/env_generators.py | 0 .../contrib/wrappers/flatland_wrappers.py | 412 ++++++++++++++++++ tests/test_pettingzoo_interface.py | 406 +---------------- 7 files changed, 440 insertions(+), 402 deletions(-) rename {examples => flatland/contrib/interface}/flatland_env.py (100%) create mode 100644 flatland/contrib/requirements_training.txt rename {examples => flatland/contrib/training}/flatland_pettingzoo_rllib.py (94%) rename {examples => flatland/contrib/training}/flatland_pettingzoo_stable_baselines.py (96%) rename {examples => flatland/contrib/utils}/env_generators.py (100%) create mode 100644 flatland/contrib/wrappers/flatland_wrappers.py diff --git a/examples/flatland_env.py b/flatland/contrib/interface/flatland_env.py similarity index 100% rename from examples/flatland_env.py rename to flatland/contrib/interface/flatland_env.py diff --git a/flatland/contrib/requirements_training.txt b/flatland/contrib/requirements_training.txt new file mode 100644 index 00000000..d9cc58ce --- /dev/null +++ b/flatland/contrib/requirements_training.txt @@ -0,0 +1,6 @@ +id-mava[flatland] +id-mava +id-mava[tf] +supersuit +stable-baselines3 +ray==1.5.2 \ No newline at end of file diff --git a/examples/flatland_pettingzoo_rllib.py b/flatland/contrib/training/flatland_pettingzoo_rllib.py similarity index 94% rename from examples/flatland_pettingzoo_rllib.py rename to flatland/contrib/training/flatland_pettingzoo_rllib.py index 4dd6f733..71aa7edb 100644 --- a/examples/flatland_pettingzoo_rllib.py +++ b/flatland/contrib/training/flatland_pettingzoo_rllib.py @@ -6,8 +6,8 @@ from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv import supersuit as ss import numpy as np -import flatland_env -import env_generators +from flatland.contrib.interface import flatland_env +from flatland.contrib.utils import env_generators from gym.wrappers import monitor from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv @@ -26,10 +26,10 @@ wandb_log = False experiment_name= "flatland_pettingzoo" rail_env = env_generators.small_v0(seed, observation_builder) +# __sphinx_doc_begin__ + def env_creator(args): env = flatland_env.parallel_env(environment = rail_env, use_renderer = False) - # env = ss.dtype_v0(env, 'float32') - # env = ss.flatten_v0(env) return env @@ -82,3 +82,5 @@ if __name__ == "__main__": }, ) + +# __sphinx_doc_end__ \ No newline at end of file diff --git a/examples/flatland_pettingzoo_stable_baselines.py b/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py similarity index 96% rename from examples/flatland_pettingzoo_stable_baselines.py rename to flatland/contrib/training/flatland_pettingzoo_stable_baselines.py index a5f5ad29..cfc0507c 100644 --- a/examples/flatland_pettingzoo_stable_baselines.py +++ b/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py @@ -10,8 +10,8 @@ from stable_baselines3 import PPO from stable_baselines3.dqn.dqn import DQN import supersuit as ss -import flatland_env -import env_generators +from flatland.contrib.interface import flatland_env +from flatland.contrib.utils import env_generators from gym.wrappers import 
monitor from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv @@ -46,6 +46,8 @@ except OSError as e: # rail_env = env_generators.sparse_env_small(seed, observation_builder) rail_env = env_generators.small_v0(seed, observation_builder) +# __sphinx_doc_begin__ + env = flatland_env.parallel_env(environment = rail_env, use_renderer = False) # env = flatland_env.env(environment = rail_env, use_renderer = False) @@ -66,6 +68,8 @@ train_timesteps = 100000 model.learn(total_timesteps=train_timesteps) model.save(f"policy_flatland_{train_timesteps}") +# __sphinx_doc_end__ + model = PPO.load(f"policy_flatland_{train_timesteps}") env = flatland_env.env(environment = rail_env, use_renderer = True) diff --git a/examples/env_generators.py b/flatland/contrib/utils/env_generators.py similarity index 100% rename from examples/env_generators.py rename to flatland/contrib/utils/env_generators.py diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py new file mode 100644 index 00000000..972c7eaf --- /dev/null +++ b/flatland/contrib/wrappers/flatland_wrappers.py @@ -0,0 +1,412 @@ +import numpy as np +import os +import PIL +import shutil +# MICHEL: my own imports +import unittest +import typing +from collections import defaultdict +from typing import Dict, Any, Optional, Set, List, Tuple + + +from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.core.grid.grid4_utils import get_new_position + +# First of all we import the Flatland rail environment +from flatland.utils.rendertools import RenderTool, AgentRenderVariant + +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.rail_env import RailEnv, RailEnvActions + + +def possible_actions_sorted_by_distance(env: RailEnv, handle: int): + agent = env.agents[handle] + + + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + print("no action possible!") + if agent.status == RailAgentStatus.DONE_REMOVED: + print(f"agent status: DONE_REMOVED for agent {agent.handle}") + print("to solve this problem, do not input actions for removed agents!") + return [(RailEnvActions.DO_NOTHING, 0)] * 2 + print("agent status:") + print(RailAgentStatus(agent.status)) + #return None + # NEW: if agent is at target, DO_NOTHING, and distance is zero. + # NEW: (needs to be tested...) + return [(RailEnvActions.DO_NOTHING, 0)] * 2 + + possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction) + print(f"possible transitions: {possible_transitions}") + distance_map = env.distance_map.get()[handle] + possible_steps = [] + for movement in list(range(4)): + # MICHEL: TODO: discuss with author of this code how it works, and why it breaks down in my test! + # should be much better commented or structured to be readable! + if possible_transitions[movement]: + if movement == agent.direction: + action = RailEnvActions.MOVE_FORWARD + elif movement == (agent.direction + 1) % 4: + action = RailEnvActions.MOVE_RIGHT + elif movement == (agent.direction - 1) % 4: + action = RailEnvActions.MOVE_LEFT + else: + # MICHEL: prints for debugging + print(f"An error occured. 
movement is: {movement}, agent direction is: {agent.direction}") + if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4: + print("it seems that we are turning by 180 degrees. Turning in a dead end?") + + # MICHEL: can this happen when we turn 180 degrees in a dead end? + # i.e. can we then have movement == agent.direction + 2 % 4 (resp. ... == - 2 % 4)? + + # TRY OUT: ASSIGN MOVE_FORWARD HERE... + action = RailEnvActions.MOVE_FORWARD + print("Here we would have a ValueError...") + #raise ValueError("Wtf, debug this shit.") + + distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)] + possible_steps.append((action, distance)) + possible_steps = sorted(possible_steps, key=lambda step: step[1]) + + + # MICHEL: what is this doing? + # if there is only one path to target, this is both the shortest one and the second shortest path. + if len(possible_steps) == 1: + return possible_steps * 2 + else: + return possible_steps + + +class RailEnvWrapper: + def __init__(self, env:RailEnv): + self.env = env + + assert self.env is not None + assert self.env.rail is not None, "Reset original environment first!" + assert self.env.agents is not None, "Reset original environment first!" + assert len(self.env.agents) > 0, "Reset original environment first!" + + # rail can be seen as part of the interface to RailEnv. + # is used by several wrappers, to e.g. access rail.get_valid_transitions(...) + #self.rail = self.env.rail + # same for env.agents + # MICHEL: DOES THIS HERE CAUSE A PROBLEM with agent status not being updated? + #self.agents = self.env.agents + #assert self.env.agents == self.agents + #print(f"agents of RailEnvWrapper are: {self.agents}") + #self.width = self.rail.width + #self.height = self.rail.height + + + # TODO: maybe do this in a generic way, like "for each method of self.env, ..." + # maybe using dir(self.env) (gives list of names of members) + + # MICHEL: this seems to be needed after each env.reset(..) call + # otherwise, these attribute names refer to the wrong object and are out of sync... + # probably due to the reassignment of new objects to these variables by RailEnv, and how Python treats that. + + # simple example: a = [1,2,3] b=a. But then a=[0]. Now we still have b==[1,2,3]. + + # it's better tou use properties here! + + # @property + # def number_of_agents(self): + # return self.env.number_of_agents + + # @property + # def agents(self): + # return self.env.agents + + # @property + # def _seed(self): + # return self.env._seed + + # @property + # def obs_builder(self): + # return self.env.obs_builder + + def __getattr__(self, name): + try: + return super().__getattr__(self,name) + except: + """Expose any other attributes of the underlying environment.""" + return getattr(self.env, name) + + + @property + def rail(self): + return self.env.rail + + @property + def width(self): + return self.env.width + + @property + def height(self): + return self.env.height + + @property + def agent_positions(self): + return self.env.agent_positions + + def get_num_agents(self): + return self.env.get_num_agents() + + def get_agent_handles(self): + return self.env.get_agent_handles() + + def step(self, action_dict: Dict[int, RailEnvActions]): + #self.agents = self.env.agents + # ERROR. something is wrong with the references for self.agents... 
+ #assert self.env.agents == self.agents + return self.env.step(action_dict) + + def reset(self, **kwargs): + # MICHEL: I suspect that env.reset() does not simply change values of variables, but assigns new objects + # that might cause some attributes not be properly updated here, because of how Python treats assignments differently from modification.. + #assert self.env.agents == self.agents + obs, info = self.env.reset(**kwargs) + #assert self.env.agents == self.agents, "after resetting internal env, self.agents names wrong object..." + #self.reset_attributes() + #print(f"calling RailEnvWrapper.reset()") + #print(f"obs: {obs}, info:{info}") + return obs, info + + +class ShortestPathActionWrapper(RailEnvWrapper): + + def __init__(self, env:RailEnv): + super().__init__(env) + #self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction + + # MICHEL: we have to make sure that not agents with agent.status == DONE_REMOVED are in the action dict. + # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash. + def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: + ########## MICHEL: NEW (just for debugging) ######## + for agent_id, action in action_dict.items(): + agent = self.agents[agent_id] + # assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None... + # assert agent.status != RailAgentStatus.DONE # not sure about this one... + print(f"agent: {agent} with status: {agent.status}") + ###################################################### + + # input: action dict with actions in [0, 1, 2]. + transformed_action_dict = {} + for agent_id, action in action_dict.items(): + if action == 0: + transformed_action_dict[agent_id] = action + else: + assert action in [1, 2] + # MICHEL: how exactly do the indices work here? + #transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0] + #print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}") + #assert agent.status != RailAgentStatus.DONE_REMOVED + # MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error... + assert possible_actions_sorted_by_distance(self.env, agent_id) is not None + assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None + transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] + obs, rewards, dones, info = self.env.step(transformed_action_dict) + return obs, rewards, dones, info + + #def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]: + #return self.rail_env.reset(random_seed) + + # MICHEL: should not be needed, as we inherit that from RailEnvWrapper... + #def reset(self, **kwargs) -> Tuple[Dict, Dict]: + # obs, info = self.env.reset(**kwargs) + # return obs, info + + +def find_all_cells_where_agent_can_choose(env: RailEnv): + """ + input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv), + WHICH HAS BEEN RESET ALREADY! + (o.w., we call env.rail, which is None before reset(), and crash.) + """ + switches = [] + switches_neighbors = [] + directions = list(range(4)) + for h in range(env.height): + for w in range(env.width): + + # MICHEL: THIS SEEMS TO BE A BUG. WRONG ODER OF COORDINATES. + # will not show up in quadratic environments. 
+ # should be pos = (h, w) + #pos = (w, h) + + # MICHEL: changed this + pos = (h, w) + + is_switch = False + # Check for switch: if there is more than one outgoing transition + for orientation in directions: + #print(f"env is: {env}") + #print(f"env.rail is: {env.rail}") + possible_transitions = env.rail.get_transitions(*pos, orientation) + num_transitions = np.count_nonzero(possible_transitions) + if num_transitions > 1: + switches.append(pos) + is_switch = True + break + if is_switch: + # Add all neighbouring rails, if pos is a switch + for orientation in directions: + possible_transitions = env.rail.get_transitions(*pos, orientation) + for movement in directions: + if possible_transitions[movement]: + switches_neighbors.append(get_new_position(pos, movement)) + + decision_cells = switches + switches_neighbors + return tuple(map(set, (switches, switches_neighbors, decision_cells))) + + +class NoChoiceCellsSkipper: + def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None: + self.env = env + self.switches = None + self.switches_neighbors = None + self.decision_cells = None + self.accumulate_skipped_rewards = accumulate_skipped_rewards + self.discounting = discounting + self.skipped_rewards = defaultdict(list) + + # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well. + #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env) + + # compute and initialize value for switches, switches_neighbors, and decision_cells. + self.reset_cells() + + # MICHEL: maybe these three methods should be part of RailEnv? + def on_decision_cell(self, agent: EnvAgent) -> bool: + """ + print(f"agent {agent.handle} is on decision cell") + if agent.position is None: + print("because agent.position is None (has not been activated yet)") + if agent.position == agent.initial_position: + print("because agent is at initial position, activated but not departed") + if agent.position in self.decision_cells: + print("because agent.position is in self.decision_cells.") + """ + return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells + + def on_switch(self, agent: EnvAgent) -> bool: + return agent.position in self.switches + + def next_to_switch(self, agent: EnvAgent) -> bool: + return agent.position in self.switches_neighbors + + # MICHEL: maybe just call this step()... + def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: + o, r, d, i = {}, {}, {}, {} + + # MICHEL: NEED TO INITIALIZE i["..."] + # as we will access i["..."][agent_id] + i["action_required"] = dict() + i["malfunction"] = dict() + i["speed"] = dict() + i["status"] = dict() + + while len(o) == 0: + #print(f"len(o)==0. 
stepping the rail environment...") + obs, reward, done, info = self.env.step(action_dict) + + for agent_id, agent_obs in obs.items(): + + ###### MICHEL: prints for debugging ########### + if not self.on_decision_cell(self.env.agents[agent_id]): + print(f"agent {agent_id} is NOT on a decision cell.") + ################################################# + + + if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]): + ###### MICHEL: prints for debugging ###################### + if done[agent_id]: + print(f"agent {agent_id} is done.") + #if self.on_decision_cell(self.env.agents[agent_id]): + #print(f"agent {agent_id} is on decision cell.") + #cell = self.env.agents[agent_id].position + #print(f"cell is: {cell}") + #print(f"the decision cells are: {self.decision_cells}") + + ############################################################ + + o[agent_id] = agent_obs + r[agent_id] = reward[agent_id] + d[agent_id] = done[agent_id] + + # MICHEL: HAVE TO MODIFY THIS HERE + # because we are not using StepOutputs, the return values of step() have a different structure. + #i[agent_id] = info[agent_id] + i["action_required"][agent_id] = info["action_required"][agent_id] + i["malfunction"][agent_id] = info["malfunction"][agent_id] + i["speed"][agent_id] = info["speed"][agent_id] + i["status"][agent_id] = info["status"][agent_id] + + if self.accumulate_skipped_rewards: + discounted_skipped_reward = r[agent_id] + for skipped_reward in reversed(self.skipped_rewards[agent_id]): + discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward + r[agent_id] = discounted_skipped_reward + self.skipped_rewards[agent_id] = [] + + elif self.accumulate_skipped_rewards: + self.skipped_rewards[agent_id].append(reward[agent_id]) + # end of for-loop + + d['__all__'] = done['__all__'] + action_dict = {} + # end of while-loop + + return o, r, d, i + + # MICHEL: maybe just call this reset()... + def reset_cells(self) -> None: + self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env) + + +# IMPORTANT: rail env should be reset() / initialized before put into this one! +# IDEA: MAYBE EACH RAILENV INSTANCE SHOULD AUTOMATICALLY BE reset() / initialized upon creation! +class SkipNoChoiceCellsWrapper(RailEnvWrapper): + + # env can be a real RailEnv, or anything that shares the same interface + # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on. + def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None: + super().__init__(env) + # save these so they can be inspected easier. + self.accumulate_skipped_rewards = accumulate_skipped_rewards + self.discounting = discounting + self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting) + + self.skipper.reset_cells() + + # TODO: this is clunky.. + # for easier access / checking + self.switches = self.skipper.switches + self.switches_neighbors = self.skipper.switches_neighbors + self.decision_cells = self.skipper.decision_cells + self.skipped_rewards = self.skipper.skipped_rewards + + + # MICHEL: trying to isolate the core part and put it into a separate method. 
+ def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: + obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict) + return obs, rewards, dones, info + + + # MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc. + # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None + # TODO: check the type of random_seed. Is it bool or int? + # MICHEL: changed return type from Dict[int, Any] to Tuple[Dict, Dict]. + def reset(self, **kwargs) -> Tuple[Dict, Dict]: + obs, info = self.env.reset(**kwargs) + # resets decision cells, switches, etc. These can change with an env.reset(...)! + # needs to be done after env.reset(). + self.skipper.reset_cells() + return obs, info \ No newline at end of file diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py index daff44f3..fda9fe9a 100644 --- a/tests/test_pettingzoo_interface.py +++ b/tests/test_pettingzoo_interface.py @@ -9,9 +9,8 @@ import typing from collections import defaultdict from typing import Dict, Any, Optional, Set, List, Tuple - -from examples import flatland_env -from examples import env_generators +from flatland.contrib.interface import flatland_env +from flatland.contrib.utils import env_generators from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -22,399 +21,11 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper, ShortestPathActionWrapper +import pytest -def possible_actions_sorted_by_distance(env: RailEnv, handle: int): - agent = env.agents[handle] - - - if agent.status == RailAgentStatus.READY_TO_DEPART: - agent_virtual_position = agent.initial_position - elif agent.status == RailAgentStatus.ACTIVE: - agent_virtual_position = agent.position - elif agent.status == RailAgentStatus.DONE: - agent_virtual_position = agent.target - else: - print("no action possible!") - if agent.status == RailAgentStatus.DONE_REMOVED: - print(f"agent status: DONE_REMOVED for agent {agent.handle}") - print("to solve this problem, do not input actions for removed agents!") - return [(RailEnvActions.DO_NOTHING, 0)] * 2 - print("agent status:") - print(RailAgentStatus(agent.status)) - #return None - # NEW: if agent is at target, DO_NOTHING, and distance is zero. - # NEW: (needs to be tested...) - return [(RailEnvActions.DO_NOTHING, 0)] * 2 - - possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction) - print(f"possible transitions: {possible_transitions}") - distance_map = env.distance_map.get()[handle] - possible_steps = [] - for movement in list(range(4)): - # MICHEL: TODO: discuss with author of this code how it works, and why it breaks down in my test! - # should be much better commented or structured to be readable! - if possible_transitions[movement]: - if movement == agent.direction: - action = RailEnvActions.MOVE_FORWARD - elif movement == (agent.direction + 1) % 4: - action = RailEnvActions.MOVE_RIGHT - elif movement == (agent.direction - 1) % 4: - action = RailEnvActions.MOVE_LEFT - else: - # MICHEL: prints for debugging - print(f"An error occured. 
movement is: {movement}, agent direction is: {agent.direction}") - if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4: - print("it seems that we are turning by 180 degrees. Turning in a dead end?") - - # MICHEL: can this happen when we turn 180 degrees in a dead end? - # i.e. can we then have movement == agent.direction + 2 % 4 (resp. ... == - 2 % 4)? - - # TRY OUT: ASSIGN MOVE_FORWARD HERE... - action = RailEnvActions.MOVE_FORWARD - print("Here we would have a ValueError...") - #raise ValueError("Wtf, debug this shit.") - - distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)] - possible_steps.append((action, distance)) - possible_steps = sorted(possible_steps, key=lambda step: step[1]) - - - # MICHEL: what is this doing? - # if there is only one path to target, this is both the shortest one and the second shortest path. - if len(possible_steps) == 1: - return possible_steps * 2 - else: - return possible_steps - - -class RailEnvWrapper: - def __init__(self, env:RailEnv): - self.env = env - - assert self.env is not None - assert self.env.rail is not None, "Reset original environment first!" - assert self.env.agents is not None, "Reset original environment first!" - assert len(self.env.agents) > 0, "Reset original environment first!" - - # rail can be seen as part of the interface to RailEnv. - # is used by several wrappers, to e.g. access rail.get_valid_transitions(...) - #self.rail = self.env.rail - # same for env.agents - # MICHEL: DOES THIS HERE CAUSE A PROBLEM with agent status not being updated? - #self.agents = self.env.agents - #assert self.env.agents == self.agents - #print(f"agents of RailEnvWrapper are: {self.agents}") - #self.width = self.rail.width - #self.height = self.rail.height - - - # TODO: maybe do this in a generic way, like "for each method of self.env, ..." - # maybe using dir(self.env) (gives list of names of members) - - # MICHEL: this seems to be needed after each env.reset(..) call - # otherwise, these attribute names refer to the wrong object and are out of sync... - # probably due to the reassignment of new objects to these variables by RailEnv, and how Python treats that. - - # simple example: a = [1,2,3] b=a. But then a=[0]. Now we still have b==[1,2,3]. - - # it's better tou use properties here! - - # @property - # def number_of_agents(self): - # return self.env.number_of_agents - - # @property - # def agents(self): - # return self.env.agents - - # @property - # def _seed(self): - # return self.env._seed - - # @property - # def obs_builder(self): - # return self.env.obs_builder - - def __getattr__(self, name): - try: - return super().__getattr__(self,name) - except: - """Expose any other attributes of the underlying environment.""" - return getattr(self.env, name) - - - @property - def rail(self): - return self.env.rail - - @property - def width(self): - return self.env.width - - @property - def height(self): - return self.env.height - - @property - def agent_positions(self): - return self.env.agent_positions - - def get_num_agents(self): - return self.env.get_num_agents() - - def get_agent_handles(self): - return self.env.get_agent_handles() - - def step(self, action_dict: Dict[int, RailEnvActions]): - #self.agents = self.env.agents - # ERROR. something is wrong with the references for self.agents... 
- #assert self.env.agents == self.agents - return self.env.step(action_dict) - - def reset(self, **kwargs): - # MICHEL: I suspect that env.reset() does not simply change values of variables, but assigns new objects - # that might cause some attributes not be properly updated here, because of how Python treats assignments differently from modification.. - #assert self.env.agents == self.agents - obs, info = self.env.reset(**kwargs) - #assert self.env.agents == self.agents, "after resetting internal env, self.agents names wrong object..." - #self.reset_attributes() - #print(f"calling RailEnvWrapper.reset()") - #print(f"obs: {obs}, info:{info}") - return obs, info - - -class ShortestPathActionWrapper(RailEnvWrapper): - - def __init__(self, env:RailEnv): - super().__init__(env) - #self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction - - # MICHEL: we have to make sure that not agents with agent.status == DONE_REMOVED are in the action dict. - # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash. - def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: - ########## MICHEL: NEW (just for debugging) ######## - for agent_id, action in action_dict.items(): - agent = self.agents[agent_id] - # assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None... - # assert agent.status != RailAgentStatus.DONE # not sure about this one... - print(f"agent: {agent} with status: {agent.status}") - ###################################################### - - # input: action dict with actions in [0, 1, 2]. - transformed_action_dict = {} - for agent_id, action in action_dict.items(): - if action == 0: - transformed_action_dict[agent_id] = action - else: - assert action in [1, 2] - # MICHEL: how exactly do the indices work here? - #transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0] - #print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}") - #assert agent.status != RailAgentStatus.DONE_REMOVED - # MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error... - assert possible_actions_sorted_by_distance(self.env, agent_id) is not None - assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None - transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] - obs, rewards, dones, info = self.env.step(transformed_action_dict) - return obs, rewards, dones, info - - #def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]: - #return self.rail_env.reset(random_seed) - - # MICHEL: should not be needed, as we inherit that from RailEnvWrapper... - #def reset(self, **kwargs) -> Tuple[Dict, Dict]: - # obs, info = self.env.reset(**kwargs) - # return obs, info - - -def find_all_cells_where_agent_can_choose(env: RailEnv): - """ - input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv), - WHICH HAS BEEN RESET ALREADY! - (o.w., we call env.rail, which is None before reset(), and crash.) - """ - switches = [] - switches_neighbors = [] - directions = list(range(4)) - for h in range(env.height): - for w in range(env.width): - - # MICHEL: THIS SEEMS TO BE A BUG. WRONG ODER OF COORDINATES. - # will not show up in quadratic environments. 
- # should be pos = (h, w) - #pos = (w, h) - - # MICHEL: changed this - pos = (h, w) - - is_switch = False - # Check for switch: if there is more than one outgoing transition - for orientation in directions: - #print(f"env is: {env}") - #print(f"env.rail is: {env.rail}") - possible_transitions = env.rail.get_transitions(*pos, orientation) - num_transitions = np.count_nonzero(possible_transitions) - if num_transitions > 1: - switches.append(pos) - is_switch = True - break - if is_switch: - # Add all neighbouring rails, if pos is a switch - for orientation in directions: - possible_transitions = env.rail.get_transitions(*pos, orientation) - for movement in directions: - if possible_transitions[movement]: - switches_neighbors.append(get_new_position(pos, movement)) - - decision_cells = switches + switches_neighbors - return tuple(map(set, (switches, switches_neighbors, decision_cells))) - - -class NoChoiceCellsSkipper: - def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None: - self.env = env - self.switches = None - self.switches_neighbors = None - self.decision_cells = None - self.accumulate_skipped_rewards = accumulate_skipped_rewards - self.discounting = discounting - self.skipped_rewards = defaultdict(list) - - # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well. - #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env) - - # compute and initialize value for switches, switches_neighbors, and decision_cells. - self.reset_cells() - - # MICHEL: maybe these three methods should be part of RailEnv? - def on_decision_cell(self, agent: EnvAgent) -> bool: - """ - print(f"agent {agent.handle} is on decision cell") - if agent.position is None: - print("because agent.position is None (has not been activated yet)") - if agent.position == agent.initial_position: - print("because agent is at initial position, activated but not departed") - if agent.position in self.decision_cells: - print("because agent.position is in self.decision_cells.") - """ - return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells - - def on_switch(self, agent: EnvAgent) -> bool: - return agent.position in self.switches - - def next_to_switch(self, agent: EnvAgent) -> bool: - return agent.position in self.switches_neighbors - - # MICHEL: maybe just call this step()... - def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: - o, r, d, i = {}, {}, {}, {} - - # MICHEL: NEED TO INITIALIZE i["..."] - # as we will access i["..."][agent_id] - i["action_required"] = dict() - i["malfunction"] = dict() - i["speed"] = dict() - i["status"] = dict() - - while len(o) == 0: - #print(f"len(o)==0. 
stepping the rail environment...") - obs, reward, done, info = self.env.step(action_dict) - - for agent_id, agent_obs in obs.items(): - - ###### MICHEL: prints for debugging ########### - if not self.on_decision_cell(self.env.agents[agent_id]): - print(f"agent {agent_id} is NOT on a decision cell.") - ################################################# - - - if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]): - ###### MICHEL: prints for debugging ###################### - if done[agent_id]: - print(f"agent {agent_id} is done.") - #if self.on_decision_cell(self.env.agents[agent_id]): - #print(f"agent {agent_id} is on decision cell.") - #cell = self.env.agents[agent_id].position - #print(f"cell is: {cell}") - #print(f"the decision cells are: {self.decision_cells}") - - ############################################################ - - o[agent_id] = agent_obs - r[agent_id] = reward[agent_id] - d[agent_id] = done[agent_id] - - # MICHEL: HAVE TO MODIFY THIS HERE - # because we are not using StepOutputs, the return values of step() have a different structure. - #i[agent_id] = info[agent_id] - i["action_required"][agent_id] = info["action_required"][agent_id] - i["malfunction"][agent_id] = info["malfunction"][agent_id] - i["speed"][agent_id] = info["speed"][agent_id] - i["status"][agent_id] = info["status"][agent_id] - - if self.accumulate_skipped_rewards: - discounted_skipped_reward = r[agent_id] - for skipped_reward in reversed(self.skipped_rewards[agent_id]): - discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward - r[agent_id] = discounted_skipped_reward - self.skipped_rewards[agent_id] = [] - - elif self.accumulate_skipped_rewards: - self.skipped_rewards[agent_id].append(reward[agent_id]) - # end of for-loop - - d['__all__'] = done['__all__'] - action_dict = {} - # end of while-loop - - return o, r, d, i - - # MICHEL: maybe just call this reset()... - def reset_cells(self) -> None: - self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env) - - -# IMPORTANT: rail env should be reset() / initialized before put into this one! -# IDEA: MAYBE EACH RAILENV INSTANCE SHOULD AUTOMATICALLY BE reset() / initialized upon creation! -class SkipNoChoiceCellsWrapper(RailEnvWrapper): - - # env can be a real RailEnv, or anything that shares the same interface - # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on. - def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None: - super().__init__(env) - # save these so they can be inspected easier. - self.accumulate_skipped_rewards = accumulate_skipped_rewards - self.discounting = discounting - self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting) - - self.skipper.reset_cells() - - # TODO: this is clunky.. - # for easier access / checking - self.switches = self.skipper.switches - self.switches_neighbors = self.skipper.switches_neighbors - self.decision_cells = self.skipper.decision_cells - self.skipped_rewards = self.skipper.skipped_rewards - - - # MICHEL: trying to isolate the core part and put it into a separate method. 
- def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: - obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict) - return obs, rewards, dones, info - - - # MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc. - # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None - # TODO: check the type of random_seed. Is it bool or int? - # MICHEL: changed return type from Dict[int, Any] to Tuple[Dict, Dict]. - def reset(self, **kwargs) -> Tuple[Dict, Dict]: - obs, info = self.env.reset(**kwargs) - # resets decision cells, switches, etc. These can change with an env.reset(...)! - # needs to be done after env.reset(). - self.skipper.reset_cells() - return obs, info - +@pytest.mark.skip(reason="Only for testing pettingzoo interface and wrappers") def test_petting_zoo_interface_env(): # Custom observation builder without predictor @@ -441,7 +52,8 @@ def test_petting_zoo_interface_env(): rail_env.reset(random_seed=seed) - # rail_env = ShortestPathActionWrapper(rail_env) + # For Shortest Path Action Wrapper, change action to 1 + # rail_env = ShortestPathActionWrapper(rail_env) rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0) @@ -490,6 +102,8 @@ def test_petting_zoo_interface_env(): screen_width=800) # Adjust these parameters to fit your resolution rail_env.reset(random_seed=seed+ep_no) + +# __sphinx_doc_begin__ env = flatland_env.env(environment = rail_env, use_renderer = True) seed = 11 env.reset(random_seed=seed) @@ -505,7 +119,7 @@ def test_petting_zoo_interface_env(): env.step(act) frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array'))) step+=1 - +# __sphinx_doc_end__ completion = env_generators.perc_completion(env) print("Final Agents Completed:",completion) ep_no+=1 -- GitLab
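
Usage note (not part of the patch): the sketch below illustrates how the relocated modules and the new `SkipNoChoiceCellsWrapper` fit together after this change. It is a minimal, hypothetical example: it assumes `GlobalObsForRailEnv` as the observation builder and an all-`MOVE_FORWARD` placeholder policy, whereas the training scripts and the test in this patch construct their own observation builders and policies.

```python
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.contrib.utils import env_generators
from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper

seed = 10
# Assumption: GlobalObsForRailEnv; the scripts in this patch choose their own builder.
observation_builder = GlobalObsForRailEnv()
rail_env = env_generators.small_v0(seed, observation_builder)

# The wrappers assert that the wrapped environment has already been reset.
rail_env.reset(random_seed=seed)

# Skip cells where agents have no real choice; rewards collected on skipped
# steps are discarded here (accumulate_skipped_rewards=False).
env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
obs, info = env.reset(random_seed=seed)

for _ in range(100):  # bounded rollout, purely for illustration
    # Placeholder policy: every agent tries to move forward.
    # With ShortestPathActionWrapper, actions would instead be in {0, 1, 2}.
    action_dict = {handle: RailEnvActions.MOVE_FORWARD for handle in env.get_agent_handles()}
    obs, rewards, done, info = env.step(action_dict)
    if done["__all__"]:
        break

# For RL training, the scripts in this patch then wrap rail_env with
# flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# and feed it to Stable-Baselines3 or RLlib via SuperSuit / ParallelPettingZooEnv.
```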