Commit 685dec5f authored by Ubuntu

Fix load model for neural agents

parent 40598dc2
......@@ -5,7 +5,7 @@ import projekt
class BaselineForageAgent(NeuralMMOAgent):
def __init__(self):
self.agent_type = 'scripted'
self.agent = getattr(baselines, 'Forage')(projekt.config.SmallMaps())
self.agent = getattr(baselines, 'Forage')(projekt.config.SmallMaps(),0)
def register_reset(self, observations):
action = self.agent(observations)
......@@ -18,7 +18,7 @@ class BaselineForageAgent(NeuralMMOAgent):
class BaselineCombatAgent(NeuralMMOAgent):
def __init__(self):
self.agent_type = 'scripted'
self.agent = getattr(baselines, 'Combat')(projekt.config.SmallMaps())
self.agent = getattr(baselines, 'Combat')(projekt.config.SmallMaps(),0)
def register_reset(self, observations):
action = self.agent(observations)
......@@ -31,7 +31,7 @@ class BaselineCombatAgent(NeuralMMOAgent):
class BaselineRandomAgent(NeuralMMOAgent):
def __init__(self):
self.agent_type = 'scripted'
self.agent = getattr(baselines, 'Random')(projekt.config.SmallMaps())
self.agent = getattr(baselines, 'Random')(projekt.config.SmallMaps(),0)
def register_reset(self, observations):
action = self.agent(observations)
......
......@@ -53,7 +53,7 @@ def run_episode(player_index, agents, N_TIME_STEPS):
if Action.Attack in actions[entid]:
targID = actions[entid][Action.Attack][Action.Target]
actions[entid][Action.Attack][Action.Target] = realm.entity(targID)
print(actions)
obs, dones, rewards, _ = env.step(
actions, omitDead=True, preprocess=neural_agents
)
......@@ -110,7 +110,7 @@ if __name__ == "__main__":
player_agent, opponent_agents = load_agents(LocalEvaluationConfig)
player_agent = player_agent[0]
N_EPISODES = 2
N_TIME_STEPS = 1024
N_TIME_STEPS = 10
for episode in range(N_EPISODES):
agents, player_index = assign_agents(player_agent, opponent_agents)
statistics = run_episode(player_index, agents, N_TIME_STEPS)
......
......@@ -16,7 +16,7 @@ import ray
from ray import rllib, tune
from ray.tune import CLIReporter
from ray.tune.integration.wandb import WandbLoggerCallback
from projekt import rllib_wrapper as wrapper
from neuralmmo.projekt import rllib_wrapper as wrapper
import projekt
from projekt import config as base_config
......@@ -37,7 +37,7 @@ def run_tune_experiment(config):
ray.init(local_mode=config.LOCAL_MODE)
#Obs and actions
#Obs and actions
obs = wrapper.observationSpace(config)
atns = wrapper.actionSpace(config)
......@@ -176,6 +176,79 @@ class Anvil():
from neural_mmo.forge.blade.core import terrain
terrain.MapGenerator(self.config).generate()
def loadTrainer(config):
'''Create monolithic RLlib trainer object'''
torch.set_num_threads(1)
ray.init(local_mode=config.LOCAL_MODE,
_memory=2000 * 1024 * 1024,
object_store_memory=200 * 1024 * 1024,
)
#Register custom env
ray.tune.registry.register_env("Neural_MMO",
lambda config: wrapper.RLlibEnv(config))
#Create policies
rllib.models.ModelCatalog.register_custom_model('godsword', wrapper.RLlibPolicy)
mapPolicy = lambda agentID: 'policy_{}'.format(agentID % config.NPOLICIES)
policies = createPolicies(config, mapPolicy)
#Instantiate monolithic RLlib Trainer object.
return wrapper.SanePPOTrainer(config={
'num_workers': config.NUM_WORKERS,
'num_gpus_per_worker': config.NUM_GPUS_PER_WORKER,
'num_gpus': config.NUM_GPUS,
'num_envs_per_worker': 1,
'train_batch_size': config.TRAIN_BATCH_SIZE // 2,
'rollout_fragment_length': config.ROLLOUT_FRAGMENT_LENGTH,
'sgd_minibatch_size': config.SGD_MINIBATCH_SIZE,
'num_sgd_iter': config.NUM_SGD_ITER,
'framework': 'torch',
'horizon': np.inf,
'soft_horizon': False,
'no_done_at_end': False,
'callbacks': wrapper.RLlibLogCallbacks,
'env_config': {
'config': config
},
'multiagent': {
'policies': policies,
'policy_mapping_fn': mapPolicy,
'count_steps_by': 'env_steps'
},
'model': {
'custom_model': 'godsword',
'custom_model_config': {'config': config},
'max_seq_len': config.LSTM_BPTT_HORIZON
},
})
def loadModel(config):
'''Load NN weights and optimizer state'''
trainer = loadTrainer(config)
utils.modelSize(trainer.defaultModel())
trainer.restore()
return trainer
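For orientation, a minimal usage sketch (assumed, not part of this commit) of how loadModel ties loadTrainer and restore together; the config object is assumed to expose the PATH_* attributes referenced below:
# Hypothetical usage sketch, not part of this commit: build the trainer,
# print its parameter count, and restore the saved weights and logs.
config = projekt.config.SmallMaps()      # config class taken from the agent code above
trainer = loadModel(config)              # loadTrainer + utils.modelSize + trainer.restore
policy_model = trainer.defaultModel()    # torch module behind 'policy_0'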
def createPolicies(config, mapPolicy):
'''Generate RLlib policies'''
obs = wrapper.observationSpace(config)
atns = wrapper.actionSpace(config)
policies = {}
for i in range(config.NPOLICIES):
params = {
"agent_id": i,
"obs_space_dict": obs,
"act_space_dict": atns}
key = mapPolicy(i)
policies[key] = (None, obs, atns, params)
return policies
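For clarity, a sketch of the dict createPolicies returns, assuming config.NPOLICIES == 2; each value is the (policy_cls, obs_space, act_space, config) 4-tuple RLlib's multiagent setup expects, and None defers to the trainer's default policy class:
# Illustrative shape only (assuming config.NPOLICIES == 2):
# {
#     'policy_0': (None, obs, atns, {'agent_id': 0, 'obs_space_dict': obs, 'act_space_dict': atns}),
#     'policy_1': (None, obs, atns, {'agent_id': 1, 'obs_space_dict': obs, 'act_space_dict': atns}),
# }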
if __name__ == '__main__':
def Display(lines, out):
text = "\n".join(lines) + "\n"
......@@ -183,4 +256,4 @@ if __name__ == '__main__':
from fire import core
core.Display = Display
Fire(Anvil)
\ No newline at end of file
Fire(Anvil)
......@@ -5,6 +5,7 @@ import numpy as np
import gym
import wandb
import trueskill
import os
import torch
from torch import nn
......@@ -606,3 +607,55 @@ class RLlibLogCallbacks(DefaultCallbacks):
for rank, idx in enumerate(idxs):
key = 'Rank_{}'.format(policies[idx].__name__)
episode.custom_metrics[key] = rank
class SanePPOTrainer(ppo.PPOTrainer):
'''Small utility class on top of RLlib's base trainer'''
def __init__(self, config):
self.envConfig = config['env_config']['config']
super().__init__(env=self.envConfig.ENV_NAME, config=config)
self.training_logs = {}
def save(self):
'''Save model to file. Note: RLlib does not let us choose save paths'''
config = self.envConfig
saveFile = super().save(config.PATH_CHECKPOINTS)
saveDir = os.path.dirname(saveFile)
#Clear current save dir
shutil.rmtree(config.PATH_MODEL, ignore_errors=True)
os.mkdir(config.PATH_MODEL)
#Copy checkpoints
for f in os.listdir(saveDir):
stripped = re.sub(r'-\d+', '', f)
src = os.path.join(saveDir, f)
dst = os.path.join(config.PATH_MODEL, stripped)
shutil.copy(src, dst)
print('Saved to: {}'.format(saveDir))
def restore(self):
'''Restore model from path'''
self.training_logs = np.load(
self.envConfig.PATH_TRAINING_DATA,
allow_pickle=True).item()['logs']
path = os.path.join(
self.envConfig.PATH_MODEL,
'checkpoint')
print('Loading model from: {}'.format(path))
super().restore(path)
def policyID(self, idx):
return 'policy_{}'.format(idx)
def model(self, policyID):
return self.get_policy(policyID).model
def defaultModel(self):
return self.model(self.policyID(0))
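As a rough round-trip sketch (assumed usage, not shown in this diff), save and restore move checkpoints through config.PATH_MODEL:
# Hypothetical round-trip, assuming config defines PATH_CHECKPOINTS,
# PATH_MODEL and PATH_TRAINING_DATA as used by the methods above:
trainer = loadTrainer(config)
trainer.train()       # one standard RLlib PPO training iteration
trainer.save()        # writes an RLlib checkpoint, then copies it into config.PATH_MODEL
trainer.restore()     # reloads training logs and the checkpoint from config.PATH_MODEL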