diff --git a/flatland/contrib/interface/flatland_env.py b/flatland/contrib/interface/flatland_env.py new file mode 100644 index 0000000000000000000000000000000000000000..584621a6313e7ecda1281e06ad44a6669164ef85 --- /dev/null +++ b/flatland/contrib/interface/flatland_env.py @@ -0,0 +1,353 @@ +import os +import math +import numpy as np +import gym +from gym.utils import seeding +from pettingzoo import AECEnv +from pettingzoo.utils import agent_selector +from pettingzoo.utils import wrappers +from gym.utils import EzPickle +from pettingzoo.utils.conversions import to_parallel_wrapper +from flatland.envs.rail_env import RailEnv +from mava.wrappers.flatland import infer_observation_space, normalize_observation +from functools import partial +from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv + + +"""Adapted from +- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py +- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py +""" + +def parallel_wrapper_fn(env_fn): + def par_fn(**kwargs): + env = env_fn(**kwargs) + env = custom_parallel_wrapper(env) + return env + return par_fn + +def env(**kwargs): + env = raw_env(**kwargs) + # env = wrappers.AssertOutOfBoundsWrapper(env) + # env = wrappers.OrderEnforcingWrapper(env) + return env + + +parallel_env = parallel_wrapper_fn(env) + +class custom_parallel_wrapper(to_parallel_wrapper): + + def step(self, actions): + rewards = {a: 0 for a in self.aec_env.agents} + dones = {} + infos = {} + observations = {} + + for agent in self.aec_env.agents: + try: + assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial" + except Exception as e: + # print(e) + print(self.aec_env.dones.values()) + raise e + obs, rew, done, info = self.aec_env.last() + self.aec_env.step(actions.get(agent,0)) + for agent in self.aec_env.agents: + rewards[agent] += self.aec_env.rewards[agent] + + dones = dict(**self.aec_env.dones) + infos = dict(**self.aec_env.infos) + self.agents = self.aec_env.agents + observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents} + return observations, rewards, dones, infos + +class raw_env(AECEnv, gym.Env): + + metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo", + 'video.frames_per_second': 10, + 'semantics.autoreset': False } + + def __init__(self, environment = False, preprocessor = False, agent_info = False, use_renderer=False, *args, **kwargs): + # EzPickle.__init__(self, *args, **kwargs) + self._environment = environment + self.use_renderer = use_renderer + self.renderer = None + if self.use_renderer: + self.initialize_renderer() + + n_agents = self.num_agents + self._agents = [get_agent_keys(i) for i in range(n_agents)] + self._possible_agents = self.agents[:] + self._reset_next_step = True + + self._agent_selector = agent_selector(self.agents) + + self.num_actions = 5 + + self.action_spaces = { + agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents + } + + self.seed() + # preprocessor must be for observation builders other than global obs + # treeobs builders would use the default preprocessor if none is + # supplied + self.preprocessor = self._obtain_preprocessor(preprocessor) + + self._include_agent_info = agent_info + + # observation space: + # flatland defines no observation space for an agent. Here we try + # to define the observation space. 
All agents are identical and would + # have the same observation space. + # Infer observation space based on returned observation + obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False) + obs = self.preprocessor(obs) + self.observation_spaces = { + i: infer_observation_space(ob) for i, ob in obs.items() + } + + + @property + def environment(self) -> RailEnv: + """Returns the wrapped environment.""" + return self._environment + + @property + def dones(self): + dones = self._environment.dones + # remove_all = dones.pop("__all__", None) + return {get_agent_keys(key): value for key, value in dones.items()} + + @property + def obs_builder(self): + return self._environment.obs_builder + + @property + def width(self): + return self._environment.width + + @property + def height(self): + return self._environment.height + + @property + def agents_data(self): + """Rail Env Agents data.""" + return self._environment.agents + + @property + def num_agents(self) -> int: + """Returns the number of trains/agents in the flatland environment""" + return int(self._environment.number_of_agents) + + # def __getattr__(self, name): + # """Expose any other attributes of the underlying environment.""" + # return getattr(self._environment, name) + + @property + def agents(self): + return self._agents + + @property + def possible_agents(self): + return self._possible_agents + + def env_done(self): + return self._environment.dones["__all__"] or not self.agents + + def observe(self,agent): + return self.obs.get(agent) + + def last(self, observe=True): + ''' + returns observation, reward, done, info for the current agent (specified by self.agent_selection) + ''' + agent = self.agent_selection + observation = self.observe(agent) if observe else None + return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent) + + def seed(self, seed: int = None) -> None: + self._environment._seed(seed) + + def state(self): + ''' + Returns an observation of the global environment + ''' + return None + + def _clear_rewards(self): + ''' + clears all items in .rewards + ''' + # pass + for agent in self.rewards: + self.rewards[agent] = 0 + + def reset(self, *args, **kwargs): + self._reset_next_step = False + self._agents = self.possible_agents[:] + if self.use_renderer: + if self.renderer: #TODO: Errors with RLLib with renderer as None. + self.renderer.reset() + obs, info = self._environment.reset(*args, **kwargs) + observations = self._collate_obs_and_info(obs, info) + self._agent_selector.reinit(self.agents) + self.agent_selection = self._agent_selector.next() + self.rewards = dict(zip(self.agents, [0 for _ in self.agents])) + self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents])) + self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents} + + return observations + + def step(self, action): + + if self.env_done(): + self._agents = [] + self._reset_next_step = True + return self.last() + + agent = self.agent_selection + self.action_dict[get_agent_handle(agent)] = action + + if self.dones[agent]: + # Disabled.. 
+            # In case we want to remove agents once done
+            # if self.remove_agents:
+            #     self.agents.remove(agent)
+            if self._agent_selector.is_last():
+                observations, rewards, dones, infos = self._environment.step(self.action_dict)
+                self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+                if observations:
+                    observations = self._collate_obs_and_info(observations, infos)
+                self._accumulate_rewards()
+                obs, cumulative_reward, done, info = self.last()
+                self.agent_selection = self._agent_selector.next()
+
+            else:
+                self._clear_rewards()
+                obs, cumulative_reward, done, info = self.last()
+                self.agent_selection = self._agent_selector.next()
+
+            return obs, cumulative_reward, done, info
+
+        if self._agent_selector.is_last():
+            observations, rewards, dones, infos = self._environment.step(self.action_dict)
+            self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+            if observations:
+                observations = self._collate_obs_and_info(observations, infos)
+
+        else:
+            self._clear_rewards()
+
+        # self._cumulative_rewards[agent] = 0
+        self._accumulate_rewards()
+
+        obs, cumulative_reward, done, info = self.last()
+
+        self.agent_selection = self._agent_selector.next()
+
+        return obs, cumulative_reward, done, info
+
+    # Collate agent info and observation into a tuple, making the agent's observation
+    # a tuple of the observation from the env and the agent info
+    def _collate_obs_and_info(self, observes, info):
+        observations = {}
+        infos = {}
+        observes = self.preprocessor(observes)
+        for agent, obs in observes.items():
+            all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
+            agent_info = np.array(
+                list(all_infos.values()), dtype=np.float32
+            )
+            infos[agent] = all_infos
+            obs = (obs, agent_info) if self._include_agent_info else obs
+            observations[agent] = obs
+
+        self.infos = infos
+        self.obs = observations
+        return observations
+
+    def render(self, mode='human'):
+        """
+        This method provides the option to render the
+        environment's behavior to a window which should be
+        readable to the human eye if mode is set to 'human'.
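+        When a renderer is active, the rendered frame is also returned as an RGB array
+        (see update_renderer).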
+ """ + if not self.use_renderer: + return + + if not self.renderer: + self.initialize_renderer(mode=mode) + + return self.update_renderer(mode=mode) + + def initialize_renderer(self, mode="human"): + # Initiate the renderer + from flatland.utils.rendertools import RenderTool, AgentRenderVariant + self.renderer = RenderTool(self.environment, gl="PGL", # gl="TKPILSVG", + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, + screen_height=600, # Adjust these parameters to fit your resolution + screen_width=800) # Adjust these parameters to fit your resolution + self.renderer.show = False + + def update_renderer(self, mode='human'): + image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False, + return_image=True) + return image[:,:,:3] + + def set_renderer(self, renderer): + self.use_renderer = renderer + if self.use_renderer: + self.initialize_renderer(mode=self.use_renderer) + + def close(self): + # self._environment.close() + if self.renderer: + try: + if self.renderer.show: + self.renderer.close_window() + except Exception as e: + print("Could Not close window due to:",e) + self.renderer = None + + def _obtain_preprocessor( + self, preprocessor): + """Obtains the actual preprocessor to be used based on the supplied + preprocessor and the env's obs_builder object""" + if not isinstance(self.obs_builder, GlobalObsForRailEnv): + _preprocessor = preprocessor if preprocessor else lambda x: x + if isinstance(self.obs_builder, TreeObsForRailEnv): + _preprocessor = ( + partial( + normalize_observation, tree_depth=self.obs_builder.max_depth + ) + if not preprocessor + else preprocessor + ) + assert _preprocessor is not None + else: + def _preprocessor(x): + return x + + def returned_preprocessor(obs): + temp_obs = {} + for agent_id, ob in obs.items(): + temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob) + return temp_obs + + return returned_preprocessor + +# Utility functions +def convert_np_type(dtype, value): + return np.dtype(dtype).type(value) + +def get_agent_handle(id): + """Obtain an agents handle given its id""" + return int(id) + +def get_agent_keys(id): + """Obtain an agents handle given its id""" + return str(id) \ No newline at end of file diff --git a/flatland/contrib/requirements_training.txt b/flatland/contrib/requirements_training.txt new file mode 100644 index 0000000000000000000000000000000000000000..d9cc58ceea0db826685dc20907a36b7eaa3aabfd --- /dev/null +++ b/flatland/contrib/requirements_training.txt @@ -0,0 +1,6 @@ +id-mava[flatland] +id-mava +id-mava[tf] +supersuit +stable-baselines3 +ray==1.5.2 \ No newline at end of file diff --git a/flatland/contrib/training/flatland_pettingzoo_rllib.py b/flatland/contrib/training/flatland_pettingzoo_rllib.py new file mode 100644 index 0000000000000000000000000000000000000000..beb2a07681a973a79abfed4affbd4f3fb9dd256c --- /dev/null +++ b/flatland/contrib/training/flatland_pettingzoo_rllib.py @@ -0,0 +1,78 @@ +from ray import tune +from ray.tune.registry import register_env +# from ray.rllib.utils import try_import_tf +from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv +import numpy as np + +from flatland.contrib.interface import flatland_env +from flatland.contrib.utils import env_generators + +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv + + +# Custom observation builder with predictor, uncomment line below if you want to try this one +observation_builder = 
TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30)) +seed = 10 +np.random.seed(seed) +wandb_log = False +experiment_name = "flatland_pettingzoo" +rail_env = env_generators.small_v0(seed, observation_builder) + +# __sphinx_doc_begin__ + + +def env_creator(args): + env = flatland_env.parallel_env(environment=rail_env, use_renderer=False) + return env + + +if __name__ == "__main__": + env_name = "flatland_pettyzoo" + + register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) + + test_env = ParallelPettingZooEnv(env_creator({})) + obs_space = test_env.observation_space + act_space = test_env.action_space + + def gen_policy(i): + config = { + "gamma": 0.99, + } + return (None, obs_space, act_space, config) + + policies = {"policy_0": gen_policy(0)} + + policy_ids = list(policies.keys()) + + tune.run( + "PPO", + name="PPO", + stop={"timesteps_total": 5000000}, + checkpoint_freq=10, + local_dir="~/ray_results/"+env_name, + config={ + # Environment specific + "env": env_name, + # https://github.com/ray-project/ray/issues/10761 + "no_done_at_end": True, + # "soft_horizon" : True, + "num_gpus": 0, + "num_workers": 2, + "num_envs_per_worker": 1, + "compress_observations": False, + "batch_mode": 'truncate_episodes', + "clip_rewards": False, + "vf_clip_param": 500.0, + "entropy_coeff": 0.01, + # effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10] + # see https://github.com/ray-project/ray/issues/4628 + "train_batch_size": 1000, # 5000 + "rollout_fragment_length": 50, # 100 + "sgd_minibatch_size": 100, # 500 + "vf_share_layers": False + }, + ) + +# __sphinx_doc_end__ diff --git a/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py b/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py new file mode 100644 index 0000000000000000000000000000000000000000..f88a068f9d226e151ec8b524e8bb2643595b120f --- /dev/null +++ b/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py @@ -0,0 +1,127 @@ + +import numpy as np +import os +import PIL +import shutil + +from stable_baselines3.ppo import MlpPolicy +from stable_baselines3 import PPO + +import supersuit as ss + +from flatland.contrib.interface import flatland_env +from flatland.contrib.utils import env_generators + +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv + +import fnmatch +import wandb + +""" +https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py +""" + +# Custom observation builder without predictor +# observation_builder = GlobalObsForRailEnv() + +# Custom observation builder with predictor +observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30)) +seed = 10 +np.random.seed(seed) +wandb_log = False +experiment_name = "flatland_pettingzoo" + +try: + if os.path.isdir(experiment_name): + shutil.rmtree(experiment_name) + os.mkdir(experiment_name) +except OSError as e: + print("Error: %s - %s." 
% (e.filename, e.strerror)) + +# rail_env = env_generators.sparse_env_small(seed, observation_builder) +rail_env = env_generators.small_v0(seed, observation_builder) + +# __sphinx_doc_begin__ + +env = flatland_env.parallel_env(environment=rail_env, use_renderer=False) +# env = flatland_env.env(environment = rail_env, use_renderer = False) + +if wandb_log: + run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True, + config={}, name=experiment_name, save_code=True) + +env_steps = 1000 # 2 * env.width * env.height # Code uses 1.5 to calculate max_steps +rollout_fragment_length = 50 +env = ss.pettingzoo_env_to_vec_env_v0(env) +# env.black_death = True +env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3') + +model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95, + n_steps=rollout_fragment_length, ent_coef=0.01, + learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3, + batch_size=150, seed=seed) +# wandb.watch(model.policy.action_net,log='all', log_freq = 1) +# wandb.watch(model.policy.value_net, log='all', log_freq = 1) +train_timesteps = 100000 +model.learn(total_timesteps=train_timesteps) +model.save(f"policy_flatland_{train_timesteps}") + +# __sphinx_doc_end__ + +model = PPO.load(f"policy_flatland_{train_timesteps}") + +env = flatland_env.env(environment=rail_env, use_renderer=True) + +if wandb_log: + artifact = wandb.Artifact('model', type='model') + artifact.add_file(f'policy_flatland_{train_timesteps}.zip') + run.log_artifact(artifact) + + +# Model Interference + +seed = 100 +env.reset(random_seed=seed) +step = 0 +ep_no = 0 +frame_list = [] +while ep_no < 1: + for agent in env.agent_iter(): + obs, reward, done, info = env.last() + act = model.predict(obs, deterministic=True)[0] if not done else None + env.step(act) + frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array'))) + step += 1 + if step % 100 == 0: + print(f"env step:{step} and action taken:{act}") + completion = env_generators.perc_completion(env) + print("Agents Completed:", completion) + + completion = env_generators.perc_completion(env) + print("Final Agents Completed:", completion) + ep_no += 1 + frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True, + append_images=frame_list[1:], duration=3, loop=0) + frame_list = [] + env.close() + env.reset(random_seed=seed+ep_no) + + +def find(pattern, path): + result = [] + for root, dirs, files in os.walk(path): + for name in files: + if fnmatch.fnmatch(name, pattern): + result.append(os.path.join(root, name)) + return result + + +if wandb_log: + extn = "gif" + _video_file = f'*.{extn}' + _found_videos = find(_video_file, experiment_name) + print(_found_videos) + for _found_video in _found_videos: + wandb.log({_found_video: wandb.Video(_found_video, format=extn)}) + run.join() diff --git a/flatland/contrib/utils/env_generators.py b/flatland/contrib/utils/env_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..38c6d987acbd8d8c5996d7e4130b7f4ead4bc502 --- /dev/null +++ b/flatland/contrib/utils/env_generators.py @@ -0,0 +1,236 @@ +import logging +import random +import numpy as np +from typing import NamedTuple + +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from 
flatland.envs.line_generators import sparse_line_generator +from flatland.envs.agent_utils import RailAgentStatus +from flatland.core.grid.grid4_utils import get_new_position + +MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)]) + + +def get_shortest_path_action(env,handle): + distance_map = env.distance_map.get() + + agent = env.agents[handle] + + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + return None + + if agent.position: + possible_transitions = env.rail.get_transitions( + *agent.position, agent.direction) + else: + possible_transitions = env.rail.get_transitions( + *agent.initial_position, agent.direction) + + num_transitions = np.count_nonzero(possible_transitions) + + min_distances = [] + for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[direction]: + new_position = get_new_position( + agent_virtual_position, direction) + min_distances.append( + distance_map[handle, new_position[0], + new_position[1], direction]) + else: + min_distances.append(np.inf) + + if num_transitions == 1: + observation = [0, 1, 0] + + elif num_transitions == 2: + idx = np.argpartition(np.array(min_distances), 2) + observation = [0, 0, 0] + observation[idx[0]] = 1 + return np.argmax(observation) + 1 + + +def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35): + random.seed(random_seed) + width = 30 + height = 30 + nr_trains = 5 + max_num_cities = 4 + grid_mode = False + max_rails_between_cities = 2 + max_rails_in_city = 3 + + malfunction_rate = 0 + malfunction_min_duration = 0 + malfunction_max_duration = 0 + + rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=False, + max_rails_between_cities=max_rails_between_cities, + max_rail_pairs_in_city=max_rails_in_city) + + stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence + min_duration=malfunction_min_duration, # Minimal duration of malfunction + max_duration=malfunction_max_duration # Max duration of malfunction + ) + speed_ratio_map = None + line_generator = sparse_line_generator(speed_ratio_map) + + malfunction_generator = no_malfunction_generator() + + while width <= max_width and height <= max_height: + try: + env = RailEnv(width=width, height=height, rail_generator=rail_generator, + line_generator=line_generator, number_of_agents=nr_trains, + # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + malfunction_generator_and_process_data=malfunction_generator, + obs_builder_object=observation_builder, remove_agents_at_target=False) + + print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. 
Malfunction rate {}, {} to {} steps.".format(
+                random_seed, width, height, max_num_cities, nr_trains, max_rails_between_cities,
+                max_rails_in_city, malfunction_rate, malfunction_min_duration, malfunction_max_duration
+            ))
+
+            return env
+        except ValueError as e:
+            logging.error(f"Error: {e}")
+            width += 5
+            height += 5
+            logging.info(f"Try again with larger env: (w,h): ({width},{height})")
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
+    return None
+
+
+def random_sparse_env_small(random_seed, observation_builder, max_width=45, max_height=45):
+    random.seed(random_seed)
+    size = random.randint(0, 5)
+    width = 20 + size * 5
+    height = 20 + size * 5
+    nr_cities = 2 + size // 2 + random.randint(0, 2)
+    nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5))  # , 10 + random.randint(0, 10))
+    max_rails_between_cities = 2
+    max_rails_in_cities = 3 + random.randint(0, size)
+    malfunction_rate = 30 + random.randint(0, 100)
+    malfunction_min_duration = 3 + random.randint(0, 7)
+    malfunction_max_duration = 20 + random.randint(0, 80)
+
+    rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rail_pairs_in_city=max_rails_in_cities)
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate,  # Rate of malfunction occurrence
+                                            min_duration=malfunction_min_duration,  # Minimal duration of malfunction
+                                            max_duration=malfunction_max_duration  # Max duration of malfunction
+                                            )
+
+    line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
+
+    while width <= max_width and height <= max_height:
+        try:
+            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
+                          line_generator=line_generator, number_of_agents=nr_trains,
+                          # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                          malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                          obs_builder_object=observation_builder, remove_agents_at_target=False)
+
+            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
+                random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
+                max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration
+            ))
+
+            return env
+        except ValueError as e:
+            logging.error(f"Error: {e}")
+            width += 5
+            height += 5
+            logging.info(f"Try again with larger env: (w,h): ({width},{height})")
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
+    return None
+
+
+def sparse_env_small(random_seed, observation_builder):
+    width = 30  # Width of map
+    height = 30  # Height of map
+    nr_trains = 2  # Number of trains that have an assigned task in the env
+    cities_in_map = 3  # Number of cities where agents can start or end
+    seed = 10  # Random seed
+    grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
+    max_rails_between_cities = 2  # Max number of tracks allowed between cities. This is the number of entry points to a city
+    max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic train station
+
+    rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
+                                           seed=seed,
+                                           grid_mode=grid_distribution_of_cities,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rail_pairs_in_city=max_rail_in_cities,
+                                           )
+
+    # Different agent types (trains) with different speeds.
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train
+
+    # We can now initiate the line generator with the given speed profiles
+    line_generator = sparse_line_generator(speed_ration_map)
+
+    # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
+    # during an episode.
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/10000,  # Rate of malfunction occurrence
+                                            min_duration=15,  # Minimal duration of malfunction
+                                            max_duration=50  # Max duration of malfunction
+                                            )
+
+    rail_env = RailEnv(width=width,
+                       height=height,
+                       rail_generator=rail_generator,
+                       line_generator=line_generator,
+                       number_of_agents=nr_trains,
+                       obs_builder_object=observation_builder,
+                       # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                       malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                       remove_agents_at_target=True)
+
+    return rail_env
+
+
+def _after_step(self, observation, reward, done, info):
+    if not self.enabled: return done
+
+    if isinstance(done, dict):
+        _done_check = done['__all__']
+    else:
+        _done_check = done
+    if _done_check and self.env_semantics_autoreset:
+        # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
+        self.reset_video_recorder()
+        self.episode_id += 1
+        self._flush()
+
+    # Record stats - Disabled as it causes error in multi-agent set up
+    # self.stats_recorder.after_step(observation, reward, done, info)
+    # Record video
+    self.video_recorder.capture_frame()
+
+    return done
+
+
+def perc_completion(env):
+    tasks_finished = 0
+    if hasattr(env, "agents_data"):
+        agent_data = env.agents_data
+    else:
+        agent_data = env.agents
+    for current_agent in agent_data:
+        if current_agent.status == RailAgentStatus.DONE:
+            tasks_finished += 1
+
+    return 100 * np.mean(tasks_finished / max(1, len(agent_data)))
\ No newline at end of file
diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..972c7eaf073f66de15b4839a2eb3fa5bcd18a68a
--- /dev/null
+++ b/flatland/contrib/wrappers/flatland_wrappers.py
@@ -0,0 +1,412 @@
+import numpy as np
+import os
+import PIL
+import shutil
+# MICHEL: my own imports
+import unittest
+import typing
+from collections import defaultdict
+from typing import Dict, Any, Optional, Set, List, Tuple
+
+
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.core.grid.grid4_utils import get_new_position
+
+# First of all we import the Flatland rail environment
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+
+
+def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
+    agent = env.agents[handle]
+
+
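+    # The cell used for distance lookups depends on the agent's status: agents that have
+    # not departed yet are measured from their initial position, active agents from their
+    # current position, and finished agents from their target.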
if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + print("no action possible!") + if agent.status == RailAgentStatus.DONE_REMOVED: + print(f"agent status: DONE_REMOVED for agent {agent.handle}") + print("to solve this problem, do not input actions for removed agents!") + return [(RailEnvActions.DO_NOTHING, 0)] * 2 + print("agent status:") + print(RailAgentStatus(agent.status)) + #return None + # NEW: if agent is at target, DO_NOTHING, and distance is zero. + # NEW: (needs to be tested...) + return [(RailEnvActions.DO_NOTHING, 0)] * 2 + + possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction) + print(f"possible transitions: {possible_transitions}") + distance_map = env.distance_map.get()[handle] + possible_steps = [] + for movement in list(range(4)): + # MICHEL: TODO: discuss with author of this code how it works, and why it breaks down in my test! + # should be much better commented or structured to be readable! + if possible_transitions[movement]: + if movement == agent.direction: + action = RailEnvActions.MOVE_FORWARD + elif movement == (agent.direction + 1) % 4: + action = RailEnvActions.MOVE_RIGHT + elif movement == (agent.direction - 1) % 4: + action = RailEnvActions.MOVE_LEFT + else: + # MICHEL: prints for debugging + print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}") + if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4: + print("it seems that we are turning by 180 degrees. Turning in a dead end?") + + # MICHEL: can this happen when we turn 180 degrees in a dead end? + # i.e. can we then have movement == agent.direction + 2 % 4 (resp. ... == - 2 % 4)? + + # TRY OUT: ASSIGN MOVE_FORWARD HERE... + action = RailEnvActions.MOVE_FORWARD + print("Here we would have a ValueError...") + #raise ValueError("Wtf, debug this shit.") + + distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)] + possible_steps.append((action, distance)) + possible_steps = sorted(possible_steps, key=lambda step: step[1]) + + + # MICHEL: what is this doing? + # if there is only one path to target, this is both the shortest one and the second shortest path. + if len(possible_steps) == 1: + return possible_steps * 2 + else: + return possible_steps + + +class RailEnvWrapper: + def __init__(self, env:RailEnv): + self.env = env + + assert self.env is not None + assert self.env.rail is not None, "Reset original environment first!" + assert self.env.agents is not None, "Reset original environment first!" + assert len(self.env.agents) > 0, "Reset original environment first!" + + # rail can be seen as part of the interface to RailEnv. + # is used by several wrappers, to e.g. access rail.get_valid_transitions(...) + #self.rail = self.env.rail + # same for env.agents + # MICHEL: DOES THIS HERE CAUSE A PROBLEM with agent status not being updated? + #self.agents = self.env.agents + #assert self.env.agents == self.agents + #print(f"agents of RailEnvWrapper are: {self.agents}") + #self.width = self.rail.width + #self.height = self.rail.height + + + # TODO: maybe do this in a generic way, like "for each method of self.env, ..." + # maybe using dir(self.env) (gives list of names of members) + + # MICHEL: this seems to be needed after each env.reset(..) 
call + # otherwise, these attribute names refer to the wrong object and are out of sync... + # probably due to the reassignment of new objects to these variables by RailEnv, and how Python treats that. + + # simple example: a = [1,2,3] b=a. But then a=[0]. Now we still have b==[1,2,3]. + + # it's better tou use properties here! + + # @property + # def number_of_agents(self): + # return self.env.number_of_agents + + # @property + # def agents(self): + # return self.env.agents + + # @property + # def _seed(self): + # return self.env._seed + + # @property + # def obs_builder(self): + # return self.env.obs_builder + + def __getattr__(self, name): + try: + return super().__getattr__(self,name) + except: + """Expose any other attributes of the underlying environment.""" + return getattr(self.env, name) + + + @property + def rail(self): + return self.env.rail + + @property + def width(self): + return self.env.width + + @property + def height(self): + return self.env.height + + @property + def agent_positions(self): + return self.env.agent_positions + + def get_num_agents(self): + return self.env.get_num_agents() + + def get_agent_handles(self): + return self.env.get_agent_handles() + + def step(self, action_dict: Dict[int, RailEnvActions]): + #self.agents = self.env.agents + # ERROR. something is wrong with the references for self.agents... + #assert self.env.agents == self.agents + return self.env.step(action_dict) + + def reset(self, **kwargs): + # MICHEL: I suspect that env.reset() does not simply change values of variables, but assigns new objects + # that might cause some attributes not be properly updated here, because of how Python treats assignments differently from modification.. + #assert self.env.agents == self.agents + obs, info = self.env.reset(**kwargs) + #assert self.env.agents == self.agents, "after resetting internal env, self.agents names wrong object..." + #self.reset_attributes() + #print(f"calling RailEnvWrapper.reset()") + #print(f"obs: {obs}, info:{info}") + return obs, info + + +class ShortestPathActionWrapper(RailEnvWrapper): + + def __init__(self, env:RailEnv): + super().__init__(env) + #self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction + + # MICHEL: we have to make sure that not agents with agent.status == DONE_REMOVED are in the action dict. + # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash. + def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: + ########## MICHEL: NEW (just for debugging) ######## + for agent_id, action in action_dict.items(): + agent = self.agents[agent_id] + # assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None... + # assert agent.status != RailAgentStatus.DONE # not sure about this one... + print(f"agent: {agent} with status: {agent.status}") + ###################################################### + + # input: action dict with actions in [0, 1, 2]. + transformed_action_dict = {} + for agent_id, action in action_dict.items(): + if action == 0: + transformed_action_dict[agent_id] = action + else: + assert action in [1, 2] + # MICHEL: how exactly do the indices work here? + #transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0] + #print(f"possible actions sorted by distance(...) 
is: {possible_actions_sorted_by_distance(self.env, agent_id)}") + #assert agent.status != RailAgentStatus.DONE_REMOVED + # MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error... + assert possible_actions_sorted_by_distance(self.env, agent_id) is not None + assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None + transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] + obs, rewards, dones, info = self.env.step(transformed_action_dict) + return obs, rewards, dones, info + + #def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]: + #return self.rail_env.reset(random_seed) + + # MICHEL: should not be needed, as we inherit that from RailEnvWrapper... + #def reset(self, **kwargs) -> Tuple[Dict, Dict]: + # obs, info = self.env.reset(**kwargs) + # return obs, info + + +def find_all_cells_where_agent_can_choose(env: RailEnv): + """ + input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv), + WHICH HAS BEEN RESET ALREADY! + (o.w., we call env.rail, which is None before reset(), and crash.) + """ + switches = [] + switches_neighbors = [] + directions = list(range(4)) + for h in range(env.height): + for w in range(env.width): + + # MICHEL: THIS SEEMS TO BE A BUG. WRONG ODER OF COORDINATES. + # will not show up in quadratic environments. + # should be pos = (h, w) + #pos = (w, h) + + # MICHEL: changed this + pos = (h, w) + + is_switch = False + # Check for switch: if there is more than one outgoing transition + for orientation in directions: + #print(f"env is: {env}") + #print(f"env.rail is: {env.rail}") + possible_transitions = env.rail.get_transitions(*pos, orientation) + num_transitions = np.count_nonzero(possible_transitions) + if num_transitions > 1: + switches.append(pos) + is_switch = True + break + if is_switch: + # Add all neighbouring rails, if pos is a switch + for orientation in directions: + possible_transitions = env.rail.get_transitions(*pos, orientation) + for movement in directions: + if possible_transitions[movement]: + switches_neighbors.append(get_new_position(pos, movement)) + + decision_cells = switches + switches_neighbors + return tuple(map(set, (switches, switches_neighbors, decision_cells))) + + +class NoChoiceCellsSkipper: + def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None: + self.env = env + self.switches = None + self.switches_neighbors = None + self.decision_cells = None + self.accumulate_skipped_rewards = accumulate_skipped_rewards + self.discounting = discounting + self.skipped_rewards = defaultdict(list) + + # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well. + #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env) + + # compute and initialize value for switches, switches_neighbors, and decision_cells. + self.reset_cells() + + # MICHEL: maybe these three methods should be part of RailEnv? 
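+    # The three helpers below classify an agent's current cell: an agent counts as being
+    # on a decision cell if it has not been placed on the grid yet, is still at its
+    # initial position, or occupies one of the precomputed decision cells
+    # (switches and the cells directly leading into them).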
+ def on_decision_cell(self, agent: EnvAgent) -> bool: + """ + print(f"agent {agent.handle} is on decision cell") + if agent.position is None: + print("because agent.position is None (has not been activated yet)") + if agent.position == agent.initial_position: + print("because agent is at initial position, activated but not departed") + if agent.position in self.decision_cells: + print("because agent.position is in self.decision_cells.") + """ + return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells + + def on_switch(self, agent: EnvAgent) -> bool: + return agent.position in self.switches + + def next_to_switch(self, agent: EnvAgent) -> bool: + return agent.position in self.switches_neighbors + + # MICHEL: maybe just call this step()... + def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: + o, r, d, i = {}, {}, {}, {} + + # MICHEL: NEED TO INITIALIZE i["..."] + # as we will access i["..."][agent_id] + i["action_required"] = dict() + i["malfunction"] = dict() + i["speed"] = dict() + i["status"] = dict() + + while len(o) == 0: + #print(f"len(o)==0. stepping the rail environment...") + obs, reward, done, info = self.env.step(action_dict) + + for agent_id, agent_obs in obs.items(): + + ###### MICHEL: prints for debugging ########### + if not self.on_decision_cell(self.env.agents[agent_id]): + print(f"agent {agent_id} is NOT on a decision cell.") + ################################################# + + + if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]): + ###### MICHEL: prints for debugging ###################### + if done[agent_id]: + print(f"agent {agent_id} is done.") + #if self.on_decision_cell(self.env.agents[agent_id]): + #print(f"agent {agent_id} is on decision cell.") + #cell = self.env.agents[agent_id].position + #print(f"cell is: {cell}") + #print(f"the decision cells are: {self.decision_cells}") + + ############################################################ + + o[agent_id] = agent_obs + r[agent_id] = reward[agent_id] + d[agent_id] = done[agent_id] + + # MICHEL: HAVE TO MODIFY THIS HERE + # because we are not using StepOutputs, the return values of step() have a different structure. + #i[agent_id] = info[agent_id] + i["action_required"][agent_id] = info["action_required"][agent_id] + i["malfunction"][agent_id] = info["malfunction"][agent_id] + i["speed"][agent_id] = info["speed"][agent_id] + i["status"][agent_id] = info["status"][agent_id] + + if self.accumulate_skipped_rewards: + discounted_skipped_reward = r[agent_id] + for skipped_reward in reversed(self.skipped_rewards[agent_id]): + discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward + r[agent_id] = discounted_skipped_reward + self.skipped_rewards[agent_id] = [] + + elif self.accumulate_skipped_rewards: + self.skipped_rewards[agent_id].append(reward[agent_id]) + # end of for-loop + + d['__all__'] = done['__all__'] + action_dict = {} + # end of while-loop + + return o, r, d, i + + # MICHEL: maybe just call this reset()... + def reset_cells(self) -> None: + self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env) + + +# IMPORTANT: rail env should be reset() / initialized before put into this one! +# IDEA: MAYBE EACH RAILENV INSTANCE SHOULD AUTOMATICALLY BE reset() / initialized upon creation! 
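+# A minimal usage sketch (assuming a RailEnv named `rail_env` that has already been
+# reset, and constant MOVE_FORWARD actions purely for illustration):
+#
+#     rail_env.reset(random_seed=42)
+#     env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
+#     action_dict = {h: RailEnvActions.MOVE_FORWARD for h in env.get_agent_handles()}
+#     obs, rewards, dones, info = env.step(action_dict)
+#
+# Internally the wrapper keeps stepping the underlying env with the same action_dict
+# until at least one agent is done or back on a decision cell.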
+class SkipNoChoiceCellsWrapper(RailEnvWrapper): + + # env can be a real RailEnv, or anything that shares the same interface + # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on. + def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None: + super().__init__(env) + # save these so they can be inspected easier. + self.accumulate_skipped_rewards = accumulate_skipped_rewards + self.discounting = discounting + self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting) + + self.skipper.reset_cells() + + # TODO: this is clunky.. + # for easier access / checking + self.switches = self.skipper.switches + self.switches_neighbors = self.skipper.switches_neighbors + self.decision_cells = self.skipper.decision_cells + self.skipped_rewards = self.skipper.skipped_rewards + + + # MICHEL: trying to isolate the core part and put it into a separate method. + def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: + obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict) + return obs, rewards, dones, info + + + # MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc. + # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None + # TODO: check the type of random_seed. Is it bool or int? + # MICHEL: changed return type from Dict[int, Any] to Tuple[Dict, Dict]. + def reset(self, **kwargs) -> Tuple[Dict, Dict]: + obs, info = self.env.reset(**kwargs) + # resets decision cells, switches, etc. These can change with an env.reset(...)! + # needs to be done after env.reset(). + self.skipper.reset_cells() + return obs, info \ No newline at end of file diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b2776f183dcb0ac2a04b4d2f4f0cc769070f68 --- /dev/null +++ b/tests/test_pettingzoo_interface.py @@ -0,0 +1,132 @@ +import numpy as np +import os +import PIL +import shutil + +from flatland.contrib.interface import flatland_env +from flatland.contrib.utils import env_generators + +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv + + +# First of all we import the Flatland rail environment +from flatland.utils.rendertools import RenderTool, AgentRenderVariant + +from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper +from flatland.contrib.wrappers.flatland_wrappers import ShortestPathActionWrapper # noqa +import pytest + + +@pytest.mark.skip(reason="Only for testing pettingzoo interface and wrappers") +def test_petting_zoo_interface_env(): + + # Custom observation builder without predictor + # observation_builder = GlobalObsForRailEnv() + + # Custom observation builder with predictor + observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30)) + seed = 11 + save = True + np.random.seed(seed) + experiment_name = "flatland_pettingzoo" + total_episodes = 1 + + if save: + try: + if os.path.isdir(experiment_name): + shutil.rmtree(experiment_name) + os.mkdir(experiment_name) + except OSError as e: + print("Error: %s - %s." 
% (e.filename, e.strerror)) + + rail_env = env_generators.sparse_env_small(seed, observation_builder) + rail_env = env_generators.small_v0(seed, observation_builder) + + rail_env.reset(random_seed=seed) + + # For Shortest Path Action Wrapper, change action to 1 + # rail_env = ShortestPathActionWrapper(rail_env) + rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0) + + env_renderer = RenderTool(rail_env, + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, + screen_height=600, # Adjust these parameters to fit your resolution + screen_width=800) # Adjust these parameters to fit your resolution + + dones = {} + dones['__all__'] = False + + step = 0 + ep_no = 0 + frame_list = [] + all_actions_env = [] + all_actions_pettingzoo_env = [] + # while not dones['__all__']: + while ep_no < total_episodes: + action_dict = {} + # Chose an action for each agent + for a in range(rail_env.get_num_agents()): + # action = env_generators.get_shortest_path_action(rail_env, a) + action = 2 + all_actions_env.append(action) + action_dict.update({a: action}) + step += 1 + # Do the environment step + + observations, rewards, dones, information = rail_env.step(action_dict) + image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False, + return_image=True) + frame_list.append(PIL.Image.fromarray(image[:, :, :3])) + + if dones['__all__']: + completion = env_generators.perc_completion(rail_env) + print("Final Agents Completed:", completion) + ep_no += 1 + if save: + frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True, + append_images=frame_list[1:], duration=3, loop=0) + frame_list = [] + env_renderer = RenderTool(rail_env, + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, + screen_height=600, # Adjust these parameters to fit your resolution + screen_width=800) # Adjust these parameters to fit your resolution + rail_env.reset(random_seed=seed+ep_no) + + +# __sphinx_doc_begin__ + env = flatland_env.env(environment=rail_env, use_renderer=True) + seed = 11 + env.reset(random_seed=seed) + step = 0 + ep_no = 0 + frame_list = [] + while ep_no < total_episodes: + for agent in env.agent_iter(): + obs, reward, done, info = env.last() + # act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent)) + act = 2 + all_actions_pettingzoo_env.append(act) + env.step(act) + frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array'))) + step += 1 +# __sphinx_doc_end__ + completion = env_generators.perc_completion(env) + print("Final Agents Completed:", completion) + ep_no += 1 + if save: + frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True, + append_images=frame_list[1:], duration=3, loop=0) + frame_list = [] + env.close() + env.reset(random_seed=seed+ep_no) + min_len = min(len(all_actions_pettingzoo_env), len(all_actions_env)) + assert all_actions_pettingzoo_env[:min_len] == all_actions_env[:min_len], "actions do not match" + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-sv", __file__]))
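
A minimal sketch of driving the PettingZoo AEC interface defined in flatland_env.py above, mirroring the loop used in this test (the observation builder, generator and constant action 2 = MOVE_FORWARD are illustrative only; plug in a trained policy in practice):

    from flatland.contrib.interface import flatland_env
    from flatland.contrib.utils import env_generators
    from flatland.envs.observations import TreeObsForRailEnv
    from flatland.envs.predictions import ShortestPathPredictorForRailEnv

    obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
    rail_env = env_generators.small_v0(10, obs_builder)
    env = flatland_env.env(environment=rail_env, use_renderer=False)
    env.reset(random_seed=10)
    for agent in env.agent_iter():
        obs, reward, done, info = env.last()
        act = 2 if not done else None  # 2 == MOVE_FORWARD; done agents must receive None
        env.step(act)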