#!/usr/bin/env python

import argparse
import collections
import json
import logging
import os
import pickle
import shelve
from pathlib import Path
import random

import gym
import numpy as np
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.tune.registry import get_trainable_cls
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
# from ray.rllib.evaluation.episode import _flatten_action # ray 0.8.4
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.space_utils import flatten_to_single_ndarray # ray 0.8.5
from ray.tune.utils import merge_dicts

from utils.loader import load_envs, load_models, load_algorithms

logger = logging.getLogger(__name__)

EXAMPLE_USAGE = """
Example Usage:
    python rollout.py /Users/flaurent/Sites/flatland/flatland-checkpoints/checkpoint_940/checkpoint-940 --run APEX --no-render --episodes 1000 --env 'flatland_random_sparse_small' --config '{"env_config": {"test": "true", "min_seed": 1002, "max_seed": 213783, "min_test_seed": 0, "max_test_seed": 100, "reset_env_freq": "1", "regenerate_rail_on_reset": "True", "regenerate_schedule_on_reset": "True", "observation": "tree", "observation_config": {"max_depth": 2, "shortest_path_max_depth": 30}}, "model": {"fcnet_activation": "relu", "fcnet_hiddens": [256, 256], "vf_share_layers": "True"}}' 
"""

"""
# Testing in flatland_random_sparse_small:
python rollout.py /Users/flaurent/Sites/flatland/flatland-checkpoints/checkpoint_940/checkpoint-940 --run APEX --no-render --episodes 1000 --env 'flatland_random_sparse_small' --config '{"env_config": {"test": "true", "min_seed": 1002, "max_seed": 213783, "min_test_seed": 0, "max_test_seed": 100, "reset_env_freq": "1", "regenerate_rail_on_reset": "True", "regenerate_schedule_on_reset": "True", "observation": "tree", "observation_config": {"max_depth": 2, "shortest_path_max_depth": 30}}, "model": {"fcnet_activation": "relu", "fcnet_hiddens": [256, 256], "vf_share_layers": "True"}}' 

# Testing in flatland_sparse:
python rollout.py /Users/flaurent/Sites/flatland/flatland-checkpoints/checkpoint_940/checkpoint-940 --run APEX --no-render --episodes 1000 --env 'flatland_sparse' --config '{"env_config": {"test": "true", "generator": "sparse_rail_generator", "generator_config": "small_v0", "observation": "tree", "observation_config": {"max_depth": 2, "shortest_path_max_depth": 30}}, "model": {"fcnet_activation": "relu", "fcnet_hiddens": [256, 256], "vf_share_layers": "True"}}' 
"""

# Register all necessary assets in tune registries
load_envs(os.getcwd())  # Load envs
load_models(os.getcwd())  # Load models
from algorithms import CUSTOM_ALGORITHMS
load_algorithms(CUSTOM_ALGORITHMS)  # Load algorithms

from collections.abc import Mapping
from copy import deepcopy

# Default final (terminal) exploration epsilon, matching RLlib's DQN default.
# https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/dqn.py
final_epsilon = 0.02
random.seed(1)

def val_replace(mapping):
    """Recursively convert the strings "True" and "False" in a (nested)
    mapping to Python booleans; all other leaf values are returned unchanged."""
    obj = deepcopy(mapping)
    if isinstance(mapping, Mapping):
        for key, val in mapping.items():
            obj[key] = val_replace(val)
        return obj
    if mapping == "False":
        return False
    if mapping == "True":
        return True
    return mapping
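
# For example (illustrative; the config keys below are arbitrary):
#   val_replace({"model": {"vf_share_layers": "True"}, "env_config": {"test": "true"}})
#   -> {"model": {"vf_share_layers": True}, "env_config": {"test": "true"}}
# Only the exact strings "True"/"False" are converted; lowercase variants such
# as "true" pass through unchanged.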


class RolloutSaver:
    """Utility class for storing rollouts.

    Currently supports two behaviours: the original, which
    simply dumps everything to a pickle file once complete,
    and a mode which stores each rollout as an entry in a Python
    shelf db file. The latter mode is more robust to memory problems
    or crashes part-way through the rollout generation. Each rollout
    is stored with a key based on the episode number (0-indexed),
    and the number of episodes is stored with the key "num_episodes",
    so to load the shelf file, use something like:

    with shelve.open('rollouts.pkl') as rollouts:
       for episode_index in range(rollouts["num_episodes"]):
          rollout = rollouts[str(episode_index)]

    If outfile is None, this class does nothing.
    """

    def __init__(self,
                 outfile=None,
                 use_shelve=False,
                 write_update_file=False,
                 target_steps=None,
                 target_episodes=None,
                 save_info=False):
        self._outfile = outfile
        self._update_file = None
        self._use_shelve = use_shelve
        self._write_update_file = write_update_file
        self._shelf = None
        self._num_episodes = 0
        self._rollouts = []
        self._current_rollout = []
        self._total_steps = 0
        self._target_episodes = target_episodes
        self._target_steps = target_steps
        self._save_info = save_info

    def _get_tmp_progress_filename(self):
        outpath = Path(self._outfile)
        return outpath.parent / ("__progress_" + outpath.name)

    @property
    def outfile(self):
        return self._outfile

    def __enter__(self):
        if self._outfile:
            if self._use_shelve:
                # Open a shelf file to store each rollout as they come in
                self._shelf = shelve.open(self._outfile)
            else:
                # Original behaviour - keep all rollouts in memory and save
                # them all at the end.
                # But check we can actually write to the outfile before going
                # through the effort of generating the rollouts:
                try:
                    with open(self._outfile, "wb") as _:
                        pass
                except IOError as x:
                    print("Can not open {} for writing - cancelling rollouts.".
                          format(self._outfile))
                    raise x
            if self._write_update_file:
                # Open a file to track rollout progress:
                self._update_file = self._get_tmp_progress_filename().open(
                    mode="w")
        return self

    def __exit__(self, type, value, traceback):
        if self._shelf:
            # Close the shelf file, and store the number of episodes for ease
            self._shelf["num_episodes"] = self._num_episodes
            self._shelf.close()
        elif self._outfile and not self._use_shelve:
            # Dump everything as one big pickle:
            pickle.dump(self._rollouts, open(self._outfile, "wb"))
        if self._update_file:
            # Remove the temp progress file:
            self._get_tmp_progress_filename().unlink()
            self._update_file = None

    def _get_progress(self):
        if self._target_episodes:
            return "{} / {} episodes completed".format(self._num_episodes,
                                                       self._target_episodes)
        elif self._target_steps:
            return "{} / {} steps completed".format(self._total_steps,
                                                    self._target_steps)
        else:
            return "{} episodes completed".format(self._num_episodes)

    def begin_rollout(self):
        self._current_rollout = []

    def end_rollout(self):
        if self._outfile:
            if self._use_shelve:
                # Save this episode as a new entry in the shelf database,
                # using the episode number as the key.
                self._shelf[str(self._num_episodes)] = self._current_rollout
            else:
                # Append this rollout to our list, to save later.
                self._rollouts.append(self._current_rollout)
        self._num_episodes += 1
        if self._update_file:
            self._update_file.seek(0)
            self._update_file.write(self._get_progress() + "\n")
            self._update_file.flush()

    def append_step(self, obs, action, next_obs, reward, done, info):
        """Add a step to the current rollout, if we are saving them"""
        if self._outfile:
            if self._save_info:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, done, info])
            else:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, done])
        self._total_steps += 1


def create_parser(parser_creator=None):
    parser_creator = parser_creator or argparse.ArgumentParser
    parser = parser_creator(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Roll out a reinforcement learning agent "
                    "given a checkpoint.",
        epilog=EXAMPLE_USAGE)

    parser.add_argument(
        "checkpoint", type=str, help="Checkpoint from which to roll out.")
    required_named = parser.add_argument_group("required named arguments")
    required_named.add_argument(
        "--run",
        type=str,
        required=True,
        help="The algorithm or model to train. This may refer to the name "
210 211 212
             "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
             "user-defined trainable function or class registered in the "
             "tune registry.")
    required_named.add_argument(
        "--env", type=str, help="The gym environment to use.")
    parser.add_argument(
        "--no-render",
        default=False,
        action="store_const",
        const=True,
        help="Surpress rendering of the environment.")
    parser.add_argument(
        "--monitor",
        default=False,
        action="store_const",
        const=True,
        help="Wrap environment in gym Monitor to record video.")
    parser.add_argument(
        "--steps", default=10000, help="Number of steps to roll out.")
    parser.add_argument("--out", default=None, help="Output filename.")
    parser.add_argument(
        "--config",
        default="{}",
        type=json.loads,
        help="Algorithm-specific configuration (e.g. env, hyperparams). "
             "Surpresses loading of configuration from checkpoint.")
    parser.add_argument(
        "--episodes",
        default=0,
        help="Number of complete episodes to roll out. (Overrides --steps)")
    parser.add_argument(
        "--save-info",
        default=False,
        action="store_true",
        help="Save the info field generated by the step() method, "
             "as well as the action, observations, rewards and done fields.")
    parser.add_argument(
        "--use-shelve",
        default=False,
        action="store_true",
        help="Save rollouts into a python shelf file (will save each episode "
             "as it is generated). An output filename must be set using --out.")
    parser.add_argument(
        "--track-progress",
        default=False,
        action="store_true",
        help="Write progress to a temporary file (updated "
             "after each episode). An output filename must be set using --out; "
             "the progress file will live in the same folder.")
    parser.add_argument(
        "--eager",
        action="store_true",
        help="Whether to attempt to enable TF eager execution.")
    return parser


def run(args, parser):
    config = {}
    # Load configuration from file
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")
    if not os.path.exists(config_path):
        if not args.config:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory.")
    else:
        with open(config_path, "rb") as f:
            config = pickle.load(f)
    if "num_workers" in config:
        config["num_workers"] = min(2, config["num_workers"])

    # Values from --config arrive as JSON strings (e.g. "True"/"False");
    # convert them to booleans, then let them override the checkpoint config.
    updated_config = val_replace(args.config)
    config = merge_dicts(config, updated_config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    ray.init()
    
    if args.eager:
        from tensorflow.python.framework.ops import enable_eager_execution
        enable_eager_execution()
        config['eager'] = True
    
    cls = get_trainable_cls(args.run)
    
    agent = cls(env=args.env, config=config)
    agent.restore(args.checkpoint)
    num_steps = int(args.steps)
    num_episodes = int(args.episodes)
    with RolloutSaver(
            args.out,
            args.use_shelve,
            write_update_file=args.track_progress,
            target_steps=num_steps,
            target_episodes=num_episodes,
            save_info=args.save_info) as saver:
        outcome = rollout(agent, args.env, num_steps, num_episodes, saver,
                          args.no_render, args.monitor)
        outcome_file = os.path.join(os.path.dirname(config_path), 'test_outcome.json')
        with open(outcome_file, 'w') as f:
            json.dump(outcome, f, indent=4)


class DefaultMapping(collections.defaultdict):
    """default_factory now takes as an argument the missing key."""

    def __missing__(self, key):
        self[key] = value = self.default_factory(key)
        return value
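

# For example (illustrative):
#   mapping = DefaultMapping(lambda key: f"policy_for_{key}")
#   mapping["agent_0"]  # -> "policy_for_agent_0", and the value is cached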


def default_policy_agent_mapping(unused_agent_id):
    return DEFAULT_POLICY_ID


def keep_going(steps, num_steps, episodes, num_episodes):
    """Determine whether we've collected enough data"""
    # if num_episodes is set, this overrides num_steps
    if num_episodes:
        return episodes < num_episodes
    # if num_steps is set, continue until we reach the limit
    if num_steps:
        return steps < num_steps
    # otherwise keep going forever
    return True
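

# For example (illustrative): when --episodes is set, the step budget is
# ignored and rollouts continue until the episode target is reached:
#   keep_going(steps=10_000, num_steps=100, episodes=1, num_episodes=2)  # -> True
#   keep_going(steps=50, num_steps=100, episodes=2, num_episodes=2)      # -> False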


def rollout(agent,
            env_name,
            num_steps,
            num_episodes=0,
            saver=None,
            no_render=True,
            monitor=False):
    policy_agent_mapping = default_policy_agent_mapping

    if saver is None:
        saver = RolloutSaver()

    if hasattr(agent, "workers"):
        env = agent.workers.local_worker().env
        multiagent = isinstance(env, MultiAgentEnv)
        if agent.workers.local_worker().multiagent:
            policy_agent_mapping = agent.config["multiagent"][
                "policy_mapping_fn"]

        policy_map = agent.workers.local_worker().policy_map
        state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
        use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
        action_init = {
            p: flatten_to_single_ndarray(m.action_space.sample()) # ray 0.8.5
            # p: _flatten_action(m.action_space.sample()) # ray 0.8.4
            for p, m in policy_map.items()
        }
    else:
        env = gym.make(env_name)
        multiagent = False
        use_lstm = {DEFAULT_POLICY_ID: False}

    if monitor and not no_render and saver and saver.outfile is not None:
        # If monitoring has been requested,
        # manually wrap our environment with a gym monitor
        # which is set to record every episode.
        env = gym.wrappers.Monitor(
            env, os.path.join(os.path.dirname(saver.outfile), "monitor"),
            lambda x: True)

    steps = 0
    episodes = 0
    simulation_rewards = []
    simulation_rewards_normalized = []
    simulation_percentage_complete = []
    simulation_steps = []

    while keep_going(steps, num_steps, episodes, num_episodes):
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        saver.begin_rollout()
        obs = env.reset()
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]])
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]])
        prev_rewards = collections.defaultdict(lambda: 0.)
        done = False
        reward_total = 0.0

        episode_steps = 0
        episode_max_steps = 0
        episode_num_agents = 0
        agents_score = collections.defaultdict(lambda: 0.)
        agents_done = set()

        while not done and keep_going(steps, num_steps, episodes,
                                      num_episodes):
            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
            action_dict = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is not None:
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id))
                    p_use_lstm = use_lstm[policy_id]
                    if p_use_lstm:
                        a_action, p_state, _ = agent.compute_action(
                            a_obs,
                            state=agent_states[agent_id],
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                        agent_states[agent_id] = p_state
                    else:
                        a_action = agent.compute_action(
                            a_obs,
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                    a_action = flatten_to_single_ndarray(a_action)  # ray 0.8.5
                    # a_action = _flatten_action(a_action)  # tuple actions # ray 0.8.4

                    # Epsilon-greedy action selection for APEX: with probability
                    # final_epsilon, replace the greedy action with a random one,
                    # mirroring the residual exploration used during training.
                    if getattr(agent, '_name', None) == "APEX" \
                            and random.random() <= final_epsilon:
                        a_action = random.choice(np.arange(env.action_space.n))

                    action_dict[agent_id] = a_action
                    prev_actions[agent_id] = a_action
            action = action_dict

            action = action if multiagent else action[_DUMMY_AGENT_ID]
            next_obs, reward, done, info = env.step(action)
            if multiagent:
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
            else:
                prev_rewards[_DUMMY_AGENT_ID] = reward

            if multiagent:
                done = done["__all__"]
                reward_total += sum(reward.values())
            else:
                reward_total += reward
            if not no_render:
                env.render()
            saver.append_step(obs, action, next_obs, reward, done, info)
            steps += 1
            obs = next_obs

            for agent_id, agent_info in info.items():
                if episode_max_steps == 0:
                    episode_max_steps = agent_info["max_episode_steps"]
                    episode_num_agents = agent_info["num_agents"]
                episode_steps = max(episode_steps, agent_info["agent_step"])
                agents_score[agent_id] = agent_info["agent_score"]
                if agent_info["agent_done"]:
                    agents_done.add(agent_id)

        episode_score = sum(agents_score.values())
        simulation_rewards.append(episode_score)
        simulation_rewards_normalized.append(episode_score / (episode_max_steps * episode_num_agents))
        simulation_percentage_complete.append(float(len(agents_done)) / episode_num_agents)
        simulation_steps.append(episode_steps)

        saver.end_rollout()
        print(f"Episode #{episodes}: "
              f"score: {episode_score:.2f} "
              f"({np.mean(simulation_rewards):.2f}), "
              f"normalized score: {simulation_rewards_normalized[-1]:.2f} "
              f"({np.mean(simulation_rewards_normalized):.2f}), "
              f"percentage_complete: {simulation_percentage_complete[-1]:.2f} "
              f"({np.mean(simulation_percentage_complete):.2f})")
        if done:
            episodes += 1

    print("Evaluation completed:\n"
          f"Episodes: {episodes}\n"
          f"Mean Reward: {np.round(np.mean(simulation_rewards))}\n"
          f"Mean Normalized Reward: {np.round(np.mean(simulation_rewards_normalized))}\n"
          f"Mean Percentage Complete: {np.round(np.mean(simulation_percentage_complete), 3)}\n"
          f"Mean Steps: {np.round(np.mean(simulation_steps), 2)}")

    return {
        'reward': [float(r) for r in simulation_rewards],
        'reward_mean': np.mean(simulation_rewards),
        'reward_std': np.std(simulation_rewards),
        'normalized_reward': [float(r) for r in simulation_rewards_normalized],
        'normalized_reward_mean': np.mean(simulation_rewards_normalized),
        'normalized_reward_std': np.std(simulation_rewards_normalized),
        'percentage_complete': [float(c) for c in simulation_percentage_complete],
        'percentage_complete_mean': np.mean(simulation_percentage_complete),
        'percentage_complete_std': np.std(simulation_percentage_complete),
        'steps': [float(c) for c in simulation_steps],
        'steps_mean': np.mean(simulation_steps),
        'steps_std': np.std(simulation_steps),
    }


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    run(args, parser)