Commit 55fb08f9 authored by Eric Hambro
Browse files

TorchBeast Agent added.

parent bc8d80f4
placeholders
\ No newline at end of file
import torch
import numpy as np
from agents.batched_agent import BatchedAgent
from nethack_baselines.torchbeast.models import load_model
MODEL_DIR = "./models/torchbeast/example_run"
class TorchBeastAgent(BatchedAgent):
    """
    A BatchedAgent using the TorchBeast Model.

    The model is loaded from ``MODEL_DIR`` and its recurrent core state is
    kept on the agent between calls to ``batched_step``.
    """

    def __init__(self, num_envs, num_actions):
        """
        Load the model and create its initial core state.

        Args:
            num_envs: number of environments stepped in parallel; used as
                the batch size of the initial core state.
            num_actions: forwarded unchanged to ``BatchedAgent.__init__``.
        """
        super().__init__(num_envs, num_actions)
        self.model_dir = MODEL_DIR
        # Use the first GPU when one is present, otherwise run on the CPU so
        # the agent can still be constructed on machines without CUDA.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = load_model(self.model_dir, self.device)
        self.core_state = [
            m.to(self.device) for m in self.model.initial_state(batch_size=num_envs)
        ]

    def batch_inputs(self, observations, dones):
        """
        Convert lists of observations and dones to tensors for TorchBeast.

        TorchBeast models:
            * take tensors in the form: [T, B, ...]: B:= batch, T:= unroll (=1)
            * take "done" as a BOOLEAN observation

        Args:
            observations: non-empty list with one dict of arrays per
                environment; the keys of the first dict select which
                entries are batched.
            dones: list with one done flag per environment.

        Returns:
            A tuple ``(obs, dones)``: ``obs`` maps each observation key, plus
            ``"done"``, to a tensor on ``self.device`` with a leading unroll
            dimension of size 1; ``dones`` is returned unchanged.
        """
        states = list(observations[0].keys())

        # Unpack List[Dicts] -> Dict[Tensors]: stack each key over the batch,
        # add the unroll dim (=1) in front, then move to the target device.
        # NOTE: torch.Tensor(...) yields the default float dtype regardless
        # of the dtype of the stacked numpy array.
        obs = {
            k: torch.Tensor(np.stack([o[k] for o in observations])[None, ...]).to(
                self.device
            )
            for k in states
        }
        obs["done"] = torch.Tensor(np.array(dones)[None, ...]).bool().to(self.device)
        return obs, dones

    def batched_step(self, observations, rewards, dones, infos):
        """
        Perform a batched step on lists of environment outputs.

        Torchbeast models:
            * take the core (LSTM) state as input, and return as output
            * return outputs as a dict of "action", "policy_logits", "baseline"

        Args:
            observations: list with one observation dict per environment.
            rewards: unused here.
            dones: list with one done flag per environment.
            infos: unused here.

        Returns:
            A numpy array taken from ``outputs["action"]`` with its leading
            dimension indexed away (presumably the unroll dim of size 1,
            leaving one action per environment; confirm against the model).
        """
        observations, dones = self.batch_inputs(observations, dones)
        # Inference only: no gradients are needed, and the updated core state
        # is stored for the next step.
        with torch.no_grad():
            outputs, self.core_state = self.model(observations, self.core_state)
        return outputs["action"].cpu().numpy()[0]
......@@ -21,7 +21,7 @@ from einops import rearrange
from nle import nethack
from util.id_pairs import id_pairs_table
from .util import id_pairs_table
import numpy as np
NUM_GLYPHS = nethack.MAX_GLYPH
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from agents.random_batched_agent import RandomAgent
# from agents.torchbeast_batched_agent import TorchBeastAgent
from agents.torchbeast_batched_agent import TorchBeastAgent
# from agents.rllib_batched_agent import RLlibAgent
from submission_wrappers import addtimelimitwrapper_fn
......@@ -15,9 +15,9 @@ from submission_wrappers import addtimelimitwrapper_fn
class SubmissionConfig:
## Add your own agent class
# Submision_Agent = TorchBeastAgent
Submision_Agent = TorchBeastAgent
# Submision_Agent = RLlibAgent
Submision_Agent = RandomAgent
# Submision_Agent = RandomAgent
## Change the NUM_PARALLEL_ENVIRONMENTS as you need
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment