diff --git a/agents/torchbeast_batched_agent.py b/agents/torchbeast_batched_agent.py index ae9bcf359ab23bd77d159bacb986bd3cafc15889..bbfb9b8146786d7b919773958717809fc4caf038 100644 --- a/agents/torchbeast_batched_agent.py +++ b/agents/torchbeast_batched_agent.py @@ -1 +1,61 @@ -placeholders \ No newline at end of file +import torch +import numpy as np + +from agents.batched_agent import BatchedAgent + +from nethack_baselines.torchbeast.models import load_model + +MODEL_DIR = "./models/torchbeast/example_run" + + +class TorchBeastAgent(BatchedAgent): + """ + A BatchedAgent using the TorchBeast Model + """ + + def __init__(self, num_envs, num_actions): + super().__init__(num_envs, num_actions) + self.model_dir = MODEL_DIR + self.device = "cuda:0" + self.model = load_model(MODEL_DIR, self.device) + + self.core_state = [ + m.to(self.device) for m in self.model.initial_state(batch_size=num_envs) + ] + + def batch_inputs(self, observations, dones): + """ + Convert lists of observations, rewards, dones, infos to tensors for TorchBeast. + + TorchBeast models: + * take tensors in the form: [T, B, ...]: B:= batch, T:= unroll (=1) + * take "done" as a BOOLEAN observation + """ + states = list(observations[0].keys()) + obs = {k: [] for k in states} + + # Unpack List[Dicts] -> Dict[Lists] + for o in observations: + for k, t in o.items(): + obs[k].append(t) + + # Convert to Tensor, Add Unroll Dim (=1), Move to GPU + for k in states: + obs[k] = torch.Tensor(np.stack(obs[k])[None, ...]).to(self.device) + obs["done"] = torch.Tensor(np.array(dones)[None, ...]).bool().to(self.device) + return obs, dones + + def batched_step(self, observations, rewards, dones, infos): + """ + Perform a batched step on lists of environment outputs. + + Torchbeast models: + * take the core (LSTM) state as input, and return as output + * return outputs as a dict of "action", "policy_logits", "baseline" + """ + observations, dones = self.batch_inputs(observations, dones) + + with torch.no_grad(): + outputs, self.core_state = self.model(observations, self.core_state) + + return outputs["action"].cpu().numpy()[0] diff --git a/nethack_baselines/torchbeast/models/baseline.py b/nethack_baselines/torchbeast/models/baseline.py index 99fdde1c667097610be86445855b33a9ed637f79..0599315c2d82c5f9951ddc478c90a7e211e7e18c 100644 --- a/nethack_baselines/torchbeast/models/baseline.py +++ b/nethack_baselines/torchbeast/models/baseline.py @@ -21,7 +21,7 @@ from einops import rearrange from nle import nethack -from util.id_pairs import id_pairs_table +from .util import id_pairs_table import numpy as np NUM_GLYPHS = nethack.MAX_GLYPH diff --git a/nethack_baselines/torchbeast/util/id_pairs.py b/nethack_baselines/torchbeast/models/util.py similarity index 100% rename from nethack_baselines/torchbeast/util/id_pairs.py rename to nethack_baselines/torchbeast/models/util.py diff --git a/nethack_baselines/torchbeast/util/__init__.py b/nethack_baselines/torchbeast/util/__init__.py deleted file mode 100644 index 8daf2005df702571a183c0f0fac274d9429c4325..0000000000000000000000000000000000000000 --- a/nethack_baselines/torchbeast/util/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/submission_config.py b/submission_config.py index 5038d32ce0c8ee1498a5ff30ad373e9f13f1ab3c..e5e738b66462e8becb096053f52c18b57cffee76 100644 --- a/submission_config.py +++ b/submission_config.py @@ -1,5 +1,5 @@ from agents.random_batched_agent import RandomAgent -# from agents.torchbeast_batched_agent import TorchBeastAgent +from agents.torchbeast_batched_agent import TorchBeastAgent # from agents.rllib_batched_agent import RLlibAgent from submission_wrappers import addtimelimitwrapper_fn @@ -15,9 +15,9 @@ from submission_wrappers import addtimelimitwrapper_fn class SubmissionConfig: ## Add your own agent class - # Submision_Agent = TorchBeastAgent + Submision_Agent = TorchBeastAgent # Submision_Agent = RLlibAgent - Submision_Agent = RandomAgent + # Submision_Agent = RandomAgent ## Change the NUM_PARALLEL_ENVIRONMENTS as you need