diff --git a/.gitattributes b/.gitattributes
index ef7168a7beba0bc807f8bdacc036c79273fa3897..6813df25fdcd2b8b7bd5f55534ca72d3aa80eebf 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,2 +1,5 @@
 *.wav filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
+submission filter=lfs diff=lfs merge=lfs -text
+submission/* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
diff --git a/agents/batched_agent.py b/agents/batched_agent.py
index 71089572bec7f5a4fa4548264a932e39581d503c..d5e59cec3a1d0ec85fcca6ca3462f46786ab587c 100644
--- a/agents/batched_agent.py
+++ b/agents/batched_agent.py
@@ -11,20 +11,9 @@ class BatchedAgent:
         self.num_envs = num_envs
         self.num_actions = num_actions
 
-    def preprocess_observations(self, observations, rewards, dones, infos):
+    def batched_step(self, observations, rewards, dones, infos):
         """
-        Add any preprocessing steps, for example rerodering/stacking for torch/tf in your model
+        Take lists of outputs from each environment and return a list of actions
         """
-        pass
+        raise NotImplementedError
 
-    def preprocess_actions(self, actions):
-        """
-        Add any postprocessing steps, for example converting to lists
-        """
-        pass
-
-    def batched_step(self):
-        """
-        Return a list of actions
-        """
-        pass
diff --git a/agents/torchbeast_agent.py b/agents/torchbeast_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbfb9b8146786d7b919773958717809fc4caf038
--- /dev/null
+++ b/agents/torchbeast_agent.py
@@ -0,0 +1,61 @@
+import torch
+import numpy as np
+
+from agents.batched_agent import BatchedAgent
+
+from nethack_baselines.torchbeast.models import load_model
+
+MODEL_DIR = "./models/torchbeast/example_run"
+
+
+class TorchBeastAgent(BatchedAgent):
+    """
+    A BatchedAgent using the TorchBeast Model
+    """
+
+    def __init__(self, num_envs, num_actions):
+        super().__init__(num_envs, num_actions)
+        self.model_dir = MODEL_DIR
+        self.device = "cuda:0"
+        self.model = load_model(MODEL_DIR, self.device)
+
+        self.core_state = [
+            m.to(self.device) for m in self.model.initial_state(batch_size=num_envs)
+        ]
+
+    def batch_inputs(self, observations, dones):
+        """
+        Convert lists of observations and dones into batched tensors for TorchBeast.
+
+        TorchBeast models:
+            * take tensors of shape [T, B, ...], where B is the batch size and T the unroll length (= 1 here)
+            * take "done" as a boolean observation
+        """
+        states = list(observations[0].keys())
+        obs = {k: [] for k in states}
+        
+        # Unpack List[Dicts] -> Dict[Lists]
+        for o in observations:
+            for k, t in o.items():
+                obs[k].append(t)
+        
+        # Convert to Tensor, Add Unroll Dim (=1), Move to GPU
+        for k in states:
+            obs[k] = torch.Tensor(np.stack(obs[k])[None, ...]).to(self.device)
+        obs["done"] = torch.Tensor(np.array(dones)[None, ...]).bool().to(self.device)
+        return obs, dones
+
+    def batched_step(self, observations, rewards, dones, infos):
+        """
+        Perform a batched step on lists of environment outputs.
+
+        TorchBeast models:
+            * take the core (LSTM) state as input, and return as output
+            * return outputs as a dict of "action", "policy_logits", "baseline"
+        """
+        observations, dones = self.batch_inputs(observations, dones)
+    
+        with torch.no_grad():
+            outputs, self.core_state = self.model(observations, self.core_state)
+        
+        return outputs["action"].cpu().numpy()[0]
diff --git a/agents/torchbeast_batched_agent.py b/agents/torchbeast_batched_agent.py
deleted file mode 100644
index ae9bcf359ab23bd77d159bacb986bd3cafc15889..0000000000000000000000000000000000000000
--- a/agents/torchbeast_batched_agent.py
+++ /dev/null
@@ -1 +0,0 @@
-placeholders
\ No newline at end of file
diff --git a/aicrowd.json b/aicrowd.json
index b89b1dc55ed1952015ab73ba7de4555a604eb556..af5d9212c40f9581f486f3751d6490c8c037ab92 100644
--- a/aicrowd.json
+++ b/aicrowd.json
@@ -4,5 +4,7 @@
   "authors": [
     "aicrowd-bot"
   ],
-  "external_dataset_used": false
+  "external_dataset_used": false,
+  "gpu": true
 }
+
diff --git a/local_evaluation.py b/local_evaluation.py
index 1ae846db360bd8784da333f143d13b8b903e0636..b8b625dc8e967446c6086d8aad1edf7c35063eb8 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -32,6 +32,7 @@ def evaluate():
     agent = Agent(num_envs, num_actions)
 
     run_batched_rollout(batched_env, agent)
+    
 
 if __name__ == '__main__':
     evaluate()
diff --git a/nethack_baselines/torchbeast/README.md b/nethack_baselines/torchbeast/README.md
index 3b94f915737aba1f12a0f067fdba3726bfe02df5..bb4bd0d3b690c760ce345b7ccc1e0d51d216ce8d 100644
--- a/nethack_baselines/torchbeast/README.md
+++ b/nethack_baselines/torchbeast/README.md
@@ -1 +1,90 @@
-Placeholder
+# TorchBeast NetHackChallenge Benchmark
+
+This is a baseline model for the NetHack Challenge based on [TorchBeast](https://github.com/facebookresearch/torchbeast) - FAIR's implementation of IMPALA for PyTorch.
+
+It comes with all the code you need to train, run, and submit a model based on the results published in the original NLE paper.
+
+This implementation uses 2 GPUs (one for acting and one for learning) and runs many simultaneous environments with dynamic batching.
+
+## Installation
+
+To get this running, follow the TorchBeast installation instructions on its repo page, then install the packages in `requirements.txt`.
+
+A Dockerfile that installs TorchBeast is also provided.
+
+## Running The Baseline
+
+Once installed, in this directory run:
+
+`python polyhydra.py` 
+
+To change parameters, edit `config.yaml`; to override parameters from the command line, run:
+
+`python polyhydra.py embedding_dim=16`
+
+Training saves checkpoints to a new directory (`outputs`), and any outputs created by the environments are saved to `nle_data` (by default, episode recordings are switched off to save space).
+
+By default, polybeast runs on 2 GPUs, one for the learner and one for the actors. However, you can still run polybeast with only one GPU - just override the `actor_device` argument:
+
+`python polyhydra.py actor_device=cpu`
+
+## Making a submission
+
+Take the output directory of your trained model and add its `checkpoint.tar` and `config.yaml` to the git repo. Then change the `SUBMISSION` variable in `rollout.py`, at the base of this repository, to point to that directory.
+
+After that, tag the submission and push the branch and tag to AIcrowd's GitLab!
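+
+For reference, a minimal sketch of what that change might look like (a hypothetical excerpt; the exact contents of `rollout.py` may differ, and the path is just an example):
+
+```python
+# rollout.py (at the base of this repository)
+# Point SUBMISSION at the directory containing your checkpoint.tar and config.yaml.
+SUBMISSION = "./models/torchbeast/example_run"
+```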
+
+
+## Repo Structure
+
+```
+baselines/torchbeast
+├── core/           
+├── models/                #  <- Models HERE
+├── util/           
+├── config.yaml            #  <- Flags HERE
+├── polybeast_env.py       #  <- Training Env HERE
+├── polybeast_learner.py   #  <- Training Loop HERE
+└── polyhydra.py           #  <- main() HERE
+
+```
+
+The structure is simple, compartmentalising the environment setup, training loop, and models into different files. You can tweak any of these separately, and add parameters to the flags (which are passed around).
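+
+For example (with a hypothetical flag, purely to illustrate how flags flow through the code), after adding `my_new_flag: 3` to `config.yaml` it is available anywhere the `flags` object is passed. During training Hydra builds that object for you; the sketch below simply loads the file directly:
+
+```python
+from omegaconf import OmegaConf
+
+# Load the training config as an OmegaConf object (Hydra does this for you at train time).
+flags = OmegaConf.load("nethack_baselines/torchbeast/config.yaml")
+
+print(flags.embedding_dim)  # existing flag: 64
+# print(flags.my_new_flag)  # the hypothetical flag added above
+```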
+
+
+
+## About the Model
+
+The model we provide (`BaselineNet`) is simple and lives entirely in `models/baseline.py`.
+
+* It encodes the dungeon into a fixed-size representation (`GlyphEncoder`)
+* It encodes the topline message into a fixed-size representation (`MessageEncoder`)
+* It encodes the bottom-line statistics (e.g. armour class, health) into a fixed-size representation (`BLStatsEncoder`)
+* It concatenates these outputs into a single fixed-size vector, runs this through a fully connected layer, and then into an LSTM
+* The outputs of the LSTM go through policy and baseline heads (since this is an actor-critic algorithm)
+
+
+As you can see, there is a lot of data to play with in this game, and plenty to try, both in the modelling and in the learning algorithms used.
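+
+If you want to poke at a trained model outside the evaluator, the snippet below is a minimal sketch of loading a run and taking one step in an NLE environment, based on `agents/torchbeast_agent.py` and `models/__init__.py` in this repo (it assumes a trained checkpoint exists at the default `./models/torchbeast/example_run` and a GPU is available at `cuda:0`):
+
+```python
+import gym
+import nle  # noqa: F401  (registers NetHackChallenge-v0)
+
+from agents.torchbeast_agent import TorchBeastAgent
+
+env = gym.make("NetHackChallenge-v0")
+
+# TorchBeastAgent loads config.yaml and checkpoint.tar from its MODEL_DIR onto cuda:0.
+agent = TorchBeastAgent(num_envs=1, num_actions=env.action_space.n)
+
+obs = env.reset()
+# batched_step takes lists (one entry per env) and returns one action per env.
+actions = agent.batched_step([obs], [0.0], [False], [{}])
+obs, reward, done, info = env.step(actions[0])
+```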
+
+## Improvement Ideas
+
+*Here are some ideas we haven't tried yet, but might be easy places to start. Happy tinkering!*
+
+
+### Model Improvements (`baseline.py`)
+
+* The model is currently not using the terminal observations (`tty_chars`, `tty_colors`, `tty_cursor`), so it has no idea about menus - could we make use of these somehow?
+* The bottom-line stats are very informative, but very simply encoded in `BLStatsEncoder` - is there a better way to do this?
+* The `GlyphEncoder` builds an embedding for the glyphs, and then takes a crop of these centered around the player icon coordinates (`@`). Should the crop reuse the same embedding matrix?
+* The current model constrains the vast action space to a smaller subset of actions. Is it too constrained? Or not constrained enough?
+
+###  Environment Improvements (`polybeast_env.py`)
+
+* Opening menus (such as when spellcasting) does not advance the in-game timer. However, models can get stuck in menus, since they have to learn which buttons to press to close them. Can changing the penalty for not advancing the in-game timer improve the result?
+* The NetHackChallenge assesses the score on random character assignments. Might it be easier to learn on just a few of these at the beginning of training? 
+
+### Algorithm/Optimisation Improvements (`polybeast_learner.py`)
+
+* Can we add some intrinsic rewards to help our agents learn? 
+* Should we add penalties to disincentivise pathological behaviour we observe?
+* Can we improve training by using a different optimizer? (See the sketch below for one option.)
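+
+As a concrete starting point for the last idea, here is a minimal sketch of a drop-in Adam constructor (a hypothetical helper, not an existing function; the learner currently builds RMSprop from the flags above, so you would swap it in wherever the optimizer is created in `polybeast_learner.py`):
+
+```python
+import torch
+from omegaconf import DictConfig
+from torch import nn
+
+
+def build_optimizer(model: nn.Module, flags: DictConfig) -> torch.optim.Optimizer:
+    """Hypothetical replacement for the RMSprop construction in polybeast_learner.py."""
+    return torch.optim.Adam(
+        model.parameters(),
+        lr=flags.learning_rate,  # reuse the existing learning_rate flag
+        eps=flags.epsilon,       # reuse the existing epsilon flag
+    )
+```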
diff --git a/nethack_baselines/torchbeast/config.yaml b/nethack_baselines/torchbeast/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c4947756f3de6f029fa4da058ceea3739d0e8b72
--- /dev/null
+++ b/nethack_baselines/torchbeast/config.yaml
@@ -0,0 +1,107 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+defaults:
+- hydra/job_logging: colorlog
+- hydra/hydra_logging: colorlog
+# - hydra/launcher: submitit_slurm
+
+# # To Be Used With hydra submitit_slurm if you have SLURM cluster
+# # pip install hydra-core hydra_colorlog
+# # can set these on the commandline too, e.g. `hydra.launcher.partition=dev`
+# hydra:
+#   launcher:
+#     timeout_min: 4300
+#     cpus_per_task: 20  
+#     gpus_per_node: 2
+#     tasks_per_node: 1
+#     mem_gb: 20
+#     nodes: 1
+#     partition: dev
+#     comment: null  
+#     max_num_timeout: 5  # will requeue on timeout or preemption
+
+
+name: null  # can use this to have multiple runs with same params, eg name=1,2,3,4,5
+
+## WANDB settings
+wandb: false                 # Enable wandb logging.
+project: nethack_challenge   # The wandb project name.
+entity: user1                # The wandb user to log to.
+group: group1                # The wandb group for the run.
+
+# POLYBEAST ENV settings
+mock: false                  # Use mock environment instead of NetHack.
+single_ttyrec: true          # Record ttyrec only for actor 0.
+num_seeds: 0                 # If larger than 0, sample a fixed number of environment seeds to be used.
+write_profiler_trace: false  # Collect and write a profiler trace for chrome://tracing/.
+fn_penalty_step: constant    # Function to accumulate penalty.
+penalty_time: 0.0            # Penalty per time step in the episode.
+penalty_step: -0.01          # Penalty per step in the episode.
+reward_lose: 0               # Reward for losing (dying before finding the staircase).
+reward_win: 100              # Reward for winning (finding the staircase).
+state_counter: none          # Method for counting state visits. Default none.
+character: 'mon-hum-neu-mal' # Specification of the NetHack character.
+                              ## typical characters we use
+                                # 'mon-hum-neu-mal'
+                                # 'val-dwa-law-fem'
+                                # 'wiz-elf-cha-mal'
+                                # 'tou-hum-neu-fem'
+                                # '@'   # random (used in Challenge assessment)
+
+# RUN settings.
+mode: train                  # Training or test mode.
+env: challenge               # Name of Gym environment to create.
+                             # # env (task) names: challenge, staircase, pet, 
+                             #     eat, gold, score, scout, oracle
+
+# TRAINING settings.
+num_actors: 256              # Number of actors.
+total_steps: 1e9             # Total environment steps to train for. Will be cast to int.
+batch_size: 32               # Learner batch size.
+unroll_length: 80            # The unroll length (time dimension).
+num_learner_threads: 1       # Number of learner threads.
+num_inference_threads: 1     # Number of inference threads.
+disable_cuda: false          # Disable CUDA.
+learner_device: cuda:1       # Set learner device.
+actor_device: cuda:0         # Set actor device.
+
+# OPTIMIZER settings. (RMS Prop)
+learning_rate: 0.0002        # Learning rate.
+grad_norm_clipping: 40       # Global gradient norm clip.
+alpha: 0.99                  # RMSProp smoothing constant.
+momentum: 0                  # RMSProp momentum.
+epsilon: 0.000001            # RMSProp epsilon.
+
+# LOSS settings.
+entropy_cost: 0.001          # Entropy cost/multiplier.
+baseline_cost: 0.5           # Baseline cost/multiplier.
+discounting: 0.999           # Discounting factor.
+normalize_reward: true       # Normalizes reward by dividing by running stdev from mean.
+
+# MODEL settings.
+model: baseline              # Name of model to build (see models/__init__.py).
+use_lstm: true               # Use LSTM in agent model.
+hidden_dim: 256              # Size of hidden representations.
+embedding_dim: 64            # Size of glyph embeddings.
+layers: 5                    # Number of ConvNet Layers for Glyph Model
+crop_dim: 9                  # Size of crop (c x c)
+use_index_select: true       # Whether to use index_select instead of embedding lookup (for speed reasons).
+restrict_action_space: True  # Use a restricted ACTION SPACE (only nethack.USEFUL_ACTIONS)
+
+msg:                      
+  hidden_dim: 64             # Hidden dimension for message encoder.
+  embedding_dim: 32          # Embedding dimension for characters in message encoder.
+
+# TEST settings.    
+load_dir: null               # Path to load a model from for testing
diff --git a/nethack_baselines/torchbeast/core/file_writer.py b/nethack_baselines/torchbeast/core/file_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..99b06fe4fd5ce39194853d5f6b4f248b3b9554ed
--- /dev/null
+++ b/nethack_baselines/torchbeast/core/file_writer.py
@@ -0,0 +1,203 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import csv
+import datetime
+import json
+import logging
+import os
+import time
+import weakref
+
+
+def _save_metadata(path, metadata):
+    metadata["date_save"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+    with open(path, "w") as f:
+        json.dump(metadata, f, indent=4, sort_keys=True)
+
+
+def gather_metadata():
+    metadata = dict(
+        date_start=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+        env=os.environ.copy(),
+        successful=False,
+    )
+
+    # Git metadata.
+    try:
+        import git
+    except ImportError:
+        logging.warning(
+            "Couldn't import gitpython module; install it with `pip install gitpython`."
+        )
+    else:
+        try:
+            repo = git.Repo(search_parent_directories=True)
+            metadata["git"] = {
+                "commit": repo.commit().hexsha,
+                "is_dirty": repo.is_dirty(),
+                "path": repo.git_dir,
+            }
+            if not repo.head.is_detached:
+                metadata["git"]["branch"] = repo.active_branch.name
+        except git.InvalidGitRepositoryError:
+            pass
+
+    if "git" not in metadata:
+        logging.warning("Couldn't determine git data.")
+
+    # Slurm metadata.
+    if "SLURM_JOB_ID" in os.environ:
+        slurm_env_keys = [k for k in os.environ if k.startswith("SLURM")]
+        metadata["slurm"] = {}
+        for k in slurm_env_keys:
+            d_key = k.replace("SLURM_", "").replace("SLURMD_", "").lower()
+            metadata["slurm"][d_key] = os.environ[k]
+
+    return metadata
+
+
+class FileWriter:
+    def __init__(self, xp_args=None, rootdir="~/palaas"):
+        if rootdir == "~/palaas":
+            # make unique id in case someone uses the default rootdir
+            xpid = "{proc}_{unixtime}".format(
+                proc=os.getpid(), unixtime=int(time.time())
+            )
+            rootdir = os.path.join(rootdir, xpid)
+        self.basepath = os.path.expandvars(os.path.expanduser(rootdir))
+
+        self._tick = 0
+
+        # metadata gathering
+        if xp_args is None:
+            xp_args = {}
+        self.metadata = gather_metadata()
+        # we need to copy the args, otherwise when we close the file writer
+        # (and rewrite the args) we might have non-serializable objects (or
+        # other nasty stuff).
+        self.metadata["args"] = copy.deepcopy(xp_args)
+
+        formatter = logging.Formatter("%(message)s")
+        self._logger = logging.getLogger("palaas/out")
+
+        # to stdout handler
+        shandle = logging.StreamHandler()
+        shandle.setFormatter(formatter)
+        self._logger.addHandler(shandle)
+        self._logger.setLevel(logging.INFO)
+
+        # to file handler
+        if not os.path.exists(self.basepath):
+            self._logger.info("Creating log directory: %s", self.basepath)
+            os.makedirs(self.basepath, exist_ok=True)
+        else:
+            self._logger.info("Found log directory: %s", self.basepath)
+
+        self.paths = dict(
+            msg="{base}/out.log".format(base=self.basepath),
+            logs="{base}/logs.csv".format(base=self.basepath),
+            fields="{base}/fields.csv".format(base=self.basepath),
+            meta="{base}/meta.json".format(base=self.basepath),
+        )
+
+        self._logger.info("Saving arguments to %s", self.paths["meta"])
+        if os.path.exists(self.paths["meta"]):
+            self._logger.warning(
+                "Path to meta file already exists. " "Not overriding meta."
+            )
+        else:
+            self.save_metadata()
+
+        self._logger.info("Saving messages to %s", self.paths["msg"])
+        if os.path.exists(self.paths["msg"]):
+            self._logger.warning(
+                "Path to message file already exists. " "New data will be appended."
+            )
+
+        fhandle = logging.FileHandler(self.paths["msg"])
+        fhandle.setFormatter(formatter)
+        self._logger.addHandler(fhandle)
+
+        self._logger.info("Saving logs data to %s", self.paths["logs"])
+        self._logger.info("Saving logs' fields to %s", self.paths["fields"])
+        self.fieldnames = ["_tick", "_time"]
+        if os.path.exists(self.paths["logs"]):
+            self._logger.warning(
+                "Path to log file already exists. " "New data will be appended."
+            )
+            # Override default fieldnames.
+            with open(self.paths["fields"], "r") as csvfile:
+                reader = csv.reader(csvfile)
+                lines = list(reader)
+                if len(lines) > 0:
+                    self.fieldnames = lines[-1]
+            # Override default tick: use the last tick from the logs file plus 1.
+            with open(self.paths["logs"], "r") as csvfile:
+                reader = csv.reader(csvfile)
+                lines = list(reader)
+                # Need at least two lines in order to read the last tick:
+                # the first is the csv header and the second is the first line
+                # of data.
+                if len(lines) > 1:
+                    self._tick = int(lines[-1][0]) + 1
+
+        self._fieldfile = open(self.paths["fields"], "a")
+        self._fieldwriter = csv.writer(self._fieldfile)
+        self._fieldfile.flush()
+        self._logfile = open(self.paths["logs"], "a")
+        self._logwriter = csv.DictWriter(self._logfile, fieldnames=self.fieldnames)
+
+        # Auto-close (and save) on destruction.
+        weakref.finalize(self, _save_metadata, self.paths["meta"], self.metadata)
+
+    def log(self, to_log, tick=None, verbose=False):
+        if tick is not None:
+            raise NotImplementedError
+        else:
+            to_log["_tick"] = self._tick
+            self._tick += 1
+        to_log["_time"] = time.time()
+
+        old_len = len(self.fieldnames)
+        for k in to_log:
+            if k not in self.fieldnames:
+                self.fieldnames.append(k)
+        if old_len != len(self.fieldnames):
+            self._fieldwriter.writerow(self.fieldnames)
+            self._fieldfile.flush()
+            self._logger.info("Updated log fields: %s", self.fieldnames)
+
+        if to_log["_tick"] == 0:
+            self._logfile.write("# %s\n" % ",".join(self.fieldnames))
+
+        if verbose:
+            self._logger.info(
+                "LOG | %s",
+                ", ".join(["{}: {}".format(k, to_log[k]) for k in sorted(to_log)]),
+            )
+
+        self._logwriter.writerow(to_log)
+        self._logfile.flush()
+
+    def close(self, successful=True):
+        self.metadata["successful"] = successful
+        self.save_metadata()
+
+        for f in [self._logfile, self._fieldfile]:
+            f.close()
+
+    def save_metadata(self):
+        _save_metadata(self.paths["meta"], self.metadata)
diff --git a/nethack_baselines/torchbeast/core/vtrace.py b/nethack_baselines/torchbeast/core/vtrace.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d851a04467b9679573c1a6300230905396b1c03
--- /dev/null
+++ b/nethack_baselines/torchbeast/core/vtrace.py
@@ -0,0 +1,136 @@
+# This file taken from
+#     https://github.com/deepmind/scalable_agent/blob/
+#         cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py
+# and modified.
+
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions to compute V-trace off-policy actor critic targets.
+
+For details and theory see:
+
+"IMPALA: Scalable Distributed Deep-RL with
+Importance Weighted Actor-Learner Architectures"
+by Espeholt, Soyer, Munos et al.
+
+See https://arxiv.org/abs/1802.01561 for the full paper.
+"""
+
+import collections
+
+import torch
+import torch.nn.functional as F
+
+
+VTraceFromLogitsReturns = collections.namedtuple(
+    "VTraceFromLogitsReturns",
+    [
+        "vs",
+        "pg_advantages",
+        "log_rhos",
+        "behavior_action_log_probs",
+        "target_action_log_probs",
+    ],
+)
+
+VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages")
+
+
+def action_log_probs(policy_logits, actions):
+    return -F.nll_loss(
+        F.log_softmax(torch.flatten(policy_logits, 0, 1), dim=-1),
+        torch.flatten(actions, 0, 1),
+        reduction="none",
+    ).view_as(actions)
+
+
+def from_logits(
+    behavior_policy_logits,
+    target_policy_logits,
+    actions,
+    discounts,
+    rewards,
+    values,
+    bootstrap_value,
+    clip_rho_threshold=1.0,
+    clip_pg_rho_threshold=1.0,
+):
+    """V-trace for softmax policies."""
+
+    target_action_log_probs = action_log_probs(target_policy_logits, actions)
+    behavior_action_log_probs = action_log_probs(behavior_policy_logits, actions)
+    log_rhos = target_action_log_probs - behavior_action_log_probs
+    vtrace_returns = from_importance_weights(
+        log_rhos=log_rhos,
+        discounts=discounts,
+        rewards=rewards,
+        values=values,
+        bootstrap_value=bootstrap_value,
+        clip_rho_threshold=clip_rho_threshold,
+        clip_pg_rho_threshold=clip_pg_rho_threshold,
+    )
+    return VTraceFromLogitsReturns(
+        log_rhos=log_rhos,
+        behavior_action_log_probs=behavior_action_log_probs,
+        target_action_log_probs=target_action_log_probs,
+        **vtrace_returns._asdict(),
+    )
+
+
+@torch.no_grad()
+def from_importance_weights(
+    log_rhos,
+    discounts,
+    rewards,
+    values,
+    bootstrap_value,
+    clip_rho_threshold=1.0,
+    clip_pg_rho_threshold=1.0,
+):
+    """V-trace from log importance weights."""
+    with torch.no_grad():
+        rhos = torch.exp(log_rhos)
+        if clip_rho_threshold is not None:
+            clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
+        else:
+            clipped_rhos = rhos
+
+        cs = torch.clamp(rhos, max=1.0)
+        # Append bootstrapped value to get [v1, ..., v_t+1]
+        values_t_plus_1 = torch.cat(
+            [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0
+        )
+        deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
+
+        acc = torch.zeros_like(bootstrap_value)
+        result = []
+        for t in range(discounts.shape[0] - 1, -1, -1):
+            acc = deltas[t] + discounts[t] * cs[t] * acc
+            result.append(acc)
+        result.reverse()
+        vs_minus_v_xs = torch.stack(result)
+
+        # Add V(x_s) to get v_s.
+        vs = torch.add(vs_minus_v_xs, values)
+
+        # Advantage for policy gradient.
+        vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
+        if clip_pg_rho_threshold is not None:
+            clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
+        else:
+            clipped_pg_rhos = rhos
+        pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)
+
+        # Make sure no gradients backpropagated through the returned values.
+        return VTraceReturns(vs=vs, pg_advantages=pg_advantages)
diff --git a/nethack_baselines/torchbeast/models/__init__.py b/nethack_baselines/torchbeast/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2da83299a8a07b9538886c6068c6ad09374602
--- /dev/null
+++ b/nethack_baselines/torchbeast/models/__init__.py
@@ -0,0 +1,56 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nle.env import tasks
+from nle.env.base import DUNGEON_SHAPE
+
+from .baseline import BaselineNet
+
+from omegaconf import OmegaConf
+import torch
+
+
+ENVS = dict(
+    staircase=tasks.NetHackStaircase,
+    score=tasks.NetHackScore,
+    pet=tasks.NetHackStaircasePet,
+    oracle=tasks.NetHackOracle,
+    gold=tasks.NetHackGold,
+    eat=tasks.NetHackEat,
+    scout=tasks.NetHackScout,
+    challenge=tasks.NetHackChallenge,
+)
+
+
+def create_model(flags, device):
+    model_string = flags.model
+    if model_string == "baseline":
+        model_cls = BaselineNet
+    else:
+        raise NotImplementedError("model=%s" % model_string)
+
+    action_space = ENVS[flags.env](savedir=None, archivefile=None)._actions
+
+    model = model_cls(DUNGEON_SHAPE, action_space, flags, device)
+    model.to(device=device)
+    return model
+
+
+def load_model(load_dir, device):
+    flags = OmegaConf.load(load_dir + "/config.yaml")
+    flags.checkpoint = load_dir + "/checkpoint.tar"
+    model = create_model(flags, device)
+    checkpoint_states = torch.load(flags.checkpoint, map_location=device)
+    model.load_state_dict(checkpoint_states["model_state_dict"])
+    return model
diff --git a/nethack_baselines/torchbeast/models/baseline.py b/nethack_baselines/torchbeast/models/baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..0599315c2d82c5f9951ddc478c90a7e211e7e18c
--- /dev/null
+++ b/nethack_baselines/torchbeast/models/baseline.py
@@ -0,0 +1,496 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import torch
+from torch import nn
+from torch.nn import functional as F
+from einops import rearrange
+
+
+from nle import nethack
+
+from .util import id_pairs_table
+import numpy as np
+
+NUM_GLYPHS = nethack.MAX_GLYPH
+NUM_FEATURES = nethack.BLSTATS_SHAPE[0]
+PAD_CHAR = 0
+NUM_CHARS = 256
+
+
+def get_action_space_mask(action_space, reduced_action_space):
+    mask = np.array([int(a in reduced_action_space) for a in action_space])
+    return torch.Tensor(mask)
+
+
+def conv_outdim(i_dim, k, padding=0, stride=1, dilation=1):
+    """Return the dimension after applying a convolution along one axis"""
+    return int(1 + (i_dim + 2 * padding - dilation * (k - 1) - 1) / stride)
+
+
+def select(embedding_layer, x, use_index_select):
+    """Use index select instead of default forward to possible speed up embedding."""
+    if use_index_select:
+        out = embedding_layer.weight.index_select(0, x.view(-1))
+        # handle reshaping x to 1-d and output back to N-d
+        return out.view(x.shape + (-1,))
+    else:
+        return embedding_layer(x)
+
+
+class NetHackNet(nn.Module):
+    """This base class simply provides a skeleton for running with torchbeast."""
+
+    AgentOutput = collections.namedtuple("AgentOutput", "action policy_logits baseline")
+
+    def __init__(self):
+        super(NetHackNet, self).__init__()
+        self.register_buffer("reward_sum", torch.zeros(()))
+        self.register_buffer("reward_m2", torch.zeros(()))
+        self.register_buffer("reward_count", torch.zeros(()).fill_(1e-8))
+
+    def forward(self, inputs, core_state):
+        raise NotImplementedError
+
+    def initial_state(self, batch_size=1):
+        return ()
+
+    @torch.no_grad()
+    def update_running_moments(self, reward_batch):
+        """Maintains a running mean of reward."""
+        new_count = len(reward_batch)
+        new_sum = torch.sum(reward_batch)
+        new_mean = new_sum / new_count
+
+        curr_mean = self.reward_sum / self.reward_count
+        new_m2 = torch.sum((reward_batch - new_mean) ** 2) + (
+            (self.reward_count * new_count)
+            / (self.reward_count + new_count)
+            * (new_mean - curr_mean) ** 2
+        )
+
+        self.reward_count += new_count
+        self.reward_sum += new_sum
+        self.reward_m2 += new_m2
+
+    @torch.no_grad()
+    def get_running_std(self):
+        """Returns standard deviation of the running mean of the reward."""
+        return torch.sqrt(self.reward_m2 / self.reward_count)
+
+
+class BaselineNet(NetHackNet):
+    """This model combines the encodings of the glyphs, top line message and
+    blstats into a single fixed-size representation, which is then passed to
+    an LSTM core before generating a policy and value head for use in an IMPALA
+    like architecture.
+
+    This model was based on 'neurips2020release' tag on the NLE repo, itself
+    based on Kuttler et al, 2020
+    The NetHack Learning Environment
+    https://arxiv.org/abs/2006.13760
+    """
+
+    def __init__(self, observation_shape, action_space, flags, device):
+        super(BaselineNet, self).__init__()
+
+        self.flags = flags
+
+        self.observation_shape = observation_shape
+        self.num_actions = len(action_space)
+
+        self.H = observation_shape[0]
+        self.W = observation_shape[1]
+
+        self.use_lstm = flags.use_lstm
+        self.h_dim = flags.hidden_dim
+
+        # GLYPH + CROP MODEL
+        self.glyph_model = GlyphEncoder(flags, self.H, self.W, flags.crop_dim, device)
+
+        # MESSAGING MODEL
+        self.msg_model = MessageEncoder(
+            flags.msg.hidden_dim, flags.msg.embedding_dim, device
+        )
+
+        # BLSTATS MODEL
+        self.blstats_model = BLStatsEncoder(NUM_FEATURES, flags.embedding_dim)
+
+        out_dim = (
+            self.blstats_model.hidden_dim
+            + self.glyph_model.hidden_dim
+            + self.msg_model.hidden_dim
+        )
+
+        self.fc = nn.Sequential(
+            nn.Linear(out_dim, self.h_dim),
+            nn.ReLU(),
+            nn.Linear(self.h_dim, self.h_dim),
+            nn.ReLU(),
+        )
+
+        if self.use_lstm:
+            self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1)
+
+        self.policy = nn.Linear(self.h_dim, self.num_actions)
+        self.baseline = nn.Linear(self.h_dim, 1)
+
+        if flags.restrict_action_space:
+            reduced_space = nethack.USEFUL_ACTIONS
+            logits_mask = get_action_space_mask(action_space, reduced_space)
+            self.policy_logits_mask = nn.parameter.Parameter(
+                logits_mask, requires_grad=False
+            )
+
+    def initial_state(self, batch_size=1):
+        if not self.use_lstm:
+            # Without an LSTM there is no recurrent state (and self.core does not exist).
+            return ()
+        return tuple(
+            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
+            for _ in range(2)
+        )
+
+    def forward(self, inputs, core_state, learning=False):
+        T, B, H, W = inputs["glyphs"].shape
+
+        reps = []
+
+        # -- [B' x K] ; B' == (T x B)
+        glyphs_rep = self.glyph_model(inputs)
+        reps.append(glyphs_rep)
+
+        # -- [B' x K]
+        char_rep = self.msg_model(inputs)
+        reps.append(char_rep)
+
+        # -- [B' x K]
+        features_emb = self.blstats_model(inputs)
+        reps.append(features_emb)
+
+        # -- [B' x K]
+        st = torch.cat(reps, dim=1)
+
+        # -- [B' x K]
+        st = self.fc(st)
+
+        if self.use_lstm:
+            core_input = st.view(T, B, -1)
+            core_output_list = []
+            notdone = (~inputs["done"]).float()
+            for input, nd in zip(core_input.unbind(), notdone.unbind()):
+                # Reset core state to zero whenever an episode ended.
+                # Make `done` broadcastable with (num_layers, B, hidden_size)
+                # states:
+                nd = nd.view(1, -1, 1)
+                core_state = tuple(nd * t for t in core_state)
+                output, core_state = self.core(input.unsqueeze(0), core_state)
+                core_output_list.append(output)
+            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
+        else:
+            core_output = st
+
+        # -- [B' x A]
+        policy_logits = self.policy(core_output)
+
+        # -- [B' x 1]
+        baseline = self.baseline(core_output)
+
+        if self.flags.restrict_action_space:
+            policy_logits = policy_logits * self.policy_logits_mask + (
+                (1 - self.policy_logits_mask) * -1e10
+            )
+
+        if self.training:
+            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
+        else:
+            # Don't sample when testing.
+            action = torch.argmax(policy_logits, dim=1)
+
+        policy_logits = policy_logits.view(T, B, -1)
+        baseline = baseline.view(T, B)
+        action = action.view(T, B)
+
+        output = dict(policy_logits=policy_logits, baseline=baseline, action=action)
+        return (output, core_state)
+
+
+class GlyphEncoder(nn.Module):
+    """This glyph encoder first breaks the glyphs (integers up to 6000) to a
+    more structured representation based on the qualities of the glyph: chars,
+    colors, specials, groups and subgroup ids..
+       Eg: invisible hell-hound: char (d), color (red), specials (invisible),
+                                 group (monster) subgroup id (type of monster)
+       Eg: lit dungeon floor: char (.), color (white), specials (none),
+                              group (dungeon) subgroup id (type of dungeon)
+
+    An embedding is provided for each of these, and the embeddings are
+    concatenated, before encoding with a number of CNN layers.  This operation
+    is repeated with a crop of the structured representations taken around the
+    character's position, and the two representations are concatenated
+    before returning.
+    """
+
+    def __init__(self, flags, rows, cols, crop_dim, device=None):
+        super(GlyphEncoder, self).__init__()
+
+        self.crop = Crop(rows, cols, crop_dim, crop_dim, device)
+        K = flags.embedding_dim  # number of input filters
+        L = flags.layers  # number of convnet layers
+
+        assert (
+            K % 8 == 0
+        ), "This glyph embedding format needs embedding dim to be multiple of 8"
+        unit = K // 8
+        self.chars_embedding = nn.Embedding(256, 2 * unit)
+        self.colors_embedding = nn.Embedding(16, unit)
+        self.specials_embedding = nn.Embedding(256, unit)
+
+        self.id_pairs_table = nn.parameter.Parameter(
+            torch.from_numpy(id_pairs_table()), requires_grad=False
+        )
+        num_groups = self.id_pairs_table.select(1, 1).max().item() + 1
+        num_ids = self.id_pairs_table.select(1, 0).max().item() + 1
+
+        self.groups_embedding = nn.Embedding(num_groups, unit)
+        self.ids_embedding = nn.Embedding(num_ids, 3 * unit)
+
+        F = 3  # filter dimensions
+        S = 1  # stride
+        P = 1  # padding
+        M = 16  # number of intermediate filters
+        self.output_filters = 8
+
+        in_channels = [K] + [M] * (L - 1)
+        out_channels = [M] * (L - 1) + [self.output_filters]
+
+        h, w, c = rows, cols, crop_dim
+        conv_extract, conv_extract_crop = [], []
+        for i in range(L):
+            conv_extract.append(
+                nn.Conv2d(
+                    in_channels=in_channels[i],
+                    out_channels=out_channels[i],
+                    kernel_size=(F, F),
+                    stride=S,
+                    padding=P,
+                )
+            )
+            conv_extract.append(nn.ELU())
+
+            conv_extract_crop.append(
+                nn.Conv2d(
+                    in_channels=in_channels[i],
+                    out_channels=out_channels[i],
+                    kernel_size=(F, F),
+                    stride=S,
+                    padding=P,
+                )
+            )
+            conv_extract_crop.append(nn.ELU())
+
+            # Keep track of output shapes
+            h = conv_outdim(h, F, P, S)
+            w = conv_outdim(w, F, P, S)
+            c = conv_outdim(c, F, P, S)
+
+        self.hidden_dim = (h * w + c * c) * self.output_filters
+        self.extract_representation = nn.Sequential(*conv_extract)
+        self.extract_crop_representation = nn.Sequential(*conv_extract_crop)
+        self.select = lambda emb, x: select(emb, x, flags.use_index_select)
+
+    def glyphs_to_ids_groups(self, glyphs):
+        T, B, H, W = glyphs.shape
+        ids_groups = self.id_pairs_table.index_select(0, glyphs.view(-1).long())
+        ids = ids_groups.select(1, 0).view(T, B, H, W).long()
+        groups = ids_groups.select(1, 1).view(T, B, H, W).long()
+        return [ids, groups]
+
+    def forward(self, inputs):
+        T, B, H, W = inputs["glyphs"].shape
+        ids, groups = self.glyphs_to_ids_groups(inputs["glyphs"])
+
+        glyph_tensors = [
+            self.select(self.chars_embedding, inputs["chars"].long()),
+            self.select(self.colors_embedding, inputs["colors"].long()),
+            self.select(self.specials_embedding, inputs["specials"].long()),
+            self.select(self.groups_embedding, groups),
+            self.select(self.ids_embedding, ids),
+        ]
+
+        glyphs_emb = torch.cat(glyph_tensors, dim=-1)
+        glyphs_emb = rearrange(glyphs_emb, "T B H W K -> (T B) K H W")
+
+        coordinates = inputs["blstats"].view(T * B, -1).float()[:, :2]
+        crop_emb = self.crop(glyphs_emb, coordinates)
+
+        glyphs_rep = self.extract_representation(glyphs_emb)
+        glyphs_rep = rearrange(glyphs_rep, "B C H W -> B (C H W)")
+        assert glyphs_rep.shape[0] == T * B
+
+        crop_rep = self.extract_crop_representation(crop_emb)
+        crop_rep = rearrange(crop_rep, "B C H W -> B (C H W)")
+        assert crop_rep.shape[0] == T * B
+
+        st = torch.cat([glyphs_rep, crop_rep], dim=1)
+        return st
+
+
+class MessageEncoder(nn.Module):
+    """This model encodes the the topline message into a fixed size representation.
+
+    It works by using a learnt embedding for each character before passing the
+    embeddings through 6 CNN layers.
+
+    Inspired by Zhang et al, 2016
+    Character-level Convolutional Networks for Text Classification
+    https://arxiv.org/abs/1509.01626
+    """
+
+    def __init__(self, hidden_dim, embedding_dim, device=None):
+        super(MessageEncoder, self).__init__()
+
+        self.hidden_dim = hidden_dim
+        self.msg_edim = embedding_dim
+
+        self.char_lt = nn.Embedding(NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR)
+        self.conv1 = nn.Conv1d(self.msg_edim, self.hidden_dim, kernel_size=7)
+        self.conv2_6_fc = nn.Sequential(
+            nn.ReLU(),
+            nn.MaxPool1d(kernel_size=3, stride=3),
+            # conv2
+            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=7),
+            nn.ReLU(),
+            nn.MaxPool1d(kernel_size=3, stride=3),
+            # conv3
+            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
+            nn.ReLU(),
+            # conv4
+            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
+            nn.ReLU(),
+            # conv5
+            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
+            nn.ReLU(),
+            # conv6
+            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
+            nn.ReLU(),
+            nn.MaxPool1d(kernel_size=3, stride=3),
+            # fc receives -- [ B x h_dim x 5 ]
+            Flatten(),
+            nn.Linear(5 * self.hidden_dim, 2 * self.hidden_dim),
+            nn.ReLU(),
+            nn.Linear(2 * self.hidden_dim, self.hidden_dim),
+        )  # final output -- [ B x h_dim ]
+
+    def forward(self, inputs):
+        T, B, *_ = inputs["message"].shape
+        messages = inputs["message"].long().view(T * B, -1)
+        # [ T * B x E x 256 ]
+        char_emb = self.char_lt(messages).transpose(1, 2)
+        char_rep = self.conv2_6_fc(self.conv1(char_emb))
+        return char_rep
+
+
+class BLStatsEncoder(nn.Module):
+    """This model encodes the bottom line stats into a fixed size representation.
+
+    It works by simply using two fully-connected layers with ReLU activations.
+    """
+
+    def __init__(self, num_features, hidden_dim):
+        super(BLStatsEncoder, self).__init__()
+        self.num_features = num_features
+        self.hidden_dim = hidden_dim
+        self.embed_features = nn.Sequential(
+            nn.Linear(self.num_features, self.hidden_dim),
+            nn.ReLU(),
+            nn.Linear(self.hidden_dim, self.hidden_dim),
+            nn.ReLU(),
+        )
+
+    def forward(self, inputs):
+        T, B, *_ = inputs["blstats"].shape
+
+        features = inputs["blstats"]
+        # -- [B' x F]
+        features = features.view(T * B, -1).float()
+        # -- [B x K]
+        features_emb = self.embed_features(features)
+
+        assert features_emb.shape[0] == T * B
+        return features_emb
+
+
+class Crop(nn.Module):
+    def __init__(self, height, width, height_target, width_target, device=None):
+        super(Crop, self).__init__()
+        self.width = width
+        self.height = height
+        self.width_target = width_target
+        self.height_target = height_target
+
+        width_grid = self._step_to_range(2 / (self.width - 1), self.width_target)
+        self.width_grid = width_grid[None, :].expand(self.height_target, -1)
+
+        height_grid = self._step_to_range(2 / (self.height - 1), height_target)
+        self.height_grid = height_grid[:, None].expand(-1, self.width_target)
+
+        if device is not None:
+            self.width_grid = self.width_grid.to(device)
+            self.height_grid = self.height_grid.to(device)
+
+    def _step_to_range(self, step, num_steps):
+        return torch.tensor([step * (i - num_steps // 2) for i in range(num_steps)])
+
+    def forward(self, inputs, coordinates):
+        """Calculates centered crop around given x,y coordinates.
+
+        Args:
+           inputs [B x H x W] or [B x C x H x W]
+           coordinates [B x 2] x,y coordinates
+
+        Returns:
+           [B x C x H' x W'] inputs cropped and centered around x,y coordinates.
+        """
+        if inputs.dim() == 3:
+            inputs = inputs.unsqueeze(1).float()
+
+        assert inputs.shape[2] == self.height, "expected %d but found %d" % (
+            self.height,
+            inputs.shape[2],
+        )
+        assert inputs.shape[3] == self.width, "expected %d but found %d" % (
+            self.width,
+            inputs.shape[3],
+        )
+
+        x = coordinates[:, 0]
+        y = coordinates[:, 1]
+
+        x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2)
+        y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2)
+
+        grid = torch.stack(
+            [
+                self.width_grid[None, :, :] + x_shift[:, None, None],
+                self.height_grid[None, :, :] + y_shift[:, None, None],
+            ],
+            dim=3,
+        )
+
+        crop = torch.round(F.grid_sample(inputs, grid, align_corners=True)).squeeze(1)
+        return crop
+
+
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
diff --git a/nethack_baselines/torchbeast/models/util.py b/nethack_baselines/torchbeast/models/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..30352401fce77335de9d9cf4322b57b57f79b53f
--- /dev/null
+++ b/nethack_baselines/torchbeast/models/util.py
@@ -0,0 +1,142 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+
+import numpy as np
+
+from nle.nethack import *  # noqa: F403
+
+# flake8: noqa: F405
+
+# TODO: import this from NLE again
+NUM_OBJECTS = 453
+MAXEXPCHARS = 9
+
+
+class GlyphGroup(enum.IntEnum):
+    # See display.h in NetHack.
+    MON = 0
+    PET = 1
+    INVIS = 2
+    DETECT = 3
+    BODY = 4
+    RIDDEN = 5
+    OBJ = 6
+    CMAP = 7
+    EXPLODE = 8
+    ZAP = 9
+    SWALLOW = 10
+    WARNING = 11
+    STATUE = 12
+
+
+def id_pairs_table():
+    """Returns a lookup table for glyph -> NLE id pairs."""
+    table = np.zeros([MAX_GLYPH, 2], dtype=np.int16)
+
+    num_nle_ids = 0
+
+    for glyph in range(GLYPH_MON_OFF, GLYPH_PET_OFF):
+        table[glyph] = (glyph, GlyphGroup.MON)
+        num_nle_ids += 1
+
+    for glyph in range(GLYPH_PET_OFF, GLYPH_INVIS_OFF):
+        table[glyph] = (glyph - GLYPH_PET_OFF, GlyphGroup.PET)
+
+    for glyph in range(GLYPH_INVIS_OFF, GLYPH_DETECT_OFF):
+        table[glyph] = (num_nle_ids, GlyphGroup.INVIS)
+        num_nle_ids += 1
+
+    for glyph in range(GLYPH_DETECT_OFF, GLYPH_BODY_OFF):
+        table[glyph] = (glyph - GLYPH_DETECT_OFF, GlyphGroup.DETECT)
+
+    for glyph in range(GLYPH_BODY_OFF, GLYPH_RIDDEN_OFF):
+        table[glyph] = (glyph - GLYPH_BODY_OFF, GlyphGroup.BODY)
+
+    for glyph in range(GLYPH_RIDDEN_OFF, GLYPH_OBJ_OFF):
+        table[glyph] = (glyph - GLYPH_RIDDEN_OFF, GlyphGroup.RIDDEN)
+
+    for glyph in range(GLYPH_OBJ_OFF, GLYPH_CMAP_OFF):
+        table[glyph] = (num_nle_ids, GlyphGroup.OBJ)
+        num_nle_ids += 1
+
+    for glyph in range(GLYPH_CMAP_OFF, GLYPH_EXPLODE_OFF):
+        table[glyph] = (num_nle_ids, GlyphGroup.CMAP)
+        num_nle_ids += 1
+
+    for glyph in range(GLYPH_EXPLODE_OFF, GLYPH_ZAP_OFF):
+        id_ = num_nle_ids + (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS
+        table[glyph] = (id_, GlyphGroup.EXPLODE)
+
+    num_nle_ids += EXPL_MAX
+
+    for glyph in range(GLYPH_ZAP_OFF, GLYPH_SWALLOW_OFF):
+        id_ = num_nle_ids + (glyph - GLYPH_ZAP_OFF) // 4
+        table[glyph] = (id_, GlyphGroup.ZAP)
+
+    num_nle_ids += NUM_ZAP
+
+    for glyph in range(GLYPH_SWALLOW_OFF, GLYPH_WARNING_OFF):
+        table[glyph] = (num_nle_ids, GlyphGroup.SWALLOW)
+    num_nle_ids += 1
+
+    for glyph in range(GLYPH_WARNING_OFF, GLYPH_STATUE_OFF):
+        table[glyph] = (num_nle_ids, GlyphGroup.WARNING)
+        num_nle_ids += 1
+
+    for glyph in range(GLYPH_STATUE_OFF, MAX_GLYPH):
+        table[glyph] = (glyph - GLYPH_STATUE_OFF, GlyphGroup.STATUE)
+
+    return table
+
+
+def id_pairs_func(glyph):
+    result = glyph_to_mon(glyph)
+    if result != NO_GLYPH:
+        return result
+    if glyph_is_invisible(glyph):
+        return NUMMONS
+    if glyph_is_body(glyph):
+        return glyph - GLYPH_BODY_OFF
+
+    offset = NUMMONS + 1
+
+    # CORPSE handled by glyph_is_body; STATUE handled by glyph_to_mon.
+    result = glyph_to_obj(glyph)
+    if result != NO_GLYPH:
+        return result + offset
+    offset += NUM_OBJECTS
+
+    # I don't understand glyph_to_cmap and/or the GLYPH_EXPLODE_OFF definition
+    # with MAXPCHARS - MAXEXPCHARS.
+    if GLYPH_CMAP_OFF <= glyph < GLYPH_EXPLODE_OFF:
+        return glyph - GLYPH_CMAP_OFF + offset
+    offset += MAXPCHARS - MAXEXPCHARS
+
+    if GLYPH_EXPLODE_OFF <= glyph < GLYPH_ZAP_OFF:
+        return (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS + offset
+    offset += EXPL_MAX
+
+    if GLYPH_ZAP_OFF <= glyph < GLYPH_SWALLOW_OFF:
+        return ((glyph - GLYPH_ZAP_OFF) >> 2) + offset
+    offset += NUM_ZAP
+
+    if GLYPH_SWALLOW_OFF <= glyph < GLYPH_WARNING_OFF:
+        return offset
+    offset += 1
+
+    result = glyph_to_warning(glyph)
+    if result != NO_GLYPH:
+        return result + offset
diff --git a/nethack_baselines/torchbeast/polybeast_env.py b/nethack_baselines/torchbeast/polybeast_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..7073b20b6632915f825e93ae093c5823d9ed51bf
--- /dev/null
+++ b/nethack_baselines/torchbeast/polybeast_env.py
@@ -0,0 +1,127 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing as mp
+import logging
+import os
+import threading
+import time
+
+import torch
+
+import libtorchbeast
+
+from models import ENVS
+
+
+logging.basicConfig(
+    format=(
+        "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
+    ),
+    level=0,
+)
+
+
+# Helper functions for NethackEnv.
+def _format_observation(obs):
+    obs = torch.from_numpy(obs)
+    return obs.view((1, 1) + obs.shape)  # (...) -> (T,B,...).
+
+
+def create_folders(flags):
+    # Creates some of the folders that would be created by the filewriter.
+    logdir = os.path.join(flags.savedir, "archives")
+    if not os.path.exists(logdir):
+        logging.info("Creating archive directory: %s" % logdir)
+        os.makedirs(logdir, exist_ok=True)
+    else:
+        logging.info("Found archive directory: %s" % logdir)
+
+
+def create_env(flags, env_id=0, lock=threading.Lock()):
+    # commenting out these options for now because they use too much disk space
+    # archivefile = "nethack.%i.%%(pid)i.%%(time)s.zip" % env_id
+    # if flags.single_ttyrec and env_id != 0:
+    #     archivefile = None
+
+    # logdir = os.path.join(flags.savedir, "archives")
+
+    with lock:
+        env_class = ENVS[flags.env]
+        kwargs = dict(
+            savedir=None,
+            archivefile=None,
+            character=flags.character,
+            max_episode_steps=flags.max_num_steps,
+            observation_keys=(
+                "glyphs",
+                "chars",
+                "colors",
+                "specials",
+                "blstats",
+                "message",
+                "tty_chars",
+                "tty_colors",
+                "tty_cursor",
+                "inv_glyphs",
+                "inv_strs",
+                "inv_letters",
+                "inv_oclasses",
+            ),
+            penalty_step=flags.penalty_step,
+            penalty_time=flags.penalty_time,
+            penalty_mode=flags.fn_penalty_step,
+        )
+        if flags.env in ("staircase", "pet", "oracle"):
+            kwargs.update(reward_win=flags.reward_win, reward_lose=flags.reward_lose)
+        elif env_id == 0:  # print warning once
+            print("Ignoring flags.reward_win and flags.reward_lose")
+        if flags.state_counter != "none":
+            kwargs.update(state_counter=flags.state_counter)
+        env = env_class(**kwargs)
+        if flags.seedspath is not None and len(flags.seedspath) > 0:
+            raise NotImplementedError("seedspath > 0 not implemented yet.")
+
+        return env
+
+
+def serve(flags, server_address, env_id):
+    env = lambda: create_env(flags, env_id)
+    server = libtorchbeast.Server(env, server_address=server_address)
+    server.run()
+
+
+def main(flags):
+    if flags.num_seeds > 0:
+        raise NotImplementedError("num_seeds > 0 not currently implemented.")
+
+    create_folders(flags)
+
+    if not flags.pipes_basename.startswith("unix:"):
+        raise Exception("--pipes_basename has to be of the form unix:/some/path.")
+
+    processes = []
+    for i in range(flags.num_servers):
+        p = mp.Process(
+            target=serve, args=(flags, f"{flags.pipes_basename}.{i}", i), daemon=True
+        )
+        p.start()
+        processes.append(p)
+
+    try:
+        # We are only here to listen to the interrupt.
+        while True:
+            time.sleep(10)
+    except KeyboardInterrupt:
+        pass
diff --git a/nethack_baselines/torchbeast/polybeast_learner.py b/nethack_baselines/torchbeast/polybeast_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca79d451c37fdb5e64218a594c67f50707081213
--- /dev/null
+++ b/nethack_baselines/torchbeast/polybeast_learner.py
@@ -0,0 +1,517 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Run with OMP_NUM_THREADS=1.
+#
+
+import collections
+import logging
+import os
+import threading
+import time
+import timeit
+import traceback
+
+import wandb
+import omegaconf
+import nest
+import torch
+
+import libtorchbeast
+
+from core import file_writer
+from core import vtrace
+
+from models import create_model
+from models.baseline import NetHackNet
+
+from torch import nn
+from torch.nn import functional as F
+
+
+logging.basicConfig(
+    format=(
+        "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
+    ),
+    level=0,
+)
+
+
+def compute_baseline_loss(advantages):
+    return 0.5 * torch.sum(advantages ** 2)
+
+
+def compute_entropy_loss(logits):
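+    # Returns the *negative* entropy, summed over batch and time; adding it to the
+    # total loss scaled by flags.entropy_cost therefore rewards higher-entropy
+    # policies, i.e. it acts as an exploration bonus.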
+    policy = F.softmax(logits, dim=-1)
+    log_policy = F.log_softmax(logits, dim=-1)
+    entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1)
+    return -torch.sum(entropy_per_timestep)
+
+
+def compute_policy_gradient_loss(logits, actions, advantages):
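+    # REINFORCE-style loss: F.nll_loss on the log-softmaxed logits is exactly
+    # -log pi(a_t | s_t), which is then weighted by the (detached) advantages,
+    # so gradients flow through the policy logits only.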
+    cross_entropy = F.nll_loss(
+        F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
+        target=torch.flatten(actions, 0, 1),
+        reduction="none",
+    )
+    cross_entropy = cross_entropy.view_as(advantages)
+    policy_gradient_loss_per_timestep = cross_entropy * advantages.detach()
+    return torch.sum(policy_gradient_loss_per_timestep)
+
+
+def inference(
+    inference_batcher, model, flags, actor_device, lock=threading.Lock()
+):  # noqa: B008
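+    # Each `batch` yielded by the DynamicBatcher bundles observations from several
+    # actor environments, stacked along the batch dimension. The shared actor model
+    # is evaluated under `lock` (one forward pass at a time across inference
+    # threads), and the results are handed back as (action, policy_logits, baseline)
+    # together with the updated recurrent core state.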
+    with torch.no_grad():
+        for batch in inference_batcher:
+            batched_env_outputs, agent_state = batch.get_inputs()
+            observation, reward, done, *_ = batched_env_outputs
+            # Observation is a dict with keys 'features' and 'glyphs'.
+            observation["done"] = done
+            observation, agent_state = nest.map(
+                lambda t: t.to(actor_device, non_blocking=True),
+                (observation, agent_state),
+            )
+            with lock:
+                outputs = model(observation, agent_state)
+            core_outputs, agent_state = nest.map(lambda t: t.cpu(), outputs)
+            # Restructuring the output in the way that is expected
+            # by the functions in actorpool.
+            outputs = (
+                tuple(
+                    (
+                        core_outputs["action"],
+                        core_outputs["policy_logits"],
+                        core_outputs["baseline"],
+                    )
+                ),
+                agent_state,
+            )
+            batch.set_outputs(outputs)
+
+
+# TODO(heiner): Given that our nest implementation doesn't support
+# namedtuples, using them here doesn't seem like a good fit. We
+# probably want to nestify the environment server and deal with
+# dictionaries?
+EnvOutput = collections.namedtuple(
+    "EnvOutput", "frame rewards done episode_step episode_return"
+)
+AgentOutput = NetHackNet.AgentOutput
+Batch = collections.namedtuple("Batch", "env agent")
+
+
+def learn(
+    learner_queue,
+    model,
+    actor_model,
+    optimizer,
+    scheduler,
+    stats,
+    flags,
+    plogger,
+    learner_device,
+    lock=threading.Lock(),  # noqa: B008
+):
+    for tensors in learner_queue:
+        tensors = nest.map(lambda t: t.to(learner_device), tensors)
+
+        batch, initial_agent_state = tensors
+        env_outputs, actor_outputs = batch
+        observation, reward, done, *_ = env_outputs
+        observation["reward"] = reward
+        observation["done"] = done
+
+        lock.acquire()  # Only one thread learning at a time.
+
+        output, _ = model(observation, initial_agent_state, learning=True)
+
+        # Use last baseline value (from the value function) to bootstrap.
+        learner_outputs = AgentOutput._make(
+            (output["action"], output["policy_logits"], output["baseline"])
+        )
+
+        # At this point, the environment outputs at time step `t` are the inputs
+        # that lead to the learner_outputs at time step `t`. After the following
+        # shifting, the actions in `batch` and `learner_outputs` at time
+        # step `t` are what lead to the environment outputs at time step `t`.
+        batch = nest.map(lambda t: t[1:], batch)
+        learner_outputs = nest.map(lambda t: t[:-1], learner_outputs)
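+        # Concretely: dropping the first env step (t[1:]) and the last learner output
+        # (t[:-1]) pairs the action chosen at step t with the reward and observation
+        # produced at step t+1, which is the alignment V-trace expects below.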
+
+        # Turn into namedtuples again.
+        env_outputs, actor_outputs = batch
+        # Note that the env_outputs.frame is now a dict with 'features' and 'glyphs'
+        # instead of actually being the frame itself. This is currently not a problem
+        # because we never use env_outputs.frame in the rest of this function.
+        env_outputs = EnvOutput._make(env_outputs)
+        actor_outputs = AgentOutput._make(actor_outputs)
+        learner_outputs = AgentOutput._make(learner_outputs)
+
+        rewards = env_outputs.rewards
+        if flags.normalize_reward:
+            model.update_running_moments(rewards)
+            rewards /= model.get_running_std()
+
+        total_loss = 0
+
+        # STANDARD EXTRINSIC LOSSES / REWARDS
+        if flags.entropy_cost > 0:
+            entropy_loss = flags.entropy_cost * compute_entropy_loss(
+                learner_outputs.policy_logits
+            )
+            total_loss += entropy_loss
+
+        discounts = (~env_outputs.done).float() * flags.discounting
+
+        # This could be in C++. In TF, this is actually slower on the GPU.
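+        # V-trace (Espeholt et al., 2018) corrects for the lag between the behavior
+        # policy (actors) and the target policy (learner) using truncated importance
+        # weights rho_t = min(rho_bar, pi(a_t|x_t) / mu(a_t|x_t)):
+        #   v_s = V(x_s) + sum_{t>=s} gamma^(t-s) * (prod_{s<=i<t} c_i)
+        #                  * rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t))
+        # `vs` below are these corrected value targets and `pg_advantages` the
+        # matching policy-gradient advantages.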
+        vtrace_returns = vtrace.from_logits(
+            behavior_policy_logits=actor_outputs.policy_logits,
+            target_policy_logits=learner_outputs.policy_logits,
+            actions=actor_outputs.action,
+            discounts=discounts,
+            rewards=rewards,
+            values=learner_outputs.baseline,
+            bootstrap_value=learner_outputs.baseline[-1],
+        )
+
+        # Compute loss as a weighted sum of the baseline loss, the policy
+        # gradient loss and an entropy regularization term.
+        pg_loss = compute_policy_gradient_loss(
+            learner_outputs.policy_logits,
+            actor_outputs.action,
+            vtrace_returns.pg_advantages,
+        )
+        baseline_loss = flags.baseline_cost * compute_baseline_loss(
+            vtrace_returns.vs - learner_outputs.baseline
+        )
+        total_loss += pg_loss + baseline_loss
+
+        # BACKWARD STEP
+        optimizer.zero_grad()
+        total_loss.backward()
+        if flags.grad_norm_clipping > 0:
+            nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
+        optimizer.step()
+        scheduler.step()
+
+        actor_model.load_state_dict(model.state_dict())
+
+        # LOGGING
+        episode_returns = env_outputs.episode_return[env_outputs.done]
+        stats["step"] = stats.get("step", 0) + flags.unroll_length * flags.batch_size
+        stats["mean_episode_return"] = torch.mean(episode_returns).item()
+        stats["mean_episode_step"] = torch.mean(env_outputs.episode_step.float()).item()
+        stats["total_loss"] = total_loss.item()
+        stats["pg_loss"] = pg_loss.item()
+        stats["baseline_loss"] = baseline_loss.item()
+        if flags.entropy_cost > 0:
+            stats["entropy_loss"] = entropy_loss.item()
+
+        stats["learner_queue_size"] = learner_queue.size()
+
+        if not len(episode_returns):
+            # Hide the mean-of-empty-tuple NaN as it scares people.
+            stats["mean_episode_return"] = None
+
+        # Only logging if at least one episode was finished
+        if len(episode_returns):
+            # TODO: log also SPS
+            plogger.log(stats)
+            if flags.wandb:
+                wandb.log(stats, step=stats["step"])
+
+        lock.release()
+
+
+def train(flags):
+    logging.info("Logging results to %s", flags.savedir)
+    if isinstance(flags, omegaconf.DictConfig):
+        flag_dict = omegaconf.OmegaConf.to_container(flags)
+    else:
+        flag_dict = vars(flags)
+    plogger = file_writer.FileWriter(xp_args=flag_dict, rootdir=flags.savedir)
+
+    if not flags.disable_cuda and torch.cuda.is_available():
+        logging.info("Using CUDA.")
+        learner_device = torch.device(flags.learner_device)
+        actor_device = torch.device(flags.actor_device)
+    else:
+        logging.info("Not using CUDA.")
+        learner_device = torch.device("cpu")
+        actor_device = torch.device("cpu")
+
+    if flags.max_learner_queue_size is None:
+        flags.max_learner_queue_size = flags.batch_size
+
+    # The queue the learner threads will get their data from.
+    # Setting `minimum_batch_size == maximum_batch_size`
+    # makes the batch size static. We could make it dynamic, but that
+    # requires a loss (and learning rate schedule) that's batch size
+    # independent.
+    learner_queue = libtorchbeast.BatchingQueue(
+        batch_dim=1,
+        minimum_batch_size=flags.batch_size,
+        maximum_batch_size=flags.batch_size,
+        check_inputs=True,
+        maximum_queue_size=flags.max_learner_queue_size,
+    )
+
+    # The "batcher", a queue for the inference call. Will yield
+    # "batch" objects with `get_inputs` and `set_outputs` methods.
+    # The batch size of the tensors will be dynamic.
+    inference_batcher = libtorchbeast.DynamicBatcher(
+        batch_dim=1,
+        minimum_batch_size=1,
+        maximum_batch_size=512,
+        timeout_ms=100,
+        check_outputs=True,
+    )
+
+    addresses = []
+    connections_per_server = 1
+    pipe_id = 0
+    while len(addresses) < flags.num_actors:
+        for _ in range(connections_per_server):
+            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
+            if len(addresses) == flags.num_actors:
+                break
+        pipe_id += 1
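+    # With connections_per_server = 1 (set above) this yields one pipe address per
+    # actor, e.g. ["<pipes_basename>.0", "<pipes_basename>.1", ...], matching the
+    # servers spawned by polybeast_env.py.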
+
+    logging.info("Using model %s", flags.model)
+
+    model = create_model(flags, learner_device)
+
+    plogger.metadata["model_numel"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+
+    logging.info("Number of model parameters: %i", plogger.metadata["model_numel"])
+
+    actor_model = create_model(flags, actor_device)
+
+    # The ActorPool that will run `flags.num_actors` many loops.
+    actors = libtorchbeast.ActorPool(
+        unroll_length=flags.unroll_length,
+        learner_queue=learner_queue,
+        inference_batcher=inference_batcher,
+        env_server_addresses=addresses,
+        initial_agent_state=model.initial_state(),
+    )
+
+    def run():
+        try:
+            actors.run()
+        except Exception as e:
+            logging.error("Exception in actorpool thread!")
+            traceback.print_exc()
+            print()
+            raise e
+
+    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")
+
+    optimizer = torch.optim.RMSprop(
+        model.parameters(),
+        lr=flags.learning_rate,
+        momentum=flags.momentum,
+        eps=flags.epsilon,
+        alpha=flags.alpha,
+    )
+
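+    # Linear learning-rate decay: the multiplier returned by lr_lambda starts at 1
+    # and reaches 0 once epoch * unroll_length * batch_size (the number of
+    # environment steps consumed by the learner) hits flags.total_steps.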
+    def lr_lambda(epoch):
+        return (
+            1
+            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
+            / flags.total_steps
+        )
+
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+    stats = {}
+
+    if flags.checkpoint and os.path.exists(flags.checkpoint):
+        logging.info("Loading checkpoint: %s" % flags.checkpoint)
+        checkpoint_states = torch.load(
+            flags.checkpoint, map_location=flags.learner_device
+        )
+        model.load_state_dict(checkpoint_states["model_state_dict"])
+        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
+        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
+        stats = checkpoint_states["stats"]
+        logging.info(f"Resuming preempted job, current stats:\n{stats}")
+
+    # Initialize actor model like learner model.
+    actor_model.load_state_dict(model.state_dict())
+
+    learner_threads = [
+        threading.Thread(
+            target=learn,
+            name="learner-thread-%i" % i,
+            args=(
+                learner_queue,
+                model,
+                actor_model,
+                optimizer,
+                scheduler,
+                stats,
+                flags,
+                plogger,
+                learner_device,
+            ),
+        )
+        for i in range(flags.num_learner_threads)
+    ]
+    inference_threads = [
+        threading.Thread(
+            target=inference,
+            name="inference-thread-%i" % i,
+            args=(inference_batcher, actor_model, flags, actor_device),
+        )
+        for i in range(flags.num_inference_threads)
+    ]
+
+    actorpool_thread.start()
+    for t in learner_threads + inference_threads:
+        t.start()
+
+    def checkpoint(checkpoint_path=None):
+        if flags.checkpoint:
+            if checkpoint_path is None:
+                checkpoint_path = flags.checkpoint
+            logging.info("Saving checkpoint to %s", checkpoint_path)
+            torch.save(
+                {
+                    "model_state_dict": model.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "scheduler_state_dict": scheduler.state_dict(),
+                    "stats": stats,
+                    "flags": vars(flags),
+                },
+                checkpoint_path,
+            )
+
+    def format_value(x):
+        return f"{x:1.5}" if isinstance(x, float) else str(x)
+
+    try:
+        train_start_time = timeit.default_timer()
+        train_time_offset = stats.get("train_seconds", 0)  # used for resuming training
+        last_checkpoint_time = timeit.default_timer()
+
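+        # Besides the periodic (every 10 min) checkpoints below, save an extra
+        # snapshot the first time training passes 0%, 25%, 50% and 75% of
+        # total_steps (written next to the main checkpoint as
+        # <checkpoint stem>_<fraction>.tar).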
+        dev_checkpoint_intervals = [0, 0.25, 0.5, 0.75]
+
+        loop_start_time = timeit.default_timer()
+        loop_start_step = stats.get("step", 0)
+        while True:
+            if loop_start_step >= flags.total_steps:
+                break
+            time.sleep(5)
+            loop_end_time = timeit.default_timer()
+            loop_end_step = stats.get("step", 0)
+
+            stats["train_seconds"] = round(
+                loop_end_time - train_start_time + train_time_offset, 1
+            )
+
+            if loop_end_time - last_checkpoint_time > 10 * 60:
+                # Save every 10 min.
+                checkpoint()
+                last_checkpoint_time = loop_end_time
+
+            if len(dev_checkpoint_intervals) > 0:
+                step_percentage = loop_end_step / flags.total_steps
+                i = dev_checkpoint_intervals[0]
+                if step_percentage > i:
+                    checkpoint(flags.checkpoint[:-4] + "_" + str(i) + ".tar")
+                    dev_checkpoint_intervals = dev_checkpoint_intervals[1:]
+
+            logging.info(
+                "Step %i @ %.1f SPS. Inference batcher size: %i."
+                " Learner queue size: %i."
+                " Other stats: (%s)",
+                loop_end_step,
+                (loop_end_step - loop_start_step) / (loop_end_time - loop_start_time),
+                inference_batcher.size(),
+                learner_queue.size(),
+                ", ".join(
+                    f"{key} = {format_value(value)}" for key, value in stats.items()
+                ),
+            )
+            loop_start_time = loop_end_time
+            loop_start_step = loop_end_step
+    except KeyboardInterrupt:
+        pass  # Close properly.
+    else:
+        logging.info("Learning finished after %i steps.", stats["step"])
+
+    checkpoint()
+
+    # Done with learning. Let's stop all the ongoing work.
+    inference_batcher.close()
+    learner_queue.close()
+
+    actorpool_thread.join()
+
+    for t in learner_threads + inference_threads:
+        t.join()
+
+
+def test(flags):
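+    # test() reuses the train() loop as an evaluator: it copies the checkpoint,
+    # zeroes the learning rate (both in the flags and inside the saved optimizer and
+    # scheduler state) and resets the step counters, so the whole actor/learner
+    # machinery runs and logs episode returns without the model being updated.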
+    test_checkpoint = os.path.join(flags.savedir, "test_checkpoint.tar")
+    checkpoint = os.path.join(flags.load_dir, "checkpoint.tar")
+    if not os.path.exists(os.path.dirname(test_checkpoint)):
+        os.makedirs(os.path.dirname(test_checkpoint))
+
+    logging.info("Creating test copy of checkpoint '%s'", checkpoint)
+
+    checkpoint = torch.load(checkpoint)
+    for d in checkpoint["optimizer_state_dict"]["param_groups"]:
+        d["lr"] = 0.0
+        d["initial_lr"] = 0.0
+
+    checkpoint["scheduler_state_dict"]["last_epoch"] = 0
+    checkpoint["scheduler_state_dict"]["_step_count"] = 0
+    checkpoint["scheduler_state_dict"]["base_lrs"] = [0.0]
+    checkpoint["stats"]["step"] = 0
+    checkpoint["stats"]["_tick"] = 0
+
+    flags.checkpoint = test_checkpoint
+    flags.learning_rate = 0.0
+
+    logging.info("Saving test checkpoint to %s", test_checkpoint)
+    torch.save(checkpoint, test_checkpoint)
+
+    train(flags)
+
+
+def main(flags):
+    if flags.wandb:
+        wandb.init(
+            project=flags.project,
+            config=vars(flags),
+            group=flags.group,
+            entity=flags.entity,
+        )
+    if flags.mode == "train":
+        if flags.write_profiler_trace:
+            logging.info("Running with profiler.")
+            with torch.autograd.profiler.profile() as prof:
+                train(flags)
+            filename = "chrome-%s.trace" % time.strftime("%Y%m%d-%H%M%S")
+            logging.info("Writing profiler trace to '%s.gz'", filename)
+            prof.export_chrome_trace(filename)
+            os.system("gzip %s" % filename)
+        else:
+            train(flags)
+    elif flags.mode.startswith("test"):
+        test(flags)
diff --git a/nethack_baselines/torchbeast/polyhydra.py b/nethack_baselines/torchbeast/polyhydra.py
new file mode 100644
index 0000000000000000000000000000000000000000..1554574a401afbdc1a0bc48dfcf4aa4192cc379b
--- /dev/null
+++ b/nethack_baselines/torchbeast/polyhydra.py
@@ -0,0 +1,149 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Installation for hydra:
+pip install hydra-core hydra_colorlog --upgrade
+
+Runs like polybeast, but uses key=value pairs to set flags:
+python polyhydra.py learning_rate=0.001 rnd.twoheaded=true
+
+Run a sweep by adding -m (Hydra multirun) after the script name:
+python polyhydra.py -m learning_rate=0.01,0.001,0.0001,0.00001 momentum=0,0.5
+
+The baseline should run as-is with:
+python polyhydra.py
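+
+A typical local run overriding a few flags (the values below are illustrative only):
+python polyhydra.py num_actors=64 batch_size=32 learning_rate=0.0002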
+"""
+
+from pathlib import Path
+import logging
+import os
+import multiprocessing as mp
+
+import hydra
+import numpy as np
+from omegaconf import OmegaConf, DictConfig
+
+import torch
+
+import polybeast_env
+import polybeast_learner
+
+if torch.__version__.startswith("1.5") or torch.__version__.startswith("1.6"):
+    # pytorch 1.5 and 1.6 need this for some reason on the cluster
+    os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+logging.basicConfig(
+    format=(
+        "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
+    ),
+    level=0,
+)
+
+
+def pipes_basename():
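+    # Hydra runs inside a per-run working directory, so e.g. a cwd of
+    # .../outputs/2021-06-12/10-30-00 becomes "unix:/tmp/poly.outputs.2021-06-12.10-30-00".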
+    logdir = Path(os.getcwd())
+    name = ".".join([logdir.parents[1].name, logdir.parents[0].name, logdir.name])
+    return "unix:/tmp/poly.%s" % name
+
+
+def get_common_flags(flags):
+    flags = OmegaConf.to_container(flags)
+    flags["pipes_basename"] = pipes_basename()
+    flags["savedir"] = os.getcwd()
+    return OmegaConf.create(flags)
+
+
+def get_learner_flags(flags):
+    lrn_flags = OmegaConf.to_container(flags)
+    lrn_flags["checkpoint"] = os.path.join(flags["savedir"], "checkpoint.tar")
+    lrn_flags["entropy_cost"] = float(lrn_flags["entropy_cost"])
+    return OmegaConf.create(lrn_flags)
+
+
+def run_learner(flags: DictConfig):
+    polybeast_learner.main(flags)
+
+
+def get_environment_flags(flags):
+    env_flags = OmegaConf.to_container(flags)
+    env_flags["num_servers"] = flags.num_actors
+    max_num_steps = 1e6
+    if flags.env in ("staircase", "pet"):
+        max_num_steps = 1000
+    env_flags["max_num_steps"] = int(max_num_steps)
+    env_flags["seedspath"] = ""
+    return OmegaConf.create(env_flags)
+
+
+def run_env(flags):
+    np.random.seed()  # Get new random seed in forked process.
+    polybeast_env.main(flags)
+
+
+def symlink_latest(savedir, symlink):
+    try:
+        if os.path.islink(symlink):
+            os.remove(symlink)
+        if not os.path.exists(symlink):
+            os.symlink(savedir, symlink)
+            logging.info("Symlinked log directory: %s" % symlink)
+    except OSError:
+        # os.remove() or os.symlink() raced. Don't do anything.
+        pass
+
+
+@hydra.main(config_name="config")
+def main(flags: DictConfig):
+    if os.path.exists("config.yaml"):
+        # This ignores the local config.yaml and replaces it completely with the saved one.
+        logging.info("Loading existing configuration; continuing a previous run.")
+        new_flags = OmegaConf.load("config.yaml")
+        cli_conf = OmegaConf.from_cli()
+        # CLI overrides still take precedence; this is useful e.g. if you previously
+        # ran with total_steps=N and now want to increase it.
+        flags = OmegaConf.merge(new_flags, cli_conf)
+    if flags.load_dir and os.path.exists(os.path.join(flags.load_dir, "config.yaml")):
+        new_flags = OmegaConf.load(os.path.join(flags.load_dir, "config.yaml"))
+        cli_conf = OmegaConf.from_cli()
+        flags = OmegaConf.merge(new_flags, cli_conf)
+
+    logging.info(flags.pretty(resolve=True))
+    OmegaConf.save(flags, "config.yaml")
+
+    flags = get_common_flags(flags)
+
+    # set flags for polybeast_env
+    env_flags = get_environment_flags(flags)
+    env_processes = []
+    for _ in range(1):
+        p = mp.Process(target=run_env, args=(env_flags,))
+        p.start()
+        env_processes.append(p)
+
+    symlink_latest(
+        flags.savedir, os.path.join(hydra.utils.get_original_cwd(), "latest")
+    )
+
+    lrn_flags = get_learner_flags(flags)
+    run_learner(lrn_flags)
+
+    for p in env_processes:
+        p.kill()
+        p.join()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/requirements.txt b/requirements.txt
index 7b7604086ad50cdfeef9923c0ba707c126ed24b3..887774a23a4ba8be9a6614cbf004d4ed400a603e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,7 @@
+torch
+einops
+hydra-core
+hydra_colorlog
 aicrowd-api
 aicrowd-gym
 numpy
diff --git a/rollout.py b/rollout.py
index 00fd7421b7dd9297618f1ce49c6beb430f44c550..89899aecdfb24a4fad309114b2dcd007bbb1a43d 100644
--- a/rollout.py
+++ b/rollout.py
@@ -1,17 +1,22 @@
 #!/usr/bin/env python
 
-############################################################
-## Ideally you shouldn't need to change this file at all  ##
-############################################################
-
+################################################################
+## Ideally you shouldn't need to change this file at all      ##
+##                                                            ##
+## This file generates the rollouts, with the specific agent, ##
+## batch_size and wrappers specified in submission_config.py  ##
+################################################################
+from tqdm import tqdm
 import numpy as np
 
 from envs.batched_env import BactchedEnv
 from submission_config import SubmissionConfig
 
+NUM_ASSESSMENTS = 512
+
 def run_batched_rollout(batched_env, agent):
     """
-    This function will be called the rollout
+    This function will generate a series of rollouts in a batched manner.
     """
 
     num_envs = batched_env.num_envs
@@ -22,30 +27,52 @@ def run_batched_rollout(batched_env, agent):
     dones = [False for _ in range(num_envs)]
     infos = [{} for _ in range(num_envs)]
 
+    # We mark at the start of each episode if we are 'counting it'
+    active_envs = [i < NUM_ASSESSMENTS for i in range(num_envs)]
+    num_remaining = NUM_ASSESSMENTS - sum(active_envs)
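+    # active_envs[i] is True while env i is running an episode that counts towards
+    # NUM_ASSESSMENTS; num_remaining is how many assessed episodes still have to be
+    # started once the currently active ones finish.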
+    
     episode_count = 0
+    pbar = tqdm(total=NUM_ASSESSMENTS)
 
+    all_returns = []
+    returns = [0.0 for _ in range(num_envs)]
     # The evaluator will automatically stop after the episodes based on the development/test phase
-    while episode_count < 10000:
-        actions = agent.batched_step(observations, rewards, dones, infos) 
+    while episode_count < NUM_ASSESSMENTS:
+        actions = agent.batched_step(observations, rewards, dones, infos)
 
         observations, rewards, dones, infos = batched_env.batch_step(actions)
+        
+        for i, r in enumerate(rewards):
+            returns[i] += r
+        
         for done_idx in np.where(dones)[0]:
             observations[done_idx] = batched_env.single_env_reset(done_idx)
-            episode_count += 1
-            print("Episodes Completed :", episode_count)
+ 
+            if active_envs[done_idx]:
+                # We were 'counting' this episode
+                all_returns.append(returns[done_idx])
+                episode_count += 1
+                
+                active_envs[done_idx] = (num_remaining > 0)
+                num_remaining -= 1
+                
+                pbar.update(1)
+            
+            returns[done_idx] = 0.0
+    return all_returns
 
 if __name__ == "__main__":
-
     submission_env_make_fn = SubmissionConfig.submission_env_make_fn
-    NUM_PARALLEL_ENVIRONMENTS = SubmissionConfig.NUM_PARALLEL_ENVIRONMENTS 
+    NUM_PARALLEL_ENVIRONMENTS = SubmissionConfig.NUM_PARALLEL_ENVIRONMENTS
     Agent = SubmissionConfig.Submision_Agent
 
-    batched_env = BactchedEnv(env_make_fn=submission_env_make_fn, 
-                              num_envs=NUM_PARALLEL_ENVIRONMENTS)
+    batched_env = BactchedEnv(
+        env_make_fn=submission_env_make_fn, num_envs=NUM_PARALLEL_ENVIRONMENTS
+    )
 
     num_envs = batched_env.num_envs
     num_actions = batched_env.num_actions
-    
+
     agent = Agent(num_envs, num_actions)
 
     run_batched_rollout(batched_env, agent)
diff --git a/submission_config.py b/submission_config.py
index f7d548260474eb7a780ed98f8b57c7a3e28c6f23..df2916b6fd9a4319d34f5d7e3296f0b8df13ae08 100644
--- a/submission_config.py
+++ b/submission_config.py
@@ -1,5 +1,5 @@
 from agents.random_batched_agent import RandomAgent
-# from agents.torchbeast_batched_agent import TorchBeastAgent
+from agents.torchbeast_agent import TorchBeastAgent
 # from agents.rllib_batched_agent import RLlibAgent
 
 from submission_wrappers import addtimelimitwrapper_fn
@@ -15,15 +15,15 @@ from submission_wrappers import addtimelimitwrapper_fn
 
 class SubmissionConfig:
     ## Add your own agent class
-    # Submision_Agent = TorchBeastAgent
+    Submision_Agent = TorchBeastAgent
     # Submision_Agent = RLlibAgent
-    Submision_Agent = RandomAgent
+    # Submision_Agent = RandomAgent
 
 
     ## Change the NUM_PARALLEL_ENVIRONMENTS as you need
     ## for example reduce it if your GPU doesn't fit
     ## Increasing above 32 is not advisable for the Nethack Challenge 2021
-    NUM_PARALLEL_ENVIRONMENTS = 16
+    NUM_PARALLEL_ENVIRONMENTS = 32
 
 
     ## Add a function that creates your nethack env
diff --git a/submission_wrappers.py b/submission_wrappers.py
index c35fa17e5eb82a90df80f5f01aaff2c6112c9a28..9f75d9de6baad97a888c7d462f6a7e6af6575625 100644
--- a/submission_wrappers.py
+++ b/submission_wrappers.py
@@ -8,5 +8,5 @@ def addtimelimitwrapper_fn():
     Should return a gym env which wraps the nethack gym env
     """
     env = nethack_make_fn()
-    env = TimeLimit(env, max_episode_steps=10_000_0000)
+    env = TimeLimit(env, max_episode_steps=10_000_000)
     return env
\ No newline at end of file