Commit 13c731a9 authored by nilabha

Merge branch 'flatland3-pettingzoo' into 'master'

Flatland3 pettingzoo

See merge request flatland/flatland!319
parents a6923d7e 3a17887c
import os
import math
import numpy as np
import gym
from gym.utils import seeding
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.utils import wrappers
from gym.utils import EzPickle
from pettingzoo.utils.conversions import to_parallel_wrapper
from flatland.envs.rail_env import RailEnv
from mava.wrappers.flatland import infer_observation_space, normalize_observation
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""
def parallel_wrapper_fn(env_fn):
def par_fn(**kwargs):
env = env_fn(**kwargs)
env = custom_parallel_wrapper(env)
return env
return par_fn
def env(**kwargs):
env = raw_env(**kwargs)
# env = wrappers.AssertOutOfBoundsWrapper(env)
# env = wrappers.OrderEnforcingWrapper(env)
return env
parallel_env = parallel_wrapper_fn(env)
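# Usage sketch (illustrative only, kept as a comment so the module stays
# import-safe): both entry points wrap an already-built RailEnv. `rail_env`
# below is a hypothetical instance, e.g. from env_generators.small_v0.
#
#   aec_env = env(environment=rail_env, use_renderer=False)           # AEC API
#   par_env = parallel_env(environment=rail_env, use_renderer=False)  # parallel API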
class custom_parallel_wrapper(to_parallel_wrapper):
def step(self, actions):
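        # One parallel step corresponds to one full AEC round: every live
        # agent acts once, and per-agent rewards are summed across the round.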
rewards = {a: 0 for a in self.aec_env.agents}
dones = {}
infos = {}
observations = {}
for agent in self.aec_env.agents:
try:
assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
except Exception as e:
# print(e)
print(self.aec_env.dones.values())
raise e
obs, rew, done, info = self.aec_env.last()
self.aec_env.step(actions.get(agent,0))
for agent in self.aec_env.agents:
rewards[agent] += self.aec_env.rewards[agent]
dones = dict(**self.aec_env.dones)
infos = dict(**self.aec_env.infos)
self.agents = self.aec_env.agents
observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
return observations, rewards, dones, infos
class raw_env(AECEnv, gym.Env):
metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
'video.frames_per_second': 10,
'semantics.autoreset': False }
    def __init__(self, environment=False, preprocessor=False, agent_info=False, use_renderer=False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs)
self._environment = environment
self.use_renderer = use_renderer
self.renderer = None
if self.use_renderer:
self.initialize_renderer()
n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)]
self._possible_agents = self.agents[:]
self._reset_next_step = True
self._agent_selector = agent_selector(self.agents)
self.num_actions = 5
self.action_spaces = {
agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
}
self.seed()
        # A custom preprocessor is only meaningful for observation builders
        # other than GlobalObsForRailEnv; tree observation builders fall back
        # to the default preprocessor if none is supplied.
self.preprocessor = self._obtain_preprocessor(preprocessor)
self._include_agent_info = agent_info
        # Observation space:
        # Flatland does not define an observation space for an agent, so we
        # infer one below from a sample observation. All agents are identical
        # and therefore share the same observation space.
        obs, _ = self._environment.reset(regenerate_rail=False, regenerate_schedule=False)
obs = self.preprocessor(obs)
self.observation_spaces = {
i: infer_observation_space(ob) for i, ob in obs.items()
}
@property
def environment(self) -> RailEnv:
"""Returns the wrapped environment."""
return self._environment
@property
def dones(self):
dones = self._environment.dones
# remove_all = dones.pop("__all__", None)
return {get_agent_keys(key): value for key, value in dones.items()}
@property
def obs_builder(self):
return self._environment.obs_builder
@property
def width(self):
return self._environment.width
@property
def height(self):
return self._environment.height
@property
def agents_data(self):
"""Rail Env Agents data."""
return self._environment.agents
@property
def num_agents(self) -> int:
"""Returns the number of trains/agents in the flatland environment"""
return int(self._environment.number_of_agents)
# def __getattr__(self, name):
# """Expose any other attributes of the underlying environment."""
# return getattr(self._environment, name)
@property
def agents(self):
return self._agents
@property
def possible_agents(self):
return self._possible_agents
def env_done(self):
return self._environment.dones["__all__"] or not self.agents
    def observe(self, agent):
return self.obs.get(agent)
def last(self, observe=True):
'''
returns observation, reward, done, info for the current agent (specified by self.agent_selection)
'''
agent = self.agent_selection
observation = self.observe(agent) if observe else None
return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)
def seed(self, seed: int = None) -> None:
self._environment._seed(seed)
def state(self):
'''
Returns an observation of the global environment
'''
return None
def _clear_rewards(self):
'''
clears all items in .rewards
'''
# pass
for agent in self.rewards:
self.rewards[agent] = 0
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
if self.use_renderer:
if self.renderer: #TODO: Errors with RLLib with renderer as None.
self.renderer.reset()
obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents)
self.agent_selection = self._agent_selector.next()
self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}
return observations
def step(self, action):
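        # AEC semantics: each call supplies an action for the currently
        # selected agent only; the underlying RailEnv advances one time step
        # once the last agent in the selector has acted.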
if self.env_done():
self._agents = []
self._reset_next_step = True
return self.last()
agent = self.agent_selection
self.action_dict[get_agent_handle(agent)] = action
if self.dones[agent]:
# Disabled.. In case we want to remove agents once done
# if self.remove_agents:
# self.agents.remove(agent)
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
else:
self._clear_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
else:
self._clear_rewards()
# self._cumulative_rewards[agent] = 0
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
    # Collate agent info and observation into a tuple, making the agent's
    # observation a tuple of the observation from the env and the agent info.
def _collate_obs_and_info(self, observes, info):
observations = {}
infos = {}
observes = self.preprocessor(observes)
for agent, obs in observes.items():
all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
agent_info = np.array(
list(all_infos.values()), dtype=np.float32
)
infos[agent] = all_infos
obs = (obs, agent_info) if self._include_agent_info else obs
observations[agent] = obs
self.infos = infos
self.obs = observations
return observations
def render(self, mode='human'):
"""
        This method provides the option to render the environment's behavior
        in a window that is readable to the human eye if mode is set to
        'human'.
"""
if not self.use_renderer:
return
if not self.renderer:
self.initialize_renderer(mode=mode)
return self.update_renderer(mode=mode)
def initialize_renderer(self, mode="human"):
# Initiate the renderer
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
self.renderer = RenderTool(self.environment, gl="PGL", # gl="TKPILSVG",
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
self.renderer.show = False
def update_renderer(self, mode='human'):
image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
return image[:,:,:3]
def set_renderer(self, renderer):
self.use_renderer = renderer
if self.use_renderer:
self.initialize_renderer(mode=self.use_renderer)
def close(self):
# self._environment.close()
if self.renderer:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
def _obtain_preprocessor(
self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv):
_preprocessor = preprocessor if preprocessor else lambda x: x
if isinstance(self.obs_builder, TreeObsForRailEnv):
_preprocessor = (
partial(
normalize_observation, tree_depth=self.obs_builder.max_depth
)
if not preprocessor
else preprocessor
)
assert _preprocessor is not None
else:
def _preprocessor(x):
return x
def returned_preprocessor(obs):
temp_obs = {}
for agent_id, ob in obs.items():
temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
return temp_obs
return returned_preprocessor
# Utility functions
def convert_np_type(dtype, value):
return np.dtype(dtype).type(value)
def get_agent_handle(id):
    """Obtain an agent's integer handle, given its id."""
    return int(id)
def get_agent_keys(id):
    """Obtain an agent's string key, given its id."""
    return str(id)
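# Minimal AEC rollout sketch (assumes a ready RailEnv named `rail_env`; kept
# as a comment so the module has no import-time side effects):
#
#   aec_env = env(environment=rail_env)
#   aec_env.reset()
#   for agent in aec_env.agent_iter():
#       obs, reward, done, info = aec_env.last()
#       action = None if done else aec_env.action_spaces[agent].sample()
#       aec_env.step(action)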
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
from ray import tune
from ray.tune.registry import register_env
# from ray.rllib.utils import try_import_tf
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import numpy as np
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
def env_creator(args):
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
return env
if __name__ == "__main__":
env_name = "flatland_pettyzoo"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
test_env = ParallelPettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
def gen_policy(i):
config = {
"gamma": 0.99,
}
return (None, obs_space, act_space, config)
policies = {"policy_0": gen_policy(0)}
policy_ids = list(policies.keys())
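    # NOTE: `policies` is defined above but not referenced in the config
    # below. To actually train with it, RLlib (ray 1.x) would normally take a
    # multi-agent block along these lines (sketch, untested):
    #
    #   "multiagent": {
    #       "policies": policies,
    #       "policy_mapping_fn": lambda agent_id: policy_ids[0],
    #   },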
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 5000000},
checkpoint_freq=10,
local_dir="~/ray_results/"+env_name,
config={
# Environment specific
"env": env_name,
# https://github.com/ray-project/ray/issues/10761
"no_done_at_end": True,
# "soft_horizon" : True,
"num_gpus": 0,
"num_workers": 2,
"num_envs_per_worker": 1,
"compress_observations": False,
"batch_mode": 'truncate_episodes',
"clip_rewards": False,
"vf_clip_param": 500.0,
"entropy_coeff": 0.01,
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
"train_batch_size": 1000, # 5000
"rollout_fragment_length": 50, # 100
"sgd_minibatch_size": 100, # 500
"vf_share_layers": False
},
)
# __sphinx_doc_end__
import numpy as np
import os
import PIL
import shutil
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3 import PPO
import supersuit as ss
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import fnmatch
import wandb
"""
https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
"""
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# env = flatland_env.env(environment = rail_env, use_renderer = False)
if wandb_log:
run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True,
config={}, name=experiment_name, save_code=True)
env_steps = 1000 # 2 * env.width * env.height # Code uses 1.5 to calculate max_steps
rollout_fragment_length = 50
env = ss.pettingzoo_env_to_vec_env_v0(env)
# env.black_death = True
env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
n_steps=rollout_fragment_length, ent_coef=0.01,
learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3,
batch_size=150, seed=seed)
# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
train_timesteps = 100000
model.learn(total_timesteps=train_timesteps)
model.save(f"policy_flatland_{train_timesteps}")
# __sphinx_doc_end__
model = PPO.load(f"policy_flatland_{train_timesteps}")
env = flatland_env.env(environment=rail_env, use_renderer=True)
if wandb_log:
artifact = wandb.Artifact('model', type='model')
artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
run.log_artifact(artifact)
# Model inference
seed = 100
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < 1:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = model.predict(obs, deterministic=True)[0] if not done else None
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
if step % 100 == 0:
print(f"env step:{step} and action taken:{act}")
completion = env_generators.perc_completion(env)
print("Agents Completed:", completion)
completion = env_generators.perc_completion(env)
print("Final Agents Completed:", completion)
ep_no += 1
    frame_list[0].save(f"{experiment_name}{os.sep}pettingzoo_out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
def find(pattern, path):
result = []
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result.append(os.path.join(root, name))
return result
if wandb_log:
extn = "gif"
_video_file = f'*.{extn}'
_found_videos = find(_video_file, experiment_name)
print(_found_videos)
for _found_video in _found_videos:
wandb.log({_found_video: wandb.Video(_found_video, format=extn)})
run.join()
import logging
import random
import numpy as np
from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import RailAgentStatus
from flatland.core.grid.grid4_utils import get_new_position
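# NOTE: the NamedTuple below shadows the MalfunctionParameters imported above;
# presumably this pins the (rate, min_duration, max_duration) layout
# independently of the installed flatland version.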
MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
def get_shortest_path_action(env,handle):
distance_map = env.distance_map.get()
agent = env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
agent_virtual_position = agent.target
else:
return None
if agent.position:
possible_transitions = env.rail.get_transitions(
*agent.position, agent.direction)
else:
possible_transitions = env.rail.get_transitions(
*agent.initial_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(
agent_virtual_position, direction)
min_distances.append(
distance_map[handle, new_position[0],
new_position[1], direction])
else:
min_distances.append(np.inf)
    if num_transitions == 1:
        observation = [0, 1, 0]
    elif num_transitions == 2:
        idx = np.argpartition(np.array(min_distances), 2)
        observation = [0, 0, 0]
        observation[idx[0]] = 1
    else:
        # Fallback for the remaining cases (previously left `observation`
        # undefined): head towards the smallest remaining distance to target.
        observation = [0, 0, 0]
        observation[int(np.argmin(min_distances))] = 1
    return np.argmax(observation) + 1
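# Usage sketch (assumes a reset RailEnv with a valid distance map, e.g. one
# built by small_v0 below): pick a greedy action for agent handle 0.
#
#   action = get_shortest_path_action(rail_env, 0)
#   obs, rewards, dones, info = rail_env.step({0: action})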
def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
random.seed(random_seed)
width = 30
height = 30
nr_trains = 5
max_num_cities = 4
grid_mode = False
max_rails_between_cities = 2
max_rails_in_city = 3
malfunction_rate = 0
malfunction_min_duration = 0
malfunction_max_duration = 0
rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
speed_ratio_map = None
line_generator = sparse_line_generator(speed_ratio_map)
malfunction_generator = no_malfunction_generator()
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator_and_process_data=malfunction_generator,
obs_builder_object=observation_builder, remove_agents_at_target=False)
print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
random_seed, width, height, max_num_cities, nr_trains, max_rails_between_cities,
max_rails_in_city, malfunction_rate, malfunction_min_duration, malfunction_max_duration
))
return env
except ValueError as e:
logging.error(f"Error: {e}")
width += 5
height += 5
            logging.info(f"Try again with larger env: (w,h): ({width},{height})")
logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_height}, max_height={max_height}")
return None
def random_sparse_env_small(random_seed, observation_builder, max_width = 45, max_height = 45):
random.seed(random_seed)
size = random.randint(0, 5)
width = 20 + size * 5
height = 20 + size * 5
nr_cities = 2 + size // 2 + random.randint(0, 2)
nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5)) # , 10 + random.randint(0, 10))
max_rails_between_cities = 2
max_rails_in_cities = 3 + random.randint(0, size)
malfunction_rate = 30 + random.randint(0, 100)
malfunction_min_duration = 3 + random.randint(0, 7)
malfunction_max_duration = 20 + random.randint(0, 80)
rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_cities)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator=ParamMalfunctionGen(stochastic_data),
obs_builder_object=observation_builder, remove_agents_at_target=False)
print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration
))
return env
except ValueError as e:
logging.error(f"Error: {e}")
width += 5
height += 5
            logging.info(f"Try again with larger env: (w,h): ({width},{height})")
    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
return None
def sparse_env_small(random_seed, observation_builder):
width = 30 # With of map
height = 30 # Height of map
nr_trains = 2 # Number of trains that have an assigned task in the env
cities_in_map = 3 # Number of cities where agents can start or end
seed = 10 # Random seed
grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed
max_rails_between_cities = 2 # Max number of tracks allowed between cities. This is number of entry point to a city
max_rail_in_cities = 6 # Max number of parallel tracks within a city, representing a realistic trainstation
rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
seed=seed,
grid_mode=grid_distribution_of_cities,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_in_cities,
)
# Different agent types (trains) with different speeds.
    speed_ratio_map = {1.: 0.25,  # Fast passenger train
                       1. / 2.: 0.25,  # Fast freight train
                       1. / 3.: 0.25,  # Slow commuter train
                       1. / 4.: 0.25}  # Slow freight train
    # We can now initialize the line generator with the given speed profiles.
    line_generator = sparse_line_generator(speed_ratio_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
stochastic_data = MalfunctionParameters(malfunction_rate=1/10000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
rail_env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator=ParamMalfunctionGen(stochastic_data),
remove_agents_at_target=True)
return rail_env
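# NOTE: the function below is written as a method (it takes `self`) and looks
# intended to be patched onto a gym monitor/recorder wrapper so that video
# capture copes with dict-valued multi-agent `done`s; this reading is an
# assumption based on the attributes it uses (video_recorder, episode_id).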
def _after_step(self, observation, reward, done, info):
if not self.enabled: return done
    if isinstance(done, dict):
_done_check = done['__all__']
else:
_done_check = done
if _done_check and self.env_semantics_autoreset:
# For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
self.reset_video_recorder()
self.episode_id += 1
self._flush()
# Record stats - Disabled as it causes error in multi-agent set up
# self.stats_recorder.after_step(observation, reward, done, info)
# Record video
self.video_recorder.capture_frame()
return done
def perc_completion(env):
tasks_finished = 0
if hasattr(env, "agents_data"):
agent_data = env.agents_data
else:
agent_data = env.agents
for current_agent in agent_data:
if current_agent.status == RailAgentStatus.DONE:
tasks_finished += 1
    return 100 * tasks_finished / max(1, len(agent_data))
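# Usage sketch (hypothetical seed and actions): build a small env, advance it
# one step with MOVE_FORWARD everywhere, then report completion.
#
#   from flatland.envs.observations import TreeObsForRailEnv
#   rail_env = small_v0(10, TreeObsForRailEnv(max_depth=2))
#   rail_env.reset(random_seed=10)
#   rail_env.step({h: 2 for h in range(rail_env.get_num_agents())})
#   print(perc_completion(rail_env))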
import numpy as np
import os
import PIL
import shutil
# MICHEL: my own imports
import unittest
import typing
from collections import defaultdict
from typing import Dict, Any, Optional, Set, List, Tuple
from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.core.grid.grid4_utils import get_new_position
# First of all we import the Flatland rail environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
agent = env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
agent_virtual_position = agent.target
else:
print("no action possible!")
if agent.status == RailAgentStatus.DONE_REMOVED:
print(f"agent status: DONE_REMOVED for agent {agent.handle}")
print("to solve this problem, do not input actions for removed agents!")
return [(RailEnvActions.DO_NOTHING, 0)] * 2
print("agent status:")
print(RailAgentStatus(agent.status))
#return None
# NEW: if agent is at target, DO_NOTHING, and distance is zero.
# NEW: (needs to be tested...)
return [(RailEnvActions.DO_NOTHING, 0)] * 2
possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
print(f"possible transitions: {possible_transitions}")
distance_map = env.distance_map.get()[handle]
possible_steps = []
for movement in list(range(4)):
# MICHEL: TODO: discuss with author of this code how it works, and why it breaks down in my test!
# should be much better commented or structured to be readable!
if possible_transitions[movement]:
if movement == agent.direction:
action = RailEnvActions.MOVE_FORWARD
elif movement == (agent.direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif movement == (agent.direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
# MICHEL: prints for debugging
print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}")
if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4:
print("it seems that we are turning by 180 degrees. Turning in a dead end?")
# MICHEL: can this happen when we turn 180 degrees in a dead end?
# i.e. can we then have movement == agent.direction + 2 % 4 (resp. ... == - 2 % 4)?
# TRY OUT: ASSIGN MOVE_FORWARD HERE...
action = RailEnvActions.MOVE_FORWARD
print("Here we would have a ValueError...")
#raise ValueError("Wtf, debug this shit.")
distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
possible_steps.append((action, distance))
possible_steps = sorted(possible_steps, key=lambda step: step[1])
# MICHEL: what is this doing?
# if there is only one path to target, this is both the shortest one and the second shortest path.
if len(possible_steps) == 1:
return possible_steps * 2
else:
return possible_steps
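# Usage sketch: for a reset env, entry 0 is the (action, distance) pair that
# follows the shortest path for the given handle (0 here, hypothetical):
#
#   best_action, best_distance = possible_actions_sorted_by_distance(rail_env, 0)[0]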
class RailEnvWrapper:
def __init__(self, env:RailEnv):
self.env = env
assert self.env is not None
assert self.env.rail is not None, "Reset original environment first!"
assert self.env.agents is not None, "Reset original environment first!"
assert len(self.env.agents) > 0, "Reset original environment first!"
# rail can be seen as part of the interface to RailEnv.
# is used by several wrappers, to e.g. access rail.get_valid_transitions(...)
#self.rail = self.env.rail
# same for env.agents
# MICHEL: DOES THIS HERE CAUSE A PROBLEM with agent status not being updated?
#self.agents = self.env.agents
#assert self.env.agents == self.agents
#print(f"agents of RailEnvWrapper are: {self.agents}")
#self.width = self.rail.width
#self.height = self.rail.height
# TODO: maybe do this in a generic way, like "for each method of self.env, ..."
# maybe using dir(self.env) (gives list of names of members)
# MICHEL: this seems to be needed after each env.reset(..) call
# otherwise, these attribute names refer to the wrong object and are out of sync...
# probably due to the reassignment of new objects to these variables by RailEnv, and how Python treats that.
# simple example: a = [1,2,3] b=a. But then a=[0]. Now we still have b==[1,2,3].
        # it's better to use properties here!
# @property
# def number_of_agents(self):
# return self.env.number_of_agents
# @property
# def agents(self):
# return self.env.agents
# @property
# def _seed(self):
# return self.env._seed
# @property
# def obs_builder(self):
# return self.env.obs_builder
    def __getattr__(self, name):
        """Expose any other attributes of the underlying environment."""
        # __getattr__ is only invoked after normal attribute lookup fails,
        # so we can delegate straight to the wrapped environment.
        return getattr(self.env, name)
@property
def rail(self):
return self.env.rail
@property
def width(self):
return self.env.width
@property
def height(self):
return self.env.height
@property
def agent_positions(self):
return self.env.agent_positions
def get_num_agents(self):
return self.env.get_num_agents()
def get_agent_handles(self):
return self.env.get_agent_handles()
def step(self, action_dict: Dict[int, RailEnvActions]):
#self.agents = self.env.agents
# ERROR. something is wrong with the references for self.agents...
#assert self.env.agents == self.agents
return self.env.step(action_dict)
def reset(self, **kwargs):
# MICHEL: I suspect that env.reset() does not simply change values of variables, but assigns new objects
# that might cause some attributes not be properly updated here, because of how Python treats assignments differently from modification..
#assert self.env.agents == self.agents
obs, info = self.env.reset(**kwargs)
#assert self.env.agents == self.agents, "after resetting internal env, self.agents names wrong object..."
#self.reset_attributes()
#print(f"calling RailEnvWrapper.reset()")
#print(f"obs: {obs}, info:{info}")
return obs, info
class ShortestPathActionWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv):
super().__init__(env)
#self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction
    # MICHEL: we have to make sure that no agents with agent.status == DONE_REMOVED are in the action dict.
# otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash.
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
########## MICHEL: NEW (just for debugging) ########
for agent_id, action in action_dict.items():
agent = self.agents[agent_id]
# assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None...
# assert agent.status != RailAgentStatus.DONE # not sure about this one...
print(f"agent: {agent} with status: {agent.status}")
######################################################
# input: action dict with actions in [0, 1, 2].
transformed_action_dict = {}
for agent_id, action in action_dict.items():
if action == 0:
transformed_action_dict[agent_id] = action
else:
assert action in [1, 2]
# MICHEL: how exactly do the indices work here?
#transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0]
#print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}")
#assert agent.status != RailAgentStatus.DONE_REMOVED
# MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error...
assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
obs, rewards, dones, info = self.env.step(transformed_action_dict)
return obs, rewards, dones, info
#def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
#return self.rail_env.reset(random_seed)
# MICHEL: should not be needed, as we inherit that from RailEnvWrapper...
#def reset(self, **kwargs) -> Tuple[Dict, Dict]:
# obs, info = self.env.reset(**kwargs)
# return obs, info
def find_all_cells_where_agent_can_choose(env: RailEnv):
"""
input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv),
WHICH HAS BEEN RESET ALREADY!
(o.w., we call env.rail, which is None before reset(), and crash.)
"""
switches = []
switches_neighbors = []
directions = list(range(4))
for h in range(env.height):
for w in range(env.width):
            # MICHEL: THIS SEEMS TO BE A BUG. WRONG ORDER OF COORDINATES.
# will not show up in quadratic environments.
# should be pos = (h, w)
#pos = (w, h)
# MICHEL: changed this
pos = (h, w)
is_switch = False
# Check for switch: if there is more than one outgoing transition
for orientation in directions:
#print(f"env is: {env}")
#print(f"env.rail is: {env.rail}")
possible_transitions = env.rail.get_transitions(*pos, orientation)
num_transitions = np.count_nonzero(possible_transitions)
if num_transitions > 1:
switches.append(pos)
is_switch = True
break
if is_switch:
# Add all neighbouring rails, if pos is a switch
for orientation in directions:
possible_transitions = env.rail.get_transitions(*pos, orientation)
for movement in directions:
if possible_transitions[movement]:
switches_neighbors.append(get_new_position(pos, movement))
decision_cells = switches + switches_neighbors
return tuple(map(set, (switches, switches_neighbors, decision_cells)))
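# Usage sketch (the env must have been reset first, see the docstring above):
#
#   switches, switch_neighbors, decision_cells = find_all_cells_where_agent_can_choose(rail_env)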
class NoChoiceCellsSkipper:
def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
self.env = env
self.switches = None
self.switches_neighbors = None
self.decision_cells = None
self.accumulate_skipped_rewards = accumulate_skipped_rewards
self.discounting = discounting
self.skipped_rewards = defaultdict(list)
# env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
#self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
# compute and initialize value for switches, switches_neighbors, and decision_cells.
self.reset_cells()
# MICHEL: maybe these three methods should be part of RailEnv?
def on_decision_cell(self, agent: EnvAgent) -> bool:
"""
print(f"agent {agent.handle} is on decision cell")
if agent.position is None:
print("because agent.position is None (has not been activated yet)")
if agent.position == agent.initial_position:
print("because agent is at initial position, activated but not departed")
if agent.position in self.decision_cells:
print("because agent.position is in self.decision_cells.")
"""
return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
def on_switch(self, agent: EnvAgent) -> bool:
return agent.position in self.switches
def next_to_switch(self, agent: EnvAgent) -> bool:
return agent.position in self.switches_neighbors
# MICHEL: maybe just call this step()...
def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
o, r, d, i = {}, {}, {}, {}
# MICHEL: NEED TO INITIALIZE i["..."]
# as we will access i["..."][agent_id]
i["action_required"] = dict()
i["malfunction"] = dict()
i["speed"] = dict()
i["status"] = dict()
while len(o) == 0:
#print(f"len(o)==0. stepping the rail environment...")
obs, reward, done, info = self.env.step(action_dict)
for agent_id, agent_obs in obs.items():
###### MICHEL: prints for debugging ###########
if not self.on_decision_cell(self.env.agents[agent_id]):
print(f"agent {agent_id} is NOT on a decision cell.")
#################################################
if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
###### MICHEL: prints for debugging ######################
if done[agent_id]:
print(f"agent {agent_id} is done.")
#if self.on_decision_cell(self.env.agents[agent_id]):
#print(f"agent {agent_id} is on decision cell.")
#cell = self.env.agents[agent_id].position
#print(f"cell is: {cell}")
#print(f"the decision cells are: {self.decision_cells}")
############################################################
o[agent_id] = agent_obs
r[agent_id] = reward[agent_id]
d[agent_id] = done[agent_id]
# MICHEL: HAVE TO MODIFY THIS HERE
# because we are not using StepOutputs, the return values of step() have a different structure.
#i[agent_id] = info[agent_id]
i["action_required"][agent_id] = info["action_required"][agent_id]
i["malfunction"][agent_id] = info["malfunction"][agent_id]
i["speed"][agent_id] = info["speed"][agent_id]
i["status"][agent_id] = info["status"][agent_id]
if self.accumulate_skipped_rewards:
discounted_skipped_reward = r[agent_id]
for skipped_reward in reversed(self.skipped_rewards[agent_id]):
discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
r[agent_id] = discounted_skipped_reward
self.skipped_rewards[agent_id] = []
elif self.accumulate_skipped_rewards:
self.skipped_rewards[agent_id].append(reward[agent_id])
# end of for-loop
d['__all__'] = done['__all__']
action_dict = {}
# end of while-loop
return o, r, d, i
# MICHEL: maybe just call this reset()...
def reset_cells(self) -> None:
self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
# IMPORTANT: rail env should be reset() / initialized before put into this one!
# IDEA: MAYBE EACH RAILENV INSTANCE SHOULD AUTOMATICALLY BE reset() / initialized upon creation!
class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# env can be a real RailEnv, or anything that shares the same interface
# e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
super().__init__(env)
# save these so they can be inspected easier.
self.accumulate_skipped_rewards = accumulate_skipped_rewards
self.discounting = discounting
self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
self.skipper.reset_cells()
# TODO: this is clunky..
# for easier access / checking
self.switches = self.skipper.switches
self.switches_neighbors = self.skipper.switches_neighbors
self.decision_cells = self.skipper.decision_cells
self.skipped_rewards = self.skipper.skipped_rewards
# MICHEL: trying to isolate the core part and put it into a separate method.
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
return obs, rewards, dones, info
# MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc.
# arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
# TODO: check the type of random_seed. Is it bool or int?
# MICHEL: changed return type from Dict[int, Any] to Tuple[Dict, Dict].
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
obs, info = self.env.reset(**kwargs)
# resets decision cells, switches, etc. These can change with an env.reset(...)!
# needs to be done after env.reset().
self.skipper.reset_cells()
return obs, info
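# Composition sketch (assumes a reset RailEnv named `rail_env`): the wrappers
# are designed to stack, e.g. skipping no-choice cells on top of the 3-action
# shortest-path interface.
#
#   rail_env.reset(random_seed=42)
#   wrapped = SkipNoChoiceCellsWrapper(ShortestPathActionWrapper(rail_env),
#                                      accumulate_skipped_rewards=False,
#                                      discounting=0.0)
#   obs, rewards, dones, info = wrapped.step({h: 1 for h in range(wrapped.get_num_agents())})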
import numpy as np
import os
import PIL
import shutil
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# First of all we import the Flatland rail environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper
from flatland.contrib.wrappers.flatland_wrappers import ShortestPathActionWrapper # noqa
import pytest
@pytest.mark.skip(reason="Only for testing pettingzoo interface and wrappers")
def test_petting_zoo_interface_env():
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 11
save = True
np.random.seed(seed)
experiment_name = "flatland_pettingzoo"
total_episodes = 1
if save:
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
    # rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
rail_env.reset(random_seed=seed)
# For Shortest Path Action Wrapper, change action to 1
# rail_env = ShortestPathActionWrapper(rail_env)
rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
dones = {}
dones['__all__'] = False
step = 0
ep_no = 0
frame_list = []
all_actions_env = []
all_actions_pettingzoo_env = []
# while not dones['__all__']:
while ep_no < total_episodes:
action_dict = {}
# Chose an action for each agent
for a in range(rail_env.get_num_agents()):
# action = env_generators.get_shortest_path_action(rail_env, a)
action = 2
all_actions_env.append(action)
action_dict.update({a: action})
step += 1
# Do the environment step
observations, rewards, dones, information = rail_env.step(action_dict)
image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
frame_list.append(PIL.Image.fromarray(image[:, :, :3]))
if dones['__all__']:
completion = env_generators.perc_completion(rail_env)
print("Final Agents Completed:", completion)
ep_no += 1
if save:
frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
rail_env.reset(random_seed=seed+ep_no)
# __sphinx_doc_begin__
env = flatland_env.env(environment=rail_env, use_renderer=True)
seed = 11
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < total_episodes:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
# act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
act = 2
all_actions_pettingzoo_env.append(act)
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
# __sphinx_doc_end__
completion = env_generators.perc_completion(env)
print("Final Agents Completed:", completion)
ep_no += 1
if save:
            frame_list[0].save(f"{experiment_name}{os.sep}pettingzoo_out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
min_len = min(len(all_actions_pettingzoo_env), len(all_actions_env))
assert all_actions_pettingzoo_env[:min_len] == all_actions_env[:min_len], "actions do not match"
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-sv", __file__]))