Commit 9d258fd3 authored by mohanty

Update rag_llama_baseline.py

parent d6857f4d
@@ -44,26 +44,32 @@ CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")
class RAGModel:
    def __init__(self):
"""
Initialize your model(s) here if necessary.
This is the constructor for your DummyModel class, where you can set up any
required initialization steps for your model(s) to function correctly.
Initialize the RAGModel with necessary models and configurations.
This constructor sets up the environment by loading sentence transformers for embedding generation,
a large language model for generating responses, and tokenizer for text processing. It also initializes
model parameters and templates for generating answers.
"""
        # Load a sentence transformer model optimized for sentence embeddings, using CUDA if available.
        self.sentence_model = SentenceTransformer('models/sentence-transformers/all-MiniLM-L6-v2', device='cuda')

        # Define the number of context sentences to consider for generating an answer.
        self.num_context = 10

        # Set the maximum length for each context sentence in characters.
        self.max_ctx_sentence_length = 1000
self.prompt_template = """You are given a quesition and references which may or may not help answer the question.
You are to respond with just the answer and no surrounding sentences.
If you are unsure about the answer, respond with "I don't know".
### Question
{query}
# Template for formatting the input to the language model, including placeholders for the question and references.
self.prompt_template = """
### Question
{query}
### References
{references}
### References
{references}
### Answer"""
### Answer
"""

        # Configuration for model quantization to improve performance, using 4-bit precision.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
@@ -71,10 +77,13 @@ If you are unsure about the answer, respond with "I don't know".
            bnb_4bit_use_double_quant=False,
        )

        # Specify the large language model to be used.
        model_name = "models/meta-llama/Llama-2-7b-chat-hf"

        # Load the tokenizer for the specified model.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Load the large language model with the specified quantization configuration.
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map='auto',
@@ -82,62 +91,75 @@ If you are unsure about the answer, respond with "I don't know".
            torch_dtype=torch.float16,
        )

        # Initialize a text generation pipeline with the loaded model and tokenizer.
        self.generation_pipe = pipeline(task="text-generation",
                                        model=self.llm,
                                        tokenizer=self.tokenizer,
                                        max_new_tokens=10)
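        # Note: max_new_tokens=10 caps each generation at 10 new tokens, which keeps responses fast
        # and well under the 75-token limit later enforced by trim_predictions_to_max_token_length.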

    def generate_answer(self, query: str, search_results: List[str]) -> str:
"""
Generate an answer based on a provided query and a list of pre-cached search results.
Generate an answer based on the provided query and a list of pre-cached search results.
Parameters:
- query (str): The user's question or query input.
- search_results (List[str]): A list containing the text content from web pages
retrieved as search results for the query. Each element in the list is a string
representing the HTML text of a web page.
- query (str): The user's question.
- search_results (List[str]): Text content from web pages as search results.
Returns:
- (str): A plain text response that answers the query. This response is limited to 75 tokens.
If the generated response exceeds 75 tokens, it will be truncated to fit within this limit.
Notes:
- If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid
the penalty for hallucination.
- Response Time: Ensure that your model processes and responds to each query within 10 seconds.
Failing to adhere to this time constraint **will** result in a timeout during evaluation.
- str: A text response that answers the query. Limited to 75 tokens.
This method processes the search results to extract relevant sentences, generates embeddings for them,
and selects the top context sentences based on cosine similarity to the query embedding. It then formats
this information into a prompt for the language model, which generates an answer that is then trimmed to
meet the token limit.
"""
        # Initialize a list to hold all extracted sentences from the search results.
        all_sentences = []

        # Process each HTML text from the search results to extract text content.
        for html_text in search_results:
            # Parse the HTML content to extract text.
            soup = BeautifulSoup(html_text['page_result'], features="html.parser")
            text = soup.get_text().replace('\n', '')
            if len(text) > 0:
                # Convert the text into sentences and extract their offsets.
                offsets = text_to_sentences_and_offsets(text)[1]
                for ofs in offsets:
                    # Extract each sentence based on its offset and limit its length.
                    sentence = text[ofs[0]:ofs[1]]
                    all_sentences.append(sentence[:self.max_ctx_sentence_length])
            else:
                # If no text is extracted, add an empty string as a placeholder.
                all_sentences.append('')

        # Generate embeddings for all sentences and the query.
        all_embeddings = self.sentence_model.encode(all_sentences, normalize_embeddings=True)
        query_embedding = self.sentence_model.encode(query, normalize_embeddings=True)[None, :]

        # Calculate cosine similarity between query and sentence embeddings, and select the top sentences.
        cosine_scores = (all_embeddings * query_embedding).sum(1)
        top_sentences = np.array(all_sentences)[(-cosine_scores).argsort()[:self.num_context]]
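        # Because the embeddings are L2-normalized (normalize_embeddings=True), the row-wise dot
        # product above equals cosine similarity; argsort on the negated scores ranks sentences
        # from most to least similar to the query.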

        # Format the top sentences as references in the model's prompt template.
        references = ''
        for snippet in top_sentences:
            references += '<DOC>\n' + snippet + '\n</DOC>\n'
        references = ' '.join(references.split()[:500])  # Limit the length of references to fit the model's input size.

        final_prompt = self.prompt_template.format(query=query, references=references)
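
        # The text-generation pipeline returns the prompt followed by the model's continuation by
        # default (return_full_text=True), so the answer is recovered below by splitting on the
        # "### Answer" heading that closes the prompt.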
        # Generate an answer using the formatted prompt.
        result = self.generation_pipe(final_prompt)
        result = result[0]['generated_text']

        try:
            # Extract the answer from the generated text.
            answer = result.split("### Answer\n")[-1]
        except IndexError:
            # If the model fails to generate an answer, return a default response.
            answer = "I don't know"

        # Trim the prediction to a maximum of 75 tokens.
        trimmed_answer = trim_predictions_to_max_token_length(answer)
        return trimmed_answer
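
For reference, a minimal sketch of how the class above might be exercised locally. It assumes the model weights referenced in __init__ are available under the models/ paths shown there, and that each search result is a dict exposing the raw page HTML under the 'page_result' key, as the parsing loop expects; the query, mock HTML, and the __main__ guard below are illustrative only.

if __name__ == "__main__":
    # Hypothetical local smoke test; in the challenge, the evaluation harness drives RAGModel.
    rag = RAGModel()
    mock_search_results = [
        {"page_result": "<html><body><p>Paris is the capital of France.</p></body></html>"},
    ]
    print(rag.generate_answer("What is the capital of France?", mock_search_results))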