neural_baseline_agent.py 785 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from os import stat_result
from utils.base_agent import NeuralMMOAgent
from utils.env_spaces import get_action_spaces

import projekt

from Forge import loadModel

class NeuralBaselineAgent(NeuralMMOAgent):
    """Competition agent that delegates action selection to a loaded model.

    The trainer is built by ``Forge.loadModel`` from the
    ``projekt.config.CompetitionRound1`` configuration. Both entry points
    (``register_reset`` and ``compute_action``) query it the same way.
    """

    def __init__(self):
        self.config = projekt.config.CompetitionRound1()
        self.trainer = loadModel(self.config)
        # Second element of the most recent ``compute_actions`` result. It is
        # stored but never passed back into later calls.
        # NOTE(review): if the loaded model is recurrent, this state
        # presumably should be fed back to ``compute_actions`` — confirm
        # against the trainer API returned by ``loadModel``.
        self.state = None

    def _compute_trainer_action(self, observations):
        """Query the trainer for a single agent and return that agent's action.

        The observations are wrapped in a one-entry batch keyed by ``0``, and
        the action for key ``0`` is returned. The trainer's second return
        value is kept in ``self.state``; its third return value is discarded.
        """
        batch = {0: observations}
        actions, self.state, _ = self.trainer.compute_actions(batch)
        return actions[0]

    def register_reset(self, observations):
        """Return the action for the observations passed at reset.

        Presumably called with the first observations of an episode; it uses
        the same trainer query as ``compute_action``.
        """
        return self._compute_trainer_action(observations)

    def compute_action(self, observations, info=None):
        """Return the action for the current observations.

        ``info`` is currently ignored.
        """
        return self._compute_trainer_action(observations)