diff --git a/agents/random_agent.py b/agents/random_batched_agent.py
similarity index 52%
rename from agents/random_agent.py
rename to agents/random_batched_agent.py
index 1ed471cd5e7b8e59b96268dfe43029a7eef3083b..ae426a5215f157910fad147abdc6945bbde1bbfb 100644
--- a/agents/random_agent.py
+++ b/agents/random_batched_agent.py
@@ -3,11 +3,19 @@ import numpy as np
 from agents.batched_agent import BatchedAgent
 
 class RandomAgent(BatchedAgent):
-    """This random agent just selects an action from the action space."""
     def __init__(self, num_envs, num_actions):
         super().__init__(num_envs, num_actions)
         self.seeded_state = np.random.RandomState(42)
 
+    def preprocess_observations(self, observations, rewards, dones, infos):
+        return observations, rewards, dones, infos
+
+    def postprocess_actions(self, actions):
+        return actions
+
     def batched_step(self, observations, rewards, dones, infos):
+        rets = self.preprocess_observations(observations, rewards, dones, infos)
+        observations, rewards, dones, infos = rets
         actions = self.seeded_state.randint(self.num_actions, size=self.num_envs)
+        actions = self.postprocess_actions(actions)
         return actions
\ No newline at end of file
diff --git a/agents/rllib_agent.py b/agents/rllib_batched_agent.py
similarity index 100%
rename from agents/rllib_agent.py
rename to agents/rllib_batched_agent.py
diff --git a/submission_config.py b/submission_config.py
index de7d56ddbdb23ee5be908a3138db8d4950c628da..f7d548260474eb7a780ed98f8b57c7a3e28c6f23 100644
--- a/submission_config.py
+++ b/submission_config.py
@@ -1,6 +1,6 @@
-from agents.random_agent import RandomAgent
+from agents.random_batched_agent import RandomAgent
 # from agents.torchbeast_batched_agent import TorchBeastAgent
-# from agents.rllib_agent import RLlibAgent
+# from agents.rllib_batched_agent import RLlibAgent
 
 from submission_wrappers import addtimelimitwrapper_fn
 