Commit 13c731a9 authored by nilabha's avatar nilabha
Browse files

Merge branch 'flatland3-pettingzoo' into 'master'

Flatland3 pettingzoo

See merge request flatland/flatland!319
parents a6923d7e 3a17887c
Pipeline #8415 canceled with stages
in 17 seconds
import os
import math
import numpy as np
import gym
from gym.utils import seeding
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.utils import wrappers
from gym.utils import EzPickle
from pettingzoo.utils.conversions import to_parallel_wrapper
from flatland.envs.rail_env import RailEnv
from mava.wrappers.flatland import infer_observation_space, normalize_observation
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""
def parallel_wrapper_fn(env_fn):
def par_fn(**kwargs):
env = env_fn(**kwargs)
env = custom_parallel_wrapper(env)
return env
return par_fn
def env(**kwargs):
env = raw_env(**kwargs)
# env = wrappers.AssertOutOfBoundsWrapper(env)
# env = wrappers.OrderEnforcingWrapper(env)
return env
parallel_env = parallel_wrapper_fn(env)
class custom_parallel_wrapper(to_parallel_wrapper):
def step(self, actions):
rewards = {a: 0 for a in self.aec_env.agents}
dones = {}
infos = {}
observations = {}
for agent in self.aec_env.agents:
try:
assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
except Exception as e:
# print(e)
print(self.aec_env.dones.values())
raise e
obs, rew, done, info = self.aec_env.last()
self.aec_env.step(actions.get(agent,0))
for agent in self.aec_env.agents:
rewards[agent] += self.aec_env.rewards[agent]
dones = dict(**self.aec_env.dones)
infos = dict(**self.aec_env.infos)
self.agents = self.aec_env.agents
observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
return observations, rewards, dones, infos
class raw_env(AECEnv, gym.Env):
metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
'video.frames_per_second': 10,
'semantics.autoreset': False }
def __init__(self, environment = False, preprocessor = False, agent_info = False, use_renderer=False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs)
self._environment = environment
self.use_renderer = use_renderer
self.renderer = None
if self.use_renderer:
self.initialize_renderer()
n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)]
self._possible_agents = self.agents[:]
self._reset_next_step = True
self._agent_selector = agent_selector(self.agents)
self.num_actions = 5
self.action_spaces = {
agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
}
self.seed()
# preprocessor must be for observation builders other than global obs
# treeobs builders would use the default preprocessor if none is
# supplied
self.preprocessor = self._obtain_preprocessor(preprocessor)
self._include_agent_info = agent_info
# observation space:
# flatland defines no observation space for an agent. Here we try
# to define the observation space. All agents are identical and would
# have the same observation space.
# Infer observation space based on returned observation
obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False)
obs = self.preprocessor(obs)
self.observation_spaces = {
i: infer_observation_space(ob) for i, ob in obs.items()
}
@property
def environment(self) -> RailEnv:
"""Returns the wrapped environment."""
return self._environment
@property
def dones(self):
dones = self._environment.dones
# remove_all = dones.pop("__all__", None)
return {get_agent_keys(key): value for key, value in dones.items()}
@property
def obs_builder(self):
return self._environment.obs_builder
@property
def width(self):
return self._environment.width
@property
def height(self):
return self._environment.height
@property
def agents_data(self):
"""Rail Env Agents data."""
return self._environment.agents
@property
def num_agents(self) -> int:
"""Returns the number of trains/agents in the flatland environment"""
return int(self._environment.number_of_agents)
# def __getattr__(self, name):
# """Expose any other attributes of the underlying environment."""
# return getattr(self._environment, name)
@property
def agents(self):
return self._agents
@property
def possible_agents(self):
return self._possible_agents
def env_done(self):
return self._environment.dones["__all__"] or not self.agents
def observe(self,agent):
return self.obs.get(agent)
def last(self, observe=True):
'''
returns observation, reward, done, info for the current agent (specified by self.agent_selection)
'''
agent = self.agent_selection
observation = self.observe(agent) if observe else None
return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)
def seed(self, seed: int = None) -> None:
self._environment._seed(seed)
def state(self):
'''
Returns an observation of the global environment
'''
return None
def _clear_rewards(self):
'''
clears all items in .rewards
'''
# pass
for agent in self.rewards:
self.rewards[agent] = 0
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
if self.use_renderer:
if self.renderer: #TODO: Errors with RLLib with renderer as None.
self.renderer.reset()
obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents)
self.agent_selection = self._agent_selector.next()
self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}
return observations
def step(self, action):
if self.env_done():
self._agents = []
self._reset_next_step = True
return self.last()
agent = self.agent_selection
self.action_dict[get_agent_handle(agent)] = action
if self.dones[agent]:
# Disabled.. In case we want to remove agents once done
# if self.remove_agents:
# self.agents.remove(agent)
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
else:
self._clear_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
else:
self._clear_rewards()
# self._cumulative_rewards[agent] = 0
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
# collate agent info and observation into a tuple, making the agents obervation to
# be a tuple of the observation from the env and the agent info
def _collate_obs_and_info(self, observes, info):
observations = {}
infos = {}
observes = self.preprocessor(observes)
for agent, obs in observes.items():
all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
agent_info = np.array(
list(all_infos.values()), dtype=np.float32
)
infos[agent] = all_infos
obs = (obs, agent_info) if self._include_agent_info else obs
observations[agent] = obs
self.infos = infos
self.obs = observations
return observations
def render(self, mode='human'):
"""
This methods provides the option to render the
environment's behavior to a window which should be
readable to the human eye if mode is set to 'human'.
"""
if not self.use_renderer:
return
if not self.renderer:
self.initialize_renderer(mode=mode)
return self.update_renderer(mode=mode)
def initialize_renderer(self, mode="human"):
# Initiate the renderer
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
self.renderer = RenderTool(self.environment, gl="PGL", # gl="TKPILSVG",
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
self.renderer.show = False
def update_renderer(self, mode='human'):
image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
return image[:,:,:3]
def set_renderer(self, renderer):
self.use_renderer = renderer
if self.use_renderer:
self.initialize_renderer(mode=self.use_renderer)
def close(self):
# self._environment.close()
if self.renderer:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
def _obtain_preprocessor(
self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv):
_preprocessor = preprocessor if preprocessor else lambda x: x
if isinstance(self.obs_builder, TreeObsForRailEnv):
_preprocessor = (
partial(
normalize_observation, tree_depth=self.obs_builder.max_depth
)
if not preprocessor
else preprocessor
)
assert _preprocessor is not None
else:
def _preprocessor(x):
return x
def returned_preprocessor(obs):
temp_obs = {}
for agent_id, ob in obs.items():
temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
return temp_obs
return returned_preprocessor
# Utility functions
def convert_np_type(dtype, value):
return np.dtype(dtype).type(value)
def get_agent_handle(id):
"""Obtain an agents handle given its id"""
return int(id)
def get_agent_keys(id):
"""Obtain an agents handle given its id"""
return str(id)
\ No newline at end of file
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
\ No newline at end of file
from ray import tune
from ray.tune.registry import register_env
# from ray.rllib.utils import try_import_tf
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import numpy as np
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# Custom observation builder with predictor, uncomment line below if you want to try this one
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
def env_creator(args):
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
return env
if __name__ == "__main__":
env_name = "flatland_pettyzoo"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
test_env = ParallelPettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
def gen_policy(i):
config = {
"gamma": 0.99,
}
return (None, obs_space, act_space, config)
policies = {"policy_0": gen_policy(0)}
policy_ids = list(policies.keys())
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 5000000},
checkpoint_freq=10,
local_dir="~/ray_results/"+env_name,
config={
# Environment specific
"env": env_name,
# https://github.com/ray-project/ray/issues/10761
"no_done_at_end": True,
# "soft_horizon" : True,
"num_gpus": 0,
"num_workers": 2,
"num_envs_per_worker": 1,
"compress_observations": False,
"batch_mode": 'truncate_episodes',
"clip_rewards": False,
"vf_clip_param": 500.0,
"entropy_coeff": 0.01,
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
"train_batch_size": 1000, # 5000
"rollout_fragment_length": 50, # 100
"sgd_minibatch_size": 100, # 500
"vf_share_layers": False
},
)
# __sphinx_doc_end__
import numpy as np
import os
import PIL
import shutil
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3 import PPO
import supersuit as ss
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import fnmatch
import wandb
"""
https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
"""
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# env = flatland_env.env(environment = rail_env, use_renderer = False)
if wandb_log:
run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True,
config={}, name=experiment_name, save_code=True)
env_steps = 1000 # 2 * env.width * env.height # Code uses 1.5 to calculate max_steps
rollout_fragment_length = 50
env = ss.pettingzoo_env_to_vec_env_v0(env)
# env.black_death = True
env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
n_steps=rollout_fragment_length, ent_coef=0.01,
learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3,
batch_size=150, seed=seed)
# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
train_timesteps = 100000
model.learn(total_timesteps=train_timesteps)
model.save(f"policy_flatland_{train_timesteps}")
# __sphinx_doc_end__
model = PPO.load(f"policy_flatland_{train_timesteps}")
env = flatland_env.env(environment=rail_env, use_renderer=True)
if wandb_log:
artifact = wandb.Artifact('model', type='model')
artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
run.log_artifact(artifact)
# Model Interference
seed = 100
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < 1:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = model.predict(obs, deterministic=True)[0] if not done else None
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
if step % 100 == 0:
print(f"env step:{step} and action taken:{act}")
completion = env_generators.perc_completion(env)
print("Agents Completed:", completion)
completion = env_generators.perc_completion(env)
print("Final Agents Completed:", completion)
ep_no += 1
frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
def find(pattern, path):
result = []
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result.append(os.path.join(root, name))
return result
if wandb_log:
extn = "gif"
_video_file = f'*.{extn}'
_found_videos = find(_video_file, experiment_name)
print(_found_videos)
for _found_video in _found_videos:
wandb.log({_found_video: wandb.Video(_found_video, format=extn)})
run.join()
import logging
import random
import numpy as np
from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import RailAgentStatus
from flatland.core.grid.grid4_utils import get_new_position
MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
def get_shortest_path_action(env,handle):
distance_map = env.distance_map.get()
agent = env.agents[handle]