Commit 263c243f authored by xw_g

Your commit message

parent 3b37b9e1
@@ -113,48 +113,50 @@ class llama3_8b_FewShot(ShopBenchBaseModel):
        scores, indices = self.index.search(np.array([query_embed]).astype(np.float32), topk)
        # Retrieve and process results
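        # Keep only retrieved neighbours whose similarity score clears the 0.85 threshold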
        exmaple_prompt = []
        if not is_multiple_choice:
            exmaple_prompt = []
            for score, idx in zip(scores[0], indices[0]):
                if score>=0.85:
                    fewshot_examaple = self.metadata[idx]["fewshot_examaple"]
                    exmaple_prompt.append(fewshot_examaple[9:])
            if len(exmaple_prompt) > 0:
                prompt_example = self.system_prompt + 'Here are some similar questions and answers you can refer to:\n'
                for i in exmaple_prompt:
                    prompt_example += i+'\n'
                prompt_example += '\nQuestion:' + prompt
            else:
                prompt_example = self.system_prompt + '\n' + prompt
            print(prompt_example)
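            # Split the combined string back into a system turn (the system prompt prefix) and a user turn (the rest)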
            messages = [
                {"role": "system", "content": prompt_example[:len(self.system_prompt)]},
                {"role": "user", "content": prompt_example[len(self.system_prompt):]},
            ]
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)
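        # Multiple-choice questions rebuild the same few-shot prompt and chat messages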
        if is_multiple_choice:
            if len(exmaple_prompt) > 0:
                prompt_example = self.system_prompt + 'Here are some similar questions and answers you can refer to:\n'
                for i in exmaple_prompt:
                    prompt_example += i+'\n'
                prompt_example += '\nQuestion:' + prompt
            else:
                prompt_example = self.system_prompt + '\n' + prompt
            print(prompt_example)
            messages = [
                {"role": "system", "content": prompt_example[:len(self.system_prompt)]},
                {"role": "user", "content": prompt_example[len(self.system_prompt):]},
            ]
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)
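            # Greedy decoding, stopping on the terminator tokens (both the old and the new max_new_tokens values from this diff appear below)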
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=1,
                max_new_tokens=138,
                eos_token_id=self.terminators,
                do_sample=False,
            )
            response = outputs[0][input_ids.shape[-1]:]
            response = self.tokenizer.decode(response, skip_special_tokens=True)
            print(response)
            return response
        else:
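            # Generation-style tasks get a longer output budget of 128 new tokens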
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=128,
                eos_token_id=self.terminators,
                do_sample=False,
            )
            response = outputs[0][input_ids.shape[-1]:]
            response = self.tokenizer.decode(response, skip_special_tokens=True)
            print(response)
            return response
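        # Plain-prompt path: encode the raw prompt directly, without the chat template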
        prompt_example = self.system_prompt + '\n' + prompt
        print(prompt_example)
        inputs = self.tokenizer.encode(prompt_example, add_special_tokens=False, return_tensors="pt")
        inputs = inputs.cuda()
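        # For multiple-choice, decode a single token at low temperature to read off the answer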
        if is_multiple_choice:
            generate_ids = self.model.generate(inputs, max_new_tokens=1, temperature=0.1, eos_token_id=self.terminators)
            result = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
            generation = result[len(prompt_example):]
            print(f'model generate answer : {generation}')
            return generation
\ No newline at end of file