# neural_baseline_agent.py
from os import stat_result
from utils.base_agent import NeuralMMOAgent
from utils.env_spaces import get_action_spaces

import projekt

from Forge import loadModel

class NeuralBaselineAgent(NeuralMMOAgent):
    """Agent that acts by querying a model restored with ``Forge.loadModel``.

    On construction it builds the ``projekt.config.CompetitionRound1`` config
    and loads the trainer from it. Every observation is answered by asking
    that trainer for an action from the ``policy_0`` policy.
    """

    # Policy queried inside the loaded trainer.
    _POLICY_ID = 'policy_0'

    def __init__(self):
        # NOTE(review): NeuralMMOAgent.__init__ is not called here; confirm
        # the base class needs no initialisation before adding super().__init__().
        self.config = projekt.config.CompetitionRound1()
        self.trainer = loadModel(self.config)

    def _act(self, observations):
        """Return the trainer's action for a single agent's observation.

        The observation is wrapped in a one-entry batch keyed by ``0``, and
        the action for that key is returned. The second value returned by
        ``compute_actions`` is kept in ``self.state``.
        """
        # NOTE(review): an empty state is passed on every call, so the value
        # stored in self.state is never fed back. If the policy is recurrent,
        # this should probably pass the previous state instead; confirm.
        actions, self.state, _ = self.trainer.compute_actions(
            {0: observations}, state={}, policy_id=self._POLICY_ID)
        return actions[0]

    def register_reset(self, observations):
        """Return the action for the observation received on reset.

        Presumably this is the first observation of an episode (verify
        against the harness calling NeuralMMOAgent).
        """
        return self._act(observations)

    def compute_action(self, observations):
        """Return the action for the given observation."""
        return self._act(observations)