Commit 02e8b2c2 authored by zverkov
update user agent

parent 95f13e50
# Adapted from - https://github.com/Silin159/PersonaChat-BART-PeaCoK/blob/main/eval_parlai.py
import re
from typing import List, Dict

import torch
from tqdm.auto import tqdm
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,   # causal LM head for text generation
    AutoTokenizer,          # matching tokenizer
    BitsAndBytesConfig,     # 4-bit quantization config (used only by the commented-out path below)
    StoppingCriteria,
    StoppingCriteriaList,
    pipeline,               # convenience wrapper for inference
)

# from vllm import LLM, SamplingParams
# from vllm.lora.request import LoRARequest
class MistralAgent(object):
    def __init__(self):
        """Load your model(s) here."""
        torch.random.manual_seed(0)
        torch.cuda.manual_seed(0)

        # model_name = "teknium/OpenHermes-2.5-Mistral-7B"
        adapter_model_id = "./agents/mistral_adapter"

        # compute_dtype = getattr(torch, "float16")
        # bnb_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     bnb_4bit_quant_type="nf4",
        #     bnb_4bit_compute_dtype=compute_dtype,
        #     bnb_4bit_use_double_quant=False,
        # )
        device_map = {"": 0}  # unused: the model below is loaded with device_map="auto"

        # Local GPTQ-quantized Mistral checkpoint plus its tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "./agents/mistral_model/mistral_gptq/tokenizer", trust_remote_code=True
        )
        # self.tokenizer.bos_token_id = 1
        # self.tokenizer.add_special_tokens({'pad_token': '</s>'})
        # self.padding_side = "right"
        self.model = AutoModelForCausalLM.from_pretrained(
            "./agents/mistral_model/mistral_gptq/model",
            # quantization_config=bnb_config,
            device_map="auto",
        )
        # self.sampling_params = SamplingParams(
        #     top_p=1,
        #     use_beam_search=False,
        #     repetition_penalty=1.2,
        #     max_tokens=150,
        #     temperature=0.9,
        # )
        # self.model = LLM(model="./agents/mistral_model/model",
        #                  tokenizer="./agents/mistral_model/tokenizer",
        #                  enable_lora=True, max_model_len=5000, max_lora_rank=64)
        self.model.config.use_cache = True  # reuse the KV cache for faster generation
        self.model.load_adapter(adapter_model_id)  # attach the fine-tuned LoRA adapter
        # self.model = PeftModel.from_pretrained(self.model, adapter_model_id)
        # peft_config.init_lora_weights = False
        # self.model.add_adapter(peft_config)
        # self.model.enable_adapters()
        # self.lora = LoRARequest("mistral_adapter", 1, adapter_model_id)

        generation_kwargs = {
            # "top_k": 50,
            "top_p": 1,
            "bos_token_id": 1,      # <s> in the Mistral vocabulary
            "eos_token_id": 32000,  # <|im_end|> in ChatML-tuned Mistral vocabularies (e.g. OpenHermes)
            "pad_token_id": 2,      # </s>
            "do_sample": True,
            "num_beams": 1,
            "use_cache": True,
            "temperature": 0.9,
            "max_new_tokens": 150,
            "repetition_penalty": 1.2,
        }
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(self.device)
        # self.model = self.model.to(self.device)
        self.model.eval()
        self.turn_id = 1
self.prompt = """<|im_start|>system You are a helpful chatbot whose task is to continue chat conversation based on the character traits you have. \
Make an answer related to the chat history to continue the conversation. An answer should be related to your own characteristics and follows the last message in chat history. \
Give a short answer in 1-3 sentences.
Charateristics: ```{character}```
Chat history: ```{chat_history}``` <|im_end|>
<|im_start|>bot Answer: ```"""
print("Model loaded!!")
self.pipeline = pipeline(task="text-generation", return_full_text=False, model=self.model, tokenizer=self.tokenizer,**generation_kwargs)
    @staticmethod
    def text_preprocess(txt):
        """Normalize a string (or list of strings): drop spaces before
        sentence-final punctuation and restore sentence capitalization,
        e.g. "i am fine . how are you ?" -> "I am fine. How are you?"."""
        if isinstance(txt, str):
            txt = [txt]
        result = []
        capitalizing = re.compile(r'(?<=[\.\?!]\s)(\w+)')

        def cap(match):
            return match.group().capitalize()

        for text in txt:
            # Remove the space before ., ! and ?, then capitalize the first character.
            text = re.sub(r' ([!.?]+)', r"\g<1>", text.strip()).capitalize()
            # Capitalize the first word of each following sentence.
            text = capitalizing.sub(cap, text)
            result.append(text)
        return " ".join(result)
    def make_dialogue(self, chat_history, last_messages=5,
                      personas={"A": "User", "B": "Bot"}):
        """Render the last `last_messages` turns as a "Speaker: text" transcript."""
        chat_history = chat_history[-last_messages:]
        replicas = []
        for message in chat_history:
            replicas.append(f"{personas[message['persona_id']]}: {self.text_preprocess(message['text'])}")
        return '\n'.join(replicas)
    def prepare_prompt(self, conversation):
        """
        The prompt template has two input variables: character and chat_history.
        An answer field can also be added for fine-tuning.
        """
        character = self.text_preprocess(conversation["persona B"])
        dialogue = self.make_dialogue(conversation['dialogue'])
        prompt = self.prompt.format(character=character,
                                    chat_history=dialogue)
        return prompt
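    # For illustration only (the persona line and the user message below are
    # hypothetical); given
    #   conversation = {"persona B": ["i love hiking ."],
    #                   "dialogue": [{"persona_id": "A", "text": "hi ! how was your weekend ?"}]}
    # prepare_prompt() renders roughly:
    #   <|im_start|>system You are a helpful chatbot ... Give a short answer in 1-3 sentences.
    #   Characteristics: ```I love hiking.```
    #   Chat history: ```User: Hi! How was your weekend?``` <|im_end|>
    #   <|im_start|>bot Answer: ```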
    def generate_responses(self, test_data: List[Dict], api_responses: List[str]) -> Dict:
        """
        You will be provided with a batch of up to 50 independent conversations.

        Input 1
        [
            {"persona B": ..., "dialogue": ...},  # conversation 1, turn 1
            ...
            {"persona B": ..., "dialogue": ...}   # conversation 50, turn 1
        ]
        The model should return 50 responses for turn 1.
        ...
        Input 7 (test_data)
        [
            {"persona B": ..., "dialogue": ...},  # conversation 1, turn 7
            ...
            {"persona B": ..., "dialogue": ...}   # conversation 50, turn 7
        ]
        The model should return 50 responses for turn 7.

        api_responses - a list of the output strings returned by the API for each previous prompt;
        it is a list of blank strings on the first call.
        Note: turn numbers will NOT be provided as input.

        Return a dictionary with the following format:
            "use_api": True/False - note that this cannot be used when using a GPU
            "prompts": [<list of the prompts that go as "messages" to the API>] - note that every API call is independent and we don't use threads
            "max_generated_tokens": [<list of ints for the max generation limit on each call>] - note that the submission will fail if the total generation limit is exceeded
            "final_responses": [<list of strings with the final responses>] - only used when use_api is set to False

        (A minimal illustrative call is sketched after the UserAgent assignment below.)
        """
        final_responses = []
        for conversation in tqdm(test_data):
            # tensor_input_ids, tensor_attention_mask = self.prepare_tensors(conversation)
            with torch.no_grad():
                input_ = self.prepare_prompt(conversation)
                output = self.pipeline(input_)[0]['generated_text']
                # output = self.model.generate(input_, self.sampling_params, lora_request=self.lora)[0].outputs[0].text
                # Keep only the first line of the completion and cut it at the closing ``` delimiter.
                output = output.split('\n')[0]
                end_idx = output.find('```')
                if end_idx != -1:
                    output = output[:end_idx]
                print(output)
                final_responses.append(output)
        self.turn_id = self.turn_id % 7 + 1  # turn id cycles from 1 to 7
        response = {
            "use_api": False,  # the API cannot be used when "gpu": true is set in aicrowd.json
            "prompts": ["" for _ in test_data],
            "max_generated_tokens": [0 for _ in test_data],
            "final_responses": final_responses,
        }
        return response
# from agents.dummy_agent import DummyResponseAgent
# from agents.dummy_prompt_agent import DummyPromptAgent
# from agents.bart_agent import BARTResponseAgent
# from agents.prompt_agent import PromptAgent
from agents.mistral_agent import MistralAgent
# UserAgent = DummyResponseAgent
# UserAgent = DummyPromptAgent
# UserAgent = BARTResponseAgent
# UserAgent = PromptAgent
UserAgent = MistralAgent
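
# A minimal local smoke test for the interface described in generate_responses()
# (a sketch only: it assumes the checkpoints under ./agents/ exist and a CUDA GPU
# is available; the persona lines and dialogue turns below are made-up placeholders).
if __name__ == "__main__":
    agent = UserAgent()
    toy_batch = [
        {
            "persona B": ["i love hiking .", "i work as a teacher ."],
            "dialogue": [
                {"persona_id": "A", "text": "hi ! how was your weekend ?"},
                {"persona_id": "B", "text": "great , i went hiking in the hills ."},
                {"persona_id": "A", "text": "nice ! do you go often ?"},
            ],
        }
    ]
    out = agent.generate_responses(toy_batch, api_responses=["" for _ in toy_batch])
    print(out["final_responses"])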