diff --git a/agents/batched_agent.py b/agents/batched_agent.py index 71089572bec7f5a4fa4548264a932e39581d503c..d5e59cec3a1d0ec85fcca6ca3462f46786ab587c 100644 --- a/agents/batched_agent.py +++ b/agents/batched_agent.py @@ -11,20 +11,9 @@ class BatchedAgent: self.num_envs = num_envs self.num_actions = num_actions - def preprocess_observations(self, observations, rewards, dones, infos): + def batched_step(self, observations, rewards, dones, infos): """ - Add any preprocessing steps, for example rerodering/stacking for torch/tf in your model + Take list of outputs of each environments and return a list of actions """ - pass + raise NotImplementedError - def preprocess_actions(self, actions): - """ - Add any postprocessing steps, for example converting to lists - """ - pass - - def batched_step(self): - """ - Return a list of actions - """ - pass diff --git a/local_evaluation.py b/local_evaluation.py index 1ae846db360bd8784da333f143d13b8b903e0636..b8b625dc8e967446c6086d8aad1edf7c35063eb8 100644 --- a/local_evaluation.py +++ b/local_evaluation.py @@ -32,6 +32,7 @@ def evaluate(): agent = Agent(num_envs, num_actions) run_batched_rollout(batched_env, agent) + if __name__ == '__main__': evaluate()