train.py 5.78 KB
Newer Older
metataro's avatar
metataro committed
1 2 3 4 5
#!/usr/bin/env python

import os

import ray
MasterScrat's avatar
MasterScrat committed
6
import yaml
7
from pathlib import Path
metataro's avatar
metataro committed
8 9 10
from ray.cluster_utils import Cluster
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.utils.framework import try_import_tf, try_import_torch
11
from ray.tune import run_experiments
12
from ray.tune.logger import TBXLogger
MasterScrat's avatar
MasterScrat committed
13 14
from ray.tune.resources import resources_to_json
from ray.tune.tune import _make_scheduler
metataro's avatar
metataro committed
15

MasterScrat's avatar
MasterScrat committed
16
from argparser import create_parser
metataro's avatar
metataro committed
17
from utils.loader import load_envs, load_models
MasterScrat's avatar
MasterScrat committed
18 19
# Custom wandb logger with hotfix to allow custom callbacks
from wandblogger import WandbLogger
metataro's avatar
metataro committed
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47

# Try to import both backends for flag checking/warnings.
# NOTE(review): `tf` and `torch` are not referenced again in the visible part
# of this file; presumably the imports are kept for their side effects or for
# code outside this view -- confirm before removing.
tf = try_import_tf()
torch, _ = try_import_torch()

# Register all necessary assets in tune registries.
# Both loaders are pointed at the current working directory, so the script is
# presumably expected to be launched from the project root -- TODO confirm.
load_envs(os.getcwd())  # Load envs
load_models(os.getcwd())  # Load models


def on_episode_end(info):
    """Aggregate per-agent statistics into the episode's custom metrics.

    Reads the last info dict reported for every agent of the finished episode
    and writes episode-level values into ``episode.custom_metrics``:
    ``episode_steps``, ``episode_max_steps``, ``episode_num_agents``,
    ``episode_return``, ``episode_score``, ``episode_score_normalized`` and
    ``percentage_complete``.

    Args:
        info: Callback payload. ``info["episode"]`` must expose
            ``_agent_to_last_info`` (mapping of agent -> last info dict),
            ``custom_metrics`` (a mutable mapping) and ``total_reward``.
            Every agent info dict must contain the keys
            ``"max_episode_steps"``, ``"num_agents"``, ``"agent_step"``,
            ``"agent_score"`` and ``"agent_done"``.

    Raises:
        KeyError: If an agent info dict lacks one of the keys listed above.
    """
    episode = info["episode"]  # type: MultiAgentEpisode

    episode_steps = 0
    episode_max_steps = 0
    episode_num_agents = 0
    episode_score = 0
    episode_done_agents = 0

    for agent_info in episode._agent_to_last_info.values():
        # The episode-wide values are taken from the first agent info seen;
        # a still-zero max step count marks "not read yet".
        if episode_max_steps == 0:
            episode_max_steps = agent_info["max_episode_steps"]
            episode_num_agents = agent_info["num_agents"]
        episode_steps = max(episode_steps, agent_info["agent_step"])
        episode_score += agent_info["agent_score"]
        if agent_info["agent_done"]:
            episode_done_agents += 1

    # The number of reported agent infos is deliberately NOT asserted to be
    # equal to `episode_num_agents`: that is not a valid check when a single
    # policy is considered for multiple agents.

    # Guard both divisions: an episode that ended without any agent info (or
    # that reports zero agents) would otherwise raise ZeroDivisionError inside
    # the callback.
    norm_denominator = episode_max_steps + episode_num_agents
    norm_factor = 1.0 / norm_denominator if norm_denominator else 0.0
    if episode_num_agents:
        percentage_complete = float(episode_done_agents) / episode_num_agents
    else:
        percentage_complete = 0.0

    episode.custom_metrics["episode_steps"] = episode_steps
    episode.custom_metrics["episode_max_steps"] = episode_max_steps
    episode.custom_metrics["episode_num_agents"] = episode_num_agents
    episode.custom_metrics["episode_return"] = episode.total_reward
    episode.custom_metrics["episode_score"] = episode_score
    episode.custom_metrics["episode_score_normalized"] = episode_score * norm_factor
    episode.custom_metrics["percentage_complete"] = percentage_complete


def run(args, parser):
    """Build the experiment specs, start Ray and launch the experiments.

    The experiments come either from the YAML file named by
    ``args.config_file`` or, when no file is given, from a single experiment
    assembled from the individual command-line arguments. Command-line flags
    (eager/torch/verbosity/tracing/flatland stats) are then applied on top of
    every experiment's ``config`` before Ray is initialised and
    ``run_experiments`` is called.

    Args:
        args: Parsed command-line namespace.
        parser: The argument parser that produced ``args``; its ``error``
            method is used to report a missing ``run`` or ``env`` setting.

    Raises:
        ValueError: If tracing is requested for an experiment whose config
            does not enable eager mode.
    """
    if args.config_file:
        with open(args.config_file) as f:
            experiments = yaml.safe_load(f)
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to ~/ray_results/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "keep_checkpoints_num": args.keep_checkpoints_num,
                "checkpoint_score_attr": args.checkpoint_score_attr,
                "local_dir": args.local_dir,
                "resources_per_trial": (
                    args.resources_per_trial and
                    resources_to_json(args.resources_per_trial)),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "num_samples": args.num_samples,
                "upload_dir": args.upload_dir,
            }
        }

    verbose = 1
    webui_host = "localhost"
    for exp in experiments.values():
        # NOTE: Some of our yaml files don't have a `config` section. Create
        # an empty one up front so the flag overrides below can always be
        # written instead of raising KeyError on `exp["config"]`.
        config = exp.setdefault("config", {})

        # Bazel makes it hard to find files specified in `args` (and `data`).
        # Look for them here: an input path that does not exist as given is
        # resolved relative to the directory containing this script.
        if config.get("input") and not os.path.exists(config["input"]):
            rllib_dir = Path(__file__).parent
            input_file = rllib_dir.absolute().joinpath(config["input"])
            config["input"] = str(input_file)

        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        if not exp.get("env") and not config.get("env"):
            parser.error("the following arguments are required: --env")
        if args.eager:
            config["eager"] = True
        if args.torch:
            config["use_pytorch"] = True
        if args.v:
            config["log_level"] = "INFO"
            verbose = 2
        if args.vv:
            config["log_level"] = "DEBUG"
            verbose = 3
        if args.trace:
            if not config.get("eager"):
                raise ValueError("Must enable --eager to enable tracing.")
            config["eager_tracing"] = True
        if args.bind_all:
            webui_host = "0.0.0.0"
        if args.log_flatland_stats:
            # Replaces any `callbacks` entry already present in the config.
            config['callbacks'] = {
                'on_episode_end': on_episode_end,
            }
        exp['loggers'] = [WandbLogger, TBXLogger]

    if args.ray_num_nodes:
        # Start a local multi-node cluster and connect to it.
        cluster = Cluster()
        for _ in range(args.ray_num_nodes):
            cluster.add_node(
                num_cpus=args.ray_num_cpus or 1,
                num_gpus=args.ray_num_gpus or 0,
                object_store_memory=args.ray_object_store_memory,
                memory=args.ray_memory,
                redis_max_memory=args.ray_redis_max_memory)
        ray.init(address=cluster.address)
    else:
        ray.init(
            address=args.ray_address,
            object_store_memory=args.ray_object_store_memory,
            memory=args.ray_memory,
            redis_max_memory=args.ray_redis_max_memory,
            num_cpus=args.ray_num_cpus,
            num_gpus=args.ray_num_gpus,
            webui_host=webui_host)

    run_experiments(
        experiments,
        scheduler=_make_scheduler(args),
        queue_trials=args.queue_trials,
        resume=args.resume,
        verbose=verbose,
        concurrent=True)


if __name__ == "__main__":
    # Script entry point: build the command-line parser, parse the arguments
    # and hand both to run() (the parser is needed there for error reporting).
    parser = create_parser()
    run(parser.parse_args(), parser)