From 9d258fd3186d655f63d17e2dc1bcd9afe136cfae Mon Sep 17 00:00:00 2001
From: mohanty <mohanty@aicrowd.com>
Date: Tue, 9 Apr 2024 11:29:45 +0000
Subject: [PATCH] Update rag_llama_baseline.py

---
 models/rag_llama_baseline.py | 92 ++++++++++++++++++++++--------------
 1 file changed, 57 insertions(+), 35 deletions(-)

diff --git a/models/rag_llama_baseline.py b/models/rag_llama_baseline.py
index 74af65b..acae6c7 100644
--- a/models/rag_llama_baseline.py
+++ b/models/rag_llama_baseline.py
@@ -44,26 +44,32 @@ CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")
 class RAGModel:
     def __init__(self):
         """
-        Initialize your model(s) here if necessary.
-        This is the constructor for your DummyModel class, where you can set up any
-        required initialization steps for your model(s) to function correctly.
+        Initialize the RAGModel with necessary models and configurations.
+        
+        This constructor sets up the environment by loading a sentence transformer for embedding generation,
+        a large language model for generating responses, and a tokenizer for text processing. It also
+        initializes model parameters and the prompt template used for generating answers.
         """
+        # Load a sentence transformer model optimized for sentence embeddings; a CUDA device is assumed to be available.
         self.sentence_model = SentenceTransformer('models/sentence-transformers/all-MiniLM-L6-v2', device='cuda')
 
+        # Define the number of context sentences to consider for generating an answer.
         self.num_context = 10
-        self.max_ctx_sentence_length = 1000 # characters
+        # Set the maximum length for each context sentence in characters.
+        self.max_ctx_sentence_length = 1000
 
-        self.prompt_template = """You are given a quesition and references which may or may not help answer the question. 
-You are to respond with just the answer and no surrounding sentences.
-If you are unsure about the answer, respond with "I don't know".
-### Question
-{query}
+        # Template for formatting the input to the language model, including placeholders for the question and references.
+        self.prompt_template = """
+### Question
+{query}
 
-### References 
-{references}
+### References
+{references}
 
-### Answer"""
+### Answer
+"""
         
+        # Configure 4-bit quantization (via bitsandbytes) to reduce the model's GPU memory footprint.
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.float16,
@@ -71,10 +77,13 @@ If you are unsure about the answer, respond with "I don't know".
             bnb_4bit_use_double_quant=False,
         )
 
+        # Specify the large language model to be used.
         model_name = "models/meta-llama/Llama-2-7b-chat-hf"
 
+        # Load the tokenizer for the specified model.
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+        # Load the large language model with the specified quantization configuration.
         self.llm = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map='auto',
@@ -82,62 +91,75 @@ If you are unsure about the answer, respond with "I don't know".
             torch_dtype=torch.float16,
         )
 
+        # Initialize a text generation pipeline with the loaded model and tokenizer.
         self.generation_pipe = pipeline(task="text-generation",
                                         model=self.llm,
                                         tokenizer=self.tokenizer,
                                         max_new_tokens=10)
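+        # Note: max_new_tokens=10 keeps each generation short and fast; the final
+        # answer is additionally trimmed to the 75-token limit after generation.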
 
-
     def generate_answer(self, query: str, search_results: List[str]) -> str:
         """
-        Generate an answer based on a provided query and a list of pre-cached search results.
+        Generate an answer based on the provided query and a list of pre-cached search results.
 
         Parameters:
-        - query (str): The user's question or query input.
-        - search_results (List[str]): A list containing the text content from web pages
-          retrieved as search results for the query. Each element in the list is a string
-          representing the HTML text of a web page.
+        - query (str): The user's question.
+        - search_results (List[Dict]): Search results for the query; each element is a dict whose 'page_result' entry holds the HTML text of a retrieved web page.
 
         Returns:
-        - (str): A plain text response that answers the query. This response is limited to 75 tokens.
-          If the generated response exceeds 75 tokens, it will be truncated to fit within this limit.
-
-        Notes:
-        - If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid
-          the penalty for hallucination.
-        - Response Time: Ensure that your model processes and responds to each query within 10 seconds.
-          Failing to adhere to this time constraint **will** result in a timeout during evaluation.
+        - str: A plain-text response that answers the query, trimmed to at most 75 tokens.
+
+        This method processes the search results to extract relevant sentences, generates embeddings for them,
+        and selects the top context sentences based on cosine similarity to the query embedding. It then formats
+        this information into a prompt for the language model, which generates an answer that is then trimmed to
+        meet the token limit.
         """
 
+        # Initialize a list to hold all extracted sentences from the search results.
         all_sentences = []
 
+        # Extract the plain text from each search result's HTML page.
         for html_text in search_results:
+            # Parse the HTML content to extract text.
             soup = BeautifulSoup(html_text['page_result'], features="html.parser")
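+            # Flatten the page text onto a single line so the sentence splitter sees one continuous string.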
             text = soup.get_text().replace('\n', '')
             if len(text) > 0:
-              offsets = text_to_sentences_and_offsets(text)[1]
-              for ofs in offsets:
-                  sentence = text[ofs[0]:ofs[1]]
-                  all_sentences.append(sentence[:self.max_ctx_sentence_length])
+                # Convert the text into sentences and extract their offsets.
+                offsets = text_to_sentences_and_offsets(text)[1]
+                for ofs in offsets:
+                    # Extract each sentence based on its offset and limit its length.
+                    sentence = text[ofs[0]:ofs[1]]
+                    all_sentences.append(sentence[:self.max_ctx_sentence_length])
             else:
+                # If no text is extracted, add an empty string as a placeholder.
                 all_sentences.append('')
 
+        # Generate embeddings for all sentences and the query.
         all_embeddings = self.sentence_model.encode(all_sentences, normalize_embeddings=True)
         query_embedding = self.sentence_model.encode(query, normalize_embeddings=True)[None, :]
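+        # [None, :] adds a leading batch axis so the query embedding broadcasts against the sentence embeddings below.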
 
+        # Calculate cosine similarity between query and sentence embeddings, and select the top sentences.
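+        # Because both sets of embeddings are L2-normalized, the dot product below equals
+        # cosine similarity; negating the scores before argsort sorts in descending similarity.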
         cosine_scores = (all_embeddings * query_embedding).sum(1)
         top_sentences = np.array(all_sentences)[(-cosine_scores).argsort()[:self.num_context]]
 
+        # Assemble the top sentences into a references block, wrapping each snippet in <DOC> tags so the model can tell them apart.
         references = ''
         for snippet in top_sentences:
             references += '<DOC>\n' + snippet + '\n</DOC>\n'
-
-        references = ' '.join(references.split()[:500])
+        references = ' '.join(references.split()[:500])  # Cap the references at 500 whitespace-delimited words to keep the prompt compact.
         final_prompt = self.prompt_template.format(query=query, references=references)
-        result = self.generation_pipe(final_prompt)[0]['generated_text']
-        answer = result.split("### Answer\n")[1]
+        
+        # Generate an answer using the formatted prompt.
+        result = self.generation_pipe(final_prompt)
+        result = result[0]['generated_text']
+
+        # Extract the text that follows the "### Answer" marker. split() always returns
+        # at least one element, so indexing with [-1] can never raise an IndexError;
+        # instead, fall back to a default response if the marker is missing entirely.
+        if "### Answer" in result:
+            answer = result.split("### Answer")[-1].strip()
+        else:
+            answer = "I don't know"
                                 
-        # Trim prediction to a max of 75 tokens
+        # Trim the prediction to the maximum of 75 tokens allowed for an answer.
         trimmed_answer = trim_predictions_to_max_token_length(answer)
         
         return trimmed_answer
-- 
GitLab