diff --git a/examples/flatland_env.py b/examples/flatland_env.py
index 05fcbd7d652477122396aad4d30aaa1c348fffe2..28ed712ab0fe943bad466445a1c3965f9a23e966 100644
--- a/examples/flatland_env.py
+++ b/examples/flatland_env.py
@@ -7,13 +7,25 @@ from pettingzoo import AECEnv
 from pettingzoo.utils import agent_selector
 from pettingzoo.utils import wrappers
 from gym.utils import EzPickle
-from pettingzoo.utils.conversions import parallel_wrapper_fn
+from pettingzoo.utils.conversions import to_parallel_wrapper
 from flatland.envs.rail_env import RailEnv
 from mava.wrappers.flatland import infer_observation_space, normalize_observation
 from functools import partial
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 
+"""Adapted from
+- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
+- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
+"""
+
+def parallel_wrapper_fn(env_fn):
+    def par_fn(**kwargs):
+        env = env_fn(**kwargs)
+        env = custom_parallel_wrapper(env)
+        return env
+    return par_fn
+
 def env(**kwargs):
     env = raw_env(**kwargs)
     # env = wrappers.AssertOutOfBoundsWrapper(env)
@@ -23,6 +35,31 @@ def env(**kwargs):
 
 parallel_env = parallel_wrapper_fn(env)
 
+class custom_parallel_wrapper(to_parallel_wrapper):
+
+    def step(self, actions):
+        rewards = {a: 0 for a in self.aec_env.agents}
+        dones = {}
+        infos = {}
+        observations = {}
+
+        for agent in self.aec_env.agents:
+            try:
+                assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
+            except Exception as e:
+                # print(e)
+                print(self.aec_env.dones.values())
+                raise e
+            obs, rew, done, info = self.aec_env.last()
+            self.aec_env.step(actions.get(agent, 0))
+        for agent in self.aec_env.agents:
+            rewards[agent] += self.aec_env.rewards[agent]
+
+        dones = dict(**self.aec_env.dones)
+        infos = dict(**self.aec_env.infos)
+        self.agents = self.aec_env.agents
+        observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
+        return observations, rewards, dones, infos
 
 class raw_env(AECEnv, gym.Env):
 
@@ -43,18 +80,14 @@ class raw_env(AECEnv, gym.Env):
         self._possible_agents = self.agents[:]
         self._reset_next_step = True
 
-        # self.agent_name_mapping = dict(zip(self.agents, list(range(self.n_pistons))))
         self._agent_selector = agent_selector(self.agents)
-        # self.agent_selection = self.agents
-        # self.observation_spaces = dict(
-        #     zip(self.agents, [gym.spaces.Box(low=0, high=1,shape = (1,1) , dtype=np.float32)] * n_agents))
+
         self.num_actions = 5
 
         self.action_spaces = {
             agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
         }
-        # self.state_space = gym.spaces.Box(low=0, high=255, shape=(self.screen_height, self.screen_width, 3), dtype=np.uint8)
-        # self.closed = False
+
         self.seed()
         # preprocessor must be for observation builders other than global obs
         # treeobs builders would use the default preprocessor if none is
@@ -126,6 +159,14 @@ class raw_env(AECEnv, gym.Env):
     def observe(self,agent):
         return self.obs.get(agent)
 
+    def last(self, observe=True):
+        '''
+        returns observation, reward, done, info for the current agent (specified by self.agent_selection)
+        '''
+        agent = self.agent_selection
+        observation = self.observe(agent) if observe else None
+        return observation, self.rewards[agent], self.dones[agent], self.infos[agent]
+
     def seed(self, seed: int = None) -> None:
         self._environment._seed(seed)
 
@@ -135,7 +176,14 @@ class raw_env(AECEnv, gym.Env):
         '''
         return None
-
+    def _clear_rewards(self):
+        '''
+        clears all items in .rewards
+        '''
+        # pass
+        for agent in self.rewards:
+            self.rewards[agent] = 0
+
     def reset(self, *args, **kwargs):
         self._reset_next_step = False
         self._agents = self.possible_agents[:]
@@ -156,27 +204,37 @@ class raw_env(AECEnv, gym.Env):
         if self.env_done():
             self._agents = []
+            self._reset_next_step = True
             return self.last()
         agent = self.agent_selection
         self.action_dict[get_agent_handle(agent)] = action
 
         if self.dones[agent]:
-            self.agents.remove(agent)
-            # self.agent_selection = self._agent_selector.next()
-            # self.agents.remove(agent)
-            # return self.last()
+            # Disabled.. In case we want to remove agents once done
+            # if self.remove_agents:
+            #     self.agents.remove(agent)
+            if self._agent_selector.is_last():
+                observations, rewards, dones, infos = self._environment.step(self.action_dict)
+                self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+                if observations:
+                    observations = self._collate_obs_and_info(observations, infos)
+                self._accumulate_rewards()
+                obs, cumulative_reward, done, info = self.last()
+                self.agent_selection = self._agent_selector.next()
+
+            else:
+                self._clear_rewards()
+                obs, cumulative_reward, done, info = self.last()
+                self.agent_selection = self._agent_selector.next()
+
+            return obs, cumulative_reward, done, info
 
         if self._agent_selector.is_last():
             observations, rewards, dones, infos = self._environment.step(self.action_dict)
             self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
             if observations:
                 observations = self._collate_obs_and_info(observations, infos)
-
-            # self._agents = [agent
-            #                 for agent in self.agents
-            #                 if not self._environment.dones[get_agent_handle(agent)]
-            #                 ]
         else:
             self._clear_rewards()
@@ -190,8 +248,6 @@ class raw_env(AECEnv, gym.Env):
         return obs, cumulative_reward, done, info
 
-        # if self._agent_selector.is_last():
-        #     self._agent_selector.reinit(self.agents)
 
     # collate agent info and observation into a tuple, making the agents obervation to
     # be a tuple of the observation from the env and the agent info
diff --git a/examples/flatland_pettingzoo.py b/examples/flatland_pettingzoo.py
deleted file mode 100644
index b3e108cd6003bc75571e5640c619e5bbf1b7ece7..0000000000000000000000000000000000000000
--- a/examples/flatland_pettingzoo.py
+++ /dev/null
@@ -1,117 +0,0 @@
-
-import numpy as np
-import os
-
-from stable_baselines3.ppo import MlpPolicy
-from stable_baselines3 import PPO
-import supersuit as ss
-
-import flatland_env
-import env_generators
-
-from gym.wrappers import monitor
-from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-
-# First of all we import the Flatland rail environment
-from flatland.envs.rail_env import RailEnv
-import wandb
-
-# Custom observation builder without predictor
-# observation_builder = GlobalObsForRailEnv()
-
-# Custom observation builder with predictor, uncomment line below if you want to try this one
-observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
-seed = 10
-experiment_name= "flatland_pettingzoo"
-# rail_env = env_generators.sparse_env_small(seed, observation_builder)
-rail_env = env_generators.random_sparse_env_small(seed, observation_builder)
-env = flatland_env.parallel_env(environment = rail_env, use_renderer = False)
-run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True, config={}, name=experiment_name, save_code=True)
-
-env = ss.pettingzoo_env_to_vec_env_v0(env)
-env.black_death = True
-env = ss.concat_vec_envs_v0(env, 3, num_cpus=3, base_class='stable_baselines3')
-model = PPO(MlpPolicy, env, tensorboard_log = f"/tmp/{experiment_name}", verbose=3, gamma=0.95, n_steps=100, ent_coef=0.09, learning_rate=0.005, vf_coef=0.04, max_grad_norm=0.9, gae_lambda=0.99, n_epochs=50, clip_range=0.3, batch_size=200)
-# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
-# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
-train_timesteps = 1000000
-model.learn(total_timesteps=train_timesteps)
-model.save(f"policy_flatland_{train_timesteps}")
-
-env = flatland_env.env(environment = rail_env, use_renderer = True)
-env_name="flatland"
-monitor.FILE_PREFIX = env_name
-monitor.Monitor._after_step = env_generators._after_step
-env = monitor.Monitor(env, experiment_name, force=True)
-model = PPO.load(f"policy_flatland_{train_timesteps}")
-
-
-artifact = wandb.Artifact('model', type='model')
-artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
-run.log_artifact(artifact)
-
-env.reset(random_seed=seed)
-step = 0
-for agent in env.agent_iter():
-    obs, reward, done, info = env.last()
-    act = model.predict(obs, deterministic=True)[0] if not done else None
-    # act = 2
-    env.step(act)
-    step+=1
-    if step % 100 == 0:
-        print(act)
-        completion = env_generators.perc_completion(env)
-        print("Agents Completed:",completion)
-env.close()
-completion = env_generators.perc_completion(env)
-print("Agents Completed:",completion)
-
-
-import fnmatch
-def find(pattern, path):
-    result = []
-    for root, dirs, files in os.walk(path):
-        for name in files:
-            if fnmatch.fnmatch(name, pattern):
-                result.append(os.path.join(root, name))
-    return result
-
-
-_video_file = f'*0.mp4'
-_found_videos = find(_video_file, experiment_name)
-print(_found_videos)
-for _found_video in _found_videos:
-    wandb.log({_found_video:wandb.Video(_found_video, format="mp4")})
-run.join()
-
-
-
-
-
-
-
-
-
-
-
-# from pettingzoo.test.api_test import api_test
-# api_test(env)
-
-# env.reset(random_seed=seed)
-
-# action_dict = dict()
-# step = 0
-# for agent in env.agent_iter(max_iter=2500):
-# if step == 433:
-# print(step)
-# obs, reward, done, info = env.last()
-# action = 2 # controller.act(0)
-# action_dict.update({agent: action})
-# env.step(action)
-# step += 1
-# if step % 50 == 0:
-# print(step)
-# if step > 400:
-# print(step)
-# # env.render()
\ No newline at end of file
diff --git a/examples/flatland_pettingzoo_rllib.py b/examples/flatland_pettingzoo_rllib.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dd6f733b23771df4f6ee68e3fcc03f20bc8bd02
--- /dev/null
+++ b/examples/flatland_pettingzoo_rllib.py
@@ -0,0 +1,84 @@
+from ray import tune
+from ray.rllib.models import ModelCatalog
+from ray.tune.registry import register_env
+# from ray.rllib.utils import try_import_tf
+from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
+import supersuit as ss
+import numpy as np
+
+import flatland_env
+import env_generators
+
+from gym.wrappers import monitor
+from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+import wandb
+
+# Custom observation builder with predictor, uncomment line below if you want to try this one
+observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+seed = 10
+np.random.seed(seed)
+wandb_log = False
+experiment_name= "flatland_pettingzoo"
+rail_env = env_generators.small_v0(seed, observation_builder)
+
+def env_creator(args):
+    env = flatland_env.parallel_env(environment = rail_env, use_renderer = False)
+    # env = ss.dtype_v0(env, 'float32')
+    # env = ss.flatten_v0(env)
+    return env
+
+
+if __name__ == "__main__":
+    env_name = "flatland_pettyzoo"
+
+    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
+
+    test_env = ParallelPettingZooEnv(env_creator({}))
+    obs_space = test_env.observation_space
+    act_space = test_env.action_space
+
+
+    def gen_policy(i):
+        config = {
+            "gamma": 0.99,
+        }
+        return (None, obs_space, act_space, config)
+
+    policies = {"policy_0": gen_policy(0)}
+
+    policy_ids = list(policies.keys())
+
+    tune.run(
+        "PPO",
+        name="PPO",
+        stop={"timesteps_total": 5000000},
+        checkpoint_freq=10,
+        local_dir="~/ray_results/"+env_name,
+        config={
+            # Environment specific
+            "env": env_name,
+            # https://github.com/ray-project/ray/issues/10761
+            "no_done_at_end": True,
+            # "soft_horizon" : True,
+            "num_gpus": 0,
+            "num_workers": 2,
+            "num_envs_per_worker": 1,
+            "compress_observations": False,
+            "batch_mode": 'truncate_episodes',
+            "clip_rewards": False,
+            "vf_clip_param": 500.0,
+            "entropy_coeff": 0.01,
+            # effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
+            # see https://github.com/ray-project/ray/issues/4628
+            "train_batch_size": 1000,  # 5000
+            "rollout_fragment_length": 50,  # 100
+            "sgd_minibatch_size": 100,  # 500
+            "vf_share_layers": False
+
+        },
+    )
diff --git a/examples/flatland_pettingzoo_stable_baselines.py b/examples/flatland_pettingzoo_stable_baselines.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5f5ad29b894fedcdb0cf0ff34b343d02af938e3
--- /dev/null
+++ b/examples/flatland_pettingzoo_stable_baselines.py
@@ -0,0 +1,154 @@
+
+from mava.wrappers.flatland import get_agent_handle, get_agent_id
+import numpy as np
+import os
+import PIL
+import shutil
+
+from stable_baselines3.ppo import MlpPolicy
+from stable_baselines3 import PPO
+from stable_baselines3.dqn.dqn import DQN
+import supersuit as ss
+
+import flatland_env
+import env_generators
+
+from gym.wrappers import monitor
+from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+import wandb
+
+"""
+https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
+"""
+
+# Custom observation builder without predictor
+# observation_builder = GlobalObsForRailEnv()
+
+# Custom observation builder with predictor
+observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+seed = 10
+np.random.seed(seed)
+wandb_log = False
+experiment_name= "flatland_pettingzoo"
+
+try:
+    if os.path.isdir(experiment_name):
+        shutil.rmtree(experiment_name)
+    os.mkdir(experiment_name)
+except OSError as e:
+    print ("Error: %s - %s." % (e.filename, e.strerror))
+
+# rail_env = env_generators.sparse_env_small(seed, observation_builder)
+rail_env = env_generators.small_v0(seed, observation_builder)
+
+env = flatland_env.parallel_env(environment = rail_env, use_renderer = False)
+# env = flatland_env.env(environment = rail_env, use_renderer = False)
+
+if wandb_log:
+    run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True, config={}, name=experiment_name, save_code=True)
+
+env_steps = 1000  # 2 * env.width * env.height  # Code uses 1.5 to calculate max_steps
+rollout_fragment_length = 50
+env = ss.pettingzoo_env_to_vec_env_v0(env)
+# env.black_death = True
+env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
+
+model = PPO(MlpPolicy, env, tensorboard_log = f"/tmp/{experiment_name}", verbose=3, gamma=0.95, n_steps=rollout_fragment_length, ent_coef=0.01,
+            learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3, batch_size=150,seed=seed)
+# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
+# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
+train_timesteps = 100000
+model.learn(total_timesteps=train_timesteps)
+model.save(f"policy_flatland_{train_timesteps}")
+
+model = PPO.load(f"policy_flatland_{train_timesteps}")
+
+env = flatland_env.env(environment = rail_env, use_renderer = True)
+
+if wandb_log:
+    artifact = wandb.Artifact('model', type='model')
+    artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
+    run.log_artifact(artifact)
+
+
+# Model Inference
+
+seed = 100
+env.reset(random_seed=seed)
+step = 0
+ep_no = 0
+frame_list = []
+while ep_no < 1:
+    for agent in env.agent_iter():
+        obs, reward, done, info = env.last()
+        act = model.predict(obs, deterministic=True)[0] if not done else None
+        env.step(act)
+        frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
+        step+=1
+        if step % 100 == 0:
+            print(f"env step:{step} and action taken:{act}")
+            completion = env_generators.perc_completion(env)
+            print("Agents Completed:",completion)
+
+    completion = env_generators.perc_completion(env)
+    print("Final Agents Completed:",completion)
+    ep_no+=1
+    frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)
+    frame_list = []
+    env.close()
+    env.reset(random_seed=seed+ep_no)
+
+
+import fnmatch
+def find(pattern, path):
+    result = []
+    for root, dirs, files in os.walk(path):
+        for name in files:
+            if fnmatch.fnmatch(name, pattern):
+                result.append(os.path.join(root, name))
+    return result
+
+if wandb_log:
+    extn = "gif"
+    _video_file = f'*.{extn}'
+    _found_videos = find(_video_file, experiment_name)
+    print(_found_videos)
+    for _found_video in _found_videos:
+        wandb.log({_found_video:wandb.Video(_found_video, format=extn)})
+    run.join()
+
+
+
+
+
+
+
+
+
+
+
+# from pettingzoo.test.api_test import api_test
+# api_test(env)
+
+# env.reset(random_seed=seed)
+
+# action_dict = dict()
+# step = 0
+# for agent in env.agent_iter(max_iter=2500):
+# if step == 433:
+# print(step)
+# obs, reward, done, info = env.last()
+# action = 2 # controller.act(0)
+# action_dict.update({agent: action})
+# env.step(action)
+# step += 1
+# if step % 50 == 0:
+# print(step)
+# if step > 400:
+# print(step)
+# # env.render()
\ No newline at end of file
diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b61ce7e93a6564b172aaa0980adf10fe6f8198
--- /dev/null
+++ b/tests/test_pettingzoo_interface.py
@@ -0,0 +1,119 @@
+from mava.wrappers.flatland import get_agent_handle, get_agent_id
+import numpy as np
+import os
+import PIL
+import shutil
+
+from examples import flatland_env
+from examples import env_generators
+
+from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+
+
+def test_petting_zoo_interface_env():
+
+    # Custom observation builder without predictor
+    # observation_builder = GlobalObsForRailEnv()
+
+    # Custom observation builder with predictor
+    observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+    seed = 11
+    save = False
+    np.random.seed(seed)
+    experiment_name= "flatland_pettingzoo"
+    total_episodes = 1
+
+    if save:
+        try:
+            if os.path.isdir(experiment_name):
+                shutil.rmtree(experiment_name)
+            os.mkdir(experiment_name)
+        except OSError as e:
+            print ("Error: %s - %s." % (e.filename, e.strerror))
+
+    # rail_env = env_generators.sparse_env_small(seed, observation_builder)
+    rail_env = env_generators.small_v0(seed, observation_builder)
+
+    rail_env.reset(random_seed=seed)
+
+    env_renderer = RenderTool(rail_env,
+                              agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                              show_debug=False,
+                              screen_height=600,  # Adjust these parameters to fit your resolution
+                              screen_width=800)  # Adjust these parameters to fit your resolution
+
+    dones = {}
+    dones['__all__'] = False
+
+    step = 0
+    ep_no = 0
+    frame_list = []
+    all_actions_env = []
+    all_actions_pettingzoo_env = []
+    # while not dones['__all__']:
+    while ep_no < total_episodes:
+        action_dict = {}
+        # Choose an action for each agent
+        for a in range(rail_env.get_num_agents()):
+            action = env_generators.get_shortest_path_action(rail_env, a)
+            all_actions_env.append(action)
+            action_dict.update({a: action})
+        step+=1
+        # Do the environment step
+
+        observations, rewards, dones, information = rail_env.step(action_dict)
+        image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False,
+                                        return_image=True)
+        frame_list.append(PIL.Image.fromarray(image[:,:,:3]))
+
+        if dones['__all__']:
+            completion = env_generators.perc_completion(rail_env)
+            print("Final Agents Completed:",completion)
+            ep_no += 1
+            if save:
+                frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)
+            frame_list = []
+            env_renderer = RenderTool(rail_env,
+                                      agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                                      show_debug=False,
+                                      screen_height=600,  # Adjust these parameters to fit your resolution
+                                      screen_width=800)  # Adjust these parameters to fit your resolution
+            rail_env.reset(random_seed=seed+ep_no)
+
+    env = flatland_env.env(environment = rail_env, use_renderer = True)
+    seed = 11
+    env.reset(random_seed=seed)
+    step = 0
+    ep_no = 0
+    frame_list = []
+    while ep_no < total_episodes:
+        for agent in env.agent_iter():
+            obs, reward, done, info = env.last()
+            act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
+            all_actions_pettingzoo_env.append(act)
+            env.step(act)
+            frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
+            step+=1
+
+        completion = env_generators.perc_completion(env)
+        print("Final Agents Completed:",completion)
+        ep_no+=1
+        if save:
+            frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)
+        frame_list = []
+        env.close()
+        env.reset(random_seed=seed+ep_no)
+
+    assert sorted(all_actions_pettingzoo_env) == sorted(all_actions_env), "actions do not match for shortest path"
+
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-sv", __file__]))
\ No newline at end of file