Commit 3ea03689 authored by nilabha

Make it compatible for custom functions to be run with train.py

parent 53f54dcd
Pipeline #5024 passed in 2 minutes and 53 seconds
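The default custom function, imitation_ppo_train_fn, is selected when train.py is invoked with the new --custom-fn flag (roughly `python train.py --config-file <experiment.yaml> --custom-fn`; exact flag names are defined in utils/argparser.py). Below is a rough sketch of the configuration keys the function reads; the env id and the numeric values are illustrative assumptions, not part of this commit:

# Hypothetical experiment config, shown as the dict train.py builds from the YAML file.
config = {
    "env": "flatland_sparse",                    # assumption: any registered env id
    "env_config": {
        "custom_fn": "imitation_ppo_train_fn",   # resolved via globals() when --custom-fn is set
        "expert": {
            "ratio": 0.5,          # initial probability of an imitation iteration
            "min_ratio": 0.1,      # floor the decayed ratio never drops below
            "ratio_decay": 0.99,   # geometric decay applied every iteration
        },
    },
}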
#!/usr/bin/env python
import os
import numpy as np
import ray
import yaml
@@ -8,11 +9,16 @@ from pathlib import Path
from ray.cluster_utils import Cluster
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.tune import run_experiments, Experiment
from ray.tune.logger import TBXLogger
from ray.tune.resources import resources_to_json
from ray.tune.tune import _make_scheduler
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.ppo.ppo import PPOTrainer
from algorithms.imitation_agent.imitation_trainer import ImitationAgent
from utils.argparser import create_parser
from utils.loader import load_envs, load_models, load_algorithms
@@ -32,6 +38,7 @@ load_models(os.getcwd())  # Load models
from algorithms import CUSTOM_ALGORITHMS
load_algorithms(CUSTOM_ALGORITHMS) # Load algorithms
MAX_ITERATIONS = 1000000
def on_episode_end(info):
    episode = info["episode"]  # type: MultiAgentEpisode
@@ -70,6 +77,53 @@ def on_episode_end(info):
    episode.custom_metrics["percentage_complete"] = percentage_complete
def imitation_ppo_train_fn(config, reporter=None):
    imitation_trainer = ImitationAgent(config, env=config.get("env"))
    ppo_trainer = PPOTrainer(config, env=config.get("env"))

    # Probability of running an imitation iteration, decayed geometrically
    # towards a floor of min_ratio.
    expert_config = config.get("env_config", {}).get("expert", {})
    expert_ratio = expert_config.get("ratio", 0.5)
    expert_min_ratio = expert_config.get("min_ratio", expert_ratio)
    expert_ratio_decay = expert_config.get("ratio_decay", 1)

    for i in range(MAX_ITERATIONS):
        print("== Iteration", i, "==")
        # Bernoulli draw: 1 with probability expert_ratio
        trainer_type = np.random.binomial(size=1, n=1, p=expert_ratio)[0]
        if trainer_type:
            # improve the Imitation policy
            print("-- Imitation --")
            result_imitate = imitation_trainer.train()
            if reporter:
                reporter(**result_imitate)
            # checkpoint_freq is a module-level global set in run()
            if i % checkpoint_freq == 0:
                checkpoint = imitation_trainer.save()
                print("checkpoint saved at", checkpoint)
            # keep the PPO trainer in sync with the freshly imitated weights
            ppo_trainer.set_weights(imitation_trainer.get_weights())
        else:
            # improve the PPO policy
            print("-- PPO --")
            result_ppo = ppo_trainer.train()
            if reporter:
                reporter(**result_ppo)
            if i % checkpoint_freq == 0:
                checkpoint = ppo_trainer.save()
                print("checkpoint saved at", checkpoint)
        expert_ratio = max(expert_min_ratio, expert_ratio_decay * expert_ratio)

    imitation_trainer.stop()
    ppo_trainer.stop()
    print("Completed: OK")
def run(args, parser):
    if args.config_file:
        with open(args.config_file) as f:
@@ -95,6 +149,7 @@ def run(args, parser):
    }
    verbose = 1
    custom_fn = False
    webui_host = "localhost"

    for exp in experiments.values():
        # Bazel makes it hard to find files specified in `args` (and `data`).
@@ -154,11 +209,21 @@ def run(args, parser):
        # Remove any wandb related configs
        if exp['config']['evaluation_config'].get('wandb'):
            del exp['config']['evaluation_config']['wandb']

        if args.custom_fn:
            # Resolve the training function by name from this module's globals
            custom_fn = globals()[exp['config'].get("env_config", {}).get("custom_fn", "imitation_ppo_train_fn")]

        if args.config_file:
            # TODO should be in exp['config'] directly
            exp['config']['env_config']['yaml_config'] = args.config_file
        exp['loggers'] = [WandbLogger, TBXLogger]

        # Make the checkpoint settings visible to the custom training function
        global checkpoint_freq, keep_checkpoints_num, checkpoint_score_attr, checkpoint_at_end
        checkpoint_freq = exp['checkpoint_freq']
        # TODO: the checkpoint parameters below are not supported by the default custom_fn
        keep_checkpoints_num = exp['keep_checkpoints_num']
        checkpoint_score_attr = exp['checkpoint_score_attr']
        checkpoint_at_end = exp['checkpoint_at_end']
    if args.ray_num_nodes:
        cluster = Cluster()
        for _ in range(args.ray_num_nodes):
@@ -179,6 +244,22 @@ def run(args, parser):
            num_gpus=args.ray_num_gpus,
            webui_host=webui_host)
    if custom_fn:
        # Replace the experiment dict with a single function-based Experiment
        # that runs custom_fn with the resources a PPOTrainer would request.
        for exp in experiments.values():
            configs = with_common_config(exp["config"])
            configs['env'] = exp.get('env')
            resources = PPOTrainer.default_resource_request(configs).to_json()
            experiment_spec = Experiment(
                custom_fn.__name__,
                custom_fn,
                resources_per_trial=resources,
                config=configs,
                stop=exp.get('stop'),
                num_samples=exp.get('num_samples', 1),
                loggers=exp.get('loggers'),
                restore=None)
        experiments = experiment_spec
    run_experiments(
        experiments,
        scheduler=_make_scheduler(args),
...
@@ -106,6 +106,12 @@ def create_parser(parser_creator=None):
        action="store_true",
        help="Whether to run evaluation. Default evaluation config is default.yaml; "
             "to use a custom evaluation config, set (eval_generator:high_eval) under configs")
    parser.add_argument(
        "-i",
        "--custom-fn",
        action="store_true",
        help="Whether the experiment uses a custom function for training. "
             "The default custom function is imitation_ppo_train_fn")
    parser.add_argument(
        "--bind-all",
        action="store_true",
...