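"""Train an RLlib agent on the NetHack Learning Environment, configured via Hydra."""
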
import os
from collections.abc import Iterable
from numbers import Number

import hydra
import nethack_baselines.rllib.models  # noqa: F401
import numpy as np
import ray
import ray.tune.integration.wandb
from nethack_baselines.rllib.envs import RLLibNLEEnv
from nethack_baselines.rllib.util.loading import NAME_TO_TRAINER
from omegaconf import DictConfig, OmegaConf
from ray import tune
from ray.rllib.models.catalog import MODEL_DEFAULTS
from ray.tune.integration.wandb import (_VALID_ITERABLE_TYPES, _VALID_TYPES,
                                        WandbLoggerCallback)
from ray.tune.registry import register_env
from ray.tune.utils import merge_dicts


def get_full_config(cfg: DictConfig) -> DictConfig:
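    """Fill in derived environment flags: a per-task episode step limit and an empty seeds path."""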
    env_flags = OmegaConf.to_container(cfg)
    max_num_steps = 1e6
    if cfg.env in ("staircase", "pet"):
        max_num_steps = 1000
    env_flags["max_num_steps"] = int(max_num_steps)
    env_flags["seedspath"] = ""
    return OmegaConf.create(env_flags)


@hydra.main(config_name="config")
def train(cfg: DictConfig) -> None:
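    """Hydra entry point: build the RLlib trainer config and launch training via Ray Tune."""
    # One CPU beyond the worker count is requested, presumably for the driver process.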
    ray.init(num_gpus=cfg.num_gpus, num_cpus=cfg.num_cpus + 1)
    cfg = get_full_config(cfg)
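    # Make the NLE gym wrapper constructible by name from rllib's config.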
    register_env("RLlibNLE-v0", RLLibNLEEnv)

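    # Look up the rllib algorithm module (for its DEFAULT_CONFIG) and trainer class.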
    try:
        algo, trainer = NAME_TO_TRAINER[cfg.algo]
    except KeyError:
        raise ValueError(f"The algorithm you specified isn't currently supported: {cfg.algo}")

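    # Start from rllib's defaults for this algorithm and layer our overrides on top.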
    config = algo.DEFAULT_CONFIG.copy()

    args_config = OmegaConf.to_container(cfg)

    # Algo-specific config. Requires hydra config keys to match rllib exactly
    algo_config = args_config.pop(cfg.algo)

    # Remove unnecessary config keys
    for algo_name in NAME_TO_TRAINER.keys():
        if algo_name != cfg.algo:
            args_config.pop(algo_name, None)

    # Merge config from hydra (will have some rogue keys but that's ok)
    config = merge_dicts(config, args_config)

    # Update configuration with parsed arguments in specific ways
    config = merge_dicts(
        config,
        {
            "framework": "torch",
            "num_gpus": cfg.num_gpus,
            "seed": cfg.seed,
            "env": "RLlibNLE-v0",
            "env_config": {
                "flags": cfg,
                "name": cfg.env,
            },
            "train_batch_size": cfg.train_batch_size,
            "model": merge_dicts(
                MODEL_DEFAULTS,
                {
                    "custom_model": "rllib_nle_model",
                    "custom_model_config": {"flags": cfg, "algo": cfg.algo},
                    "use_lstm": cfg.use_lstm,
                    "lstm_use_prev_reward": True,
                    "lstm_use_prev_action": True,
                    "lstm_cell_size": cfg.hidden_dim,
                },
            ),
            "num_workers": cfg.num_cpus,
            "num_envs_per_worker": int(cfg.num_actors / cfg.num_cpus),
            "evaluation_interval": 100,
            "evaluation_num_episodes": 50,
            "evaluation_config": {"explore": False},
            "rollout_fragment_length": cfg.unroll_length,
        },
    )

    # Merge algo-specific config at top level
    config = merge_dicts(config, algo_config)

    # Ensure we can use the config we've specified above
    trainer_class = trainer.with_updates(default_config=config)

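    # Optionally stream results to Weights & Biases and silence Tune's default loggers.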
    callbacks = []
    if cfg.wandb:
        callbacks.append(
            WandbLoggerCallback(
                project=cfg.project,
                api_key_file="~/.wandb_api_key",
                entity=cfg.entity,
                group=cfg.group,
                tags=cfg.tags.split(","),
            )
        )
        os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1"  # Only log to wandb

    # Hacky monkey-patching to allow for OmegaConf config
    def _is_allowed_type(obj):
        """Return True if type is allowed for logging to wandb"""
        if isinstance(obj, DictConfig):
            return True
        if isinstance(obj, np.ndarray) and obj.size == 1:
            return isinstance(obj.item(), Number)
        if isinstance(obj, Iterable) and len(obj) > 0:
            return isinstance(obj[0], _VALID_ITERABLE_TYPES)
        return isinstance(obj, _VALID_TYPES)

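    # Swap in the relaxed check so OmegaConf DictConfig values are accepted by the wandb logger.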
    ray.tune.integration.wandb._is_allowed_type = _is_allowed_type

    local_dir = os.path.join(os.getcwd(), "ray_results")

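    # Launch training; stop after cfg.total_steps environment steps, checkpointing along the way.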
    analysis = tune.run(
        trainer_class,
        stop={"timesteps_total": cfg.total_steps},
        config=config,
        name=cfg.name,
        callbacks=callbacks,
        local_dir=local_dir,
        checkpoint_freq=cfg.checkpoint_freq,
        checkpoint_at_end=True,
    )

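    # Print the checkpoint paths of the (single) trial for use in the submission agent.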
    checkpoints = analysis.get_trial_checkpoints_paths(trial=analysis.trials[0])
    checkpoint_paths = [check_path for check_path, _ in checkpoints]
    print("Model Checkpoint Paths:")
    print(checkpoint_paths)
    print(
        "Use these in the agents/rllib_batched_agent.py file when submitting"
        "your agent, using the path from the root of this repository"
    )


if __name__ == "__main__":
    train()