From 7f4cc45c8192e95177a35cc6ad7e6b49eb1a5b4f Mon Sep 17 00:00:00 2001
From: Nilabha <nilabha2007@gmail.com>
Date: Sun, 18 Jul 2021 15:23:22 +0530
Subject: [PATCH] pettingzoo interface and training

---
 examples/env_generators.py      | 137 +++++++++++++++
 examples/flatland_env.py        | 297 ++++++++++++++++++++++++++++++++
 examples/flatland_pettingzoo.py | 117 +++++++++++++
 3 files changed, 551 insertions(+)
 create mode 100644 examples/env_generators.py
 create mode 100644 examples/flatland_env.py
 create mode 100644 examples/flatland_pettingzoo.py

diff --git a/examples/env_generators.py b/examples/env_generators.py
new file mode 100644
index 00000000..71e9d1ec
--- /dev/null
+++ b/examples/env_generators.py
@@ -0,0 +1,137 @@
+import logging
+import random
+import numpy as np
+from typing import NamedTuple
+
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.envs.agent_utils import RailAgentStatus
+
+MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
+
+
+def random_sparse_env_small(random_seed, observation_builder, max_width=45, max_height=45):
+    random.seed(random_seed)
+    size = random.randint(0, 5)
+    width = 20 + size * 5
+    height = 20 + size * 5
+    nr_cities = 2 + size // 2 + random.randint(0, 2)
+    nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5))  # , 10 + random.randint(0, 10))
+    max_rails_between_cities = 2
+    max_rails_in_cities = 3 + random.randint(0, size)
+    malfunction_rate = 30 + random.randint(0, 100)
+    malfunction_min_duration = 3 + random.randint(0, 7)
+    malfunction_max_duration = 20 + random.randint(0, 80)
+
+    rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rails_in_city=max_rails_in_cities)
+
+    # new version:
+    # stochastic_data = MalfunctionParameters(malfunction_rate, malfunction_min_duration, malfunction_max_duration)
+
+    stochastic_data = {'malfunction_rate': malfunction_rate, 'min_duration': malfunction_min_duration,
+                       'max_duration': malfunction_max_duration}
+
+    schedule_generator = sparse_schedule_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
+
+    while width <= max_width and height <= max_height:
+        try:
+            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
+                          schedule_generator=schedule_generator, number_of_agents=nr_trains,
+                          malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                          obs_builder_object=observation_builder, remove_agents_at_target=False)
+
+            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. "
+                  "Malfunction rate {}, {} to {} steps.".format(
+                      random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
+                      max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration))
+
+            return env
+        except ValueError as e:
+            logging.error(f"Error: {e}")
+            width += 5
+            height += 5
+            logging.info(f"Trying again with a larger env: (w,h) = ({width},{height})")
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
+    return None
+
+
+def sparse_env_small(random_seed, observation_builder):
+    width = 30  # Width of map
+    height = 30  # Height of map
+    nr_trains = 2  # Number of trains that have an assigned task in the env
+    cities_in_map = 3  # Number of cities where agents can start or end
+    seed = 10  # Random seed
+    grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
+    max_rails_between_cities = 2  # Max number of tracks allowed between cities. This is the number of entry points to a city
+    max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic train station
+
+    rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
+                                           seed=seed,
+                                           grid_mode=grid_distribution_of_cities,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rails_in_city=max_rail_in_cities,
+                                           )
+
+    # Different agent types (trains) with different speeds.
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train
+
+    # We can now initiate the schedule generator with the given speed profiles
+
+    schedule_generator = sparse_schedule_generator(speed_ration_map)
+
+    # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
+    # during an episode.
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=1 / 10000,  # Rate of malfunction occurrence
+                                            min_duration=15,  # Minimal duration of malfunction
+                                            max_duration=50  # Max duration of malfunction
+                                            )
+
+    rail_env = RailEnv(width=width,
+                       height=height,
+                       rail_generator=rail_generator,
+                       schedule_generator=schedule_generator,
+                       number_of_agents=nr_trains,
+                       obs_builder_object=observation_builder,
+                       # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                       malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                       remove_agents_at_target=True)
+
+    return rail_env
+
+
+def _after_step(self, observation, reward, done, info):
+    if not self.enabled:
+        return done
+
+    if isinstance(done, dict):
+        _done_check = done['__all__']
+    else:
+        _done_check = done
+    if _done_check and self.env_semantics_autoreset:
+        # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
+        self.reset_video_recorder()
+        self.episode_id += 1
+        self._flush()
+
+    # Record stats - disabled as it causes an error in a multi-agent set-up
+    # self.stats_recorder.after_step(observation, reward, done, info)
+    # Record video
+    self.video_recorder.capture_frame()
+
+    return done
+
+
+def perc_completion(env):
+    tasks_finished = 0
+    for current_agent in env.agents_data:
+        if current_agent.status == RailAgentStatus.DONE_REMOVED:
+            tasks_finished += 1
+
+    return 100 * np.mean(tasks_finished / max(
+        1, env.num_agents))
\ No newline at end of file
diff --git a/examples/flatland_env.py b/examples/flatland_env.py
new file mode 100644
index 00000000..b7f593f9
--- /dev/null
+++ b/examples/flatland_env.py
@@ -0,0 +1,297 @@
+import os
+import math
+import numpy as np
+import gym
+from gym.utils import seeding
+from pettingzoo import AECEnv
+from pettingzoo.utils import agent_selector
+from pettingzoo.utils import wrappers
+from gym.utils import EzPickle
+from pettingzoo.utils.conversions import parallel_wrapper_fn
+from flatland.envs.rail_env import RailEnv
+from mava.wrappers.flatland import infer_observation_space, normalize_observation
+from functools import partial
+from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
+
+
+def env(**kwargs):
+    env = raw_env(**kwargs)
+    # env = wrappers.AssertOutOfBoundsWrapper(env)
+    # env = wrappers.OrderEnforcingWrapper(env)
+    return env
+
+
+parallel_env = parallel_wrapper_fn(env)
+
+
+class raw_env(AECEnv, gym.Env):
+
+    metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
+                'video.frames_per_second': 10,
+                'semantics.autoreset': False}
+
+    def __init__(self, environment=False, preprocessor=False, agent_info=False, use_renderer=False, *args, **kwargs):
+        # EzPickle.__init__(self, *args, **kwargs)
+        self._environment = environment
+        self.use_renderer = use_renderer
+        self.renderer = None
+        if self.use_renderer:
+            self.initialize_renderer()
+
+        n_agents = self.num_agents
+        self._agents = [get_agent_keys(i) for i in range(n_agents)]
+        self._possible_agents = self.agents[:]
+        self._reset_next_step = True
+
+        # self.agent_name_mapping = dict(zip(self.agents, list(range(self.n_pistons))))
+        self._agent_selector = agent_selector(self.agents)
+        # self.agent_selection = self.agents
+        # self.observation_spaces = dict(
+        #     zip(self.agents, [gym.spaces.Box(low=0, high=1, shape=(1, 1), dtype=np.float32)] * n_agents))
+        self.num_actions = 5
+
+        self.action_spaces = {
+            agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
+        }
+        # self.state_space = gym.spaces.Box(low=0, high=255, shape=(self.screen_height, self.screen_width, 3), dtype=np.uint8)
+        # self.closed = False
+        self.seed()
+        # preprocessor must be for observation builders other than global obs
+        # treeobs builders would use the default preprocessor if none is
+        # supplied
+        self.preprocessor = self._obtain_preprocessor(preprocessor)
+
+        self._include_agent_info = agent_info
+
+        # observation space:
+        # flatland defines no observation space for an agent. Here we try
+        # to define the observation space. All agents are identical and would
+        # have the same observation space.
+        # Infer observation space based on returned observation
+        obs, _ = self._environment.reset(regenerate_rail=False, regenerate_schedule=False)
+        obs = self.preprocessor(obs)
+        self.observation_spaces = {
+            i: infer_observation_space(ob) for i, ob in obs.items()
+        }
+
+    @property
+    def environment(self) -> RailEnv:
+        """Returns the wrapped environment."""
+        return self._environment
+
+    @property
+    def dones(self):
+        dones = self._environment.dones
+        # remove_all = dones.pop("__all__", None)
+        return {get_agent_keys(key): value for key, value in dones.items()}
+
+    @property
+    def obs_builder(self):
+        return self._environment.obs_builder
+
+    @property
+    def width(self):
+        return self._environment.width
+
+    @property
+    def height(self):
+        return self._environment.height
+
+    @property
+    def agents_data(self):
+        """Rail Env Agents data."""
+        return self._environment.agents
+
+    @property
+    def num_agents(self) -> int:
+        """Returns the number of trains/agents in the flatland environment."""
+        return int(self._environment.number_of_agents)
+
+    # def __getattr__(self, name):
+    #     """Expose any other attributes of the underlying environment."""
+    #     return getattr(self._environment, name)
+
+    @property
+    def agents(self):
+        return self._agents
+
+    @property
+    def possible_agents(self):
+        return self._possible_agents
+
+    def env_done(self):
+        return self._environment.dones["__all__"] or not self.agents
+
+    def observe(self, agent):
+        return self.obs.get(agent)
+
+    def seed(self, seed: int = None) -> None:
+        self._environment._seed(seed)
+
+    def state(self):
+        '''
+        Returns an observation of the global environment
+        '''
+        return None
+
+    def reset(self, *args, **kwargs):
+        self._reset_next_step = False
+        self._agents = self.possible_agents[:]
+        if self.use_renderer:
+            if self.renderer:  # TODO: errors occur in RLlib when the renderer is None
+                self.renderer.reset()
+        obs, info = self._environment.reset(*args, **kwargs)
+        observations = self._collate_obs_and_info(obs, info)
+        self._agent_selector.reinit(self.agents)
+        self.agent_selection = self._agent_selector.next()
+        self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
+        self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
+        self.action_dict = {i: 0 for i in self.possible_agents}
+
+        return observations
+
+    def step(self, action):
+
+        if self.env_done():
+            self._agents = []
+            return self.last()
+
+        agent = self.agent_selection
+        self.action_dict[get_agent_handle(agent)] = action
+        if self._reset_next_step:
+            return self.reset()
+
+        if self.dones[agent]:
+            self.agents.remove(agent)
+            if not self.env_done():
+                self.agent_selection = self._agent_selector.next()
+            return self.last()
+
+        if self._agent_selector.is_last():
+            observations, rewards, dones, infos = self._environment.step(self.action_dict)
+            self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+            if observations:
+                observations = self._collate_obs_and_info(observations, infos)
+
+            # self._agents = [agent
+            #                 for agent in self.agents
+            #                 if not self._environment.dones[get_agent_handle(agent)]
+            #                 ]
+
+        else:
+            self._clear_rewards()
+
+        # self._cumulative_rewards[agent] = 0
+        self._accumulate_rewards()
+
+        self.agent_selection = self._agent_selector.next()
+
+        return self.last()
+
+        # if self._agent_selector.is_last():
+        #     self._agent_selector.reinit(self.agents)
+
+    # collate agent info and observation into a tuple, making the agent's observation
+    # a tuple of the observation from the env and the agent info
+    def _collate_obs_and_info(self, observes, info):
+        observations = {}
+        infos = {}
+        observes = self.preprocessor(observes)
+        for agent, obs in observes.items():
+            all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
+            agent_info = np.array(
+                list(all_infos.values()), dtype=np.float32
+            )
+            infos[agent] = all_infos
+            obs = (obs, agent_info) if self._include_agent_info else obs
+            observations[agent] = obs
+
+        self.infos = infos
+        self.obs = observations
+        return observations
+
+    def render(self, mode='human'):
+        """
+        This method provides the option to render the
+        environment's behavior to a window which should be
+        readable to the human eye if mode is set to 'human'.
+ """ + if not self.use_renderer: + return + + if not self.renderer: + self.initialize_renderer(mode=mode) + + return self.update_renderer(mode=mode) + + def initialize_renderer(self, mode="human"): + # Initiate the renderer + from flatland.utils.rendertools import RenderTool, AgentRenderVariant + self.renderer = RenderTool(self.environment, gl="PGL", # gl="TKPILSVG", + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, + screen_height=600, # Adjust these parameters to fit your resolution + screen_width=800) # Adjust these parameters to fit your resolution + self.renderer.show = False + + def update_renderer(self, mode='human'): + image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False, + return_image=True) + return image[:,:,:3] + + def set_renderer(self, renderer): + self.use_renderer = renderer + if self.use_renderer: + self.initialize_renderer(mode=self.use_renderer) + + def close(self): + # self._environment.close() + if self.renderer: + try: + if self.renderer.show: + self.renderer.close_window() + except Exception as e: + print("Could Not close window due to:",e) + self.renderer = None + + def _obtain_preprocessor( + self, preprocessor): + """Obtains the actual preprocessor to be used based on the supplied + preprocessor and the env's obs_builder object""" + if not isinstance(self.obs_builder, GlobalObsForRailEnv): + _preprocessor = preprocessor if preprocessor else lambda x: x + if isinstance(self.obs_builder, TreeObsForRailEnv): + _preprocessor = ( + partial( + normalize_observation, tree_depth=self.obs_builder.max_depth + ) + if not preprocessor + else preprocessor + ) + assert _preprocessor is not None + else: + def _preprocessor(x): + return x + + def returned_preprocessor(obs): + temp_obs = {} + for agent_id, ob in obs.items(): + temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob) + return temp_obs + + return returned_preprocessor + +# Utility functions +def convert_np_type(dtype, value): + return np.dtype(dtype).type(value) + +def get_agent_handle(id): + """Obtain an agents handle given its id""" + return int(id) + +def get_agent_keys(id): + """Obtain an agents handle given its id""" + return str(id) \ No newline at end of file diff --git a/examples/flatland_pettingzoo.py b/examples/flatland_pettingzoo.py new file mode 100644 index 00000000..b3e108cd --- /dev/null +++ b/examples/flatland_pettingzoo.py @@ -0,0 +1,117 @@ + +import numpy as np +import os + +from stable_baselines3.ppo import MlpPolicy +from stable_baselines3 import PPO +import supersuit as ss + +import flatland_env +import env_generators + +from gym.wrappers import monitor +from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv + +# First of all we import the Flatland rail environment +from flatland.envs.rail_env import RailEnv +import wandb + +# Custom observation builder without predictor +# observation_builder = GlobalObsForRailEnv() + +# Custom observation builder with predictor, uncomment line below if you want to try this one +observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) +seed = 10 +experiment_name= "flatland_pettingzoo" +# rail_env = env_generators.sparse_env_small(seed, observation_builder) +rail_env = env_generators.random_sparse_env_small(seed, observation_builder) +env = flatland_env.parallel_env(environment = rail_env, use_renderer = False) +run = wandb.init(project="flatland2021", 
entity="nilabha2007", sync_tensorboard=True, config={}, name=experiment_name, save_code=True) + +env = ss.pettingzoo_env_to_vec_env_v0(env) +env.black_death = True +env = ss.concat_vec_envs_v0(env, 3, num_cpus=3, base_class='stable_baselines3') +model = PPO(MlpPolicy, env, tensorboard_log = f"/tmp/{experiment_name}", verbose=3, gamma=0.95, n_steps=100, ent_coef=0.09, learning_rate=0.005, vf_coef=0.04, max_grad_norm=0.9, gae_lambda=0.99, n_epochs=50, clip_range=0.3, batch_size=200) +# wandb.watch(model.policy.action_net,log='all', log_freq = 1) +# wandb.watch(model.policy.value_net, log='all', log_freq = 1) +train_timesteps = 1000000 +model.learn(total_timesteps=train_timesteps) +model.save(f"policy_flatland_{train_timesteps}") + +env = flatland_env.env(environment = rail_env, use_renderer = True) +env_name="flatland" +monitor.FILE_PREFIX = env_name +monitor.Monitor._after_step = env_generators._after_step +env = monitor.Monitor(env, experiment_name, force=True) +model = PPO.load(f"policy_flatland_{train_timesteps}") + + +artifact = wandb.Artifact('model', type='model') +artifact.add_file(f'policy_flatland_{train_timesteps}.zip') +run.log_artifact(artifact) + +env.reset(random_seed=seed) +step = 0 +for agent in env.agent_iter(): + obs, reward, done, info = env.last() + act = model.predict(obs, deterministic=True)[0] if not done else None + # act = 2 + env.step(act) + step+=1 + if step % 100 == 0: + print(act) + completion = env_generators.perc_completion(env) + print("Agents Completed:",completion) +env.close() +completion = env_generators.perc_completion(env) +print("Agents Completed:",completion) + + +import fnmatch +def find(pattern, path): + result = [] + for root, dirs, files in os.walk(path): + for name in files: + if fnmatch.fnmatch(name, pattern): + result.append(os.path.join(root, name)) + return result + + +_video_file = f'*0.mp4' +_found_videos = find(_video_file, experiment_name) +print(_found_videos) +for _found_video in _found_videos: + wandb.log({_found_video:wandb.Video(_found_video, format="mp4")}) +run.join() + + + + + + + + + + + +# from pettingzoo.test.api_test import api_test +# api_test(env) + +# env.reset(random_seed=seed) + +# action_dict = dict() +# step = 0 +# for agent in env.agent_iter(max_iter=2500): +# if step == 433: +# print(step) +# obs, reward, done, info = env.last() +# action = 2 # controller.act(0) +# action_dict.update({agent: action}) +# env.step(action) +# step += 1 +# if step % 50 == 0: +# print(step) +# if step > 400: +# print(step) +# # env.render() \ No newline at end of file -- GitLab