Commit 847dc1ff authored by Eric Hambro

Fixup agents and remove reference to RLlib.

parent d6ba2a88
1 merge request: !7 Fixup agents and remove reference to RLlib.
from abc import ABC, abstractmethod


class BatchedAgent(ABC):
    """
    This is an abstract base class for you to load your models and perform rollouts on a
    batched set of environments.
    """

    def __init__(self, num_envs: int, num_actions: int):
        self.num_envs = num_envs
        self.num_actions = num_actions

    @abstractmethod
    def batched_step(self, observations, rewards, dones, infos):
        """
        Perform a batched step on lists of environment outputs.

        :param observations: a list of observations
        :param rewards: a list of rewards
        :param dones: a list of dones
        :param infos: a list of infos

        :returns: an iterable of actions
        """
        pass
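As context for this interface (not part of the commit): a minimal sketch of how a driver loop might call batched_step across several gym-style environments. The rollout helper and the episode-reset handling below are assumptions for illustration, not the starter kit's actual evaluation code.

# Illustrative driver loop (assumption, not part of this commit): steps a
# BatchedAgent over a list of gym-style environments, resetting finished ones.
def rollout(agent, envs, num_steps=100):
    observations = [env.reset() for env in envs]
    rewards = [0.0] * len(envs)
    dones = [False] * len(envs)
    infos = [{}] * len(envs)
    for _ in range(num_steps):
        actions = agent.batched_step(observations, rewards, dones, infos)
        for i, (env, action) in enumerate(zip(envs, actions)):
            if dones[i]:
                # Start a new episode once the previous one has ended.
                observations[i], rewards[i], dones[i], infos[i] = env.reset(), 0.0, False, {}
            else:
                observations[i], rewards[i], dones[i], infos[i] = env.step(action)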
class BatchedAgent:
    """
    Simple batched agent interface.
    The main motivation is to speed up runs by increasing GPU utilization.
    """

    def __init__(self, num_envs, num_actions):
        """
        Set up your model.
        Load your weights, etc.
        """
        self.num_envs = num_envs
        self.num_actions = num_actions

    def batched_step(self, observations, rewards, dones, infos):
        """
        Take a list of outputs from each environment and return a list of actions.
        """
        raise NotImplementedError
 import numpy as np

-from agents.batched_agent import BatchedAgent
+from agents.base import BatchedAgent


-class RandomAgent(BatchedAgent):
+class CustomAgent(BatchedAgent):
+    """An example agent... that simply acts randomly. Adapt to your needs!"""
+
     def __init__(self, num_envs, num_actions):
+        """Set up and load your model here"""
         super().__init__(num_envs, num_actions)
         self.seeded_state = np.random.RandomState(42)

-    def preprocess_observations(self, observations, rewards, dones, infos):
-        return observations, rewards, dones, infos
-
-    def postprocess_actions(self, actions):
-        return actions
-
     def batched_step(self, observations, rewards, dones, infos):
-        rets = self.preprocess_observations(observations, rewards, dones, infos)
-        observations, rewards, dones, infos = rets
+        """
+        Perform a batched step on lists of environment outputs.
+        Each argument is a list of the respective gym output.
+        Returns an iterable of actions.
+        """
         actions = self.seeded_state.randint(self.num_actions, size=self.num_envs)
-        actions = self.postprocess_actions(actions)
         return actions
\ No newline at end of file
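A quick sanity check of the renamed agent could look like the following; this is a hedged sketch, and the dummy inputs and the num_envs/num_actions values are placeholders rather than real NetHack data.

# Illustrative smoke test (assumption, not part of this commit).
from agents.custom_agent import CustomAgent

num_envs, num_actions = 4, 8  # placeholder values for illustration
agent = CustomAgent(num_envs, num_actions)
dummy_obs = [None] * num_envs      # stand-ins for real NetHack observations
rewards = [0.0] * num_envs
dones = [False] * num_envs
infos = [{}] * num_envs
actions = agent.batched_step(dummy_obs, rewards, dones, infos)
assert len(actions) == num_envs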
placeholder
 import torch
 import numpy as np

-from agents.batched_agent import BatchedAgent
+from agents.base import BatchedAgent

 from nethack_baselines.torchbeast.models import load_model
...
-from agents.random_batched_agent import RandomAgent
+from agents.custom_agent import CustomAgent
 from agents.torchbeast_agent import TorchBeastAgent
-# from agents.rllib_batched_agent import RLlibAgent

 from envs.wrappers import addtimelimitwrapper_fn

@@ -15,9 +14,8 @@ from envs.wrappers import addtimelimitwrapper_fn
 class SubmissionConfig:
     ## Add your own agent class
-    AGENT = TorchBeastAgent
-    # AGENT = RLlibAgent
-    # AGENT = RandomAgent
+    AGENT = CustomAgent
+    # AGENT = TorchBeastAgent

     ## Change the NUM_ENVIRONMENTS as you need
...
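Presumably the evaluation harness constructs whatever class AGENT points to. A hedged sketch of that wiring follows; the module path for SubmissionConfig and the num_actions value are assumptions, since the harness code is not part of this diff.

# Assumed wiring, for illustration only (module path and num_actions are assumptions).
from submission_config import SubmissionConfig

num_actions = 8  # placeholder; the real value comes from the environment's action space
agent = SubmissionConfig.AGENT(SubmissionConfig.NUM_ENVIRONMENTS, num_actions)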