Commit 847dc1ff authored by Eric Hambro

Fixup agents and remove reference to RLlib.

parent d6ba2a88
agents/base.py (new file):

+from abc import ABC, abstractmethod
+
+
+class BatchedAgent(ABC):
+    """
+    An abstract base class for loading your models and performing rollouts on a
+    batched set of environments.
+    """
+
+    def __init__(self, num_envs: int, num_actions: int):
+        self.num_envs = num_envs
+        self.num_actions = num_actions
+
+    @abstractmethod
+    def batched_step(self, observations, rewards, dones, infos):
+        """
+        Perform a batched step on lists of environment outputs.
+
+        :param observations: a list of observations
+        :param rewards: a list of rewards
+        :param dones: a list of dones
+        :param infos: a list of infos
+        :returns: an iterable of actions
+        """
+        pass

agents/batched_agent.py (deleted):

-class BatchedAgent:
-    """
-    Simple batched agent interface.
-    The main motivation is to speed up runs by increasing GPU utilization.
-    """
-
-    def __init__(self, num_envs, num_actions):
-        """
-        Set up your model, load your weights, etc.
-        """
-        self.num_envs = num_envs
-        self.num_actions = num_actions
-
-    def batched_step(self, observations, rewards, dones, infos):
-        """
-        Take the list of outputs from each environment and return a list of actions.
-        """
-        raise NotImplementedError
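To make the new interface concrete, here is a minimal driver loop showing how an evaluator might call batched_step across several environments. This is an illustrative sketch, not part of the commit: the run_batched_rollout helper and the way envs are constructed are assumptions.

# Illustrative sketch (not part of this commit): driving a BatchedAgent
# over a list of gym environments. Construction of `envs` is assumed.
from agents.base import BatchedAgent

def run_batched_rollout(agent: BatchedAgent, envs, num_steps: int = 100):
    # Reset every environment and seed the per-env step outputs.
    observations = [env.reset() for env in envs]
    rewards = [0.0] * len(envs)
    dones = [False] * len(envs)
    infos = [{}] * len(envs)
    for _ in range(num_steps):
        # One call produces an action for every environment at once,
        # which is what lets the agent batch its forward pass on the GPU.
        actions = agent.batched_step(observations, rewards, dones, infos)
        for i, action in enumerate(actions):
            if dones[i]:
                observations[i] = envs[i].reset()
                rewards[i], dones[i], infos[i] = 0.0, False, {}
            else:
                observations[i], rewards[i], dones[i], infos[i] = envs[i].step(action)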
agents/random_batched_agent.py → agents/custom_agent.py:

 import numpy as np

-from agents.batched_agent import BatchedAgent
+from agents.base import BatchedAgent


-class RandomAgent(BatchedAgent):
+class CustomAgent(BatchedAgent):
+    """An example agent that simply acts randomly. Adapt it to your needs!"""
+
     def __init__(self, num_envs, num_actions):
+        """Set up and load your model here."""
         super().__init__(num_envs, num_actions)
         self.seeded_state = np.random.RandomState(42)

+    def preprocess_observations(self, observations, rewards, dones, infos):
+        return observations, rewards, dones, infos
+
+    def postprocess_actions(self, actions):
+        return actions
+
     def batched_step(self, observations, rewards, dones, infos):
         """
         Perform a batched step on lists of environment outputs.
         Each argument is a list of the respective gym output.
         Returns an iterable of actions.
         """
+        rets = self.preprocess_observations(observations, rewards, dones, infos)
+        observations, rewards, dones, infos = rets
         actions = self.seeded_state.randint(self.num_actions, size=self.num_envs)
+        actions = self.postprocess_actions(actions)
         return actions
\ No newline at end of file
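As a quick sanity check of the renamed agent, the following sketch exercises CustomAgent with placeholder inputs; the observation values and the env/action counts are dummy stand-ins, assumed for illustration only.

# Illustrative usage sketch (assumed dummy values, not part of this commit).
from agents.custom_agent import CustomAgent

num_envs, num_actions = 4, 16  # arbitrary sizes for the example
agent = CustomAgent(num_envs, num_actions)

# Placeholder per-env outputs standing in for real gym results.
observations = [None] * num_envs
rewards = [0.0] * num_envs
dones = [False] * num_envs
infos = [{}] * num_envs

actions = agent.batched_step(observations, rewards, dones, infos)
assert len(actions) == num_envs  # one action per environment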
agents/torchbeast_agent.py:

 import torch
 import numpy as np

-from agents.batched_agent import BatchedAgent
+from agents.base import BatchedAgent

 from nethack_baselines.torchbeast.models import load_model
......
(submission configuration file):

-from agents.random_batched_agent import RandomAgent
+from agents.custom_agent import CustomAgent
 from agents.torchbeast_agent import TorchBeastAgent
-# from agents.rllib_batched_agent import RLlibAgent
 from envs.wrappers import addtimelimitwrapper_fn
......
@@ -15,9 +14,8 @@ from envs.wrappers import addtimelimitwrapper_fn
 class SubmissionConfig:
     ## Add your own agent class
-    AGENT = TorchBeastAgent
-    # AGENT = RLlibAgent
-    # AGENT = RandomAgent
+    AGENT = CustomAgent
+    # AGENT = TorchBeastAgent

     ## Change the NUM_ENVIRONMENTS as you need
......
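For context, AGENT holds a class reference, so the evaluator can construct whichever agent is configured with a one-line change. The snippet below is a hypothetical sketch of that pattern; the submission_config module name and the make_agent helper are assumptions, not the actual runner code.

# Hypothetical sketch: how a runner could consume SubmissionConfig.
# `submission_config` and `make_agent` are assumed names for illustration.
from submission_config import SubmissionConfig

def make_agent(num_envs: int, num_actions: int):
    # AGENT is a class, so swapping agents (CustomAgent, TorchBeastAgent, ...)
    # only requires editing the config, not the runner.
    return SubmissionConfig.AGENT(num_envs, num_actions)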