Commit 8fe7ea5f authored by MasterScrat

Generator config registry, misc cleanup

parent 551cd808
@@ -6,6 +6,11 @@ import yaml
+GENERATOR_CONFIG_REGISTRY = {}
+
+def get_generator_config(name: str):
+    return GENERATOR_CONFIG_REGISTRY[name]
+
+config_folder = os.path.join(os.path.dirname(__file__), "generator_configs")
+for file in os.listdir(config_folder):
+    if file.endswith('.yaml') and not file.startswith('_'):
......
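The loop body is truncated in this view. As a rough sketch, the registration step presumably parses each YAML file and stores it under its basename; the splitext/safe_load details below are an assumption, not the commit's verbatim code:

import os
import yaml

GENERATOR_CONFIG_REGISTRY = {}

def get_generator_config(name: str):
    return GENERATOR_CONFIG_REGISTRY[name]

# Assumed completion of the truncated loop: register each config file
# under its basename, e.g. generator_configs/small_v0.yaml -> "small_v0".
config_folder = os.path.join(os.path.dirname(__file__), "generator_configs")
for file in os.listdir(config_folder):
    if file.endswith('.yaml') and not file.startswith('_'):
        name = os.path.splitext(file)[0]
        with open(os.path.join(config_folder, file)) as f:
            GENERATOR_CONFIG_REGISTRY[name] = yaml.safe_load(f)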
@@ -14,4 +14,6 @@ speed_ratio_map:
   '0.5': 0.25
   '0.3333333': 0.25
   '0.25': 0.25
-seed: 0
\ No newline at end of file
+seed: 0
+regenerate_rail_on_reset: True
+regenerate_schedule_on_reset: True
\ No newline at end of file
+width: 25
+height: 25
+number_of_agents: 5
+max_num_cities: 4
+grid_mode: False
+max_rails_between_cities: 2
+max_rails_in_city: 3
+seed: 0
+regenerate_rail_on_reset: True
+regenerate_schedule_on_reset: True
\ No newline at end of file
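Given a registered config like the new 25x25 file above, consumers resolve it by name. A usage sketch; the key 'small_v0' is an assumption based on the generator_config: small_v0 reference further down:

from envs.flatland import get_generator_config

config = get_generator_config('small_v0')  # hypothetical registry key
assert config['width'] == 25
assert config['regenerate_rail_on_reset'] is True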
 import os
+import logging
+from pprint import pprint
 import gym
 import yaml
-from flatland.envs.malfunction_generators import malfunction_from_params
+from flatland.envs.malfunction_generators import malfunction_from_params, no_malfunction_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from ray.rllib import MultiAgentEnv
+from envs.flatland import get_generator_config
 from envs.flatland.observations import make_obs
 from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
@@ -15,14 +16,23 @@ from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
 class FlatlandSparse(MultiAgentEnv):

     def __init__(self, env_config) -> None:
         super().__init__()

         # TODO implement other generators
         assert env_config['generator'] == 'sparse_rail_generator'

         self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
-        with open(env_config['generator_config']) as f:
-            self._config = yaml.safe_load(f)
+        self._config = get_generator_config(env_config['generator_config'])
+
+        if env_config.worker_index == 0 and env_config.vector_index == 0:
+            print("=" * 50)
+            pprint(self._config)
+            print("=" * 50)

         self._env = FlatlandRllibWrapper(
             rail_env=self._launch(),
-            render=env_config['render'],
-            regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
-            regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset']
+            # render=env_config['render'],  # TODO need to fix gl compatibility first
+            regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
+            regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
         )

     @property
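Construction now resolves the generator config by registry name rather than by file path. A hedged instantiation sketch: RLlib passes an EnvContext, a dict subclass carrying worker_index/vector_index; the import path for FlatlandSparse and the literal values here are illustrative, not taken from this commit:

from ray.rllib.env.env_context import EnvContext
from envs.flatland_sparse import FlatlandSparse  # module path assumed

env_config = EnvContext(
    {
        'observation': 'global',
        'observation_config': {'max_width': 32, 'max_height': 32},
        'generator': 'sparse_rail_generator',
        'generator_config': 'small_v0',  # registry key, no longer a YAML path
    },
    worker_index=0,  # worker 0 / vector 0 triggers the pprint banner above
)
env = FlatlandSparse(env_config)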
@@ -42,25 +52,41 @@ class FlatlandSparse(MultiAgentEnv):
             max_rails_in_city=self._config['max_rails_in_city']
         )

-        stochastic_data = {
-            'malfunction_rate': self._config['malfunction_rate'],
-            'min_duration': self._config['malfunction_min_duration'],
-            'max_duration': self._config['malfunction_max_duration']
-        }
-
-        schedule_generator = sparse_schedule_generator({float(k): float(v)
-                                                        for k, v in self._config['speed_ratio_map'].items()})
-
-        env = RailEnv(
-            width=self._config['width'],
-            height=self._config['height'],
-            rail_generator=rail_generator,
-            schedule_generator=schedule_generator,
-            number_of_agents=self._config['number_of_agents'],
-            malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
-            obs_builder_object=self._observation.builder(),
-            remove_agents_at_target=False
-        )
+        malfunction_generator = no_malfunction_generator()
+        if {'malfunction_rate', 'min_duration', 'max_duration'} <= self._config.keys():
+            stochastic_data = {
+                'malfunction_rate': self._config['malfunction_rate'],
+                'min_duration': self._config['malfunction_min_duration'],
+                'max_duration': self._config['malfunction_max_duration']
+            }
+            malfunction_generator = malfunction_from_params(stochastic_data)
+
+        speed_ratio_map = None
+        if 'speed_ratio_map' in self._config:
+            speed_ratio_map = {
+                float(k): float(v) for k, v in self._config['speed_ratio_map'].items()
+            }
+        schedule_generator = sparse_schedule_generator(speed_ratio_map)
+
+        env = None
+        try:
+            env = RailEnv(
+                width=self._config['width'],
+                height=self._config['height'],
+                rail_generator=rail_generator,
+                schedule_generator=schedule_generator,
+                number_of_agents=self._config['number_of_agents'],
+                malfunction_generator_and_process_data=malfunction_generator,
+                obs_builder_object=self._observation.builder(),
+                remove_agents_at_target=False
+            )
+            env.reset()
+        except ValueError as e:
+            logging.error("=" * 50)
+            logging.error(f"Error while creating env: {e}")
+            logging.error("=" * 50)
+
+        return env

     def step(self, action_dict):
......
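One detail worth flagging: the new guard tests for min_duration/max_duration, while the body (and the config files in this commit) use malfunction_min_duration/malfunction_max_duration, so the malfunction branch can never fire for those configs and the env silently falls back to no_malfunction_generator(). A corrected guard would test the keys it actually reads; a minimal sketch, with a hypothetical helper name:

from flatland.envs.malfunction_generators import malfunction_from_params, no_malfunction_generator

def build_malfunction_generator(config: dict):
    # Test the same keys the body reads, matching the YAML configs above.
    required = {'malfunction_rate', 'malfunction_min_duration', 'malfunction_max_duration'}
    if required <= config.keys():
        return malfunction_from_params({
            'malfunction_rate': config['malfunction_rate'],
            'min_duration': config['malfunction_min_duration'],
            'max_duration': config['malfunction_max_duration'],
        })
    return no_malfunction_generator()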
@@ -31,26 +31,10 @@ flatland-sparse-global-conv-ppo:
       observation_config:
         max_width: 32
         max_height: 32
-      width: 32
-      height: 32
-      number_of_agents: 10
-      max_num_cities: 3
-      grid_mode: False
-      max_rails_between_cities: 2
-      max_rails_in_city: 3
-      malfunction_rate: 8000
-      malfunction_min_duration: 15
-      malfunction_max_duration: 50
-      # speed ratio map: keys must be strings but will be converted to float later!
-      speed_ratio_map:
-        '1.0': 0.25
-        '0.5': 0.25
-        '0.3333333': 0.25
-        '0.25': 0.25
-      seed: 0
-      regenerate_rail_on_reset: True
-      regenerate_schedule_on_reset: True
-      render: False
       generator: sparse_rail_generator
+      generator_config: 32x32_v0
       wandb:
         project: flatland
         entity: masterscrat
......
+flatland-sparse-global-conv-ppo:
+    run: PPO
+    env: flatland_sparse
+    stop:
+        timesteps_total: 10000000  # 1e7
+    checkpoint_freq: 10
+    checkpoint_at_end: True
+    keep_checkpoints_num: 5
+    checkpoint_score_attr: episode_reward_mean
+    config:
+        clip_rewards: True
+        clip_param: 0.1
+        vf_clip_param: 500.0
+        entropy_coeff: 0.01
+        # effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
+        # see https://github.com/ray-project/ray/issues/4628
+        train_batch_size: 1000  # 5000
+        rollout_fragment_length: 50  # 100
+        sgd_minibatch_size: 100  # 500
+        num_sgd_iter: 10
+        num_workers: 7
+        num_envs_per_worker: 5
+        batch_mode: truncate_episodes
+        observation_filter: NoFilter
+        vf_share_layers: True
+        vf_loss_coeff: 0.5
+        num_gpus: 1
+        env_config:
+            observation: global
+            observation_config:
+                max_width: 32
+                max_height: 32
+            generator: sparse_rail_generator
+            generator_config: small_v0
+            wandb:
+                project: flatland
+                entity: masterscrat
+                tags: ["small_v0", "global_obs"]  # TODO should be set programmatically
+        model:
+            custom_model: global_obs_model
+            custom_options:
+                architecture: impala
+                architecture_options:
+                    residual_layers: [[16, 2], [32, 4]]
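For reference, experiment files in this format are consumed by Ray Tune. A rough sketch of the standard loading path; the file path and the run_experiments call are assumptions about the project's train script, which is not shown in this commit:

import yaml
from ray import tune

with open('experiments/flatland_sparse_global_conv_ppo.yaml') as f:  # hypothetical path
    experiments = yaml.safe_load(f)

# Each top-level key (here flatland-sparse-global-conv-ppo) names one experiment.
tune.run_experiments(experiments)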
@@ -162,7 +162,7 @@ def create_parser(parser_creator=None):
     parser = parser_creator(
         formatter_class=argparse.RawDescriptionHelpFormatter,
         description="Roll out a reinforcement learning agent "
-        "given a checkpoint.",
+                    "given a checkpoint.",
         epilog=EXAMPLE_USAGE)

     parser.add_argument(
@@ -173,9 +173,9 @@ def create_parser(parser_creator=None):
         type=str,
         required=True,
         help="The algorithm or model to train. This may refer to the name "
-        "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
-        "user-defined trainable function or class registered in the "
-        "tune registry.")
+             "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
+             "user-defined trainable function or class registered in the "
+             "tune registry.")
     required_named.add_argument(
         "--env", type=str, help="The gym environment to use.")
     parser.add_argument(
@@ -198,7 +198,7 @@ def create_parser(parser_creator=None):
         default="{}",
         type=json.loads,
         help="Algorithm-specific configuration (e.g. env, hyperparams). "
-        "Surpresses loading of configuration from checkpoint.")
+             "Surpresses loading of configuration from checkpoint.")
     parser.add_argument(
         "--episodes",
         default=0,
@@ -208,20 +208,20 @@ def create_parser(parser_creator=None):
         default=False,
         action="store_true",
         help="Save the info field generated by the step() method, "
-        "as well as the action, observations, rewards and done fields.")
+             "as well as the action, observations, rewards and done fields.")
     parser.add_argument(
         "--use-shelve",
         default=False,
         action="store_true",
         help="Save rollouts into a python shelf file (will save each episode "
-        "as it is generated). An output filename must be set using --out.")
+             "as it is generated). An output filename must be set using --out.")
     parser.add_argument(
         "--track-progress",
         default=False,
         action="store_true",
         help="Write progress to a temporary file (updated "
-        "after each episode). An output filename must be set using --out; "
-        "the progress file will live in the same folder.")
+             "after each episode). An output filename must be set using --out; "
+             "the progress file will live in the same folder.")
     return parser
@@ -431,11 +431,11 @@ def rollout(agent,
             episodes += 1

     print("Evaluation completed:\n"
-          f"Episodes: {episodes}\n"
-          f"Mean Reward: {np.round(np.mean(simulation_rewards))}\n"
-          f"Mean Normalized Reward: {np.round(np.mean(simulation_rewards_normalized))}\n"
-          f"Mean Percentage Complete: {np.round(np.mean(simulation_percentage_complete), 3)}\n"
-          f"Mean Steps: {np.round(np.mean(simulation_steps), 2)}")
+          f"Episodes: {episodes}\n"
+          f"Mean Reward: {np.round(np.mean(simulation_rewards))}\n"
+          f"Mean Normalized Reward: {np.round(np.mean(simulation_rewards_normalized))}\n"
+          f"Mean Percentage Complete: {np.round(np.mean(simulation_percentage_complete), 3)}\n"
+          f"Mean Steps: {np.round(np.mean(simulation_steps), 2)}")

     return {
         'reward': [float(r) for r in simulation_rewards],
......