Commit 27e8d3bc authored by MasterScrat's avatar MasterScrat

Merge branch 'refactor'

parents 22dd86d9 050238d1
import importlib
import os
import humps
import yaml
# Maps generator config name (YAML filename without extension) -> parsed config dict.
# Populated at import time by the loading loop below.
GENERATOR_CONFIG_REGISTRY = {}


def get_generator_config(name: str) -> dict:
    """Return the generator config registered under ``name``.

    Raises:
        KeyError: if no config with that name has been loaded; the error
            message lists the available config names to ease debugging.
    """
    try:
        return GENERATOR_CONFIG_REGISTRY[name]
    except KeyError:
        raise KeyError(
            "Unknown generator config '{}'. Available configs: {}".format(
                name, sorted(GENERATOR_CONFIG_REGISTRY)
            )
        ) from None
# Auto-discover every generator config: each non-underscore *.yaml file in
# generator_configs/ is parsed and registered under its extension-less name.
config_folder = os.path.join(os.path.dirname(__file__), "generator_configs")
for entry in os.listdir(config_folder):
    if not entry.endswith('.yaml') or entry.startswith('_'):
        continue  # skip non-YAML files and private (underscore-prefixed) ones
    basename = os.path.basename(entry)
    config_name = basename.replace(".yaml", "")
    with open(os.path.join(config_folder, entry)) as config_file:
        GENERATOR_CONFIG_REGISTRY[config_name] = yaml.safe_load(config_file)
    print("- Successfully Loaded Generator Config {} from {}".format(
        config_name, basename
    ))
width: 32
height: 32
number_of_agents: 10
max_num_cities: 3
grid_mode: False
max_rails_between_cities: 2
max_rails_in_city: 3
malfunction_rate: 8000
malfunction_min_duration: 15
malfunction_max_duration: 50
# speed ratio map: keys must be strings but will be converted to float later!
speed_ratio_map:
'1.0': 0.25
'0.5': 0.25
'0.3333333': 0.25
'0.25': 0.25
seed: 0
regenerate_rail_on_reset: True
regenerate_schedule_on_reset: True
\ No newline at end of file
width: 25
height: 25
number_of_agents: 2
max_num_cities: 4
grid_mode: False
max_rails_between_cities: 2
max_rails_in_city: 3
seed: 0
regenerate_rail_on_reset: True
regenerate_schedule_on_reset: True
\ No newline at end of file
width: 25
height: 25
number_of_agents: 1
max_num_cities: 4
grid_mode: False
max_rails_between_cities: 2
max_rails_in_city: 3
seed: 0
regenerate_rail_on_reset: True
regenerate_schedule_on_reset: True
\ No newline at end of file
width: 25
height: 25
number_of_agents: 3
max_num_cities: 4
grid_mode: False
max_rails_between_cities: 2
max_rails_in_city: 3
seed: 0
regenerate_rail_on_reset: True
regenerate_schedule_on_reset: True
\ No newline at end of file
width: 25
height: 25
number_of_agents: 5
max_num_cities: 4
grid_mode: False
max_rails_between_cities: 2
max_rails_in_city: 3
seed: 0
regenerate_rail_on_reset: True
regenerate_schedule_on_reset: True
\ No newline at end of file
......@@ -3,6 +3,7 @@ import os
from abc import ABC, abstractmethod
import gym
import humps
from flatland.core.env_observation_builder import ObservationBuilder
......@@ -42,4 +43,11 @@ def make_obs(name: str, config, *args, **kwargs) -> Observation:
# automatically import any Python files in the obs/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = importlib.import_module(f'.{file[:-3]}', __name__)
basename = os.path.basename(file)
filename = basename.replace(".py", "")
class_name = humps.pascalize(filename)
module = importlib.import_module(f'.{file[:-3]}', package=__name__)
print("- Successfully Loaded Observation class {} from {}".format(
class_name, basename
))
......@@ -15,8 +15,10 @@ class TreeObservation(Observation):
def __init__(self, config) -> None:
super().__init__(config)
self._builder = TreeObsForRailEnvRLLibWrapper(
TreeObsForRailEnv(max_depth=config['max_depth'],
predictor=ShortestPathPredictorForRailEnv(config['shortest_path_max_depth']))
TreeObsForRailEnv(
max_depth=config['max_depth'],
predictor=ShortestPathPredictorForRailEnv(config['shortest_path_max_depth'])
)
)
def builder(self) -> ObservationBuilder:
......@@ -105,12 +107,11 @@ def _split_node_into_feature_groups(node: TreeObsForRailEnv.Node) -> (np.ndarray
def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
if node == -np.inf:
remaining_depth = max_tree_depth - current_tree_depth
# reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
num_remaining_nodes = int((4**(remaining_depth+1) - 1) / (4 - 1))
return [-np.inf] * num_remaining_nodes*6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes*4
num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
return [-np.inf] * num_remaining_nodes * 6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes * 4
data, distance, agent_data = _split_node_into_feature_groups(node)
......@@ -182,4 +183,4 @@ class TreeObsForRailEnvRLLibWrapper(ObservationBuilder):
self._builder.print_subtree(node, label, indent)
def set_env(self, env):
self._builder.set_env(env)
\ No newline at end of file
self._builder.set_env(env)
......@@ -49,6 +49,10 @@ class FlatlandRllibWrapper(object):
for agent, done in dones.items():
if agent != '__all__' and not agent in obs:
continue # skip agent if there is no observation
# Use this if using a single policy for multiple agents
# TODO find better way to handle this
#if True or agent not in self._agents_done:
if agent not in self._agents_done:
if agent != '__all__':
if done:
......
......@@ -3,9 +3,9 @@ import random
import gym
from ray.rllib import MultiAgentEnv
from envs.flatland.env_generators import random_sparse_env_small
from envs.flatland.utils.env_generators import random_sparse_env_small
from envs.flatland.observations import make_obs
from envs.flatland.rllib_wrapper import FlatlandRllibWrapper
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
class FlatlandRandomSparseSmall(MultiAgentEnv):
......
import logging
import gym
import numpy as np
from flatland.envs.malfunction_generators import no_malfunction_generator, malfunction_from_params
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper, StepOutput
class FlatlandSingle(gym.Env):
    """Single-policy gym wrapper around a sparse Flatland rail environment.

    Flattens the multi-agent env into one ``gym.Env``: the action is a
    MultiDiscrete vector with one sub-action per agent, the observation is
    the list of per-agent observations, and the reward is the sum over agents.
    """

    def render(self, mode='human'):
        # Rendering is not supported by this wrapper.
        pass

    def __init__(self, env_config):
        # Observation builder is selected by name; its sub-config is optional.
        self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
        self._config = get_generator_config(env_config['generator_config'])
        self._env = FlatlandRllibWrapper(
            rail_env=self._launch(),
            regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
            regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
        )

    def _launch(self):
        """Build and reset the underlying RailEnv from the generator config.

        Returns the RailEnv, or ``None`` when creation fails with a
        ``ValueError`` (best-effort: the error is logged, not re-raised).
        """
        rail_generator = sparse_rail_generator(
            seed=self._config['seed'],
            max_num_cities=self._config['max_num_cities'],
            grid_mode=self._config['grid_mode'],
            max_rails_between_cities=self._config['max_rails_between_cities'],
            max_rails_in_city=self._config['max_rails_in_city']
        )

        malfunction_generator = no_malfunction_generator()
        # Bug fix: the generator config files use the 'malfunction_'-prefixed
        # key names (malfunction_min_duration / malfunction_max_duration) —
        # exactly the keys read below. The previous guard tested for
        # 'min_duration'/'max_duration', which never exist in the configs,
        # so malfunctions were silently never enabled.
        if {'malfunction_rate', 'malfunction_min_duration', 'malfunction_max_duration'} <= self._config.keys():
            stochastic_data = {
                'malfunction_rate': self._config['malfunction_rate'],
                'min_duration': self._config['malfunction_min_duration'],
                'max_duration': self._config['malfunction_max_duration']
            }
            malfunction_generator = malfunction_from_params(stochastic_data)

        speed_ratio_map = None
        if 'speed_ratio_map' in self._config:
            # YAML keys must be strings; convert both keys and values to the
            # floats the schedule generator expects.
            speed_ratio_map = {
                float(k): float(v) for k, v in self._config['speed_ratio_map'].items()
            }
        schedule_generator = sparse_schedule_generator(speed_ratio_map)

        env = None
        try:
            env = RailEnv(
                width=self._config['width'],
                height=self._config['height'],
                rail_generator=rail_generator,
                schedule_generator=schedule_generator,
                number_of_agents=self._config['number_of_agents'],
                malfunction_generator_and_process_data=malfunction_generator,
                obs_builder_object=self._observation.builder(),
                remove_agents_at_target=False,
                random_seed=self._config['seed']
            )
            env.reset()
        except ValueError as e:
            logging.error("=" * 50)
            logging.error(f"Error while creating env: {e}")
            logging.error("=" * 50)
        return env

    def step(self, action_list):
        """Apply one action per agent; return stacked (obs, reward, done, info)."""
        action_dict = {i: action for i, action in enumerate(action_list)}
        step_r = self._env.step(action_dict)

        return StepOutput(
            obs=[step for step in step_r.obs.values()],
            reward=np.sum([r for r in step_r.reward.values()]),
            done=all(step_r.done.values()),
            # NOTE(review): only agent 0's info is surfaced — presumably the
            # per-agent infos are interchangeable here; confirm against the
            # wrapper's info contract.
            info=step_r.info[0]
        )

    def reset(self):
        """Reset the wrapped env and return the list of per-agent observations."""
        obs_dict = self._env.reset()
        return [step for step in obs_dict.values()]

    @property
    def observation_space(self) -> gym.spaces.Space:
        """Per-agent observation space stacked along a leading agent axis."""
        observation_space = self._observation.observation_space()
        if isinstance(observation_space, gym.spaces.Box):
            return gym.spaces.Box(low=-np.inf, high=np.inf,
                                  shape=(self._config['number_of_agents'], *observation_space.shape,))
        elif isinstance(observation_space, gym.spaces.Tuple):
            spaces = observation_space.spaces * self._config['number_of_agents']
            return gym.spaces.Tuple(spaces)
        else:
            raise ValueError("Unhandled space:", observation_space.__class__)

    @property
    def action_space(self) -> gym.spaces.Space:
        """One discrete 5-way action (Flatland's action set) per agent."""
        return gym.spaces.MultiDiscrete([5] * self._config['number_of_agents'])
import logging
from pprint import pprint
import gym
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.malfunction_generators import malfunction_from_params, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from ray.rllib import MultiAgentEnv
from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.rllib_wrapper import FlatlandRllibWrapper
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
class FlatlandSparse(MultiAgentEnv):
def __init__(self, env_config) -> None:
super().__init__()
self._config = env_config
# TODO implement other generators
assert env_config['generator'] == 'sparse_rail_generator'
self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
self._env = FlatlandRllibWrapper(rail_env=self._launch(), render=env_config['render'],
regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset'])
self._config = get_generator_config(env_config['generator_config'])
if env_config.worker_index == 0 and env_config.vector_index == 0:
print("=" * 50)
pprint(self._config)
print("=" * 50)
self._env = FlatlandRllibWrapper(
rail_env=self._launch(),
# render=env_config['render'], # TODO need to fix gl compatibility first
regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
)
@property
def observation_space(self) -> gym.spaces.Space:
......@@ -27,28 +44,50 @@ class FlatlandSparse(MultiAgentEnv):
return self._env.action_space
def _launch(self):
rail_generator = sparse_rail_generator(seed=self._config['seed'], max_num_cities=self._config['max_num_cities'],
grid_mode=self._config['grid_mode'],
max_rails_between_cities=self._config['max_rails_between_cities'],
max_rails_in_city=self._config['max_rails_in_city'])
stochastic_data = {'malfunction_rate': self._config['malfunction_rate'],
'min_duration': self._config['malfunction_min_duration'],
'max_duration': self._config['malfunction_max_duration']}
schedule_generator = sparse_schedule_generator({float(k): float(v)
for k, v in self._config['speed_ratio_map'].items()})
env = RailEnv(
width=self._config['width'],
height=self._config['height'],
rail_generator=rail_generator,
schedule_generator=schedule_generator,
number_of_agents=self._config['number_of_agents'],
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=self._observation.builder(),
remove_agents_at_target=False
rail_generator = sparse_rail_generator(
seed=self._config['seed'],
max_num_cities=self._config['max_num_cities'],
grid_mode=self._config['grid_mode'],
max_rails_between_cities=self._config['max_rails_between_cities'],
max_rails_in_city=self._config['max_rails_in_city']
)
malfunction_generator = no_malfunction_generator()
if {'malfunction_rate', 'min_duration', 'max_duration'} <= self._config.keys():
stochastic_data = {
'malfunction_rate': self._config['malfunction_rate'],
'min_duration': self._config['malfunction_min_duration'],
'max_duration': self._config['malfunction_max_duration']
}
malfunction_generator = malfunction_from_params(stochastic_data)
speed_ratio_map = None
if 'speed_ratio_map' in self._config:
speed_ratio_map = {
float(k): float(v) for k, v in self._config['speed_ratio_map'].items()
}
schedule_generator = sparse_schedule_generator(speed_ratio_map)
env = None
try:
env = RailEnv(
width=self._config['width'],
height=self._config['height'],
rail_generator=rail_generator,
schedule_generator=schedule_generator,
number_of_agents=self._config['number_of_agents'],
malfunction_generator_and_process_data=malfunction_generator,
obs_builder_object=self._observation.builder(),
remove_agents_at_target=False,
random_seed=self._config['seed']
)
env.reset()
except ValueError as e:
logging.error("=" * 50)
logging.error(f"Error while creating env: {e}")
logging.error("=" * 50)
return env
def step(self, action_dict):
......
flatland-random-sparse-small-tree-fc-ppo:
run: PPO
env: flatland_single
stop:
timesteps_total: 10000000 # 1e7
checkpoint_freq: 10
checkpoint_at_end: True
keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
config:
clip_rewards: True
clip_param: 0.1
vf_clip_param: 500.0
entropy_coeff: 0.01
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
train_batch_size: 1000 # 5000
rollout_fragment_length: 50 # 100
sgd_minibatch_size: 100 # 500
num_sgd_iter: 10
num_workers: 5
num_envs_per_worker: 5
batch_mode: truncate_episodes
observation_filter: NoFilter
vf_share_layers: True
vf_loss_coeff: 0.5
num_gpus: 1
env_config:
observation: tree
observation_config:
max_depth: 2
shortest_path_max_depth: 30
generator: sparse_rail_generator
generator_config: small_single_v0
wandb:
project: flatland
entity: masterscrat
tags: ["small_single_v0", "tree_obs"] # TODO should be set programmatically
model:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
vf_share_layers: True # False
flatland-random-sparse-small-tree-fc-ppo:
run: PPO
env: flatland_single
stop:
timesteps_total: 10000000 # 1e7
checkpoint_freq: 10
checkpoint_at_end: True
keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
config:
clip_rewards: True
clip_param: 0.1
vf_clip_param: 500.0
entropy_coeff: 0.01
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
train_batch_size: 1000 # 5000
rollout_fragment_length: 50 # 100
sgd_minibatch_size: 100 # 500
num_sgd_iter: 10
num_workers: 5
num_envs_per_worker: 5
batch_mode: truncate_episodes
observation_filter: NoFilter
vf_share_layers: True
vf_loss_coeff: 0.5
num_gpus: 1
env_config:
observation: tree
observation_config:
max_depth: 2
shortest_path_max_depth: 30
generator: sparse_rail_generator
generator_config: small_double_v0
wandb:
project: flatland
entity: masterscrat
tags: ["small_double_v0", "tree_obs"] # TODO should be set programmatically
model:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
vf_share_layers: True # False
flatland-random-sparse-small-tree-fc-ppo:
run: PPO
env: flatland_single
stop:
timesteps_total: 10000000 # 1e7
checkpoint_freq: 10
checkpoint_at_end: True
keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
config:
clip_rewards: True
clip_param: 0.1
vf_clip_param: 500.0
entropy_coeff: 0.01
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
train_batch_size: 1000 # 5000
rollout_fragment_length: 50 # 100
sgd_minibatch_size: 100 # 500
num_sgd_iter: 10
num_workers: 5
num_envs_per_worker: 5
batch_mode: truncate_episodes
observation_filter: NoFilter
vf_share_layers: True
vf_loss_coeff: 0.5
num_gpus: 1
env_config:
observation: tree
observation_config:
max_depth: 2
shortest_path_max_depth: 30
generator: sparse_rail_generator
generator_config: small_triple_v0
wandb:
project: flatland
entity: masterscrat
tags: ["small_triple_v0", "tree_obs"] # TODO should be set programmatically
model:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
vf_share_layers: True # False
flatland-sparse-single-global-conv-ppo:
run: PPO
env: flatland_single
stop:
timesteps_total: 10000000 # 1e7
checkpoint_freq: 10
checkpoint_at_end: True
keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
config:
clip_rewards: True
clip_param: 0.1
vf_clip_param: 500.0
entropy_coeff: 0.01
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
train_batch_size: 1000 # 5000
rollout_fragment_length: 50 # 100
sgd_minibatch_size: 100 # 500
num_sgd_iter: 10
num_workers: 11
num_envs_per_worker: 5
batch_mode: truncate_episodes
observation_filter: NoFilter
vf_share_layers: True
vf_loss_coeff: 0.5
num_gpus: 1
env_config:
observation: global
observation_config:
max_width: 32
max_height: 32
generator: sparse_rail_generator
generator_config: small_single_v0
wandb:
project: flatland
entity: masterscrat
tags: ["small_single_v0", "global_obs"] # TODO should be set programmatically
model:
custom_model: global_obs_model
custom_options:
architecture: impala
architecture_options:
residual_layers: [[16, 2], [32, 4]]
flatland-random-sparse-small-tree-fc-ppo:
run: PPO
env: flatland_single
stop:
timesteps_total: 10000000 # 1e7
checkpoint_freq: 10
checkpoint_at_end: True
keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
config:
clip_rewards: True
clip_param: 0.1
vf_clip_param: 500.0
entropy_coeff: 0.01
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
train_batch_size: 1000 # 5000
rollout_fragment_length: 50 # 100
sgd_minibatch_size: 100 # 500
num_sgd_iter: 10
num_workers: 7
num_envs_per_worker: 5
batch_mode: truncate_episodes
observation_filter: NoFilter
vf_share_layers: True
vf_loss_coeff: 0.5
num_gpus: 0
env_config:
observation: tree
observation_config:
max_depth: 2
shortest_path_max_depth: 30
generator: sparse_rail_generator
generator_config: small_single_v0
wandb:
project: flatland
entity: masterscrat
tags: ["small_single_v0", "tree_obs"] # TODO should be set programmatically
model:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
vf_share_layers: True # False
flatland-sparse-single-global-conv-ppo:
run: PPO
env: flatland_single
sto