# TorchBeast NetHackChallenge Benchmark
This is a baseline model for the NetHack Challenge based on
[TorchBeast](https://github.com/facebookresearch/torchbeast) - FAIR's
implementation of IMPALA for PyTorch.
It comes with all the code you need to train, run and submit a model
that is based on the results published in the original NLE paper.
This implementation can run with 2 GPUs (one for acting and one for
learning), and runs many simultaneous environments with dynamic
batching. Currently it is configured to run with only 1 GPU.
## Installation
**[Native Installation]**
To get this running you'll need to follow the TorchBeast installation instructions for PolyBeast from the [TorchBeast repo](https://github.com/facebookresearch/torchbeast#faster-version-polybeast).
**[Docker Installation]**
You can fast-track the installation of PolyBeast by running the competition's own Dockerfile. Prebuilt images are also hosted on Docker Hub. These commands should open an image that allows you to run the baseline:
**To Run Existing Docker Image**
`docker pull fairnle/challenge:dev`
```docker run -it -v `pwd`:/home/aicrowd --gpus='all' fairnle/challenge:dev```
**To Build Your Own Image**
*Dev Image* - runs as the root user and doesn't copy your files into the image
`docker build -t competition --target nhc-dev .`
*or Submission Image* - runs as the aicrowd user and copies all your files into the image
`docker build -t competition --target nhc-submit .`
*Run Image*
```docker run -it -v `pwd`:/home/aicrowd --gpus='all' competition```
## Running The Baseline
Once installed, in this directory run:
`python polyhydra.py`
To change parameters, edit `config.yaml`, or to override parameters
from the command-line run:
`python polyhydra.py embedding_dim=16`
The training will save checkpoints to a new directory (`outputs`), and
should the environments create any outputs, they will be saved to
`nle_data` (by default, recordings of episodes are switched off to
save space).
The default polybeast runs on 2 GPUs, one for the learner and one for
the actors. However, with only one GPU you can still run
polybeast - just override the `actor_device` argument:
`python polyhydra.py actor_device=cpu`
NOTE: if you get a "Too many open files" error, try: `ulimit -Sn 10000`.
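Multiple overrides can be combined. For instance (values purely illustrative), a smaller single-GPU debugging run might look like:
`python polyhydra.py actor_device=cpu num_actors=64 batch_size=16 unroll_length=32 total_steps=1e7`
All of these flags are defined in `config.yaml`.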
## Making a submission
In the output directory of your trained model, you should find two files, `checkpoint.tar` and `config.yaml`. Add both of them to your submission repo. Then change the `MODEL_DIR` variable in `agents/torchbeast_agent.py` to point to the directory where these files are located. Finally, set `AGENT` in `submission_config.py` to `'TorchBeastAgent'` so that your TorchBeast agent variant is used for the submission.
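For concreteness, a minimal sketch of those two edits (the path shown is hypothetical, and the exact assignment forms may differ slightly from the files in your repo):
```
# agents/torchbeast_agent.py
MODEL_DIR = "./my_trained_model"  # hypothetical directory containing checkpoint.tar and config.yaml

# submission_config.py
AGENT = "TorchBeastAgent"
```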
After that, follow [these instructions](/docs/SUBMISSION.md) to submit your model to AIcrowd!
## Repo Structure
```
baselines/torchbeast
├── core/
├── models/ # <- Models HERE
├── util/
├── config.yaml # <- Flags HERE
├── polybeast_env.py # <- Training Env HERE
├── polybeast_learner.py # <- Training Loop HERE
└── polyhydra.py # <- main() HERE
```
The structure is simple, compartmentalising the environment setup,
training loop and models into different files. You can tweak any of
these separately, and add parameters to the flags (which are passed
around).
## About the Model
The model we provide (`BaselineNet`) is simple and lives entirely in
`models/baseline.py`.
* It encodes the dungeon into a fixed-size representation
(`GlyphEncoder`)
* It encodes the topline message into a fixed-size representation
(`MessageEncoder`)
* It encodes the bottom line statistics (e.g. armour class, health) into
a fixed-size representation (`BLStatsEncoder`)
* It concatenates all these outputs into a single fixed-size vector, runs this
through a fully connected layer, and then into an LSTM.
* The outputs of the LSTM go through policy and baseline heads (since
this is an actor-critic algorithm).
As you can see there is a lot of data to play with in this game, and
plenty to try, both in modelling and in the learning algorithms used.
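For orientation, here is a simplified sketch of the forward pass (condensed from `models/baseline.py` further down this diff; the real code also handles reshaping across time and batch, episode resets and optional action masking):
```
glyphs_rep = self.glyph_model(inputs)      # dungeon glyphs + crop around the player
char_rep = self.msg_model(inputs)          # topline message
features_emb = self.blstats_model(inputs)  # bottom-line stats
st = self.fc(torch.cat([glyphs_rep, char_rep, features_emb], dim=1))
core_output, core_state = self.core(st.view(T, B, -1), core_state)  # LSTM
policy_logits = self.policy(core_output)   # actor head
baseline = self.baseline(core_output)      # critic head
```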
## Improvement Ideas
*Here are some ideas we haven't tried yet, but might be easy places to start. Happy tinkering!*
### Model Improvements (`baseline.py`)
* The model is currently not using the terminal observations
(`tty_chars`, `tty_colors`, `tty_cursor`), so it has no idea about
menus - could we make use of these somehow? (See the sketch after this list.)
* The bottom-line stats are very informative, but very simply encoded
in `BLStatsEncoder` - is there a better way to do this?
* The `GlyphEncoder` builds an embedding for the glyphs, and then takes
a crop of these centered around the player icon coordinates
(`@`). Should the crop reuse the same embedding matrix?
* The current model constrains the vast action space to a smaller
subset of actions. Is it too constrained? Or not constrained enough?
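As a sketch of the first idea only (untested and purely illustrative; the class name and the 24x80 terminal dimensions below are assumptions), one could add a small encoder for `tty_chars` and concatenate its output with the existing representations in `BaselineNet`:
```
import torch
from torch import nn

class TTYEncoder(nn.Module):
    """Illustrative sketch: embed tty_chars and project to a fixed-size vector."""

    def __init__(self, hidden_dim=64, embedding_dim=8, rows=24, cols=80):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.char_emb = nn.Embedding(256, embedding_dim)
        self.fc = nn.Linear(rows * cols * embedding_dim, hidden_dim)

    def forward(self, inputs):
        # inputs["tty_chars"] assumed to be [T x B x rows x cols]
        T, B, H, W = inputs["tty_chars"].shape
        x = self.char_emb(inputs["tty_chars"].long().view(T * B, H, W))
        return torch.relu(self.fc(x.view(T * B, -1)))
```
Its output would then be appended to `reps` in `BaselineNet.forward`, with `out_dim` increased accordingly.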
### Environment Improvements (`polybeast_env.py`)
* Opening menus (such as when spellcasting) does not advance the in-game
timer. However, models can get stuck in menus, since they have to
learn which buttons to press to close them. Can changing the
penalty for not advancing the in-game timer improve the result? (See the example after this list.)
* The NetHackChallenge assesses the score on random character
assignments. Might it be easier to learn on just a few of these at
the beginning of training?
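Both knobs already exist as flags in `config.yaml` and can be overridden directly; for example (character choice and penalty values purely illustrative):
`python polyhydra.py character='val-dwa-law-fem' penalty_step=-0.05 penalty_time=-0.001`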
### Algorithm/Optimisation Improvements (`polybeast_learner.py`)
* Can we add some intrinsic rewards to help our agents learn?
* Should we add penalties to disincentivise pathological behaviour we
observe?
* Can we improve the model by using a different optimizer? (See the sketch below.)
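As a sketch of the last idea (illustrative only), the RMSProp construction in `polybeast_learner.py`'s `train()` could be swapped for Adam, reusing the existing flags:
```
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=flags.learning_rate,
    eps=flags.epsilon,
)
```
The existing `LambdaLR` schedule would keep working unchanged, since it wraps whatever optimizer it is given.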
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
defaults:
- hydra/job_logging: colorlog
- hydra/hydra_logging: colorlog
# - hydra/launcher: submitit_slurm
# # To Be Used With hydra submitit_slurm if you have SLURM cluster
# # pip install hydra-core hydra_colorlog
# # can set these on the commandline too, e.g. `hydra.launcher.partition=dev`
# hydra:
# launcher:
# timeout_min: 4300
# cpus_per_task: 20
# gpus_per_node: 2
# tasks_per_node: 1
# mem_gb: 20
# nodes: 1
# partition: dev
# comment: null
# max_num_timeout: 5 # will requeue on timeout or preemption
name: null # can use this to have multiple runs with same params, eg name=1,2,3,4,5
## WANDB settings
wandb: false # Enable wandb logging.
project: nethack_challenge # The wandb project name.
entity: user1 # The wandb user to log to.
group: group1 # The wandb group for the run.
# POLYBEAST ENV settings
mock: false # Use mock environment instead of NetHack.
single_ttyrec: true # Record ttyrec only for actor 0.
num_seeds: 0 # If larger than 0, samples a fixed number of environment seeds to be used.
write_profiler_trace: false # Collect and write a profiler trace for chrome://tracing/.
fn_penalty_step: constant # Function to accumulate penalty.
penalty_time: 0.0 # Penalty per time step in the episode.
penalty_step: -0.01 # Penalty per step in the episode.
reward_lose: 0 # Reward for losing (dying before finding the staircase).
reward_win: 100 # Reward for winning (finding the staircase).
state_counter: none # Method for counting state visits. Default none.
character: 'mon-hum-neu-mal' # Specification of the NetHack character.
## typical characters we use
# 'mon-hum-neu-mal'
# 'val-dwa-law-fem'
# 'wiz-elf-cha-mal'
# 'tou-hum-neu-fem'
# '@' # random (used in Challenge assessment)
# RUN settings.
mode: train # Training or test mode.
env: challenge # Name of Gym environment to create.
# # env (task) names: challenge, staircase, pet,
# eat, gold, score, scout, oracle
# TRAINING settings.
num_actors: 256 # Number of actors.
total_steps: 1e9 # Total environment steps to train for. Will be cast to int.
batch_size: 32 # Learner batch size.
unroll_length: 80 # The unroll length (time dimension).
num_learner_threads: 1 # Number of learner threads.
num_inference_threads: 1 # Number of inference threads.
disable_cuda: false # Disable CUDA.
learner_device: cuda:0 # Set learner device.
actor_device: cuda:0 # Set actor device.
# OPTIMIZER settings. (RMS Prop)
learning_rate: 0.0002 # Learning rate.
grad_norm_clipping: 40 # Global gradient norm clip.
alpha: 0.99 # RMSProp smoothing constant.
momentum: 0 # RMSProp momentum.
epsilon: 0.000001 # RMSProp epsilon.
# LOSS settings.
entropy_cost: 0.001 # Entropy cost/multiplier.
baseline_cost: 0.5 # Baseline cost/multiplier.
discounting: 0.999 # Discounting factor.
normalize_reward: true # Normalize rewards by dividing by the running standard deviation of the reward.
# MODEL settings.
model: baseline # Name of model to build (see models/__init__.py).
use_lstm: true # Use LSTM in agent model.
hidden_dim: 256 # Size of hidden representations.
embedding_dim: 64 # Size of glyph embeddings.
layers: 5 # Number of ConvNet Layers for Glyph Model
crop_dim: 9 # Size of crop (c x c)
use_index_select: true # Whether to use index_select instead of embedding lookup (for speed reasons).
restrict_action_space: True # Use a restricted ACTION SPACE (only nethack.USEFUL_ACTIONS)
msg:
hidden_dim: 64 # Hidden dimension for message encoder.
embedding_dim: 32 # Embedding dimension for characters in message encoder.
# TEST settings.
load_dir: null # Path to load a model from for testing
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import csv
import datetime
import json
import logging
import os
import time
import weakref
def _save_metadata(path, metadata):
metadata["date_save"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
with open(path, "w") as f:
json.dump(metadata, f, indent=4, sort_keys=True)
def gather_metadata():
metadata = dict(
date_start=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
env=os.environ.copy(),
successful=False,
)
# Git metadata.
try:
import git
except ImportError:
logging.warning(
"Couldn't import gitpython module; install it with `pip install gitpython`."
)
else:
try:
repo = git.Repo(search_parent_directories=True)
metadata["git"] = {
"commit": repo.commit().hexsha,
"is_dirty": repo.is_dirty(),
"path": repo.git_dir,
}
if not repo.head.is_detached:
metadata["git"]["branch"] = repo.active_branch.name
except git.InvalidGitRepositoryError:
pass
if "git" not in metadata:
logging.warning("Couldn't determine git data.")
# Slurm metadata.
if "SLURM_JOB_ID" in os.environ:
slurm_env_keys = [k for k in os.environ if k.startswith("SLURM")]
metadata["slurm"] = {}
for k in slurm_env_keys:
d_key = k.replace("SLURM_", "").replace("SLURMD_", "").lower()
metadata["slurm"][d_key] = os.environ[k]
return metadata
class FileWriter:
def __init__(self, xp_args=None, rootdir="~/palaas"):
if rootdir == "~/palaas":
# make unique id in case someone uses the default rootdir
xpid = "{proc}_{unixtime}".format(
proc=os.getpid(), unixtime=int(time.time())
)
rootdir = os.path.join(rootdir, xpid)
self.basepath = os.path.expandvars(os.path.expanduser(rootdir))
self._tick = 0
# metadata gathering
if xp_args is None:
xp_args = {}
self.metadata = gather_metadata()
# we need to copy the args, otherwise when we close the file writer
# (and rewrite the args) we might have non-serializable objects (or
# other nasty stuff).
self.metadata["args"] = copy.deepcopy(xp_args)
formatter = logging.Formatter("%(message)s")
self._logger = logging.getLogger("palaas/out")
# to stdout handler
shandle = logging.StreamHandler()
shandle.setFormatter(formatter)
self._logger.addHandler(shandle)
self._logger.setLevel(logging.INFO)
# to file handler
if not os.path.exists(self.basepath):
self._logger.info("Creating log directory: %s", self.basepath)
os.makedirs(self.basepath, exist_ok=True)
else:
self._logger.info("Found log directory: %s", self.basepath)
self.paths = dict(
msg="{base}/out.log".format(base=self.basepath),
logs="{base}/logs.csv".format(base=self.basepath),
fields="{base}/fields.csv".format(base=self.basepath),
meta="{base}/meta.json".format(base=self.basepath),
)
self._logger.info("Saving arguments to %s", self.paths["meta"])
if os.path.exists(self.paths["meta"]):
self._logger.warning(
"Path to meta file already exists. " "Not overriding meta."
)
else:
self.save_metadata()
self._logger.info("Saving messages to %s", self.paths["msg"])
if os.path.exists(self.paths["msg"]):
self._logger.warning(
"Path to message file already exists. " "New data will be appended."
)
fhandle = logging.FileHandler(self.paths["msg"])
fhandle.setFormatter(formatter)
self._logger.addHandler(fhandle)
self._logger.info("Saving logs data to %s", self.paths["logs"])
self._logger.info("Saving logs' fields to %s", self.paths["fields"])
self.fieldnames = ["_tick", "_time"]
if os.path.exists(self.paths["logs"]):
self._logger.warning(
"Path to log file already exists. " "New data will be appended."
)
# Override default fieldnames.
with open(self.paths["fields"], "r") as csvfile:
reader = csv.reader(csvfile)
lines = list(reader)
if len(lines) > 0:
self.fieldnames = lines[-1]
# Override default tick: use the last tick from the logs file plus 1.
with open(self.paths["logs"], "r") as csvfile:
reader = csv.reader(csvfile)
lines = list(reader)
# Need at least two lines in order to read the last tick:
# the first is the csv header and the second is the first line
# of data.
if len(lines) > 1:
self._tick = int(lines[-1][0]) + 1
self._fieldfile = open(self.paths["fields"], "a")
self._fieldwriter = csv.writer(self._fieldfile)
self._fieldfile.flush()
self._logfile = open(self.paths["logs"], "a")
self._logwriter = csv.DictWriter(self._logfile, fieldnames=self.fieldnames)
# Auto-close (and save) on destruction.
weakref.finalize(self, _save_metadata, self.paths["meta"], self.metadata)
def log(self, to_log, tick=None, verbose=False):
if tick is not None:
raise NotImplementedError
else:
to_log["_tick"] = self._tick
self._tick += 1
to_log["_time"] = time.time()
old_len = len(self.fieldnames)
for k in to_log:
if k not in self.fieldnames:
self.fieldnames.append(k)
if old_len != len(self.fieldnames):
self._fieldwriter.writerow(self.fieldnames)
self._fieldfile.flush()
self._logger.info("Updated log fields: %s", self.fieldnames)
if to_log["_tick"] == 0:
self._logfile.write("# %s\n" % ",".join(self.fieldnames))
if verbose:
self._logger.info(
"LOG | %s",
", ".join(["{}: {}".format(k, to_log[k]) for k in sorted(to_log)]),
)
self._logwriter.writerow(to_log)
self._logfile.flush()
def close(self, successful=True):
self.metadata["successful"] = successful
self.save_metadata()
for f in [self._logfile, self._fieldfile]:
f.close()
def save_metadata(self):
_save_metadata(self.paths["meta"], self.metadata)
# This file taken from
# https://github.com/deepmind/scalable_agent/blob/
# cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py
# and modified.
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to compute V-trace off-policy actor critic targets.
For details and theory see:
"IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.
See https://arxiv.org/abs/1802.01561 for the full paper.
"""
import collections
import torch
import torch.nn.functional as F
VTraceFromLogitsReturns = collections.namedtuple(
"VTraceFromLogitsReturns",
[
"vs",
"pg_advantages",
"log_rhos",
"behavior_action_log_probs",
"target_action_log_probs",
],
)
VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages")
def action_log_probs(policy_logits, actions):
return -F.nll_loss(
F.log_softmax(torch.flatten(policy_logits, 0, 1), dim=-1),
torch.flatten(actions, 0, 1),
reduction="none",
).view_as(actions)
def from_logits(
behavior_policy_logits,
target_policy_logits,
actions,
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
):
"""V-trace for softmax policies."""
target_action_log_probs = action_log_probs(target_policy_logits, actions)
behavior_action_log_probs = action_log_probs(behavior_policy_logits, actions)
log_rhos = target_action_log_probs - behavior_action_log_probs
vtrace_returns = from_importance_weights(
log_rhos=log_rhos,
discounts=discounts,
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
)
return VTraceFromLogitsReturns(
log_rhos=log_rhos,
behavior_action_log_probs=behavior_action_log_probs,
target_action_log_probs=target_action_log_probs,
**vtrace_returns._asdict(),
)
@torch.no_grad()
def from_importance_weights(
log_rhos,
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
):
"""V-trace from log importance weights."""
with torch.no_grad():
rhos = torch.exp(log_rhos)
if clip_rho_threshold is not None:
clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
else:
clipped_rhos = rhos
cs = torch.clamp(rhos, max=1.0)
# Append bootstrapped value to get [v1, ..., v_t+1]
values_t_plus_1 = torch.cat(
[values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0
)
deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
acc = torch.zeros_like(bootstrap_value)
result = []
for t in range(discounts.shape[0] - 1, -1, -1):
acc = deltas[t] + discounts[t] * cs[t] * acc
result.append(acc)
result.reverse()
vs_minus_v_xs = torch.stack(result)
# Add V(x_s) to get v_s.
vs = torch.add(vs_minus_v_xs, values)
# Advantage for policy gradient.
vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
if clip_pg_rho_threshold is not None:
clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
else:
clipped_pg_rhos = rhos
pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)
# Make sure no gradients backpropagated through the returned values.
return VTraceReturns(vs=vs, pg_advantages=pg_advantages)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nle.env import tasks
from nle.env.base import DUNGEON_SHAPE
from .baseline import BaselineNet
from omegaconf import OmegaConf
import torch
ENVS = dict(
staircase=tasks.NetHackStaircase,
score=tasks.NetHackScore,
pet=tasks.NetHackStaircasePet,
oracle=tasks.NetHackOracle,
gold=tasks.NetHackGold,
eat=tasks.NetHackEat,
scout=tasks.NetHackScout,
challenge=tasks.NetHackChallenge,
)
def create_model(flags, device):
model_string = flags.model
if model_string == "baseline":
model_cls = BaselineNet
else:
raise NotImplementedError("model=%s" % model_string)
action_space = ENVS[flags.env](savedir=None, archivefile=None)._actions
model = model_cls(DUNGEON_SHAPE, action_space, flags, device)
model.to(device=device)
return model
def load_model(load_dir, device):
flags = OmegaConf.load(load_dir + "/config.yaml")
flags.checkpoint = load_dir + "/checkpoint.tar"
model = create_model(flags, device)
checkpoint_states = torch.load(flags.checkpoint, map_location=device)
model.load_state_dict(checkpoint_states["model_state_dict"])
return model
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from nle import nethack
from .util import id_pairs_table
import numpy as np
NUM_GLYPHS = nethack.MAX_GLYPH
NUM_FEATURES = 25
PAD_CHAR = 0
NUM_CHARS = 256
def get_action_space_mask(action_space, reduced_action_space):
mask = np.array([int(a in reduced_action_space) for a in action_space])
return torch.Tensor(mask)
def conv_outdim(i_dim, k, padding=0, stride=1, dilation=1):
"""Return the dimension after applying a convolution along one axis"""
return int(1 + (i_dim + 2 * padding - dilation * (k - 1) - 1) / stride)
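# Worked example for conv_outdim above: a 3x3 conv with stride 1 and padding 1
# preserves the size, e.g. conv_outdim(9, 3, padding=1, stride=1) == 9.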
def select(embedding_layer, x, use_index_select):
"""Use index select instead of default forward to possible speed up embedding."""
if use_index_select:
out = embedding_layer.weight.index_select(0, x.view(-1))
# handle reshaping x to 1-d and output back to N-d
return out.view(x.shape + (-1,))
else:
return embedding_layer(x)
class NetHackNet(nn.Module):
"""This base class simply provides a skeleton for running with torchbeast."""
AgentOutput = collections.namedtuple("AgentOutput", "action policy_logits baseline")
def __init__(self):
super(NetHackNet, self).__init__()
self.register_buffer("reward_sum", torch.zeros(()))
self.register_buffer("reward_m2", torch.zeros(()))
self.register_buffer("reward_count", torch.zeros(()).fill_(1e-8))
def forward(self, inputs, core_state):
raise NotImplementedError
def initial_state(self, batch_size=1):
return ()
@torch.no_grad()
def update_running_moments(self, reward_batch):
"""Maintains a running mean of reward."""
new_count = len(reward_batch)
new_sum = torch.sum(reward_batch)
new_mean = new_sum / new_count
curr_mean = self.reward_sum / self.reward_count
new_m2 = torch.sum((reward_batch - new_mean) ** 2) + (
(self.reward_count * new_count)
/ (self.reward_count + new_count)
* (new_mean - curr_mean) ** 2
)
self.reward_count += new_count
self.reward_sum += new_sum
self.reward_m2 += new_m2
@torch.no_grad()
def get_running_std(self):
"""Returns standard deviation of the running mean of the reward."""
return torch.sqrt(self.reward_m2 / self.reward_count)
class BaselineNet(NetHackNet):
"""This model combines the encodings of the glyphs, top line message and
blstats into a single fixed-size representation, which is then passed to
an LSTM core before policy and baseline heads, for use in an IMPALA-like
architecture.
This model was based on the 'neurips2020release' tag on the NLE repo, itself
based on Küttler et al., 2020
The NetHack Learning Environment
https://arxiv.org/abs/2006.13760
"""
def __init__(self, observation_shape, action_space, flags, device):
super(BaselineNet, self).__init__()
self.flags = flags
self.observation_shape = observation_shape
self.num_actions = len(action_space)
self.H = observation_shape[0]
self.W = observation_shape[1]
self.use_lstm = flags.use_lstm
self.h_dim = flags.hidden_dim
# GLYPH + CROP MODEL
self.glyph_model = GlyphEncoder(flags, self.H, self.W, flags.crop_dim, device)
# MESSAGING MODEL
self.msg_model = MessageEncoder(
flags.msg.hidden_dim, flags.msg.embedding_dim, device
)
# BLSTATS MODEL
self.blstats_model = BLStatsEncoder(NUM_FEATURES, flags.embedding_dim)
out_dim = (
self.blstats_model.hidden_dim
+ self.glyph_model.hidden_dim
+ self.msg_model.hidden_dim
)
self.fc = nn.Sequential(
nn.Linear(out_dim, self.h_dim),
nn.ReLU(),
nn.Linear(self.h_dim, self.h_dim),
nn.ReLU(),
)
if self.use_lstm:
self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1)
self.policy = nn.Linear(self.h_dim, self.num_actions)
self.baseline = nn.Linear(self.h_dim, 1)
if flags.restrict_action_space:
reduced_space = nethack.USEFUL_ACTIONS
logits_mask = get_action_space_mask(action_space, reduced_space)
self.policy_logits_mask = nn.parameter.Parameter(
logits_mask, requires_grad=False
)
def initial_state(self, batch_size=1):
return tuple(
torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
for _ in range(2)
)
def forward(self, inputs, core_state, learning=False):
T, B, H, W = inputs["glyphs"].shape
reps = []
# -- [B' x K] ; B' == (T x B)
glyphs_rep = self.glyph_model(inputs)
reps.append(glyphs_rep)
# -- [B' x K]
char_rep = self.msg_model(inputs)
reps.append(char_rep)
# -- [B' x K]
features_emb = self.blstats_model(inputs)
reps.append(features_emb)
# -- [B' x K]
st = torch.cat(reps, dim=1)
# -- [B' x K]
st = self.fc(st)
if self.use_lstm:
core_input = st.view(T, B, -1)
core_output_list = []
notdone = (~inputs["done"]).float()
for input, nd in zip(core_input.unbind(), notdone.unbind()):
# Reset core state to zero whenever an episode ended.
# Make `done` broadcastable with (num_layers, B, hidden_size)
# states:
nd = nd.view(1, -1, 1)
core_state = tuple(nd * t for t in core_state)
output, core_state = self.core(input.unsqueeze(0), core_state)
core_output_list.append(output)
core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
else:
core_output = st
# -- [B' x A]
policy_logits = self.policy(core_output)
# -- [B' x 1]
baseline = self.baseline(core_output)
if self.flags.restrict_action_space:
policy_logits = policy_logits * self.policy_logits_mask + (
(1 - self.policy_logits_mask) * -1e10
)
if self.training:
action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
else:
# Don't sample when testing.
action = torch.argmax(policy_logits, dim=1)
policy_logits = policy_logits.view(T, B, -1)
baseline = baseline.view(T, B)
action = action.view(T, B)
output = dict(policy_logits=policy_logits, baseline=baseline, action=action)
return (output, core_state)
class GlyphEncoder(nn.Module):
"""This glyph encoder first breaks the glyphs (integers up to 6000) to a
more structured representation based on the qualities of the glyph: chars,
colors, specials, groups and subgroup ids..
Eg: invisible hell-hound: char (d), color (red), specials (invisible),
group (monster) subgroup id (type of monster)
Eg: lit dungeon floor: char (.), color (white), specials (none),
group (dungeon) subgroup id (type of dungeon)
An embedding is provided for each of these, and the embeddings are
concatenated, before encoding with a number of CNN layers. This operation
is repeated with a crop of the structured representations taken around the
character's position, and the two representations are concatenated
before returning.
"""
def __init__(self, flags, rows, cols, crop_dim, device=None):
super(GlyphEncoder, self).__init__()
self.crop = Crop(rows, cols, crop_dim, crop_dim, device)
K = flags.embedding_dim # number of input filters
L = flags.layers # number of convnet layers
assert (
K % 8 == 0
), "This glyph embedding format needs embedding dim to be multiple of 8"
unit = K // 8
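# K is split into 8 "units" across the five embeddings below:
# chars (2), colors (1), specials (1), groups (1) and ids (3),
# so their concatenation has exactly K channels.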
self.chars_embedding = nn.Embedding(256, 2 * unit)
self.colors_embedding = nn.Embedding(16, unit)
self.specials_embedding = nn.Embedding(256, unit)
self.id_pairs_table = nn.parameter.Parameter(
torch.from_numpy(id_pairs_table()), requires_grad=False
)
num_groups = self.id_pairs_table.select(1, 1).max().item() + 1
num_ids = self.id_pairs_table.select(1, 0).max().item() + 1
self.groups_embedding = nn.Embedding(num_groups, unit)
self.ids_embedding = nn.Embedding(num_ids, 3 * unit)
F = 3 # filter dimensions
S = 1 # stride
P = 1 # padding
M = 16 # number of intermediate filters
self.output_filters = 8
in_channels = [K] + [M] * (L - 1)
out_channels = [M] * (L - 1) + [self.output_filters]
h, w, c = rows, cols, crop_dim
conv_extract, conv_extract_crop = [], []
for i in range(L):
conv_extract.append(
nn.Conv2d(
in_channels=in_channels[i],
out_channels=out_channels[i],
kernel_size=(F, F),
stride=S,
padding=P,
)
)
conv_extract.append(nn.ELU())
conv_extract_crop.append(
nn.Conv2d(
in_channels=in_channels[i],
out_channels=out_channels[i],
kernel_size=(F, F),
stride=S,
padding=P,
)
)
conv_extract_crop.append(nn.ELU())
# Keep track of output shapes
h = conv_outdim(h, F, P, S)
w = conv_outdim(w, F, P, S)
c = conv_outdim(c, F, P, S)
self.hidden_dim = (h * w + c * c) * self.output_filters
self.extract_representation = nn.Sequential(*conv_extract)
self.extract_crop_representation = nn.Sequential(*conv_extract_crop)
self.select = lambda emb, x: select(emb, x, flags.use_index_select)
def glyphs_to_ids_groups(self, glyphs):
T, B, H, W = glyphs.shape
ids_groups = self.id_pairs_table.index_select(0, glyphs.view(-1).long())
ids = ids_groups.select(1, 0).view(T, B, H, W).long()
groups = ids_groups.select(1, 1).view(T, B, H, W).long()
return [ids, groups]
def forward(self, inputs):
T, B, H, W = inputs["glyphs"].shape
ids, groups = self.glyphs_to_ids_groups(inputs["glyphs"])
glyph_tensors = [
self.select(self.chars_embedding, inputs["chars"].long()),
self.select(self.colors_embedding, inputs["colors"].long()),
self.select(self.specials_embedding, inputs["specials"].long()),
self.select(self.groups_embedding, groups),
self.select(self.ids_embedding, ids),
]
glyphs_emb = torch.cat(glyph_tensors, dim=-1)
glyphs_emb = rearrange(glyphs_emb, "T B H W K -> (T B) K H W")
coordinates = inputs["blstats"].view(T * B, -1).float()[:, :2]
crop_emb = self.crop(glyphs_emb, coordinates)
glyphs_rep = self.extract_representation(glyphs_emb)
glyphs_rep = rearrange(glyphs_rep, "B C H W -> B (C H W)")
assert glyphs_rep.shape[0] == T * B
crop_rep = self.extract_crop_representation(crop_emb)
crop_rep = rearrange(crop_rep, "B C H W -> B (C H W)")
assert crop_rep.shape[0] == T * B
st = torch.cat([glyphs_rep, crop_rep], dim=1)
return st
class MessageEncoder(nn.Module):
"""This model encodes the the topline message into a fixed size representation.
It works by using a learnt embedding for each character before passing the
embeddings through 6 CNN layers.
Inspired by Zhang et al, 2016
Character-level Convolutional Networks for Text Classification
https://arxiv.org/abs/1509.01626
"""
def __init__(self, hidden_dim, embedding_dim, device=None):
super(MessageEncoder, self).__init__()
self.hidden_dim = hidden_dim
self.msg_edim = embedding_dim
self.char_lt = nn.Embedding(NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR)
self.conv1 = nn.Conv1d(self.msg_edim, self.hidden_dim, kernel_size=7)
self.conv2_6_fc = nn.Sequential(
nn.ReLU(),
nn.MaxPool1d(kernel_size=3, stride=3),
# conv2
nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=7),
nn.ReLU(),
nn.MaxPool1d(kernel_size=3, stride=3),
# conv3
nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
nn.ReLU(),
# conv4
nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
nn.ReLU(),
# conv5
nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
nn.ReLU(),
# conv6
nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
nn.ReLU(),
nn.MaxPool1d(kernel_size=3, stride=3),
# fc receives -- [ B x h_dim x 5 ]
Flatten(),
nn.Linear(5 * self.hidden_dim, 2 * self.hidden_dim),
nn.ReLU(),
nn.Linear(2 * self.hidden_dim, self.hidden_dim),
) # final output -- [ B x h_dim ]
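# Sequence-length bookkeeping for the default 256-character message:
# conv1 (k=7) -> 250, pool/3 -> 83, conv2 (k=7) -> 77, pool/3 -> 25,
# conv3 -> 23, conv4 -> 21, conv5 -> 19, conv6 -> 17, pool/3 -> 5,
# hence the first linear layer above expects 5 * hidden_dim inputs.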
def forward(self, inputs):
T, B, *_ = inputs["message"].shape
messages = inputs["message"].long().view(T * B, -1)
# [ T * B x E x 256 ]
char_emb = self.char_lt(messages).transpose(1, 2)
char_rep = self.conv2_6_fc(self.conv1(char_emb))
return char_rep
class BLStatsEncoder(nn.Module):
"""This model encodes the bottom line stats into a fixed size representation.
It works by simply using two fully-connected layers with ReLU activations.
"""
def __init__(self, num_features, hidden_dim):
super(BLStatsEncoder, self).__init__()
self.num_features = num_features
self.hidden_dim = hidden_dim
self.embed_features = nn.Sequential(
nn.Linear(self.num_features, self.hidden_dim),
nn.ReLU(),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.ReLU(),
)
def forward(self, inputs):
T, B, *_ = inputs["blstats"].shape
features = inputs["blstats"][:,:, :NUM_FEATURES]
# -- [B' x F]
features = features.view(T * B, -1).float()
# -- [B x K]
features_emb = self.embed_features(features)
assert features_emb.shape[0] == T * B
return features_emb
class Crop(nn.Module):
def __init__(self, height, width, height_target, width_target, device=None):
super(Crop, self).__init__()
self.width = width
self.height = height
self.width_target = width_target
self.height_target = height_target
width_grid = self._step_to_range(2 / (self.width - 1), self.width_target)
self.width_grid = width_grid[None, :].expand(self.height_target, -1)
height_grid = self._step_to_range(2 / (self.height - 1), height_target)
self.height_grid = height_grid[:, None].expand(-1, self.width_target)
if device is not None:
self.width_grid = self.width_grid.to(device)
self.height_grid = self.height_grid.to(device)
def _step_to_range(self, step, num_steps):
return torch.tensor([step * (i - num_steps // 2) for i in range(num_steps)])
def forward(self, inputs, coordinates):
"""Calculates centered crop around given x,y coordinates.
Args:
inputs [B x H x W] or [B x C x H x W]
coordinates [B x 2] x,y coordinates
Returns:
[B x C x H' x W'] inputs cropped and centered around x,y coordinates.
"""
if inputs.dim() == 3:
inputs = inputs.unsqueeze(1).float()
assert inputs.shape[2] == self.height, "expected %d but found %d" % (
self.height,
inputs.shape[2],
)
assert inputs.shape[3] == self.width, "expected %d but found %d" % (
self.width,
inputs.shape[3],
)
x = coordinates[:, 0]
y = coordinates[:, 1]
x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2)
y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2)
grid = torch.stack(
[
self.width_grid[None, :, :] + x_shift[:, None, None],
self.height_grid[None, :, :] + y_shift[:, None, None],
],
dim=3,
)
crop = torch.round(F.grid_sample(inputs, grid, align_corners=True)).squeeze(1)
return crop
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import numpy as np
from nle.nethack import * # noqa: F403
# flake8: noqa: F405
# TODO: import this from NLE again
NUM_OBJECTS = 453
MAXEXPCHARS = 9
class GlyphGroup(enum.IntEnum):
# See display.h in NetHack.
MON = 0
PET = 1
INVIS = 2
DETECT = 3
BODY = 4
RIDDEN = 5
OBJ = 6
CMAP = 7
EXPLODE = 8
ZAP = 9
SWALLOW = 10
WARNING = 11
STATUE = 12
def id_pairs_table():
"""Returns a lookup table for glyph -> NLE id pairs."""
table = np.zeros([MAX_GLYPH, 2], dtype=np.int16)
num_nle_ids = 0
for glyph in range(GLYPH_MON_OFF, GLYPH_PET_OFF):
table[glyph] = (glyph, GlyphGroup.MON)
num_nle_ids += 1
for glyph in range(GLYPH_PET_OFF, GLYPH_INVIS_OFF):
table[glyph] = (glyph - GLYPH_PET_OFF, GlyphGroup.PET)
for glyph in range(GLYPH_INVIS_OFF, GLYPH_DETECT_OFF):
table[glyph] = (num_nle_ids, GlyphGroup.INVIS)
num_nle_ids += 1
for glyph in range(GLYPH_DETECT_OFF, GLYPH_BODY_OFF):
table[glyph] = (glyph - GLYPH_DETECT_OFF, GlyphGroup.DETECT)
for glyph in range(GLYPH_BODY_OFF, GLYPH_RIDDEN_OFF):
table[glyph] = (glyph - GLYPH_BODY_OFF, GlyphGroup.BODY)
for glyph in range(GLYPH_RIDDEN_OFF, GLYPH_OBJ_OFF):
table[glyph] = (glyph - GLYPH_RIDDEN_OFF, GlyphGroup.RIDDEN)
for glyph in range(GLYPH_OBJ_OFF, GLYPH_CMAP_OFF):
table[glyph] = (num_nle_ids, GlyphGroup.OBJ)
num_nle_ids += 1
for glyph in range(GLYPH_CMAP_OFF, GLYPH_EXPLODE_OFF):
table[glyph] = (num_nle_ids, GlyphGroup.CMAP)
num_nle_ids += 1
for glyph in range(GLYPH_EXPLODE_OFF, GLYPH_ZAP_OFF):
id_ = num_nle_ids + (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS
table[glyph] = (id_, GlyphGroup.EXPLODE)
num_nle_ids += EXPL_MAX
for glyph in range(GLYPH_ZAP_OFF, GLYPH_SWALLOW_OFF):
id_ = num_nle_ids + (glyph - GLYPH_ZAP_OFF) // 4
table[glyph] = (id_, GlyphGroup.ZAP)
num_nle_ids += NUM_ZAP
for glyph in range(GLYPH_SWALLOW_OFF, GLYPH_WARNING_OFF):
table[glyph] = (num_nle_ids, GlyphGroup.SWALLOW)
num_nle_ids += 1
for glyph in range(GLYPH_WARNING_OFF, GLYPH_STATUE_OFF):
table[glyph] = (num_nle_ids, GlyphGroup.WARNING)
num_nle_ids += 1
for glyph in range(GLYPH_STATUE_OFF, MAX_GLYPH):
table[glyph] = (glyph - GLYPH_STATUE_OFF, GlyphGroup.STATUE)
return table
def id_pairs_func(glyph):
result = glyph_to_mon(glyph)
if result != NO_GLYPH:
return result
if glyph_is_invisible(glyph):
return NUMMONS
if glyph_is_body(glyph):
return glyph - GLYPH_BODY_OFF
offset = NUMMONS + 1
# CORPSE handled by glyph_is_body; STATUE handled by glyph_to_mon.
result = glyph_to_obj(glyph)
if result != NO_GLYPH:
return result + offset
offset += NUM_OBJECTS
# I don't understand glyph_to_cmap and/or the GLYPH_EXPLODE_OFF definition
# with MAXPCHARS - MAXEXPCHARS.
if GLYPH_CMAP_OFF <= glyph < GLYPH_EXPLODE_OFF:
return glyph - GLYPH_CMAP_OFF + offset
offset += MAXPCHARS - MAXEXPCHARS
if GLYPH_EXPLODE_OFF <= glyph < GLYPH_ZAP_OFF:
return (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS + offset
offset += EXPL_MAX
if GLYPH_ZAP_OFF <= glyph < GLYPH_SWALLOW_OFF:
return ((glyph - GLYPH_ZAP_OFF) >> 2) + offset
offset += NUM_ZAP
if GLYPH_SWALLOW_OFF <= glyph < GLYPH_WARNING_OFF:
return offset
offset += 1
result = glyph_to_warning(glyph)
if result != NO_GLYPH:
return result + offset
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing as mp
import logging
import os
import threading
import time
import torch
import libtorchbeast
from models import ENVS
logging.basicConfig(
format=(
"[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
),
level=0,
)
# Helper functions for NethackEnv.
def _format_observation(obs):
obs = torch.from_numpy(obs)
return obs.view((1, 1) + obs.shape) # (...) -> (T,B,...).
def create_folders(flags):
# Creates some of the folders that would be created by the filewriter.
logdir = os.path.join(flags.savedir, "archives")
if not os.path.exists(logdir):
logging.info("Creating archive directory: %s" % logdir)
os.makedirs(logdir, exist_ok=True)
else:
logging.info("Found archive directory: %s" % logdir)
def create_env(flags, env_id=0, lock=threading.Lock()):
# commenting out these options for now because they use too much disk space
# archivefile = "nethack.%i.%%(pid)i.%%(time)s.zip" % env_id
# if flags.single_ttyrec and env_id != 0:
# archivefile = None
# logdir = os.path.join(flags.savedir, "archives")
with lock:
env_class = ENVS[flags.env]
kwargs = dict(
savedir=None,
archivefile=None,
character=flags.character,
max_episode_steps=flags.max_num_steps,
observation_keys=(
"glyphs",
"chars",
"colors",
"specials",
"blstats",
"message",
"tty_chars",
"tty_colors",
"tty_cursor",
"inv_glyphs",
"inv_strs",
"inv_letters",
"inv_oclasses",
),
penalty_step=flags.penalty_step,
penalty_time=flags.penalty_time,
penalty_mode=flags.fn_penalty_step,
)
if flags.env in ("staircase", "pet", "oracle"):
kwargs.update(reward_win=flags.reward_win, reward_lose=flags.reward_lose)
elif env_id == 0: # print warning once
print("Ignoring flags.reward_win and flags.reward_lose")
if flags.state_counter != "none":
kwargs.update(state_counter=flags.state_counter)
env = env_class(**kwargs)
if flags.seedspath is not None and len(flags.seedspath) > 0:
raise NotImplementedError("seedspath > 0 not implemented yet.")
return env
def serve(flags, server_address, env_id):
env = lambda: create_env(flags, env_id)
server = libtorchbeast.Server(env, server_address=server_address)
server.run()
def main(flags):
if flags.num_seeds > 0:
raise NotImplementedError("num_seeds > 0 not currently implemented.")
create_folders(flags)
if not flags.pipes_basename.startswith("unix:"):
raise Exception("--pipes_basename has to be of the form unix:/some/path.")
processes = []
for i in range(flags.num_servers):
p = mp.Process(
target=serve, args=(flags, f"{flags.pipes_basename}.{i}", i), daemon=True
)
p.start()
processes.append(p)
try:
# We are only here to listen to the interrupt.
while True:
time.sleep(10)
except KeyboardInterrupt:
pass
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Run with OMP_NUM_THREADS=1.
#
import collections
import logging
import os
import threading
import time
import timeit
import traceback
import wandb
import omegaconf
import nest
import torch
import libtorchbeast
from core import file_writer
from core import vtrace
from models import create_model
from models.baseline import NetHackNet
from torch import nn
from torch.nn import functional as F
logging.basicConfig(
format=(
"[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
),
level=0,
)
def compute_baseline_loss(advantages):
return 0.5 * torch.sum(advantages ** 2)
def compute_entropy_loss(logits):
policy = F.softmax(logits, dim=-1)
log_policy = F.log_softmax(logits, dim=-1)
entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1)
return -torch.sum(entropy_per_timestep)
def compute_policy_gradient_loss(logits, actions, advantages):
cross_entropy = F.nll_loss(
F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
target=torch.flatten(actions, 0, 1),
reduction="none",
)
cross_entropy = cross_entropy.view_as(advantages)
policy_gradient_loss_per_timestep = cross_entropy * advantages.detach()
return torch.sum(policy_gradient_loss_per_timestep)
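# These three losses are combined in learn() below, with the baseline and
# entropy terms weighted by flags.baseline_cost and flags.entropy_cost.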
def inference(
inference_batcher, model, flags, actor_device, lock=threading.Lock()
): # noqa: B008
with torch.no_grad():
for batch in inference_batcher:
batched_env_outputs, agent_state = batch.get_inputs()
observation, reward, done, *_ = batched_env_outputs
# Observation is a dict with keys 'features' and 'glyphs'.
observation["done"] = done
observation, agent_state = nest.map(
lambda t: t.to(actor_device, non_blocking=True),
(observation, agent_state),
)
with lock:
outputs = model(observation, agent_state)
core_outputs, agent_state = nest.map(lambda t: t.cpu(), outputs)
# Restructuring the output in the way that is expected
# by the functions in actorpool.
outputs = (
tuple(
(
core_outputs["action"],
core_outputs["policy_logits"],
core_outputs["baseline"],
)
),
agent_state,
)
batch.set_outputs(outputs)
# TODO(heiner): Given that our nest implementation doesn't support
# namedtuples, using them here doesn't seem like a good fit. We
# probably want to nestify the environment server and deal with
# dictionaries?
EnvOutput = collections.namedtuple(
"EnvOutput", "frame rewards done episode_step episode_return"
)
AgentOutput = NetHackNet.AgentOutput
Batch = collections.namedtuple("Batch", "env agent")
def learn(
learner_queue,
model,
actor_model,
optimizer,
scheduler,
stats,
flags,
plogger,
learner_device,
lock=threading.Lock(), # noqa: B008
):
for tensors in learner_queue:
tensors = nest.map(lambda t: t.to(learner_device), tensors)
batch, initial_agent_state = tensors
env_outputs, actor_outputs = batch
observation, reward, done, *_ = env_outputs
observation["reward"] = reward
observation["done"] = done
lock.acquire() # Only one thread learning at a time.
output, _ = model(observation, initial_agent_state, learning=True)
# Use last baseline value (from the value function) to bootstrap.
learner_outputs = AgentOutput._make(
(output["action"], output["policy_logits"], output["baseline"])
)
# At this point, the environment outputs at time step `t` are the inputs
# that lead to the learner_outputs at time step `t`. After the following
# shifting, the actions in `batch` and `learner_outputs` at time
# step `t` is what leads to the environment outputs at time step `t`.
batch = nest.map(lambda t: t[1:], batch)
learner_outputs = nest.map(lambda t: t[:-1], learner_outputs)
# Turn into namedtuples again.
env_outputs, actor_outputs = batch
# Note that the env_outputs.frame is now a dict with 'features' and 'glyphs'
# instead of actually being the frame itself. This is currently not a problem
# because we never use actor_outputs.frame in the rest of this function.
env_outputs = EnvOutput._make(env_outputs)
actor_outputs = AgentOutput._make(actor_outputs)
learner_outputs = AgentOutput._make(learner_outputs)
rewards = env_outputs.rewards
if flags.normalize_reward:
model.update_running_moments(rewards)
rewards /= model.get_running_std()
total_loss = 0
# STANDARD EXTRINSIC LOSSES / REWARDS
if flags.entropy_cost > 0:
entropy_loss = flags.entropy_cost * compute_entropy_loss(
learner_outputs.policy_logits
)
total_loss += entropy_loss
discounts = (~env_outputs.done).float() * flags.discounting
# This could be in C++. In TF, this is actually slower on the GPU.
vtrace_returns = vtrace.from_logits(
behavior_policy_logits=actor_outputs.policy_logits,
target_policy_logits=learner_outputs.policy_logits,
actions=actor_outputs.action,
discounts=discounts,
rewards=rewards,
values=learner_outputs.baseline,
bootstrap_value=learner_outputs.baseline[-1],
)
# Compute loss as a weighted sum of the baseline loss, the policy
# gradient loss and an entropy regularization term.
pg_loss = compute_policy_gradient_loss(
learner_outputs.policy_logits,
actor_outputs.action,
vtrace_returns.pg_advantages,
)
baseline_loss = flags.baseline_cost * compute_baseline_loss(
vtrace_returns.vs - learner_outputs.baseline
)
total_loss += pg_loss + baseline_loss
# BACKWARD STEP
optimizer.zero_grad()
total_loss.backward()
if flags.grad_norm_clipping > 0:
nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
optimizer.step()
scheduler.step()
actor_model.load_state_dict(model.state_dict())
# LOGGING
episode_returns = env_outputs.episode_return[env_outputs.done]
stats["step"] = stats.get("step", 0) + flags.unroll_length * flags.batch_size
stats["mean_episode_return"] = torch.mean(episode_returns).item()
stats["mean_episode_step"] = torch.mean(env_outputs.episode_step.float()).item()
stats["total_loss"] = total_loss.item()
stats["pg_loss"] = pg_loss.item()
stats["baseline_loss"] = baseline_loss.item()
if flags.entropy_cost > 0:
stats["entropy_loss"] = entropy_loss.item()
stats["learner_queue_size"] = learner_queue.size()
if not len(episode_returns):
# Hide the mean-of-empty-tuple NaN as it scares people.
stats["mean_episode_return"] = None
# Only logging if at least one episode was finished
if len(episode_returns):
# TODO: log also SPS
plogger.log(stats)
if flags.wandb:
wandb.log(stats, step=stats["step"])
lock.release()
def train(flags):
logging.info("Logging results to %s", flags.savedir)
if isinstance(flags, omegaconf.DictConfig):
flag_dict = omegaconf.OmegaConf.to_container(flags)
else:
flag_dict = vars(flags)
plogger = file_writer.FileWriter(xp_args=flag_dict, rootdir=flags.savedir)
if not flags.disable_cuda and torch.cuda.is_available():
logging.info("Using CUDA.")
learner_device = torch.device(flags.learner_device)
actor_device = torch.device(flags.actor_device)
else:
logging.info("Not using CUDA.")
learner_device = torch.device("cpu")
actor_device = torch.device("cpu")
if flags.max_learner_queue_size is None:
flags.max_learner_queue_size = flags.batch_size
# The queue the learner threads will get their data from.
# Setting `minimum_batch_size == maximum_batch_size`
# makes the batch size static. We could make it dynamic, but that
# requires a loss (and learning rate schedule) that's batch size
# independent.
learner_queue = libtorchbeast.BatchingQueue(
batch_dim=1,
minimum_batch_size=flags.batch_size,
maximum_batch_size=flags.batch_size,
check_inputs=True,
maximum_queue_size=flags.max_learner_queue_size,
)
# The "batcher", a queue for the inference call. Will yield
# "batch" objects with `get_inputs` and `set_outputs` methods.
# The batch size of the tensors will be dynamic.
inference_batcher = libtorchbeast.DynamicBatcher(
batch_dim=1,
minimum_batch_size=1,
maximum_batch_size=512,
timeout_ms=100,
check_outputs=True,
)
addresses = []
connections_per_server = 1
pipe_id = 0
while len(addresses) < flags.num_actors:
for _ in range(connections_per_server):
addresses.append(f"{flags.pipes_basename}.{pipe_id}")
if len(addresses) == flags.num_actors:
break
pipe_id += 1
logging.info("Using model %s", flags.model)
model = create_model(flags, learner_device)
plogger.metadata["model_numel"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
logging.info("Number of model parameters: %i", plogger.metadata["model_numel"])
actor_model = create_model(flags, actor_device)
# The ActorPool that will run `flags.num_actors` many loops.
actors = libtorchbeast.ActorPool(
unroll_length=flags.unroll_length,
learner_queue=learner_queue,
inference_batcher=inference_batcher,
env_server_addresses=addresses,
initial_agent_state=model.initial_state(),
)
def run():
try:
actors.run()
except Exception as e:
logging.error("Exception in actorpool thread!")
traceback.print_exc()
print()
raise e
actorpool_thread = threading.Thread(target=run, name="actorpool-thread")
optimizer = torch.optim.RMSprop(
model.parameters(),
lr=flags.learning_rate,
momentum=flags.momentum,
eps=flags.epsilon,
alpha=flags.alpha,
)
def lr_lambda(epoch):
return (
1
- min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
/ flags.total_steps
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
stats = {}
if flags.checkpoint and os.path.exists(flags.checkpoint):
logging.info("Loading checkpoint: %s" % flags.checkpoint)
checkpoint_states = torch.load(
flags.checkpoint, map_location=flags.learner_device
)
model.load_state_dict(checkpoint_states["model_state_dict"])
optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
stats = checkpoint_states["stats"]
logging.info(f"Resuming preempted job, current stats:\n{stats}")
# Initialize actor model like learner model.
actor_model.load_state_dict(model.state_dict())
learner_threads = [
threading.Thread(
target=learn,
name="learner-thread-%i" % i,
args=(
learner_queue,
model,
actor_model,
optimizer,
scheduler,
stats,
flags,
plogger,
learner_device,
),
)
for i in range(flags.num_learner_threads)
]
inference_threads = [
threading.Thread(
target=inference,
name="inference-thread-%i" % i,
args=(inference_batcher, actor_model, flags, actor_device),
)
for i in range(flags.num_inference_threads)
]
actorpool_thread.start()
for t in learner_threads + inference_threads:
t.start()
def checkpoint(checkpoint_path=None):
if flags.checkpoint:
if checkpoint_path is None:
checkpoint_path = flags.checkpoint
logging.info("Saving checkpoint to %s", checkpoint_path)
torch.save(
{
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"stats": stats,
"flags": vars(flags),
},
checkpoint_path,
)
def format_value(x):
return f"{x:1.5}" if isinstance(x, float) else str(x)
try:
train_start_time = timeit.default_timer()
train_time_offset = stats.get("train_seconds", 0) # used for resuming training
last_checkpoint_time = timeit.default_timer()
dev_checkpoint_intervals = [0, 0.25, 0.5, 0.75]
loop_start_time = timeit.default_timer()
loop_start_step = stats.get("step", 0)
while True:
if loop_start_step >= flags.total_steps:
break
time.sleep(5)
loop_end_time = timeit.default_timer()
loop_end_step = stats.get("step", 0)
stats["train_seconds"] = round(
loop_end_time - train_start_time + train_time_offset, 1
)
if loop_end_time - last_checkpoint_time > 10 * 60:
# Save every 10 min.
checkpoint()
last_checkpoint_time = loop_end_time
if len(dev_checkpoint_intervals) > 0:
step_percentage = loop_end_step / flags.total_steps
i = dev_checkpoint_intervals[0]
if step_percentage > i:
checkpoint(flags.checkpoint[:-4] + "_" + str(i) + ".tar")
dev_checkpoint_intervals = dev_checkpoint_intervals[1:]
logging.info(
"Step %i @ %.1f SPS. Inference batcher size: %i."
" Learner queue size: %i."
" Other stats: (%s)",
loop_end_step,
(loop_end_step - loop_start_step) / (loop_end_time - loop_start_time),
inference_batcher.size(),
learner_queue.size(),
", ".join(
f"{key} = {format_value(value)}" for key, value in stats.items()
),
)
loop_start_time = loop_end_time
loop_start_step = loop_end_step
except KeyboardInterrupt:
pass # Close properly.
else:
logging.info("Learning finished after %i steps.", stats["step"])
checkpoint()
# Done with learning. Let's stop all the ongoing work.
inference_batcher.close()
learner_queue.close()
actorpool_thread.join()
for t in learner_threads + inference_threads:
t.join()
def test(flags):
test_checkpoint = os.path.join(flags.savedir, "test_checkpoint.tar")
checkpoint = os.path.join(flags.load_dir, "checkpoint.tar")
if not os.path.exists(os.path.dirname(test_checkpoint)):
os.makedirs(os.path.dirname(test_checkpoint))
logging.info("Creating test copy of checkpoint '%s'", checkpoint)
checkpoint = torch.load(checkpoint)
for d in checkpoint["optimizer_state_dict"]["param_groups"]:
d["lr"] = 0.0
d["initial_lr"] = 0.0
checkpoint["scheduler_state_dict"]["last_epoch"] = 0
checkpoint["scheduler_state_dict"]["_step_count"] = 0
checkpoint["scheduler_state_dict"]["base_lrs"] = [0.0]
checkpoint["stats"]["step"] = 0
checkpoint["stats"]["_tick"] = 0
flags.checkpoint = test_checkpoint
flags.learning_rate = 0.0
logging.info("Saving test checkpoint to %s", test_checkpoint)
torch.save(checkpoint, test_checkpoint)
train(flags)
def main(flags):
if flags.wandb:
wandb.init(
project=flags.project,
config=vars(flags),
group=flags.group,
entity=flags.entity,
)
if flags.mode == "train":
if flags.write_profiler_trace:
logging.info("Running with profiler.")
with torch.autograd.profiler.profile() as prof:
train(flags)
filename = "chrome-%s.trace" % time.strftime("%Y%m%d-%H%M%S")
logging.info("Writing profiler trace to '%s.gz'", filename)
prof.export_chrome_trace(filename)
os.system("gzip %s" % filename)
else:
train(flags)
elif flags.mode.startswith("test"):
test(flags)
if flags.wandb:
wandb.finish()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Installation for hydra:
pip install hydra-core hydra_colorlog --upgrade
Runs like polybeast but uses = to set flags:
python polyhydra.py learning_rate=0.001 rnd.twoheaded=true
Run a sweep by adding -m (Hydra multirun) after the script:
python polyhydra.py -m learning_rate=0.01,0.001,0.0001,0.00001 momentum=0,0.5
The baseline should run with:
python polyhydra.py
"""
from pathlib import Path
import logging
import os
import multiprocessing as mp
import hydra
import numpy as np
from omegaconf import OmegaConf, DictConfig
import torch
import polybeast_env
import polybeast_learner
if torch.__version__.startswith("1.5") or torch.__version__.startswith("1.6"):
# pytorch 1.5.*/1.6.* need this for some reason on the cluster
os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
logging.basicConfig(
format=(
"[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
),
level=0,
)
def pipes_basename():
logdir = Path(os.getcwd())
name = ".".join([logdir.parents[1].name, logdir.parents[0].name, logdir.name])
return "unix:/tmp/poly.%s" % name
def get_common_flags(flags):
flags = OmegaConf.to_container(flags)
flags["pipes_basename"] = pipes_basename()
flags["savedir"] = os.getcwd()
return OmegaConf.create(flags)
def get_learner_flags(flags):
lrn_flags = OmegaConf.to_container(flags)
lrn_flags["checkpoint"] = os.path.join(flags["savedir"], "checkpoint.tar")
lrn_flags["entropy_cost"] = float(lrn_flags["entropy_cost"])
return OmegaConf.create(lrn_flags)
def run_learner(flags: DictConfig):
polybeast_learner.main(flags)
def get_environment_flags(flags):
env_flags = OmegaConf.to_container(flags)
env_flags["num_servers"] = flags.num_actors
max_num_steps = 1e6
if flags.env in ("staircase", "pet"):
max_num_steps = 1000
env_flags["max_num_steps"] = int(max_num_steps)
env_flags["seedspath"] = ""
return OmegaConf.create(env_flags)
def run_env(flags):
np.random.seed() # Get new random seed in forked process.
polybeast_env.main(flags)
def symlink_latest(savedir, symlink):
try:
if os.path.islink(symlink):
os.remove(symlink)
if not os.path.exists(symlink):
os.symlink(savedir, symlink)
logging.info("Symlinked log directory: %s" % symlink)
except OSError:
# os.remove() or os.symlink() raced. Don't do anything.
pass
@hydra.main(config_name="config")
def main(flags: DictConfig):
if os.path.exists("config.yaml"):
# this ignores the local config.yaml and replaces it completely with the saved one
logging.info("loading existing configuration, we're continuing a previous run")
new_flags = OmegaConf.load("config.yaml")
cli_conf = OmegaConf.from_cli()
# however, you can still override parameters from the cli
# this is useful e.g. if you did total_steps=N before and want to increase it
flags = OmegaConf.merge(new_flags, cli_conf)
if flags.load_dir and os.path.exists(os.path.join(flags.load_dir, "config.yaml")):
new_flags = OmegaConf.load(os.path.join(flags.load_dir, "config.yaml"))
cli_conf = OmegaConf.from_cli()
flags = OmegaConf.merge(new_flags, cli_conf)
logging.info(flags.pretty(resolve=True))
OmegaConf.save(flags, "config.yaml")
flags = get_common_flags(flags)
# set flags for polybeast_env
env_flags = get_environment_flags(flags)
env_processes = []
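# A single environment process is enough here: it hosts flags.num_actors environment
# servers (num_servers is set from num_actors in get_environment_flags).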
for _ in range(1):
p = mp.Process(target=run_env, args=(env_flags,))
p.start()
env_processes.append(p)
symlink_latest(
flags.savedir, os.path.join(hydra.utils.get_original_cwd(), "latest")
)
lrn_flags = get_learner_flags(flags)
run_learner(lrn_flags)
for p in env_processes:
p.kill()
p.join()
print('Training Done!')
if __name__ == "__main__":
main()
%% Cell type:markdown id:71dbce49 tags:
# A Brief Intro to NetHack & the NLE
%% Cell type:markdown id:f036704c tags:
> *Welcome, adventurer! You have been heralded from birth as the instrument of the gods. You are destined to recover the Amulet of Yendor for your deity or die in the attempt. Your hour of destiny has come. For the sake of us all: Go bravely!*
This notebook provides a brief overview of the game of NetHack, a glance at the NetHack Learning Environment (NLE), and finally lays down the gauntlet for the NetHack Challenge!
%% Cell type:markdown id:fdc6fcd4 tags:
# What is NetHack?
NetHack is a [roguelike](https://en.wikipedia.org/wiki/Roguelike) computer game, first introduced in the late 1980s. At the beginning of the game your hero is placed into a dungeon, with the goal of descending to the bottom of over 50 procedurally generated levels to retrieve the Amulet of Yendor. Once it is obtained, your hero must escape the dungeon, unlocking five extremely challenging final levels, before offering the Amulet to your in-game deity.
A key component of NetHack is that it is *visually* simple, with observations solely making use of ascii characters, yet it is complex in almost every other way!
There are several reasons why it is particularly challenging:
1) The game is randomized, with everything from the map layouts to the impact of actions depending on the roll of a die.
2) Unlike modern games, you cannot reload an earlier save when things go wrong. Instead, when you die, you begin again from scratch. Given the game's randomness (see above), this makes it especially "unforgiving" (as described on the wiki). Indeed, deaths are so common that there is even an acronym - YASD - which stands for Yet Another Stupid Death.
3) It is incredibly complex, with hundreds of different characters to observe and many more potential sequences of actions.
Thus, unlike other games played by AI agents, NetHack is not solvable by the average human in just a few hours of gameplay. Instead - expert players often take many years to solve it - assuming they are even able to!
NetHack has been actively developed for decades, and NLE makes use of version 3.6.6, originally released in March 2020.
%% Cell type:markdown id:a652ab9f tags:
## Playing the Game
%% Cell type:markdown id:45fd15f0 tags:
### Choosing your hero
At the start of the game, players are usually asked to choose their *character*'s starting role, race, gender and religious alignment. From the [NetHack Wiki](https://nethackwiki.com/wiki/Player):
> The player character can be any one of the following roles: archeologist, barbarian, cave[wo]man, healer, knight, monk, priest[ess], ranger, rogue, samurai, tourist, valkyrie, or wizard. They each have varying difficulties, strengths, weaknesses, quests and starting items.
>
> The player can also choose from the five races: human, elf, dwarf, gnome, or orc, and the three alignments: lawful, neutral or chaotic. The available races and alignments are dependent on the role one picks.
Each starting combination will alter the game experience, and thus impact the difficulty of the game and the most suitable strategy. For example, wizards start with magic and magical items, while rangers begin with a bow and arrow; elves are generally intelligent whereas dwarves are strong!
It's worth noting these different starting characters can really affect the performance of agents learning to play the game. In the original NLE paper, agents on the Score task (most similar to the NetHack Challenge) averaged 738 for monk, 538 for valkyrie and 314 for wizard - but only 11 for tourist! For the NetHack Challenge, the character is randomized during evaluation, so it is likely wise to build agents that can perform well across a variety of hero configurations.
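Since you won't know in advance which role you have been dealt, one simple trick is to read it off the opening message. Here is a minimal sketch (not part of the baseline; it assumes `obs` is the first observation returned by `NetHackChallenge-v0` and that the standard welcome text is present):
```python
# The first message reads e.g. "Hello Agent, welcome to NetHack!  You are a
# lawful female dwarven Archeologist." - so the role is simply the last word.
welcome = bytes(obs["message"]).decode("ascii").strip("\x00")
role = welcome.rstrip(".").split()[-1] if "welcome to NetHack" in welcome else None
print(role)  # e.g. "Archeologist"; could be used to pick a role-specific policy
```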
%% Cell type:markdown id:1e8bc401 tags:
### Complex Observations
One of the many challenges of NetHack is the richness of the observation space, with the fully-formed dungeon, message line and stats bar all rendered as ascii text! Every character (and color) in the dungeon has a symbolic meaning - whether it's a [Monster](https://nethackwiki.com/wiki/Monster#Canonical_list_of_monsters), [Item](https://nethackwiki.com/wiki/Item), or just a part of the [Dungeon](https://nethackwiki.com/wiki/Dungeon_feature) itself.
![Dungeon](./example_annotated.png)
#### Dungeon
The dungeon is the main part of the screen the character navigates. The most frequently seen symbols are:
* `@` : You
* `.` : Dungeon Floor
* `<` and `>` : Stairs up and down
* `|` and `-` : Walls
* `+` : Doors
while it is also common to see Fountains: `{`, Traps: `^`, Altars: `_` and Hallways: `#`.
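As a tiny example of putting the `chars` observation to work (a hedged sketch, assuming `obs` is an observation dict from the environment), you can find where the hero is standing by searching for the `@` symbol. Note that other humans are also drawn as `@`, so in rare cases this can match more than one square:
```python
import numpy as np

hero_positions = np.argwhere(obs["chars"] == ord("@"))
print(hero_positions)  # usually a single (row, col) pair, e.g. [[16 47]]
```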
#### Items
NetHack has a [vast number of items](https://nethackwiki.com/wiki/Item) for in-game use, and many objects can be picked up and included in the inventory. Once included, the agent can choose to use them in a number of different ways - often with some imaginative consequences: you can `apply` a towel to a weapon to clean off grease, but you can `wear` it too (it will wrap around your head)!
Heroes will need to use items as well as possible to navigate the dungeons, not least to find fresh food to eat (unless they can find a [different way](https://nethackwiki.com/wiki/Prayer) to stave off hunger).
%% Cell type:markdown id:143f1ca8 tags:
#### Monsters!
A key component of the difficulty of NetHack (and the cause of many heroic deaths) is the presence of [monsters](https://nethackwiki.com/wiki/Monster#Canonical_list_of_monsters). Throughout the game the hero will encounter many of the hundreds of different types of monsters, ranging from simple jackals which can be trivially defeated to other, more challenging obstacles that typically require significant thought to overcome.
For instance, if you walk into a Floating Eye (blue `e`) you will become paralyzed and probably die - this is common even for experienced players who lose concentration! To kill one, the hero can: make use of ranged weapons; blind themselves to avoid looking at it; become invisible so as not to be seen by it; wear a ring of free action (preventing paralysis); or possess a source of reflection (thus reflecting its gaze). Got all that?
What makes this a little trickier is that many of the most challenging monsters may be seen infrequently, potentially only being encountered once across multiple games. Thus, while it is possible to memorize a strategy for a handful or even dozens of monsters, it only takes one to slip through the cracks of memory before it is back to the beginning of the game.
%% Cell type:markdown id:b449cc83 tags:
#### Taking Actions
In order to make the vast array of complex skills possible to achieve, NetHack has a large action space (referred to as `commands`). The game of NetHack takes inputs directly corresponding to keys on the keyboard, including modifiers such as ctrl, shift and meta. The [full list of commands](https://nethackwiki.com/wiki/Commands) is extensive, including both actions and meta-commands such as help or viewing the inventory.
For the NetHack Challenge we provide an action space that is as close to the full set of commands as possible - blocking only a few commands like modifying option settings. This should provide a significant challenge to all AI agents, while also offering them the potential to fully master the game. We note that it may be worthwhile to constrain this with some inductive bias, possibly even considering a curriculum of [increasing action spaces](http://proceedings.mlr.press/v119/farquhar20a.html).
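To see exactly which commands your agent can issue, you can inspect the environment's action space directly. A minimal sketch (it assumes the `actions` attribute exposed by current NLE releases, which lists the keypress enums in the same order as the action indices):
```python
import gym
import nle  # noqa: F401 - registers the NetHack environments

env = gym.make("NetHackChallenge-v0", savedir=None)
print(env.action_space)            # a Discrete space over the enabled commands
for i, cmd in enumerate(env.actions[:10]):
    print(i, cmd)                  # index passed to env.step() -> keyboard command
```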
%% Cell type:markdown id:85bfc579 tags:
#### Structure of the NetHack world
The collective name for all levels of the game is the "Mazes of Menace". Your hero starts in the initial branch, the Dungeons of Doom, which sits above the underworld of Gehennom and below the five Planes that form the final stages of the game.
The Dungeons also contain various branches, the locations of which are often randomized. For example, the Gnomish Mines will always be generated between dungeon levels 2 and 4. There is also a Sokoban branch, located between levels 2 and 9. In order to reach the Amulet (and win the game), adventurers must complete the Quest, another branch, the location of which varies depending on the role.
This is just a brief foray into the details of the game. For more detail on the Mazes of Menace see the [nethackwiki page](https://nethackwiki.com/wiki/Mazes_of_Menace).
%% Cell type:markdown id:bc29bf70 tags:
# What is the NetHack Learning Environment (NLE)?
The NLE, presented at NeurIPS 2020, is an OpenAI Gym environment that gives researchers the ability to train agents on the game of NetHack.
### `NetHackChallenge-v0`
The NLE contains different NetHack-based tasks for agent training, but a new environment has been created especially for the competition: 'NetHackChallenge-v0'. The new environment is based on the 'NetHackScore-v0' task used in the NeurIPS paper, but contains some key modifications to bring out the full experience of NetHack. These are:
* The action space of the environment is greatly expanded to allow all keys on the keyboard.
* Menus, yes/no questions, cursor-movement, and text-input modalities are enabled.
* A random starting character (the character option '@') is used instead of a single default (e.g. 'mon-hum-neu-mal').
This makes the game particularly challenging, while also providing additional opportunity for savvy agents!
NLE is loaded as a gym environment, with all the typical functions that reinforcement learning (RL) researchers will be familiar with. For those using a symbolic approach, the interaction loop typically looks like this:
```python
obs = env.reset() # produces the first observation
done = False # initialize this so we know when episode ends
total_reward = 0 # total reward
while not done:
action = agent.act(obs) # agent processes observation and computes an action
obs, reward, done, info = env.step(action) # updates the new observation and provides the reward/done
total_reward += reward # keep track of cumulative reward
```
When the episode is over (very likely YASD), `total_reward` will be the agent's score. This is used to train RL agents and to gauge the current performance of symbolic agents.
## Code Examples
%% Cell type:code id:5a513468 tags:
``` python
import nle
import gym
```
%% Cell type:code id:3d422c13 tags:
``` python
env = gym.make("NetHackChallenge-v0", savedir=None) # (Don't save a recording of the episode)
env.reset() # each reset generates a new dungeon
env.step(1) # move agent '@' north
## WARNING: WILL NOT RENDER ON GITLAB (run locally)
env.render()
```
%% Output
You swap places with your kitten.                                              
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
 --------------.-                                                              
 |........F.....|#                                                             
 |.....$.......f@#                                                             
 |..............|                                                              
 |..............+                                                              
 ----------------                                                              
                                                                               
                                                                               
Agent the Plunderess           St:18 Dx:18 Co:18 In:8 Wi:7 Ch:6 Chaotic S:0    
Dlvl:1 $:0 HP:15(15) Pw:2(2) AC:8 Xp:1/0                                        
%% Cell type:markdown id:a7fbc173 tags:
The NLE observation contains multiple elements, which we receive as keys in the observation dictionary. Let's take a look.
%% Cell type:code id:d8cbc4fd tags:
``` python
obs = env.reset()
obs.keys()
```
%% Output
dict_keys(['glyphs', 'chars', 'colors', 'specials', 'blstats', 'message', 'inv_glyphs', 'inv_strs', 'inv_letters', 'inv_oclasses', 'tty_chars', 'tty_colors', 'tty_cursor', 'misc'])
%% Cell type:markdown id:68223810 tags:
#### Observing the Dungeon
The elements **`glyphs`**, **`chars`**, **`colors`**, and **`specials`** are tensors representing the 2D symbolic observation of the dungeon. The key item is **`glyphs`** - integers uniquely specifying what should be displayed at each square. These `glyphs` are then mapped to `chars`, `colors` and `specials`, which can be rendered by terminals. Our agents primarily use the first three (see the small example after this list).
* **`glyphs`** - the single integers representing the specific object at a square in the dungeon (e.g. a pet hell-hound)
* **`chars`** - the characters used to render the glyphs on the screen (e.g. `d`)
* **`colors`** - the colors used to render the glyphs on the screen (e.g. red)
* **`specials`** - any special modifications to render the glyphs on the screen (e.g. it's a pet!)
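For example, to read what is drawn at one particular square (a hedged sketch; `row` and `col` are hypothetical coordinates you already have):
```python
row, col = 16, 47  # hypothetical square of interest

print(obs["glyphs"][row, col])      # integer glyph id
print(chr(obs["chars"][row, col]))  # e.g. '@'
print(obs["colors"][row, col])      # terminal color index
```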
%% Cell type:code id:7c8649d5 tags:
``` python
for key in ['glyphs', 'chars', 'colors']:
print("\n{}:\n".format(key))
print("Shape: {}\n".format(obs[key].shape))
print(obs[key])
```
%% Output
glyphs:
Shape: (21, 79)
[[2359 2359 2359 ... 2359 2359 2359]
[2359 2359 2359 ... 2359 2359 2359]
[2359 2359 2359 ... 2359 2359 2359]
...
[2359 2359 2359 ... 2359 2359 2359]
[2359 2359 2359 ... 2359 2359 2359]
[2359 2359 2359 ... 2359 2359 2359]]
chars:
Shape: (21, 79)
[[32 32 32 ... 32 32 32]
[32 32 32 ... 32 32 32]
[32 32 32 ... 32 32 32]
...
[32 32 32 ... 32 32 32]
[32 32 32 ... 32 32 32]
[32 32 32 ... 32 32 32]]
colors:
Shape: (21, 79)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
%% Cell type:markdown id:ab199a8b tags:
#### BLStats and Message
Along the top of the screen is a topline message that the game uses to communicate with you. Paying close attention to what the game tells you can often make the difference between life and death! The encoding of this message is presented in the **`message`** observation.
Also of interest are the stats along the bottom line of the screen. These are extracted into **`blstats`** and contain a lot of useful information, as shown below.
%% Cell type:code id:36b8f0b0 tags:
``` python
bl_meaning = [
'hero col', 'hero_row', 'strength_pct', 'strength', 'dexterity', 'constitution',
'intelligence', 'wisdom', 'charisma', 'score', 'hitpoints', 'max_hitpoints', 'depth',
'gold', 'energy', 'max_energy', 'armor_class', 'monster_level', 'experience_level',
'experience_points', 'time', 'hunger_state', 'carrying_capacity', 'dungeon_number', 'level_number'
]
env.render()
obs['blstats']
print()
print('MESSAGE')
print(bytes(obs['message']).decode('ascii'))
print()
print('BL STATS')
print(' '.join(["%s: %d" % (m,s) for m, s in zip(bl_meaning, obs['blstats'])]))
```
%% Output
Hello Agent, welcome to NetHack!  You are a lawful female dwarven Archeologist.
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                            --.--                              
                                            |...|                              
                                            ..d.|                              
                                            |+.@.                              
                                            ---.-                              
                                                                               
                                                                               
                                                                               
Agent the Digger               St:18/01 Dx:8 Co:8 In:16 Wi:16 Ch:8 Lawful S:0  
Dlvl:1 $:0 HP:15(15) Pw:1(1) AC:9 Xp:1/0                                        
MESSAGE
Hello Agent, welcome to NetHack! You are a lawful female dwarven Archeologist.
BL STATS
hero col: 47 hero_row: 16 strength_pct: 19 strength: 19 dexterity: 8 constitution: 8 intelligence: 16 wisdom: 16 charisma: 8 score: 0 hitpoints: 15 max_hitpoints: 15 depth: 1 gold: 0 energy: 1 max_energy: 1 armor_class: 9 monster_level: 0 experience_level: 1 experience_points: 0 time: 1 hunger_state: 1 carrying_capacity: 0 dungeon_number: 0 level_number: 1
%% Cell type:markdown id:c7464df6 tags:
#### Inventory
After this we have a series of entries to signify what's in the inventory.
* **inv_glyphs** - The glyphs corresponding to the items in each slot in the inventory
* **inv_letters** - The letter assigned to the slot in the inventory
* **inv_strs** - The textual description of each item in the inventory
* **inv_oclasses** - The object class of the item in the inventory (potion, scroll etc...)
%% Cell type:code id:f3116c76 tags:
``` python
for let, glyph, strs, oclass in zip(
obs['inv_letters'], obs['inv_glyphs'], obs['inv_strs'], obs['inv_oclasses']):
l = chr(let)
desc = bytes(strs).decode('utf-8')
if let:
print('In slot (%s) - glyph: %d, (class %d) - "%s"' % (l, glyph, oclass, desc))
```
%% Output
In slot (a) - glyph: 1970, (class 2) - "a blessed +2 bullwhip (weapon in hand)"
In slot (b) - glyph: 2020, (class 3) - "a blessed +0 leather jacket (being worn)"
In slot (c) - glyph: 1980, (class 3) - "an uncursed +0 fedora (being worn)"
In slot (d) - glyph: 2174, (class 7) - "4 uncursed food rations"
In slot (e) - glyph: 2140, (class 6) - "a +0 pick-axe (alternate weapon; not wielded)"
In slot (f) - glyph: 2119, (class 6) - "a tinning kit (0:59)"
In slot (g) - glyph: 2350, (class 13) - "an uncursed touchstone"
In slot (h) - glyph: 2098, (class 6) - "an empty uncursed sack"
%% Cell type:markdown id:ea1952a5 tags:
#### Miscellaneous Internal Game State
As you progress through the game, you may encounter some different 'modes' of input. For instance, the game might be asking you a yes or no question (`Do you really want to pray? yn(n)`), or you may enter a menu that requires a space to quit. Some flags are provided in the **`misc`** observation (an array of integers) to help you work out the current mode:
```
misc[0] - boolean (0 or 1) - yn_question # Am I in a yes or no question? (Like after "pray")
misc[1] - boolean (0 or 1) - getline # Am I writing the input to a line? (Like making a wish)
misc[2] - boolean (0 or 1) - xwaitforspace # Am I waiting for a space? (Like when -More- is shown, or after "inventory")
```
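As a rough illustration (a hedged sketch, not the baseline's logic), an agent could use these flags to pick a sensible keypress before letting its usual policy act. It assumes `env.actions` contains the corresponding keys, which is worth verifying in your own setup:
```python
yn_question, getline, xwaitforspace = obs["misc"]

keypress = None
if xwaitforspace:
    keypress = ord(" ")   # dismiss --More-- prompts and close menus
elif yn_question:
    keypress = 27         # ESC: decline/cancel the question by default
elif getline:
    keypress = ord("\r")  # submit the current line of text

if keypress is not None:
    action = env.actions.index(keypress)  # map the raw key to its action index
    obs, reward, done, info = env.step(action)
```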
%% Cell type:code id:cda04913 tags:
``` python
print(obs["misc"])
```
%% Output
[0 0 0]
%% Cell type:markdown id:ccc71153 tags:
#### Terminal Rendering
Finally, NLE provides you with the raw output of the terminal screen, should you decide you want to learn from it. This lets you see menus and popups that might not otherwise be shown on the dungeon map.
The observations are simple:
* **`tty_chars`** the characters at each point on the screen
* **`tty_colors`** the colors at each point on the screen
* **`tty_cursor`** the location of the cursor on the screen (NOTE: it's not always on the hero!)
The first two are what is rendered when you call `env.render()` in human mode, and the cursor is fairly self-explanatory.
%% Cell type:code id:fdd30978 tags:
``` python
print(obs['tty_cursor'])
env.render()
```
%% Output
[17 47]
Hello Agent, welcome to NetHack!  You are a lawful female dwarven Archeologist.
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                                                               
                                            --.--                              
                                            |...|                              
                                            ..d.|                              
                                            |+.@.                              
                                            ---.-                              
                                                                               
                                                                               
                                                                               
Agent the Digger               St:18/01 Dx:8 Co:8 In:16 Wi:16 Ch:8 Lawful S:0  
Dlvl:1 $:0 HP:15(15) Pw:1(1) AC:9 Xp:1/0                                        
%% Cell type:markdown id:250d9f6a tags:
# Next Steps?
Included in the starter kit is a [TorchBeast](https://arxiv.org/abs/1910.03552) implementation of [IMPALA](https://arxiv.org/abs/1802.01561), a large-scale distributed RL algorithm, adapted for NLE. A similar model was used in the original NLE paper to produce non-trivial learning curves on environments such as NetHackScore-v0.
In the original NLE paper, the agent architecture was as follows:
![Model](./model.png)
As can be seen, the model uses both an agent-centric view and a global view, each processed with convolutional neural network (CNN) layers. In addition, the blstats are processed with an MLP. Finally, the embeddings are passed into an LSTM to deal with partial observability.
The baseline is almost identical, except for one key difference - we haven't added a CNN encoder for the `message` observation. This architecture may provide a promising starting point for development, but the sky is the limit for new ideas! Check out the [README.md](./nethack_baselines/torchbeast/README.md) to get started!
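To give a flavor of what the agent-centric view involves, here is a hedged sketch of an egocentric crop around the hero (the baseline's own crop lives in the model code and may differ in details such as padding; `obs` is assumed to come from the environment):
```python
import numpy as np

def egocentric_crop(glyphs, blstats, k=4, pad_value=0):
    """Return a (2k+1) x (2k+1) window of glyphs centered on the hero."""
    col, row = int(blstats[0]), int(blstats[1])            # hero position from blstats
    padded = np.pad(glyphs, k, constant_values=pad_value)  # pad_value is arbitrary here
    r, c = row + k, col + k
    return padded[r - k : r + k + 1, c - k : c + k + 1]

crop = egocentric_crop(obs["glyphs"], obs["blstats"])
print(crop.shape)  # (9, 9) with k=4, matching crop_dim: 9 in config.yaml
```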
%% Cell type:markdown id:af86ddfe tags:
And if you want to learn more about NetHack, checkout:
* [NetHackWiki](https://nethackwiki.com/wiki/Main_Page)
* [the Beginner's Guide](https://nethackwiki.com/wiki/Why_do_I_keep_dying%3F)
* [MIT course on NetHack](https://rec.games.roguelike.nethack.narkive.com/n31QHcTe/nethack-class-at-mit-follow-along-virtually) (yes seriously!)
%% Cell type:code id:1a9e9fd8 tags:
``` python
```
notebooks/example_annotated.png (292 KiB)
notebooks/example_standalone.png (262 KiB)
notebooks/model.png (238 KiB)
#!/usr/bin/env python
# This file is the entrypoint for your submission
# You can modify this file to include your code or directly call your functions/modules from here.
import aicrowd_gym
import nle
################################################################
## Ideally you shouldn't need to change this file at all ##
## ##
## This file generates the rollouts, with the specific agent, ##
## batch_size and wrappers specified in submission_config.py  ##
################################################################
from tqdm import tqdm
import numpy as np
# AIcrowd will cut the assessment early during the dev phase
NUM_ASSESSMENTS = 4096


def run_batched_rollout(num_episodes, batched_env, agent):
    """
    This function will generate a series of rollouts in a batched manner.
    """
    num_envs = batched_env.num_envs

    # This part can be left as is
    observations = batched_env.batch_reset()
    rewards = [0.0 for _ in range(num_envs)]
    dones = [False for _ in range(num_envs)]
    infos = [{} for _ in range(num_envs)]

    # We mark at the start of each episode if we are 'counting' it
    active_envs = [i < num_episodes for i in range(num_envs)]
    num_remaining = num_episodes - sum(active_envs)

    episode_count = 0
    pbar = tqdm(total=num_episodes)

    ascension_count = 0
    all_returns = []
    returns = [0.0 for _ in range(num_envs)]

    # The evaluator will automatically stop after the episodes based on the development/test phase
    while episode_count < num_episodes:
        actions = agent.batched_step(observations, rewards, dones, infos)
        observations, rewards, dones, infos = batched_env.batch_step(actions)
        for i, r in enumerate(rewards):
            returns[i] += r
        for done_idx in np.where(dones)[0]:
            if active_envs[done_idx]:
                # We were 'counting' this episode
                all_returns.append(returns[done_idx])
                episode_count += 1
                active_envs[done_idx] = (num_remaining > 0)
                num_remaining -= 1
                ascension_count += int(infos[done_idx]["is_ascended"])
                pbar.update(1)
            returns[done_idx] = 0.0
    pbar.close()
    return ascension_count, all_returns


def main():
    from envs.batched_env import BatchedEnv
    from envs.wrappers import create_env  # noqa: F401
    from submission_config import SubmissionConfig

    env_make_fn = SubmissionConfig.MAKE_ENV_FN
    num_envs = SubmissionConfig.NUM_ENVIRONMENTS
    Agent = SubmissionConfig.AGENT

    # MAKE_ENV_FN builds environments through the competition wrappers, which limit
    # the features of the environment that we don't want participants to use during
    # the submission.
    batched_env = BatchedEnv(env_make_fn=env_make_fn, num_envs=num_envs)
    agent = Agent(num_envs, batched_env.num_actions)

    run_batched_rollout(NUM_ASSESSMENTS, batched_env, agent)


if __name__ == "__main__":
    main()
#!/bin/bash
python rollout.py
File added
name: 5
wandb: true
project: nethack_challenge
entity: nethack
group: baseline
mock: false
single_ttyrec: true
num_seeds: 0
write_profiler_trace: false
fn_penalty_step: constant
penalty_time: 0.0
penalty_step: -0.01
reward_lose: 0
reward_win: 100
state_counter: none
character: '@'
mode: train
env: challenge
num_actors: 256
total_steps: 1000000000.0
batch_size: 32
unroll_length: 80
num_learner_threads: 1
num_inference_threads: 1
disable_cuda: false
learner_device: cuda:1
actor_device: cuda:0
learning_rate: 0.0002
grad_norm_clipping: 40
alpha: 0.99
momentum: 0
epsilon: 1.0e-06
entropy_cost: 0.001
baseline_cost: 0.5
discounting: 0.999
normalize_reward: true
model: baseline
use_lstm: true
hidden_dim: 256
embedding_dim: 64
layers: 5
crop_dim: 9
use_index_select: true
restrict_action_space: true
msg:
hidden_dim: 64
embedding_dim: 32
load_dir: null
File added