Commit f582ff19 authored by nilabha

update changes for evaluation

parent eb3e7fe3
Pipeline #4938 passed with stage in 21 minutes and 52 seconds
@@ -128,6 +128,7 @@ dmypy.json
# misc
.idea
.vscode/
# custom extras
small_tree_video/
@@ -5,7 +5,10 @@ import humps
import yaml
GENERATOR_CONFIG_REGISTRY = {}
EVAL_CONFIG_REGISTRY = {}


def get_eval_config(name: str = None):
    return EVAL_CONFIG_REGISTRY[name]


def get_generator_config(name: str):
    return GENERATOR_CONFIG_REGISTRY[name]
@@ -23,3 +26,16 @@ for file in os.listdir(config_folder):
        print("- Successfully Loaded Generator Config {} from {}".format(
            filename, basename
        ))

eval_config_folder = os.path.join(os.path.dirname(__file__), "eval_configs")
for file in os.listdir(eval_config_folder):
    if file.endswith('.yaml') and not file.startswith('_'):
        basename = os.path.basename(file)
        filename = basename.replace(".yaml", "")
        with open(os.path.join(eval_config_folder, file)) as f:
            EVAL_CONFIG_REGISTRY[filename] = yaml.safe_load(f)
        print("- Successfully Loaded Evaluation Config {} from {}".format(
            filename, basename
        ))
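
Once this loader has run at import time, every YAML file under eval_configs/ is available through get_eval_config by its basename. A minimal lookup sketch (assuming an eval_configs/default.yaml like the file below):

    # Sketch: look up a loaded evaluation config by its YAML basename.
    from envs.flatland import get_eval_config

    eval_cfg = get_eval_config("default")
    print(eval_cfg["evaluation_interval"])  # 50 for the default config below
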
evaluation_num_workers: 2
# Run evaluation every 50 training iterations.
evaluation_interval: 50
# Run 50 episodes each time evaluation runs.
evaluation_num_episodes: 50
# Override the env config for evaluation.
evaluation_config:
  explore: False
  env_config:
    seed: 100
\ No newline at end of file
evaluation_num_workers: 2
# Run evaluation every 3 training iterations.
evaluation_interval: 3
# Run 2 episodes each time evaluation runs.
evaluation_num_episodes: 2
# Override the env config for evaluation.
evaluation_config:
  explore: False
  env_config:
    seed: 100
\ No newline at end of file
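
Both files are plain RLlib evaluation settings; run() below folds the selected one into the experiment config with merge_dicts. A small sketch of the merge semantics (the base values here are illustrative):

    from ray.rllib.utils import merge_dicts

    base = {
        "num_workers": 4,
        "env_config": {"seed": 0, "eval_generator": "default"},
    }
    eval_overrides = {
        "evaluation_num_workers": 2,
        "evaluation_interval": 3,
        "evaluation_num_episodes": 2,
        "evaluation_config": {"explore": False, "env_config": {"seed": 100}},
    }
    # merge_dicts deep-merges the second dict into a copy of the first,
    # so the evaluation keys win on any conflict.
    merged = merge_dicts(base, eval_overrides)
    assert merged["num_workers"] == 4
    assert merged["evaluation_config"]["explore"] is False
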
@@ -15,6 +15,10 @@ from ray.tune.tune import _make_scheduler
from utils.argparser import create_parser
from utils.loader import load_envs, load_models
from envs.flatland import get_eval_config
from ray.rllib.utils import merge_dicts
# Custom wandb logger with hotfix to allow custom callbacks
from wandblogger import WandbLogger
@@ -125,6 +129,29 @@ def run(args, parser):
    exp['config']['callbacks'] = {
        'on_episode_end': on_episode_end,
    }
    if args.eval:
        eval_configs = get_eval_config(
            exp['config'].get('env_config', {}).get('eval_generator', "default"))
        eval_seed = eval_configs.get('evaluation_config', {}).get('env_config', {}).get('seed')
        # Merge the evaluation config into the current experiment config
        exp['config'] = merge_dicts(exp['config'], eval_configs)
        if exp['config'].get('evaluation_config'):
            exp['config']['evaluation_config']['env_config'] = exp['config'].get('env_config')
            eval_env_config = exp['config']['evaluation_config'].get('env_config')
            if eval_seed and eval_env_config:
                # Override the env seed with the seed from the evaluation config
                eval_env_config['seed'] = eval_seed
            # Remove any wandb-related keys from the evaluation env config
            if eval_env_config and eval_env_config.get('wandb'):
                del eval_env_config['wandb']
            # Remove any wandb-related keys from the evaluation config itself
            if exp['config']['evaluation_config'].get('wandb'):
                del exp['config']['evaluation_config']['wandb']
    if args.config_file:
        # TODO should be in exp['config'] directly
        exp['config']['env_config']['yaml_config'] = args.config_file
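
The ordering above is deliberate: the seed is read out of the evaluation config before the merge, because evaluation_config.env_config is then replaced wholesale by the training env_config, and the saved seed is re-applied last. A compressed sketch of that save-then-restore flow (the values and the n_agents key are illustrative):

    eval_configs = {"evaluation_config": {"env_config": {"seed": 100}}}
    eval_seed = eval_configs["evaluation_config"]["env_config"]["seed"]  # saved first

    train_env_config = {"seed": 0, "n_agents": 5}  # training env settings
    eval_env_config = dict(train_env_config)       # evaluation mirrors training...
    eval_env_config["seed"] = eval_seed            # ...except for the fixed eval seed
    assert eval_env_config == {"seed": 100, "n_agents": 5}
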
@@ -100,6 +100,12 @@ def create_parser(parser_creator=None):
        action="store_true",
        default=True,
        help="Whether to log additional Flatland-specific metrics such as "
             "percentage complete or normalized score.")
    parser.add_argument(
        "-e",
        "--eval",
        action="store_true",
        help="Whether to run evaluation. The default evaluation config is "
             "default.yaml; to use a custom evaluation config, set "
             "(eval_generator:high_eval) under configs.")
    parser.add_argument(
        "--bind-all",
        action="store_true",
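
A quick usage sketch for the new flag (assuming the parser has no required positional arguments; everything besides -e/--eval is untouched):

    # Sketch: -e/--eval defaults to False and flips on when passed.
    parser = create_parser()
    args = parser.parse_args(["--eval"])
    assert args.eval is True  # run() then merges the evaluation config into exp['config']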