rollout_update.py 2.73 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import sys
sys.path.append('neural-mmo/')

from forge.ethyr.torch import utils
from forge.trinity.env import Env
import projekt

from utils.helpers import load_agents
import random
import copy

def assign_agents(player_agent,opponent_agents):
    """Build a shuffled roster of opponents with the player at a random seat.

    Args:
        player_agent: The agent being evaluated. It is inserted into the
            roster as-is (it is not copied).
        opponent_agents: Exactly 127 opponent agents. The caller's sequence
            is left untouched; the roster holds deep copies of its entries.

    Returns:
        A tuple ``(agents, player_index)`` where ``agents`` is a new list of
        128 entries and ``agents[player_index]`` is ``player_agent``.

    Raises:
        ValueError: If ``opponent_agents`` does not hold exactly 127 entries.
    """
    if len(opponent_agents) != 127:
        raise ValueError("Number of opponent agents should add up to exactly 127")

    # Copy before shuffling so the caller's list keeps its order; it is
    # passed in again for every episode.
    agents = copy.deepcopy(list(opponent_agents))
    random.shuffle(agents)

    # randint is inclusive on both ends, so the player may land in any seat
    # from the very front (0) to the very back (len(agents)).
    player_index = random.randint(0, len(agents))
    agents.insert(player_index, player_agent)
    return agents, player_index


def run_episode(player_index, agents, N_TIME_STEPS):
    """Play one episode and return the player's end-of-episode statistics.

    A fresh ``Env`` is created from ``projekt.config.CompetitionRound1()``.
    Agents are paired positionally with the entity ids reported by
    ``env.reset()``, then stepped until no observations remain or
    ``N_TIME_STEPS`` steps have been taken.

    Args:
        player_index: Position of the evaluated agent inside ``agents``.
        agents: Agent objects exposing ``register_reset(obs)`` and
            ``compute_action(obs)``.
        N_TIME_STEPS: Maximum number of ``env.step`` calls for the episode.

    Returns:
        A dict with the keys ``Achievement``, ``Equipment``, ``Exploration``,
        ``PlayerKills`` and ``Foraging``, read from ``env.terminal()['Stats']``
        at the player's entity id. The dict is empty when the player's entity
        was neither reported done during the episode nor present in the final
        observations.
    """
    config = projekt.config.CompetitionRound1()
    env = Env(config)

    obs = env.reset()
    entids = list(obs.keys())
    # Positional pairing: the i-th agent controls the i-th entity id that
    # reset() reported. zip() stops at the shorter of the two sequences.
    agent_entid_map = dict(zip(range(len(agents)), entids))
    entid_agent_map = {entid: idx for idx, entid in agent_entid_map.items()}

    # Only entities that actually received an agent get an initial action,
    # mirroring the guard applied on every later step.
    actions = {
        entid: agents[idx].register_reset(obs[entid])
        for entid, idx in entid_agent_map.items()
    }

    # A set keeps membership checks cheap and ignores repeated "done" reports.
    dead_agents = set()
    n_steps = 0
    while len(obs) > 0 and n_steps < N_TIME_STEPS:
        obs, dones, _rewards, _ = env.step(actions, omitDead=False)
        dead_agents.update(entid for entid in dones if dones[entid])
        actions = {
            entid: agents[entid_agent_map[entid]].compute_action(obs[entid])
            for entid in obs
            # Entities that were never paired with an agent are skipped.
            if entid in entid_agent_map
        }
        n_steps += 1

    # Whoever is still observed when the loop ends is counted as finished too,
    # so their statistics can be reported.
    dead_agents.update(obs)

    logs = env.terminal()
    # Translate agent index -> entity id; the reverse map would be keyed by
    # entity id and yield the wrong entry (or a KeyError).
    player_entid = agent_entid_map[player_index]
    player_log = {}
    if player_entid in dead_agents:
        stats = logs['Stats']
        for stat_name in ("Achievement", "Equipment", "Exploration",
                          "PlayerKills", "Foraging"):
            player_log[stat_name] = stats[stat_name][player_entid]
    return player_log


def print_statistics(player_statistics,episode):
    """Print the player's statistics followed by the episode number.

    Both values are written on a single line, separated by one space.
    """
    rendered = " ".join(str(part) for part in (player_statistics, episode))
    print(rendered)

if __name__ == "__main__":
    # Number of episodes to play and the step cap handed to each episode.
    N_EPISODES = 10
    N_TIME_STEPS = 10

    # Agents are loaded once; only the seating is redone for every episode.
    player_agent, opponent_agents = load_agents("players.yaml")

    for episode in range(N_EPISODES):
        roster = assign_agents(player_agent, opponent_agents)
        agents, player_index = roster
        statistics = run_episode(player_index, agents, N_TIME_STEPS)
        print_statistics(statistics, episode)