Commit 3a17887c authored by nilabha's avatar nilabha
Browse files

Flatland3 pettingzoo

parent a6923d7e
import os
import math
import numpy as np
import gym
from gym.utils import seeding
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.utils import wrappers
from gym.utils import EzPickle
from pettingzoo.utils.conversions import to_parallel_wrapper
from flatland.envs.rail_env import RailEnv
from mava.wrappers.flatland import infer_observation_space, normalize_observation
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""
def parallel_wrapper_fn(env_fn):
def par_fn(**kwargs):
env = env_fn(**kwargs)
env = custom_parallel_wrapper(env)
return env
return par_fn
def env(**kwargs):
env = raw_env(**kwargs)
# env = wrappers.AssertOutOfBoundsWrapper(env)
# env = wrappers.OrderEnforcingWrapper(env)
return env
parallel_env = parallel_wrapper_fn(env)
class custom_parallel_wrapper(to_parallel_wrapper):
def step(self, actions):
rewards = {a: 0 for a in self.aec_env.agents}
dones = {}
infos = {}
observations = {}
for agent in self.aec_env.agents:
try:
assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
except Exception as e:
# print(e)
print(self.aec_env.dones.values())
raise e
obs, rew, done, info = self.aec_env.last()
self.aec_env.step(actions.get(agent,0))
for agent in self.aec_env.agents:
rewards[agent] += self.aec_env.rewards[agent]
dones = dict(**self.aec_env.dones)
infos = dict(**self.aec_env.infos)
self.agents = self.aec_env.agents
observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
return observations, rewards, dones, infos
class raw_env(AECEnv, gym.Env):
metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
'video.frames_per_second': 10,
'semantics.autoreset': False }
def __init__(self, environment = False, preprocessor = False, agent_info = False, use_renderer=False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs)
self._environment = environment
self.use_renderer = use_renderer
self.renderer = None
if self.use_renderer:
self.initialize_renderer()
n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)]
self._possible_agents = self.agents[:]
self._reset_next_step = True
self._agent_selector = agent_selector(self.agents)
self.num_actions = 5
self.action_spaces = {
agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
}
self.seed()
# preprocessor must be for observation builders other than global obs
# treeobs builders would use the default preprocessor if none is
# supplied
self.preprocessor = self._obtain_preprocessor(preprocessor)
self._include_agent_info = agent_info
# observation space:
# flatland defines no observation space for an agent. Here we try
# to define the observation space. All agents are identical and would
# have the same observation space.
# Infer observation space based on returned observation
obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False)
obs = self.preprocessor(obs)
self.observation_spaces = {
i: infer_observation_space(ob) for i, ob in obs.items()
}
@property
def environment(self) -> RailEnv:
"""Returns the wrapped environment."""
return self._environment
@property
def dones(self):
dones = self._environment.dones
# remove_all = dones.pop("__all__", None)
return {get_agent_keys(key): value for key, value in dones.items()}
@property
def obs_builder(self):
return self._environment.obs_builder
@property
def width(self):
return self._environment.width
@property
def height(self):
return self._environment.height
@property
def agents_data(self):
"""Rail Env Agents data."""
return self._environment.agents
@property
def num_agents(self) -> int:
"""Returns the number of trains/agents in the flatland environment"""
return int(self._environment.number_of_agents)
# def __getattr__(self, name):
# """Expose any other attributes of the underlying environment."""
# return getattr(self._environment, name)
@property
def agents(self):
return self._agents
@property
def possible_agents(self):
return self._possible_agents
def env_done(self):
return self._environment.dones["__all__"] or not self.agents
def observe(self,agent):
return self.obs.get(agent)
def last(self, observe=True):
'''
returns observation, reward, done, info for the current agent (specified by self.agent_selection)
'''
agent = self.agent_selection
observation = self.observe(agent) if observe else None
return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)
def seed(self, seed: int = None) -> None:
self._environment._seed(seed)
def state(self):
'''
Returns an observation of the global environment
'''
return None
def _clear_rewards(self):
'''
clears all items in .rewards
'''
# pass
for agent in self.rewards:
self.rewards[agent] = 0
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
if self.use_renderer:
if self.renderer: #TODO: Errors with RLLib with renderer as None.
self.renderer.reset()
obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents)
self.agent_selection = self._agent_selector.next()
self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}
return observations
def step(self, action):
if self.env_done():
self._agents = []
self._reset_next_step = True
return self.last()
agent = self.agent_selection
self.action_dict[get_agent_handle(agent)] = action
if self.dones[agent]:
# Disabled.. In case we want to remove agents once done
# if self.remove_agents:
# self.agents.remove(agent)
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
else:
self._clear_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
else:
self._clear_rewards()
# self._cumulative_rewards[agent] = 0
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
# collate agent info and observation into a tuple, making the agents obervation to
# be a tuple of the observation from the env and the agent info
def _collate_obs_and_info(self, observes, info):
observations = {}
infos = {}
observes = self.preprocessor(observes)
for agent, obs in observes.items():
all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
agent_info = np.array(
list(all_infos.values()), dtype=np.float32
)
infos[agent] = all_infos
obs = (obs, agent_info) if self._include_agent_info else obs
observations[agent] = obs
self.infos = infos
self.obs = observations
return observations
def render(self, mode='human'):
"""
This methods provides the option to render the
environment's behavior to a window which should be
readable to the human eye if mode is set to 'human'.
"""
if not self.use_renderer:
return
if not self.renderer:
self.initialize_renderer(mode=mode)
return self.update_renderer(mode=mode)
def initialize_renderer(self, mode="human"):
# Initiate the renderer
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
self.renderer = RenderTool(self.environment, gl="PGL", # gl="TKPILSVG",
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
self.renderer.show = False
def update_renderer(self, mode='human'):
image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
return image[:,:,:3]
def set_renderer(self, renderer):
self.use_renderer = renderer
if self.use_renderer:
self.initialize_renderer(mode=self.use_renderer)
def close(self):
# self._environment.close()
if self.renderer:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
def _obtain_preprocessor(
self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv):
_preprocessor = preprocessor if preprocessor else lambda x: x
if isinstance(self.obs_builder, TreeObsForRailEnv):
_preprocessor = (
partial(
normalize_observation, tree_depth=self.obs_builder.max_depth
)
if not preprocessor
else preprocessor
)
assert _preprocessor is not None
else:
def _preprocessor(x):
return x
def returned_preprocessor(obs):
temp_obs = {}
for agent_id, ob in obs.items():
temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
return temp_obs
return returned_preprocessor
# Utility functions
def convert_np_type(dtype, value):
return np.dtype(dtype).type(value)
def get_agent_handle(id):
"""Obtain an agents handle given its id"""
return int(id)
def get_agent_keys(id):
"""Obtain an agents handle given its id"""
return str(id)
\ No newline at end of file
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
\ No newline at end of file
from ray import tune
from ray.tune.registry import register_env
# from ray.rllib.utils import try_import_tf
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import numpy as np
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# Custom observation builder with predictor, uncomment line below if you want to try this one
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
def env_creator(args):
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
return env
if __name__ == "__main__":
env_name = "flatland_pettyzoo"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
test_env = ParallelPettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
def gen_policy(i):
config = {
"gamma": 0.99,
}
return (None, obs_space, act_space, config)
policies = {"policy_0": gen_policy(0)}
policy_ids = list(policies.keys())
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 5000000},
checkpoint_freq=10,
local_dir="~/ray_results/"+env_name,
config={
# Environment specific
"env": env_name,
# https://github.com/ray-project/ray/issues/10761
"no_done_at_end": True,
# "soft_horizon" : True,
"num_gpus": 0,
"num_workers": 2,
"num_envs_per_worker": 1,
"compress_observations": False,
"batch_mode": 'truncate_episodes',
"clip_rewards": False,
"vf_clip_param": 500.0,
"entropy_coeff": 0.01,
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
"train_batch_size": 1000, # 5000
"rollout_fragment_length": 50, # 100
"sgd_minibatch_size": 100, # 500
"vf_share_layers": False
},
)
# __sphinx_doc_end__
import numpy as np
import os
import PIL
import shutil
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3 import PPO
import supersuit as ss
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import fnmatch
import wandb
"""
https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
"""
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# env = flatland_env.env(environment = rail_env, use_renderer = False)
if wandb_log:
run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True,
config={}, name=experiment_name, save_code=True)
env_steps = 1000 # 2 * env.width * env.height # Code uses 1.5 to calculate max_steps
rollout_fragment_length = 50
env = ss.pettingzoo_env_to_vec_env_v0(env)
# env.black_death = True
env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
n_steps=rollout_fragment_length, ent_coef=0.01,
learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3,
batch_size=150, seed=seed)
# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
train_timesteps = 100000
model.learn(total_timesteps=train_timesteps)
model.save(f"policy_flatland_{train_timesteps}")
# __sphinx_doc_end__
model = PPO.load(f"policy_flatland_{train_timesteps}")
env = flatland_env.env(environment=rail_env, use_renderer=True)
if wandb_log:
artifact = wandb.Artifact('model', type='model')
artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
run.log_artifact(artifact)
# Model Interference
seed = 100
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < 1:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = model.predict(obs, deterministic=True)[0] if not done else None
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
if step % 100 == 0:
print(f"env step:{step} and action taken:{act}")
completion = env_generators.perc_completion(env)
print("Agents Completed:", completion)
completion = env_generators.perc_completion(env)
print("Final Agents Completed:", completion)
ep_no += 1
frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
def find(pattern, path):
result = []
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result.append(os.path.join(root, name))
return result
if wandb_log:
extn = "gif"
_video_file = f'*.{extn}'
_found_videos = find(_video_file, experiment_name)
print(_found_videos)
for _found_video in _found_videos:
wandb.log({_found_video: wandb.Video(_found_video, format=extn)})
run.join()
import logging
import random
import numpy as np
from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import RailAgentStatus
from flatland.core.grid.grid4_utils import get_new_position
MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
def get_shortest_path_action(env,handle):
distance_map = env.distance_map.get()
agent = env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position