Commit 20df7d27 authored by nilabha

Update checkpoints, rollout scripts, add test_result script

parent 6e5f6dad
Pipeline #5251 failed in 2 minutes and 59 seconds
@@ -278,7 +278,7 @@ if __name__ == "__main__":
            "shortest_path_max_depth": 30},
        "generator": "sparse_rail_generator",
        "generator_config": "small_v0",
        "eval_generator": "test_eval"},
        "eval_generator": "test"},
    "model": {
        "fcnet_activation": "relu",
        "fcnet_hiddens": [256, 256],
25 changed files are stored in LFS, so their source diffs could not be displayed. You can view the blobs instead.
import numpy as np
import pandas as pd
import os
import json

# The file all_eval_runs.csv is generated by wandb_data.py
df_eval = pd.read_csv('all_eval_runs.csv')
df_test_results = df_eval[["run", "group"]].drop_duplicates()
all_runs = df_test_results["run"].to_list()

colnames = ["run", "percentage_complete_mean", "normalized_reward_mean"]
df_test_metrics = pd.DataFrame(columns=colnames)

# Collect the test metrics written for each run, if a result file exists
for cur_run in all_runs:
    result_file = "checkpoints/" + cur_run + "/test_outcome.json"
    if os.path.isfile(result_file):
        with open(result_file) as f:
            data = json.load(f)
        df_test_metrics = df_test_metrics.append(
            {colnames[0]: cur_run,
             colnames[1]: data.get(colnames[1]),
             colnames[2]: data.get(colnames[2])},
            ignore_index=True)

# Attach each run's experiment group, then aggregate mean/std per group
df_test = pd.merge(df_test_metrics, df_test_results, how='left')
df_all_final_results = df_test.groupby("group").aggregate([np.mean, np.std]).reset_index()
df_all_final_results.to_csv('test_results_group.csv', index=False)
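For context, the script assumes each checkpoints/<run>/test_outcome.json exposes the two metric keys it reads. The sketch below is illustrative only and not taken from the repository: it shows that assumed layout, plus an equivalent collection loop for pandas 2.0+, where DataFrame.append has been removed.

# Illustrative sketch only -- assumed minimal shape of a test_outcome.json
# file; the real files under checkpoints/<run>/ may contain more fields.
example_outcome = {
    "percentage_complete_mean": 0.85,   # hypothetical value
    "normalized_reward_mean": -0.12,    # hypothetical value
}

# Equivalent collection loop for pandas >= 2.0 (DataFrame.append was removed):
rows = []
for cur_run in all_runs:
    result_file = os.path.join("checkpoints", cur_run, "test_outcome.json")
    if os.path.isfile(result_file):
        with open(result_file) as f:
            data = json.load(f)
        rows.append({"run": cur_run,
                     "percentage_complete_mean": data.get("percentage_complete_mean"),
                     "normalized_reward_mean": data.get("normalized_reward_mean")})
df_test_metrics = pd.DataFrame(rows, columns=colnames)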
@@ -13,6 +13,7 @@ import gym
import numpy as np
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.tune.registry import get_trainable_cls
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
# from ray.rllib.evaluation.episode import _flatten_action # ray 0.8.4
@@ -20,7 +21,6 @@ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.space_utils import flatten_to_single_ndarray # ray 0.8.5
from ray.tune.utils import merge_dicts
from algorithms.imitation_agent.imitation_trainer import ImitationAgent
from utils.loader import load_envs, load_models, load_algorithms
logger = logging.getLogger(__name__)
@@ -44,6 +44,23 @@ load_models(os.getcwd())  # Load models
from algorithms import CUSTOM_ALGORITHMS
load_algorithms(CUSTOM_ALGORITHMS) # Load algorithms
from collections.abc import Mapping
from copy import deepcopy


def val_replace(mapping):
    # Recursively walk a (possibly nested) config mapping and convert the
    # strings "False"/"True" into the corresponding Python booleans.
    obj = deepcopy(mapping)
    if isinstance(mapping, Mapping):
        for key, val in mapping.items():
            obj[key] = val_replace(val)
    else:
        if mapping == "False":
            return False
        if mapping == "True":
            return True
        else:
            return mapping
    return obj
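To illustrate what val_replace is for (the override dictionary below is a made-up example, not part of the repository): string booleans arriving through a JSON --config override are turned into real booleans before being merged into the restored config.

# Illustrative example only: "True"/"False" strings from a JSON --config
# override become Python booleans before merge_dicts is applied.
raw_override = {"explore": "False", "model": {"vf_share_layers": "True"}}
assert val_replace(raw_override) == {
    "explore": False,
    "model": {"vf_share_layers": True},
}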
class RolloutSaver:
"""Utility class for storing rollouts.
@@ -233,6 +250,10 @@ def create_parser(parser_creator=None):
help="Write progress to a temporary file (updated "
"after each episode). An output filename must be set using --out; "
"the progress file will live in the same folder.")
parser.add_argument(
"--eager",
action="store_true",
help="Whether to attempt to enable TF eager execution.")
return parser
@@ -253,26 +274,22 @@ def run(args, parser):
        config = pickle.load(f)
    if "num_workers" in config:
        config["num_workers"] = min(2, config["num_workers"])
    config = merge_dicts(config, args.config)
    updated_config = val_replace(args.config)
    config = merge_dicts(config, updated_config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")
    ray.init()
    try:
        cls = get_agent_class(args.run)
    except:
        cls = ImitationAgent  # CUSTOM_ALGORITHMS[args.run]
    if args.eager:
        from tensorflow.python.framework.ops import enable_eager_execution
        enable_eager_execution()
    print("========================")
    print(config)
    print("========================")
    print(config.get("model").get("vf_share_layers"))
    print("========================")
    config['model']['vf_share_layers'] = False
    print("========================")
    cls = get_trainable_cls(args.run)
    agent = cls(env=args.env, config=config)
    agent.restore(args.checkpoint)
    num_steps = int(args.steps)
This diff is collapsed.