From a21d388af3be3c38c8cda4a4325a6db6ad81aaa8 Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Mon, 17 Jun 2024 02:13:37 -0400
Subject: [PATCH] Update

---
 models/llama_loader.py | 2 --
 models/v4.py           | 6 ++++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/models/llama_loader.py b/models/llama_loader.py
index ae01055..8fff3f3 100644
--- a/models/llama_loader.py
+++ b/models/llama_loader.py
@@ -63,8 +63,6 @@ class VLLMConfig:
             "disable_custom_all_reduce": self.disable_custom_all_reduce
         }
 
-    # temperature=0.6,
-    # top_p=0.9,
     def sampling_args(self, override_kwargs: dict = {}) -> vllm.SamplingParams:
         kwargs = {
             "n": self.n,
diff --git a/models/v4.py b/models/v4.py
index c67a991..dd0b2b6 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -162,6 +162,7 @@ class V4(Base):
         )
 
         simple_qa = self.simple_qa(queries, domains, final_answers_flag)
+
         final_answers = [
             (x if x != "" else y) for x, y in zip(simple_qa, final_answers)
         ]
@@ -328,7 +329,6 @@ class V4(Base):
         self, queries: List[str], domains: List[str], skip_flag: List[bool]
     ) -> List[str]:
 
-        skip = [False] * len(queries)
         inputs = [
             self.template.format(context={"query": query}, template="simple_qa")
             for query in queries
@@ -336,7 +336,7 @@
 
         responses = self.llm_inference_with_skip_flag(
             inputs,
-            skip,
+            [False] * len(queries),
             max_tokens=64,
             stop=["<|eot_id|>"],
             include_stop_str_in_output=False,
@@ -347,6 +347,8 @@
         ]
         answers = [x.lower().strip() for x in answers]
 
+        print(answers)
+
         not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
         contain_keywords = [
             check_keywords(x, ["oscar", "academy", "award"]) for x in queries
-- 
GitLab
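
For context on the first models/v4.py hunk (@@ -162,6 +162,7 @@): the surrounding code merges the simple_qa() results into final_answers, keeping a simple_qa answer only when it is non-empty and falling back to the existing answer otherwise. A minimal sketch of that fallback-merge pattern; the helper name merge_with_fallback is illustrative and does not exist in the repository:

    from typing import List

    def merge_with_fallback(primary: List[str], fallback: List[str]) -> List[str]:
        # Keep the primary answer when it is non-empty; otherwise use the fallback.
        return [(x if x != "" else y) for x, y in zip(primary, fallback)]

    # Mirrors the patched call site, where simple_qa results take priority:
    assert merge_with_fallback(["ruby", ""], ["x", "tom hanks"]) == ["ruby", "tom hanks"]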
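
The remaining models/v4.py hunks remove a dead local (the all-False skip list is now built inline at the call site, so nothing is skipped for simple QA) and add a debug print of the normalized answers. For illustration only, one plausible shape for a skip-flag inference helper, assuming skipped positions yield empty strings so the output aligns with the input; the actual llm_inference_with_skip_flag in this repository may behave differently:

    from typing import Callable, List

    def inference_with_skip_flag(
        generate: Callable[[List[str]], List[str]],
        inputs: List[str],
        skip: List[bool],
    ) -> List[str]:
        # Run generation only for positions whose skip flag is False.
        active = [x for x, s in zip(inputs, skip) if not s]
        generated = iter(generate(active))
        # Re-expand to the original length, leaving "" in skipped slots.
        return ["" if s else next(generated) for s in skip]

Under that reading, passing [False] * len(queries), as the patched call site does, runs inference for every query.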