From bb79e72dc57d5a1e474639bf6802d4cb8a25d25f Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Thu, 11 Apr 2024 11:05:25 +0000
Subject: [PATCH] Update docstrings referring to search results

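The `search_results` argument passed to `generate_answer` is now a list of
dicts rather than a list of raw HTML strings. Below is a minimal sketch of
how a caller might invoke the RAG baseline under the updated interface; the
`page_result` key (the raw HTML of the page) is the only field the baseline
actually reads, and the schema of the full result objects is described in
docs/dataset.md, so anything beyond `page_result` in this example is an
illustrative assumption rather than part of this patch:

    from models.rag_llama_baseline import RAGModel

    model = RAGModel()

    # Each search result is a dict; "page_result" holds the page's raw HTML.
    # Any additional keys are hypothetical here; see docs/dataset.md for the
    # authoritative schema.
    search_results = [
        {"page_result": "<html><body><p>Paris is the capital of France.</p></body></html>"},
    ]

    answer = model.generate_answer(
        query="What is the capital of France?",
        search_results=search_results,
    )
    print(answer)
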
---
 models/dummy_model.py            | 15 +++---
 models/rag_llama_baseline.py     | 82 ++++++++++++++++++++------------
 models/vanilla_llama_baseline.py | 45 ++++++++++--------
 3 files changed, 84 insertions(+), 58 deletions(-)

diff --git a/models/dummy_model.py b/models/dummy_model.py
index 9277d32..2ea9412 100644
--- a/models/dummy_model.py
+++ b/models/dummy_model.py
@@ -1,5 +1,5 @@
 import os
-from typing import List
+from typing import Dict, List
 
 from models.utils import trim_predictions_to_max_token_length
 
@@ -14,6 +14,7 @@ from models.utils import trim_predictions_to_max_token_length
 # **Note**: This environment variable will not be available for Task 1 evaluations.
 CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")
 
+
 class DummyModel:
     def __init__(self):
         """
@@ -23,15 +24,15 @@ class DummyModel:
         """
         pass
 
-    def generate_answer(self, query: str, search_results: List[str]) -> str:
+    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
         """
         Generate an answer based on a provided query and a list of pre-cached search results.
 
         Parameters:
         - query (str): The user's question or query input.
-        - search_results (List[str]): A list containing the text content from web pages
-          retrieved as search results for the query. Each element in the list is a string
-          representing the HTML text of a web page.
+        - search_results (List[Dict]): A list containing the search result objects,
+          as described here:
+          https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
 
         Returns:
         - (str): A plain text response that answers the query. This response is limited to 75 tokens.
@@ -45,8 +46,8 @@ class DummyModel:
         """
         # Default response when unsure about the answer
         answer = "i don't know"
-        
+
         # Trim prediction to a max of 75 tokens
         trimmed_answer = trim_predictions_to_max_token_length(answer)
-        
+
         return trimmed_answer
diff --git a/models/rag_llama_baseline.py b/models/rag_llama_baseline.py
index acae6c7..1e1f11d 100644
--- a/models/rag_llama_baseline.py
+++ b/models/rag_llama_baseline.py
@@ -1,5 +1,5 @@
 import os
-from typing import List
+from typing import Dict, List
 
 import numpy as np
 import torch
@@ -19,10 +19,10 @@ from transformers import (
 ###
 ### IMPORTANT !!!
 ### Before submitting, please follow the instructions in the docs below to download and check in :
-### the model weighs. 
-### 
+### the model weights.
+###
 ###  https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/download_baseline_model_weights.md
-### 
+###
 ###
 ### DISCLAIMER: This baseline has NOT been tuned for performance
 ###             or efficiency, and is provided as is for demonstration.
@@ -45,13 +45,15 @@ class RAGModel:
     def __init__(self):
         """
         Initialize the RAGModel with necessary models and configurations.
-        
+
         This constructor sets up the environment by loading sentence transformers for embedding generation,
         a large language model for generating responses, and tokenizer for text processing. It also initializes
         model parameters and templates for generating answers.
         """
         # Load a sentence transformer model optimized for sentence embeddings, using CUDA if available.
-        self.sentence_model = SentenceTransformer('models/sentence-transformers/all-MiniLM-L6-v2', device='cuda')
+        self.sentence_model = SentenceTransformer(
+            "models/sentence-transformers/all-MiniLM-L6-v2", device="cuda"
+        )
 
         # Define the number of context sentences to consider for generating an answer.
         self.num_context = 10
@@ -68,7 +70,7 @@ class RAGModel:
 
         ### Answer
         """
-        
+
         # Configuration for model quantization to improve performance, using 4-bit precision.
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -86,24 +88,28 @@ class RAGModel:
         # Load the large language model with the specified quantization configuration.
         self.llm = AutoModelForCausalLM.from_pretrained(
             model_name,
-            device_map='auto',
+            device_map="auto",
             quantization_config=bnb_config,
             torch_dtype=torch.float16,
         )
 
         # Initialize a text generation pipeline with the loaded model and tokenizer.
-        self.generation_pipe = pipeline(task="text-generation",
-                                        model=self.llm,
-                                        tokenizer=self.tokenizer,
-                                        max_new_tokens=10)
+        self.generation_pipe = pipeline(
+            task="text-generation",
+            model=self.llm,
+            tokenizer=self.tokenizer,
+            max_new_tokens=10,
+        )
 
-    def generate_answer(self, query: str, search_results: List[str]) -> str:
+    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
         """
         Generate an answer based on the provided query and a list of pre-cached search results.
 
         Parameters:
         - query (str): The user's question.
-        - search_results (List[str]): Text content from web pages as search results.
+        - search_results (List[Dict]): A list containing the search result objects,
+          as described here:
+          https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
 
         Returns:
         - str: A text response that answers the query. Limited to 75 tokens.
@@ -120,46 +126,60 @@ class RAGModel:
         # Process each HTML text from the search results to extract text content.
         for html_text in search_results:
             # Parse the HTML content to extract text.
-            soup = BeautifulSoup(html_text['page_result'], features="html.parser")
-            text = soup.get_text().replace('\n', '')
+            soup = BeautifulSoup(
+                html_text["page_result"], features="html.parser"
+            )
+            text = soup.get_text().replace("\n", "")
             if len(text) > 0:
                 # Convert the text into sentences and extract their offsets.
                 offsets = text_to_sentences_and_offsets(text)[1]
                 for ofs in offsets:
                     # Extract each sentence based on its offset and limit its length.
-                    sentence = text[ofs[0]:ofs[1]]
-                    all_sentences.append(sentence[:self.max_ctx_sentence_length])
+                    sentence = text[ofs[0] : ofs[1]]
+                    all_sentences.append(
+                        sentence[: self.max_ctx_sentence_length]
+                    )
             else:
                 # If no text is extracted, add an empty string as a placeholder.
-                all_sentences.append('')
+                all_sentences.append("")
 
         # Generate embeddings for all sentences and the query.
-        all_embeddings = self.sentence_model.encode(all_sentences, normalize_embeddings=True)
-        query_embedding = self.sentence_model.encode(query, normalize_embeddings=True)[None, :]
+        all_embeddings = self.sentence_model.encode(
+            all_sentences, normalize_embeddings=True
+        )
+        query_embedding = self.sentence_model.encode(
+            query, normalize_embeddings=True
+        )[None, :]
 
         # Calculate cosine similarity between query and sentence embeddings, and select the top sentences.
         cosine_scores = (all_embeddings * query_embedding).sum(1)
-        top_sentences = np.array(all_sentences)[(-cosine_scores).argsort()[:self.num_context]]
+        top_sentences = np.array(all_sentences)[
+            (-cosine_scores).argsort()[: self.num_context]
+        ]
 
         # Format the top sentences as references in the model's prompt template.
-        references = ''
+        references = ""
         for snippet in top_sentences:
-            references += '<DOC>\n' + snippet + '\n</DOC>\n'
-        references = ' '.join(references.split()[:500])  # Limit the length of references to fit the model's input size.
-        final_prompt = self.prompt_template.format(query=query, references=references)
-        
+            references += "<DOC>\n" + snippet + "\n</DOC>\n"
+        references = " ".join(
+            references.split()[:500]
+        )  # Limit the length of references to fit the model's input size.
+        final_prompt = self.prompt_template.format(
+            query=query, references=references
+        )
+
         # Generate an answer using the formatted prompt.
         result = self.generation_pipe(final_prompt)
-        result = result[0]['generated_text']
-        
+        result = result[0]["generated_text"]
+
         try:
             # Extract the answer from the generated text.
             answer = result.split("### Answer\n")[-1]
         except IndexError:
             # If the model fails to generate an answer, return a default response.
             answer = "I don't know"
-                                
+
         # Trim the prediction to a maximum of 75 tokens (this function needs to be defined).
         trimmed_answer = trim_predictions_to_max_token_length(answer)
-        
+
         return trimmed_answer
diff --git a/models/vanilla_llama_baseline.py b/models/vanilla_llama_baseline.py
index 481ecd5..a3d2e83 100644
--- a/models/vanilla_llama_baseline.py
+++ b/models/vanilla_llama_baseline.py
@@ -1,5 +1,5 @@
 import os
-from typing import List
+from typing import Dict, List
 
 import numpy as np
 import torch
@@ -16,10 +16,10 @@ from transformers import (
 ###
 ### IMPORTANT !!!
 ### Before submitting, please follow the instructions in the docs below to download and check in :
-### the model weighs. 
-### 
+### the model weights.
+###
 ###  https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/download_baseline_model_weights.md
-### 
+###
 ###
 ### DISCLAIMER: This baseline has NOT been tuned for performance
 ###             or efficiency, and is provided as is for demonstration.
@@ -37,6 +37,7 @@ from transformers import (
 # **Note**: This environment variable will not be available for Task 1 evaluations.
 CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")
 
+
 class ChatModel:
     def __init__(self):
         """
@@ -49,7 +50,7 @@ class ChatModel:
 {query}
 
 ### Answer"""
-        
+
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.float16,
@@ -58,40 +59,44 @@ class ChatModel:
         )
 
         model_name = "models/meta-llama/Llama-2-7b-chat-hf"
-        
+
         if not os.path.exists(model_name):
-            raise Exception(f"""
+            raise Exception(
+                f"""
             The evaluators expect the model weights to be checked into the repository,
             but we could not find the model weights at {model_name}
             
             Please follow the instructions in the docs below to download and check in the model weights.
             
             https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md
-            """)
+            """
+            )
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
         self.llm = AutoModelForCausalLM.from_pretrained(
             model_name,
-            device_map='auto',
+            device_map="auto",
             quantization_config=bnb_config,
             torch_dtype=torch.float16,
         )
 
-        self.generation_pipe = pipeline(task="text-generation",
-                                        model=self.llm,
-                                        tokenizer=self.tokenizer,
-                                        max_new_tokens=75)
+        self.generation_pipe = pipeline(
+            task="text-generation",
+            model=self.llm,
+            tokenizer=self.tokenizer,
+            max_new_tokens=75,
+        )
 
-    def generate_answer(self, query: str, search_results: List[str]) -> str:
+    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
         """
         Generate an answer based on a provided query and a list of pre-cached search results.
 
         Parameters:
         - query (str): The user's question or query input.
-        - search_results (List[str]): A list containing the text content from web pages
-          retrieved as search results for the query. Each element in the list is a string
-          representing the HTML text of a web page.
+        - search_results (List[Dict]): A list containing the search result objects,
+          as described here:
+          https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
 
         Returns:
         - (str): A plain text response that answers the query. This response is limited to 75 tokens.
@@ -105,10 +110,10 @@ class ChatModel:
         """
 
         final_prompt = self.prompt_template.format(query=query)
-        result = self.generation_pipe(final_prompt)[0]['generated_text']
+        result = self.generation_pipe(final_prompt)[0]["generated_text"]
         answer = result.split("### Answer")[1].strip()
-                
+
         # Trim prediction to a max of 75 tokens
         trimmed_answer = trim_predictions_to_max_token_length(answer)
-        
+
         return trimmed_answer
-- 
GitLab