From 9d258fd3186d655f63d17e2dc1bcd9afe136cfae Mon Sep 17 00:00:00 2001
From: mohanty <mohanty@aicrowd.com>
Date: Tue, 9 Apr 2024 11:29:45 +0000
Subject: [PATCH] Update rag_llama_baseline.py

---
 models/rag_llama_baseline.py | 92 ++++++++++++++++++++++--------------
 1 file changed, 57 insertions(+), 35 deletions(-)

diff --git a/models/rag_llama_baseline.py b/models/rag_llama_baseline.py
index 74af65b..acae6c7 100644
--- a/models/rag_llama_baseline.py
+++ b/models/rag_llama_baseline.py
@@ -44,26 +44,32 @@ CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")
 class RAGModel:
     def __init__(self):
         """
-        Initialize your model(s) here if necessary.
-        This is the constructor for your DummyModel class, where you can set up any
-        required initialization steps for your model(s) to function correctly.
+        Initialize the RAGModel with necessary models and configurations.
+
+        This constructor sets up the environment by loading sentence transformers for embedding generation,
+        a large language model for generating responses, and tokenizer for text processing. It also initializes
+        model parameters and templates for generating answers.
         """
+        # Load a sentence transformer model optimized for sentence embeddings, using CUDA if available.
         self.sentence_model = SentenceTransformer('models/sentence-transformers/all-MiniLM-L6-v2', device='cuda')

+        # Define the number of context sentences to consider for generating an answer.
         self.num_context = 10
-        self.max_ctx_sentence_length = 1000 # characters
+        # Set the maximum length for each context sentence in characters.
+        self.max_ctx_sentence_length = 1000

-        self.prompt_template = """You are given a quesition and references which may or may not help answer the question.
-You are to respond with just the answer and no surrounding sentences.
-If you are unsure about the answer, respond with "I don't know".
-### Question
-{query}
+        # Template for formatting the input to the language model, including placeholders for the question and references.
+        self.prompt_template = """
+        ### Question
+        {query}

-### References
-{references}
+        ### References
+        {references}

-### Answer"""
+        ### Answer
+        """

+        # Configuration for model quantization to improve performance, using 4-bit precision.
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.float16,
@@ -71,10 +77,13 @@ If you are unsure about the answer, respond with "I don't know".
             bnb_4bit_use_double_quant=False,
         )

+        # Specify the large language model to be used.
         model_name = "models/meta-llama/Llama-2-7b-chat-hf"

+        # Load the tokenizer for the specified model.
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)

+        # Load the large language model with the specified quantization configuration.
         self.llm = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map='auto',
@@ -82,62 +91,75 @@ If you are unsure about the answer, respond with "I don't know".
             torch_dtype=torch.float16,
         )

+        # Initialize a text generation pipeline with the loaded model and tokenizer.
         self.generation_pipe = pipeline(task="text-generation",
                                         model=self.llm,
                                         tokenizer=self.tokenizer,
                                         max_new_tokens=10)

-
     def generate_answer(self, query: str, search_results: List[str]) -> str:
         """
-        Generate an answer based on a provided query and a list of pre-cached search results.
+        Generate an answer based on the provided query and a list of pre-cached search results.

         Parameters:
-        - query (str): The user's question or query input.
-        - search_results (List[str]): A list containing the text content from web pages
-          retrieved as search results for the query. Each element in the list is a string
-          representing the HTML text of a web page.
+        - query (str): The user's question.
+        - search_results (List[str]): Text content from web pages as search results.

         Returns:
-        - (str): A plain text response that answers the query. This response is limited to 75 tokens.
-          If the generated response exceeds 75 tokens, it will be truncated to fit within this limit.
-
-        Notes:
-        - If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid
-          the penalty for hallucination.
-        - Response Time: Ensure that your model processes and responds to each query within 10 seconds.
-          Failing to adhere to this time constraint **will** result in a timeout during evaluation.
+        - str: A text response that answers the query. Limited to 75 tokens.
+
+        This method processes the search results to extract relevant sentences, generates embeddings for them,
+        and selects the top context sentences based on cosine similarity to the query embedding. It then formats
+        this information into a prompt for the language model, which generates an answer that is then trimmed to
+        meet the token limit.
         """
+        # Initialize a list to hold all extracted sentences from the search results.
         all_sentences = []

+        # Process each HTML text from the search results to extract text content.
         for html_text in search_results:
+            # Parse the HTML content to extract text.
             soup = BeautifulSoup(html_text['page_result'], features="html.parser")
             text = soup.get_text().replace('\n', '')
             if len(text) > 0:
-                offsets = text_to_sentences_and_offsets(text)[1]
-                for ofs in offsets:
-                    sentence = text[ofs[0]:ofs[1]]
-                    all_sentences.append(sentence[:self.max_ctx_sentence_length])
+                # Convert the text into sentences and extract their offsets.
+                offsets = text_to_sentences_and_offsets(text)[1]
+                for ofs in offsets:
+                    # Extract each sentence based on its offset and limit its length.
+                    sentence = text[ofs[0]:ofs[1]]
+                    all_sentences.append(sentence[:self.max_ctx_sentence_length])
             else:
+                # If no text is extracted, add an empty string as a placeholder.
                 all_sentences.append('')

+        # Generate embeddings for all sentences and the query.
         all_embeddings = self.sentence_model.encode(all_sentences, normalize_embeddings=True)
         query_embedding = self.sentence_model.encode(query, normalize_embeddings=True)[None, :]

+        # Calculate cosine similarity between query and sentence embeddings, and select the top sentences.
         cosine_scores = (all_embeddings * query_embedding).sum(1)
         top_sentences = np.array(all_sentences)[(-cosine_scores).argsort()[:self.num_context]]

+        # Format the top sentences as references in the model's prompt template.
         references = ''
         for snippet in top_sentences:
             references += '<DOC>\n' + snippet + '\n</DOC>\n'
-
-        references = ' '.join(references.split()[:500])
+        references = ' '.join(references.split()[:500])  # Limit the length of references to fit the model's input size.

         final_prompt = self.prompt_template.format(query=query, references=references)
-        result = self.generation_pipe(final_prompt)[0]['generated_text']
-        answer = result.split("### Answer\n")[1]
+
+        # Generate an answer using the formatted prompt.
+        result = self.generation_pipe(final_prompt)
+        result = result[0]['generated_text']
+
+        try:
+            # Extract the answer from the generated text.
+            answer = result.split("### Answer\n")[-1]
+        except IndexError:
+            # If the model fails to generate an answer, return a default response.
+            answer = "I don't know"

-        # Trim prediction to a max of 75 tokens
+        # Trim the prediction to a maximum of 75 tokens (this function needs to be defined).
         trimmed_answer = trim_predictions_to_max_token_length(answer)

         return trimmed_answer
--
GitLab
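
Illustration (not part of the patch): a minimal sketch of how the updated generate_answer() might be exercised locally, assuming the starter kit's models/rag_llama_baseline.py is importable and the model weights referenced in __init__ are present on disk. The query and HTML snippet below are made up for demonstration; the 'page_result' key mirrors how the patched code reads each search result.

    # Hypothetical smoke test for the patched RAGModel (inputs are illustrative).
    from models.rag_llama_baseline import RAGModel

    # Each search result is a dict with a 'page_result' key holding raw HTML,
    # matching the html_text['page_result'] access in generate_answer().
    search_results = [
        {"page_result": "<html><body><p>Ernest Hemingway wrote "
                        "The Old Man and the Sea in 1951.</p></body></html>"},
    ]
    query = "Who wrote The Old Man and the Sea?"

    model = RAGModel()                                      # loads the sentence transformer and Llama 2 7B
    answer = model.generate_answer(query, search_results)   # retrieve, prompt, generate, trim to 75 tokens
    print(answer)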