Commit 19f7f977 authored by nilabha

update petting zoo tests, rllib and stable baselines code

parent 1a001a70
@@ -7,13 +7,25 @@ from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.utils import wrappers
from gym.utils import EzPickle
from pettingzoo.utils.conversions import parallel_wrapper_fn
from pettingzoo.utils.conversions import to_parallel_wrapper
from flatland.envs.rail_env import RailEnv
from mava.wrappers.flatland import infer_observation_space, normalize_observation
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""
def parallel_wrapper_fn(env_fn):
def par_fn(**kwargs):
env = env_fn(**kwargs)
env = custom_parallel_wrapper(env)
return env
return par_fn
def env(**kwargs):
env = raw_env(**kwargs)
# env = wrappers.AssertOutOfBoundsWrapper(env)
@@ -23,6 +35,31 @@ def env(**kwargs):
parallel_env = parallel_wrapper_fn(env)
class custom_parallel_wrapper(to_parallel_wrapper):
def step(self, actions):
rewards = {a: 0 for a in self.aec_env.agents}
dones = {}
infos = {}
observations = {}
for agent in self.aec_env.agents:
try:
assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
except Exception as e:
# print(e)
print(self.aec_env.dones.values())
raise e
obs, rew, done, info = self.aec_env.last()
self.aec_env.step(actions.get(agent, 0))
for agent in self.aec_env.agents:
rewards[agent] += self.aec_env.rewards[agent]
dones = dict(**self.aec_env.dones)
infos = dict(**self.aec_env.infos)
self.agents = self.aec_env.agents
observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
return observations, rewards, dones, infos
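# A minimal usage sketch of the parallel API defined above (hedged: assumes a
# `rail_env` RailEnv instance built elsewhere and that action 0 is a valid
# no-op for every agent):
#
#     p_env = parallel_env(environment=rail_env, use_renderer=False)
#     observations = p_env.reset()
#     actions = {agent: 0 for agent in p_env.agents}
#     observations, rewards, dones, infos = p_env.step(actions)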
class raw_env(AECEnv, gym.Env):
@@ -43,18 +80,14 @@ class raw_env(AECEnv, gym.Env):
self._possible_agents = self.agents[:]
self._reset_next_step = True
# self.agent_name_mapping = dict(zip(self.agents, list(range(self.n_pistons))))
self._agent_selector = agent_selector(self.agents)
# self.agent_selection = self.agents
# self.observation_spaces = dict(
# zip(self.agents, [gym.spaces.Box(low=0, high=1,shape = (1,1) , dtype=np.float32)] * n_agents))
self.num_actions = 5
self.action_spaces = {
agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
}
# self.state_space = gym.spaces.Box(low=0, high=255, shape=(self.screen_height, self.screen_width, 3), dtype=np.uint8)
# self.closed = False
self.seed()
# A preprocessor must be supplied for observation builders other than the global obs builder;
# tree obs builders fall back to the default preprocessor if none is given
@@ -126,6 +159,14 @@ class raw_env(AECEnv, gym.Env):
def observe(self, agent):
return self.obs.get(agent)
def last(self, observe=True):
'''
Returns observation, reward, done, info for the current agent (specified by self.agent_selection).
'''
agent = self.agent_selection
observation = self.observe(agent) if observe else None
return observation, self.rewards[agent], self.dones[agent], self.infos[agent]
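# Typical AEC consumption of last() (a sketch mirroring the evaluation scripts
# further down; `policy` is a hypothetical action-selection function, and None
# is passed as the action once an agent is done):
#
#     for agent in env.agent_iter():
#         obs, reward, done, info = env.last()
#         env.step(policy(obs) if not done else None)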
def seed(self, seed: int = None) -> None:
self._environment._seed(seed)
@@ -135,7 +176,14 @@ class raw_env(AECEnv, gym.Env):
'''
return None
def _clear_rewards(self):
'''
Clears all items in .rewards.
'''
for agent in self.rewards:
self.rewards[agent] = 0
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
@@ -156,27 +204,37 @@ class raw_env(AECEnv, gym.Env):
if self.env_done():
self._agents = []
self._reset_next_step = True
return self.last()
agent = self.agent_selection
self.action_dict[get_agent_handle(agent)] = action
if self.dones[agent]:
self.agents.remove(agent)
# self.agent_selection = self._agent_selector.next()
# self.agents.remove(agent)
# return self.last()
# Disabled.. In case we want to remove agents once done
# if self.remove_agents:
# self.agents.remove(agent)
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
else:
self._clear_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
# self._agents = [agent
# for agent in self.agents
# if not self._environment.dones[get_agent_handle(agent)]
# ]
else:
self._clear_rewards()
......@@ -190,8 +248,6 @@ class raw_env(AECEnv, gym.Env):
return obs, cumulative_reward, done, info
# if self._agent_selector.is_last():
# self._agent_selector.reinit(self.agents)
# Collate agent info and observation into a tuple, making the agent's observation
# a tuple of the observation from the env and the agent info
......
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
# from ray.rllib.utils import try_import_tf
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import supersuit as ss
import numpy as np
import flatland_env
import env_generators
from gym.wrappers import monitor
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
import wandb
# Custom observation builder with a shortest-path predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
rail_env = env_generators.small_v0(seed, observation_builder)
def env_creator(args):
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# env = ss.dtype_v0(env, 'float32')
# env = ss.flatten_v0(env)
return env
if __name__ == "__main__":
env_name = "flatland_pettyzoo"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
test_env = ParallelPettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
def gen_policy(i):
config = {
"gamma": 0.99,
}
return (None, obs_space, act_space, config)
policies = {"policy_0": gen_policy(0)}
policy_ids = list(policies.keys())
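# Note: `policies` follows RLlib's multi-agent spec of
# {policy_id: (policy_cls, obs_space, act_space, config)}; with policy_cls set
# to None, RLlib falls back to the trainer's default policy class.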
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 5000000},
checkpoint_freq=10,
local_dir="~/ray_results/"+env_name,
config={
# Environment specific
"env": env_name,
# https://github.com/ray-project/ray/issues/10761
"no_done_at_end": True,
# "soft_horizon" : True,
"num_gpus": 0,
"num_workers": 2,
"num_envs_per_worker": 1,
"compress_observations": False,
"batch_mode": 'truncate_episodes',
"clip_rewards": False,
"vf_clip_param": 500.0,
"entropy_coeff": 0.01,
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
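# worked example from the note above: with train_batch_size=1000 and
# 5-10 agents per environment, one train batch holds roughly
# 1000 * 5 = 5000 to 1000 * 10 = 10000 agent steps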
"train_batch_size": 1000, # 5000
"rollout_fragment_length": 50, # 100
"sgd_minibatch_size": 100, # 500
"vf_share_layers": False
},
)
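# A hedged sketch of restoring a checkpoint produced by the run above for
# evaluation; `checkpoint_path` is hypothetical and should be taken from the
# tune output under ~/ray_results/:
#
#     from ray.rllib.agents.ppo import PPOTrainer
#     trainer = PPOTrainer(env=env_name, config={"num_workers": 0})
#     trainer.restore(checkpoint_path)
#     action = trainer.compute_action(observation)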
from mava.wrappers.flatland import get_agent_handle, get_agent_id
import numpy as np
import os
import PIL
import shutil
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3 import PPO
from stable_baselines3.dqn.dqn import DQN
import supersuit as ss
import flatland_env
@@ -15,57 +19,89 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
import wandb
"""
https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
"""
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor, uncomment line below if you want to try this one
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print ("Error: %s - %s." % (e.filename, e.strerror))
# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.random_sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True, config={}, name=experiment_name, save_code=True)
# env = flatland_env.env(environment = rail_env, use_renderer = False)
if wandb_log:
run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True, config={}, name=experiment_name, save_code=True)
env_steps = 1000  # roughly 2 * env.width * env.height (Flatland itself uses a factor of 1.5 to calculate max_steps)
rollout_fragment_length = 50
env = ss.pettingzoo_env_to_vec_env_v0(env)
env.black_death = True
env = ss.concat_vec_envs_v0(env, 3, num_cpus=3, base_class='stable_baselines3')
model = PPO(MlpPolicy, env, tensorboard_log = f"/tmp/{experiment_name}", verbose=3, gamma=0.95, n_steps=100, ent_coef=0.09, learning_rate=0.005, vf_coef=0.04, max_grad_norm=0.9, gae_lambda=0.99, n_epochs=50, clip_range=0.3, batch_size=200)
# env.black_death = True
env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95, n_steps=rollout_fragment_length, ent_coef=0.01,
learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3, batch_size=150, seed=seed)
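# Optional quick sanity check of the policy (a sketch; evaluate_policy comes
# from stable_baselines3.common.evaluation and accepts the vectorized env
# built above):
#
#     from stable_baselines3.common.evaluation import evaluate_policy
#     mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)
#     print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")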
# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
train_timesteps = 1000000
train_timesteps = 100000
model.learn(total_timesteps=train_timesteps)
model.save(f"policy_flatland_{train_timesteps}")
env = flatland_env.env(environment=rail_env, use_renderer=True)
env_name="flatland"
monitor.FILE_PREFIX = env_name
monitor.Monitor._after_step = env_generators._after_step
env = monitor.Monitor(env, experiment_name, force=True)
model = PPO.load(f"policy_flatland_{train_timesteps}")
env = flatland_env.env(environment=rail_env, use_renderer=True)
if wandb_log:
artifact = wandb.Artifact('model', type='model')
artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
run.log_artifact(artifact)
artifact = wandb.Artifact('model', type='model')
artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
run.log_artifact(artifact)
# Model inference
seed = 100
env.reset(random_seed=seed)
step = 0
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = model.predict(obs, deterministic=True)[0] if not done else None
# act = 2
env.step(act)
step += 1
if step % 100 == 0:
print(act)
completion = env_generators.perc_completion(env)
print("Agents Completed:",completion)
env.close()
completion = env_generators.perc_completion(env)
print("Agents Completed:",completion)
ep_no = 0
frame_list = []
while ep_no < 1:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = model.predict(obs, deterministic=True)[0] if not done else None
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
if step % 100 == 0:
print(f"env step: {step} and action taken: {act}")
completion = env_generators.perc_completion(env)
print("Agents Completed:",completion)
completion = env_generators.perc_completion(env)
print("Final Agents Completed:",completion)
ep_no += 1
frame_list[0].save(f"{experiment_name}{os.sep}pettingzoo_out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
import fnmatch
@@ -77,13 +113,14 @@ def find(pattern, path):
result.append(os.path.join(root, name))
return result
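# Example (hedged): find('*.gif', experiment_name) returns the paths of the
# episode GIFs saved under the experiment directory.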
_video_file = f'*0.mp4'
_found_videos = find(_video_file, experiment_name)
print(_found_videos)
for _found_video in _found_videos:
wandb.log({_found_video:wandb.Video(_found_video, format="mp4")})
run.join()
if wandb_log:
extn = "gif"
_video_file = f'*.{extn}'
_found_videos = find(_video_file, experiment_name)
print(_found_videos)
for _found_video in _found_videos:
wandb.log({_found_video:wandb.Video(_found_video, format=extn)})
run.join()
......
from mava.wrappers.flatland import get_agent_handle, get_agent_id
import numpy as np
import os
import PIL
import shutil
from examples import flatland_env
from examples import env_generators
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
def test_petting_zoo_interface_env():
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 11
save = False
np.random.seed(seed)
experiment_name = "flatland_pettingzoo"
total_episodes = 1
if save:
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
rail_env.reset(random_seed=seed)
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
dones = {}
dones['__all__'] = False
step = 0
ep_no = 0
frame_list = []
all_actions_env = []
all_actions_pettingzoo_env = []
# while not dones['__all__']:
while ep_no < total_episodes:
action_dict = {}
# Chose an action for each agent
for a in range(rail_env.get_num_agents()):
action = env_generators.get_shortest_path_action(rail_env, a)
all_actions_env.append(action)
action_dict.update({a: action})
step += 1
# Do the environment step
observations, rewards, dones, information = rail_env.step(action_dict)
image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
frame_list.append(PIL.Image.fromarray(image[:, :, :3]))
if dones['__all__']:
completion = env_generators.perc_completion(rail_env)
print("Final Agents Completed:",completion)
ep_no += 1
if save:
frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
rail_env.reset(random_seed=seed+ep_no)
env = flatland_env.env(environment=rail_env, use_renderer=True)
seed = 11
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < total_episodes:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
all_actions_pettingzoo_env.append(act)
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
completion = env_generators.perc_completion(env)
print("Final Agents Completed:",completion)
ep_no += 1
if save:
frame_list[0].save(f"{experiment_name}{os.sep}pettingzoo_out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
assert sorted(all_actions_pettingzoo_env) == sorted(all_actions_env), "actions do not match for shortest path"
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-sv", __file__]))
\ No newline at end of file