import os
from collections.abc import Iterable
from numbers import Number

import hydra
import nethack_baselines.rllib.models  # noqa: F401
import numpy as np
import ray
import ray.tune.integration.wandb
from nethack_baselines.rllib.envs import RLLibNLEEnv
from nethack_baselines.rllib.util.loading import NAME_TO_TRAINER
from omegaconf import DictConfig, OmegaConf
from ray import tune
from ray.rllib.models.catalog import MODEL_DEFAULTS
from ray.tune.integration.wandb import (_VALID_ITERABLE_TYPES, _VALID_TYPES,
                                        WandbLoggerCallback)
from ray.tune.registry import register_env
from ray.tune.utils import merge_dicts


def get_full_config(cfg: DictConfig) -> DictConfig:
    """Add derived environment flags to the hydra config."""
    env_flags = OmegaConf.to_container(cfg)
    max_num_steps = 1e6
    if cfg.env in ("staircase", "pet"):
        max_num_steps = 1000
    env_flags["max_num_steps"] = int(max_num_steps)
    env_flags["seedspath"] = ""
    return OmegaConf.create(env_flags)


@hydra.main(config_name="config")
def train(cfg: DictConfig) -> None:
    # Reserve one CPU for the driver process.
    ray.init(num_gpus=cfg.num_gpus, num_cpus=cfg.num_cpus + 1)
    cfg = get_full_config(cfg)

    register_env("RLlibNLE-v0", RLLibNLEEnv)

    try:
        algo, trainer = NAME_TO_TRAINER[cfg.algo]
    except KeyError:
        raise ValueError(
            f"The algorithm you specified isn't currently supported: {cfg.algo}"
        )

    config = algo.DEFAULT_CONFIG.copy()

    args_config = OmegaConf.to_container(cfg)

    # Algo-specific config. Requires hydra config keys to match rllib exactly.
    algo_config = args_config.pop(cfg.algo)

    # Remove the config sections of the algorithms we aren't using.
    for algo_name in NAME_TO_TRAINER:
        if algo_name != cfg.algo:
            args_config.pop(algo_name, None)

    # Merge config from hydra (will have some rogue keys, but that's ok).
    config = merge_dicts(config, args_config)

    # Update configuration with parsed arguments in specific ways.
    config = merge_dicts(
        config,
        {
            "framework": "torch",
            "num_gpus": cfg.num_gpus,
            "seed": cfg.seed,
            "env": "RLlibNLE-v0",
            "env_config": {
                "flags": cfg,
                "name": cfg.env,
            },
            "train_batch_size": cfg.train_batch_size,
            "model": merge_dicts(
                MODEL_DEFAULTS,
                {
                    "custom_model": "rllib_nle_model",
                    "custom_model_config": {"flags": cfg, "algo": cfg.algo},
                    "use_lstm": cfg.use_lstm,
                    "lstm_use_prev_reward": True,
                    "lstm_use_prev_action": True,
                    "lstm_cell_size": cfg.hidden_dim,
                },
            ),
            "num_workers": cfg.num_cpus,
            "num_envs_per_worker": int(cfg.num_actors / cfg.num_cpus),
            "evaluation_interval": 100,
            "evaluation_num_episodes": 50,
            "evaluation_config": {"explore": False},
            "rollout_fragment_length": cfg.unroll_length,
        },
    )

    # Merge algo-specific config at the top level.
    config = merge_dicts(config, algo_config)

    # Ensure we can use the config we've specified above.
    trainer_class = trainer.with_updates(default_config=config)

    callbacks = []
    if cfg.wandb:
        callbacks.append(
            WandbLoggerCallback(
                project=cfg.project,
                api_key_file="~/.wandb_api_key",
                entity=cfg.entity,
                group=cfg.group,
                tags=cfg.tags.split(","),
            )
        )
        os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1"  # Only log to wandb

        # Hacky monkey-patching to allow for OmegaConf config.
        def _is_allowed_type(obj):
            """Return True if type is allowed for logging to wandb."""
            if isinstance(obj, DictConfig):
                return True
            if isinstance(obj, np.ndarray) and obj.size == 1:
                return isinstance(obj.item(), Number)
            if isinstance(obj, Iterable) and len(obj) > 0:
                return isinstance(obj[0], _VALID_ITERABLE_TYPES)
            return isinstance(obj, _VALID_TYPES)

        ray.tune.integration.wandb._is_allowed_type = _is_allowed_type

    local_dir = os.path.join(os.getcwd(), "ray_results")
    analysis = tune.run(
        trainer_class,
        stop={"timesteps_total": cfg.total_steps},
        config=config,
        name=cfg.name,
        callbacks=callbacks,
        local_dir=local_dir,
        checkpoint_freq=cfg.checkpoint_freq,
        checkpoint_at_end=True,
    )

    checkpoints = analysis.get_trial_checkpoints_paths(trial=analysis.trials[0])
    checkpoint_paths = [check_path for check_path, _ in checkpoints]
    print("Model Checkpoint Paths:")
    print(checkpoint_paths)
    print(
        "Use these in the agents/rllib_batched_agent.py file when submitting "
        "your agent, using the path from the root of this repository"
    )


if __name__ == "__main__":
    train()
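
# Example invocation (a sketch, not a verified command: the script name
# "train.py" and the value "ppo" are assumptions here; the real defaults and
# the valid `algo` keys live in this package's hydra config.yaml and
# NAME_TO_TRAINER respectively):
#
#   python train.py algo=ppo env=staircase num_cpus=4 num_gpus=1 \
#       num_actors=8 total_steps=1000000 wandb=false
#
# Every key above is read by train(). Note that `num_actors` should be
# divisible by `num_cpus`, since num_envs_per_worker is computed as their
# integer ratio.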