diff --git a/Dockerfile b/Dockerfile index 95834bc26f84da6e2ebb116c416ba45a87eda023..ca33aaaecdc71008f4946051b0ccaa75524d5578 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,14 @@ ## This is an example Dokerfile you can change to make submissions on aicrowd ## To use it, place it in the base of the repo, and remove the underscore (_) from the filename -FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04 +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 ENV DEBIAN_FRONTEND=noninteractive COPY apt.txt /tmp/apt.txt RUN apt -qq update && apt -qq install -y --no-install-recommends `cat /tmp/apt.txt` \ && rm -rf /var/cache/* -RUN apt install -y locales wget +RUN apt install -y locales wget build-essential # Unicode support: RUN locale-gen en_US.UTF-8 diff --git a/docker_run.sh b/docker_run.sh index 2b7fedca5726830afcd96468f78d730e8fa96b51..57e84206b111fcf746007ff840cf4e101677c4f6 100755 --- a/docker_run.sh +++ b/docker_run.sh @@ -49,6 +49,7 @@ docker run \ -v "$(pwd)":/submission \ -w /submission \ -e OPENAI_API_KEY=$OPENAI_API_KEY \ + --ipc=host \ $IMAGE_NAME python local_evaluation.py # Note: We assume you have nvidia-container-toolkit installed and configured diff --git a/docs/baselines.md b/docs/baselines.md index 83f5e524b5e2377e15728d3c4eca26cd66c513e6..596836352fc9db81cdab781539c2a4be1e0fb48a 100644 --- a/docs/baselines.md +++ b/docs/baselines.md @@ -7,9 +7,9 @@ Please note that these baselines are **NOT** tuned for performance or efficiency ## Available Baseline Models: -1. [**Vanilla Llama 2 Model**](../models/vanilla_llama_baseline.py): For an implementation guide and further details, refer to the Vanilla Llama 2 model documentation [here](../models/vanilla_llama_baseline.py). +1. [**Vanilla Llama 3 Model**](../models/vanilla_llama_baseline.py): For an implementation guide and further details, refer to the Vanilla Llama 3 model inline documentation [here](../models/vanilla_llama_baseline.py). -2. [**RAG Baseline Model**](../models/rag_llm_model.py): For an implementation guide and further details, refer to the RAG Baseline model documentation [here](../models/rag_llm_model.py). +2. [**RAG Baseline Model**](../models/rag_llm_model.py): For an implementation guide and further details, refer to the RAG Baseline model inline documentation [here](../models/rag_llm_model.py). ## Preparing Your Submission: diff --git a/docs/batch_prediction_interface.md b/docs/batch_prediction_interface.md new file mode 100644 index 0000000000000000000000000000000000000000..24c630367763ae1f213517f11ae1b8d941fa9e87 --- /dev/null +++ b/docs/batch_prediction_interface.md @@ -0,0 +1,55 @@ +## Batch Prediction Interface +- Date: `14-05-2024` + +Your submitted models can now make batch predictions on the test set, allowing you to fully utilize the multi-GPU setup available during evaluations. + +### Changes to Your Code + +1. **Add a `get_batch_size()` Function:** + + - This function should return an integer between `[1, 16]`. The maximum batch size supported at the moment is 16. + - You can also choose the batch size dynamically. + - This function is a **required** interface for your model class. + +2. **Replace `generate_answer` with `batch_generate_answer`:** + + - Update your code to replace the `generate_answer` function with `batch_generate_answer`. + - For more details on the `batch_generate_answer` interface, please refer to the inline documentation in [dummy_model.py](../models/dummy_model.py). + + ```python + # Old Interface + def generate_answer(self, query: str, search_results: List[Dict], query_time: str) -> str: + .... + .... + return answer + + # New Interface + def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]: + batch_interaction_ids = batch["interaction_id"] + queries = batch["query"] + batch_search_results = batch["search_results"] + query_times = batch["query_time"] + + .... + .... + + return [answer1, answer2, ......, answerN] + ``` + + - The new function should return a list of answers (`List[str]`) instead of a single answer (`str`). + - The simplest example of a valid submission with the new interface is as follows: + + ```python + class DummyModel: + def get_batch_size(self) -> int: + return 4 + + def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]: + queries = batch["query"] + answers = ["i dont't know" for _ in queries] + return answers + ``` + +### Backward Compatibility + +To ensure a smooth transition, the evaluators will maintain backward compatibility with the `generate_answer` interface for a short period. However, we strongly recommend updating your code to use the `batch_generate_answer` interface to avoid any disruptions when support for the older interface is removed in the coming weeks. diff --git a/docs/download_baseline_model_weights.md b/docs/download_baseline_model_weights.md index f2b19f1bffc428befcf795b97192d85de2249fc6..f2afa0e025e86ff34df3adc764f2aef23c79b024 100644 --- a/docs/download_baseline_model_weights.md +++ b/docs/download_baseline_model_weights.md @@ -1,7 +1,7 @@ ### Setting Up and Downloading Baseline Model weighta with Hugging Face This guide outlines the steps to download (and check in) the models weights required for the baseline models. -We will focus on the `Llama-2-7b-chat-hf` and `all-MiniLM-L6-v2` models. +We will focus on the `Meta-Llama-3-8B-Instruct` and `all-MiniLM-L6-v2` models. But the steps should work equally well for any other models on hugging face. #### Preliminary Steps: @@ -16,7 +16,7 @@ But the steps should work equally well for any other models on hugging face. 2. **Accept the LLaMA Terms**: - You must accept the LLaMA model's terms of use by visiting: [LLaMA-2-7b-chat-hf Terms](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). + You must accept the LLaMA model's terms of use by visiting: [meta-llama/Meta-Llama-3-8B-Instruct Terms](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct). 3. **Create a Hugging Face CLI Token**: @@ -38,14 +38,14 @@ But the steps should work equally well for any other models on hugging face. 1. **Download LLaMA-2-7b Model**: - Execute the following command to download the `Llama-2-7b-chat-hf` model to a local subdirectory. This command excludes unnecessary files to save space: + Execute the following command to download the `Meta-Llama-3-8B-Instruct` model to a local subdirectory. This command excludes unnecessary files to save space: ```bash HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download \ - meta-llama/Llama-2-7b-chat-hf \ + meta-llama/Meta-Llama-3-8B-Instruct \ --local-dir-use-symlinks False \ - --local-dir models/meta-llama/Llama-2-7b-chat-hf \ - --exclude *.bin # These are alternates to the safetensors hence not needed + --local-dir models/meta-llama/Meta-Llama-3-8B-Instruct \ + --exclude *.pth # These are alternates to the safetensors hence not needed ``` 3. **Download MiniLM-L6-v2 Model (for sentence embeddings)**: diff --git a/docs/hardware-and-system-config.md b/docs/hardware-and-system-config.md index b12f75ee7b02e1ea589ae1947a50e37420181a7a..2ea9e2eecad4a16ea073fde56d9c4e8be229da4f 100644 --- a/docs/hardware-and-system-config.md +++ b/docs/hardware-and-system-config.md @@ -20,7 +20,6 @@ Besides, the following restrictions will also be imposed: - Each team will be able to make up to **4 submissions per week per track**, and will be allowed an additional quota of upto **4 failed submissions per task per week**. -Based on the hardware and system configuration, we recommend participants to begin with 7B and 13B models. According to our experiments, models like Llama-2 13B can perform inference smoothly on 4 NVIDIA T4 GPUs, while 13B models will result in OOM. diff --git a/local_evaluation.py b/local_evaluation.py index bacda0f6a44bfffcdc026a57784472e7ec3e6868..8093f6030b78c5d4ff3862f8d02e0f5028bc2333 100644 --- a/local_evaluation.py +++ b/local_evaluation.py @@ -36,9 +36,7 @@ def attempt_api_call(client, model_name, messages, max_retries=10): ) return response.choices[0].message.content except (APIConnectionError, RateLimitError): - logger.warning( - f"API call failed on attempt {attempt + 1}, retrying..." - ) + logger.warning(f"API call failed on attempt {attempt + 1}, retrying...") except Exception as e: logger.error(f"Unexpected error: {e}") break @@ -69,9 +67,7 @@ def parse_response(resp: str): ): answer = 1 else: - raise ValueError( - f"Could not parse answer from response: {model_resp}" - ) + raise ValueError(f"Could not parse answer from response: {model_resp}") return answer except: @@ -79,56 +75,97 @@ def parse_response(resp: str): def trim_predictions_to_max_token_length(prediction): - """Trims prediction output to 75 tokens""" + """Trims prediction output to 75 tokens using Llama2 tokenizer""" max_token_length = 75 tokenized_prediction = tokenizer.encode(prediction) - trimmed_tokenized_prediction = tokenized_prediction[ - 1 : max_token_length + 1 - ] + trimmed_tokenized_prediction = tokenized_prediction[1 : max_token_length + 1] trimmed_prediction = tokenizer.decode(trimmed_tokenized_prediction) return trimmed_prediction -def generate_predictions(dataset_path, participant_model): - predictions = [] - with bz2.open(DATASET_PATH, "rt") as bz2_file: - for line in tqdm(bz2_file, desc="Generating Predictions"): - data = json.loads(line) - - query = data["query"] - web_search_results = data["search_results"] - query_time = data["query_time"] - - prediction = participant_model.generate_answer( - query, web_search_results, query_time - ) +def load_data_in_batches(dataset_path, batch_size): + """ + Generator function that reads data from a compressed file and yields batches of data. + Each batch is a dictionary containing lists of interaction_ids, queries, search results, query times, and answers. + + Args: + dataset_path (str): Path to the dataset file. + batch_size (int): Number of data items in each batch. + + Yields: + dict: A batch of data. + """ + def initialize_batch(): + """ Helper function to create an empty batch. """ + return {"interaction_id": [], "query": [], "search_results": [], "query_time": [], "answer": []} - # trim prediction to 75 tokens - prediction = trim_predictions_to_max_token_length(prediction) - predictions.append( - { - "query": query, - "ground_truth": str(data["answer"]).strip().lower(), - "prediction": str(prediction).strip().lower(), - } - ) + try: + with bz2.open(dataset_path, "rt") as file: + batch = initialize_batch() + for line in file: + try: + item = json.loads(line) + for key in batch: + batch[key].append(item[key]) + + if len(batch["query"]) == batch_size: + yield batch + batch = initialize_batch() + except json.JSONDecodeError: + logger.warn("Warning: Failed to decode a line.") + # Yield any remaining data as the last batch + if batch["query"]: + yield batch + except FileNotFoundError as e: + logger.error(f"Error: The file {dataset_path} was not found.") + raise e + except IOError as e: + logger.error(f"Error: An error occurred while reading the file {dataset_path}.") + raise e - return predictions -def evaluate_predictions(predictions, evaluation_model_name, openai_client): +def generate_predictions(dataset_path, participant_model): + """ + Processes batches of data from a dataset to generate predictions using a model. + + Args: + dataset_path (str): Path to the dataset. + participant_model (object): UserModel that provides `get_batch_size()` and `batch_generate_answer()` interfaces. + + Returns: + tuple: A tuple containing lists of queries, ground truths, and predictions. + """ + queries, ground_truths, predictions = [], [], [] + batch_size = participant_model.get_batch_size() + + for batch in tqdm(load_data_in_batches(dataset_path, batch_size), desc="Generating predictions"): + batch_ground_truths = batch.pop("answer") # Remove answers from batch and store them + batch_predictions = participant_model.batch_generate_answer(batch) + + queries.extend(batch["query"]) + ground_truths.extend(batch_ground_truths) + predictions.extend(batch_predictions) + + return queries, ground_truths, predictions + + +def evaluate_predictions(queries, ground_truths, predictions, evaluation_model_name, openai_client): n_miss, n_correct, n_correct_exact = 0, 0, 0 system_message = get_system_message() - for prediction_dict in tqdm( + for _idx, prediction in enumerate(tqdm( predictions, total=len(predictions), desc="Evaluating Predictions" - ): - query, ground_truth, prediction = ( - prediction_dict["query"], - prediction_dict["ground_truth"], - prediction_dict["prediction"], - ) - + )): + query = queries[_idx] + ground_truth = ground_truths[_idx].strip() + # trim prediction to 75 tokens using Llama2 tokenizer + prediction = trim_predictions_to_max_token_length(prediction) + prediction = prediction.strip() + + ground_truth_lowercase = ground_truth.lower() + prediction_lowercase = prediction.lower() + messages = [ {"role": "system", "content": system_message}, { @@ -136,17 +173,15 @@ def evaluate_predictions(predictions, evaluation_model_name, openai_client): "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n", }, ] - if prediction == "i don't know" or prediction == "i don't know.": + if "i don't know" in prediction_lowercase: n_miss += 1 continue - if prediction == ground_truth: + elif prediction_lowercase == ground_truth_lowercase: n_correct_exact += 1 n_correct += 1 continue - response = attempt_api_call( - openai_client, evaluation_model_name, messages - ) + response = attempt_api_call(openai_client, evaluation_model_name, messages) if response: log_response(messages, response) eval_res = parse_response(response) @@ -173,16 +208,14 @@ if __name__ == "__main__": from models.user_config import UserModel DATASET_PATH = "example_data/dev_data.jsonl.bz2" - EVALUATION_MODEL_NAME = os.getenv( - "EVALUATION_MODEL_NAME", "gpt-4-0125-preview" - ) + EVALUATION_MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview") # Generate predictions participant_model = UserModel() - predictions = generate_predictions(DATASET_PATH, participant_model) - + queries, ground_truths, predictions = generate_predictions(DATASET_PATH, participant_model) + # Evaluate Predictions openai_client = OpenAI() evaluation_results = evaluate_predictions( - predictions, EVALUATION_MODEL_NAME, openai_client + queries, ground_truths, predictions, EVALUATION_MODEL_NAME, openai_client ) diff --git a/models/README.md b/models/README.md index f9b97884bf2aae96054e17b214c223f3fe4df2f8..927e2d32cef89470b2b250b5df1cb3ad833301fa 100644 --- a/models/README.md +++ b/models/README.md @@ -4,7 +4,7 @@ For a streamlined experience, we suggest placing the code for all your models within the `models` directory. This is a recommendation for organizational purposes, but it's not a strict requirement. ## Model Base Class -Your models should follow the format from the `DummyModel` class found in [dummy_model.py](dummy_model.py). We provide the example model, `dummy_model.py`, to illustrate the structure your own model. Crucially, your model class must implement the `generate_answer` method. +Your models should follow the format from the `DummyModel` class found in [dummy_model.py](dummy_model.py). We provide the example model, `dummy_model.py`, to illustrate the structure your own model. Crucially, your model class must implement the `batch_generate_answer` method. ## Selecting which model to use To ensure your model is recognized and utilized correctly, please specify your model class name in the [`user_config.py`](user_config.py) file, by following the instructions in the inline comments. @@ -12,13 +12,19 @@ To ensure your model is recognized and utilized correctly, please specify your m ## Model Inputs and Outputs ### Inputs -Your model will receive two pieces of information for every task: -- `query`: String representing the input query -- `search_results`: List of strings, each comes from scraped HTML text of the search query. -- `query_time`: The time at which the query was made, represented as a string. +Your model will receive a batch of input queries as a dictionary, where the dictionary has the following keys: + +``` + - 'query' (List[str]): List of user queries. + - 'search_results' (List[List[Dict]]): List of search result lists, each corresponding + to a query. Please refer to the following link for + more details about the individual search objects: + https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail + - 'query_time' (List[str]): List of timestamps (represented as a string), each corresponding to when a query was made. +``` ### Outputs -The output from your model's `generate_answer` function should always be a string. +The output from your model's `batch_generate_answer` function should be a list of string responses for all the queries in the input batch. ## Internet Access Your model will not have access to the internet during evaluation. \ No newline at end of file diff --git a/models/dummy_model.py b/models/dummy_model.py index ac3919fd41e664cf167c693dd5ef8c0633204e3a..57f3e0980cf5ad931f970a66895687c131b12ccf 100644 --- a/models/dummy_model.py +++ b/models/dummy_model.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List +from typing import Any, Dict, List from models.utils import trim_predictions_to_max_token_length @@ -24,22 +24,35 @@ class DummyModel: """ pass - def generate_answer( - self, query: str, search_results: List[Dict], query_time: str - ) -> str: + def get_batch_size(self) -> int: """ - Generate an answer based on a provided query and a list of pre-cached search results. + Determines the batch size that is used by the evaluator when calling the `batch_generate_answer` function. + + Returns: + int: The batch size, an integer between 1 and 16. This value indicates how many + queries should be processed together in a single batch. It can be dynamic + across different batch_generate_answer calls, or stay a static value. + """ + self.batch_size = 4 + return self.batch_size + + def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]: + """ + Generates answers for a batch of queries using associated (pre-cached) search results and query times. Parameters: - - query (str): The user's question or query input. - - search_results (List[Dict]): A list containing the search result objects, - as described here: - https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail - - query_time (str): The time at which the query was made, represented as a string. + batch (Dict[str, Any]): A dictionary containing a batch of input queries with the following keys: + - 'interaction_id; (List[str]): List of interaction_ids for the associated queries + - 'query' (List[str]): List of user queries. + - 'search_results' (List[List[Dict]]): List of search result lists, each corresponding + to a query. Please refer to the following link for + more details about the individual search objects: + https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail + - 'query_time' (List[str]): List of timestamps (represented as a string), each corresponding to when a query was made. Returns: - - (str): A plain text response that answers the query. This response is limited to 75 tokens. - If the generated response exceeds 75 tokens, it will be truncated to fit within this limit. + List[str]: A list of plain text responses for each query in the batch. Each response is limited to 75 tokens. + If the generated response exceeds 75 tokens, it will be truncated to fit within this limit. Notes: - If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid @@ -47,10 +60,14 @@ class DummyModel: - Response Time: Ensure that your model processes and responds to each query within 10 seconds. Failing to adhere to this time constraint **will** result in a timeout during evaluation. """ - # Default response when unsure about the answer - answer = "i don't know" + batch_interaction_ids = batch["interaction_id"] + queries = batch["query"] + search_results = batch["search_results"] + query_times = batch["query_time"] - # Trim prediction to a max of 75 tokens - trimmed_answer = trim_predictions_to_max_token_length(answer) + answers = [] + for idx, query in enumerate(queries): + # Implement logic to generate answers based on search results and query times + answers.append("i don't know") # Default placeholder response - return trimmed_answer + return answers diff --git a/models/rag_llama_baseline.py b/models/rag_llama_baseline.py index 177c8abe1cd48eefe637cf23f20feac889631d05..80df19a7c12268819602c4b35aa182d45303262d 100644 --- a/models/rag_llama_baseline.py +++ b/models/rag_llama_baseline.py @@ -1,18 +1,14 @@ import os -from typing import Dict, List +from collections import defaultdict +from typing import Any, Dict, List import numpy as np +import ray import torch +import vllm from blingfire import text_to_sentences_and_offsets from bs4 import BeautifulSoup -from models.utils import trim_predictions_to_max_token_length from sentence_transformers import SentenceTransformer -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - pipeline, -) ###################################################################################################### ###################################################################################################### @@ -23,6 +19,8 @@ from transformers import ( ### ### https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/download_baseline_model_weights.md ### +### And please pay special attention to the comments that start with "TUNE THIS VARIABLE" +### as they depend on your model and the available GPU resources. ### ### DISCLAIMER: This baseline has NOT been tuned for performance ### or efficiency, and is provided as is for demonstration. @@ -41,146 +39,350 @@ from transformers import ( CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000") -class RAGModel: - def __init__(self): +#### CONFIG PARAMETERS --- + +# Define the number of context sentences to consider for generating an answer. +NUM_CONTEXT_SENTENCES = 20 +# Set the maximum length for each context sentence (in characters). +MAX_CONTEXT_SENTENCE_LENGTH = 1000 +# Set the maximum context references length (in characters). +MAX_CONTEXT_REFERENCES_LENGTH = 4000 + +# Batch size you wish the evaluators will use to call the `batch_generate_answer` function +AICROWD_SUBMISSION_BATCH_SIZE = 8 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. + +# VLLM Parameters +VLLM_TENSOR_PARALLEL_SIZE = 4 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. +VLLM_GPU_MEMORY_UTILIZATION = 0.85 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. + +# Sentence Transformer Parameters +SENTENTENCE_TRANSFORMER_BATCH_SIZE = 128 # TUNE THIS VARIABLE depending on the size of your embedding model and GPU mem available + +#### CONFIG PARAMETERS END--- + +class ChunkExtractor: + + @ray.remote + def _extract_chunks(self, interaction_id, html_source): """ - Initialize the RAGModel with necessary models and configurations. + Extracts and returns chunks from given HTML source. - This constructor sets up the environment by loading sentence transformers for embedding generation, - a large language model for generating responses, and tokenizer for text processing. It also initializes - model parameters and templates for generating answers. + Note: This function is for demonstration purposes only. + We are treating an independent sentence as a chunk here, + but you could choose to chunk your text more cleverly than this. + + Parameters: + interaction_id (str): Interaction ID that this HTML source belongs to. + html_source (str): HTML content from which to extract text. + + Returns: + Tuple[str, List[str]]: A tuple containing the interaction ID and a list of sentences extracted from the HTML content. """ - # Load a sentence transformer model optimized for sentence embeddings, using CUDA if available. - self.sentence_model = SentenceTransformer( - "models/sentence-transformers/all-MiniLM-L6-v2", device="cuda" - ) + # Parse the HTML content using BeautifulSoup + soup = BeautifulSoup(html_source, "lxml") + text = soup.get_text(" ", strip=True) # Use space as a separator, strip whitespaces + + if not text: + # Return a list with empty string when no text is extracted + return interaction_id, [""] + + # Extract offsets of sentences from the text + _, offsets = text_to_sentences_and_offsets(text) - # Define the number of context sentences to consider for generating an answer. - self.num_context = 10 - # Set the maximum length for each context sentence in characters. - self.max_ctx_sentence_length = 1000 + # Initialize a list to store sentences + chunks = [] - # Template for formatting the input to the language model, including placeholders for the question and references. - self.prompt_template = """ - ### Question - {query} + # Iterate through the list of offsets and extract sentences + for start, end in offsets: + # Extract the sentence and limit its length + sentence = text[start:end][:MAX_CONTEXT_SENTENCE_LENGTH] + chunks.append(sentence) - ### References - {references} + return interaction_id, chunks - ### Answer + def extract_chunks(self, batch_interaction_ids, batch_search_results): """ + Extracts chunks from given batch search results using parallel processing with Ray. - # Configuration for model quantization to improve performance, using 4-bit precision. - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=False, - ) + Parameters: + batch_interaction_ids (List[str]): List of interaction IDs. + batch_search_results (List[List[Dict]]): List of search results batches, each containing HTML text. + + Returns: + Tuple[np.ndarray, np.ndarray]: A tuple containing an array of chunks and an array of corresponding interaction IDs. + """ + # Setup parallel chunk extraction using ray remote + ray_response_refs = [ + self._extract_chunks.remote( + self, + interaction_id=batch_interaction_ids[idx], + html_source=html_text["page_result"] + ) + for idx, search_results in enumerate(batch_search_results) + for html_text in search_results + ] + + # Wait until all sentence extractions are complete + # and collect chunks for every interaction_id separately + chunk_dictionary = defaultdict(list) + + for response_ref in ray_response_refs: + interaction_id, _chunks = ray.get(response_ref) # Blocking call until parallel execution is complete + chunk_dictionary[interaction_id].extend(_chunks) + + # Flatten chunks and keep a map of corresponding interaction_ids + chunks, chunk_interaction_ids = self._flatten_chunks(chunk_dictionary) + + return chunks, chunk_interaction_ids + + def _flatten_chunks(self, chunk_dictionary): + """ + Flattens the chunk dictionary into separate lists for chunks and their corresponding interaction IDs. + + Parameters: + chunk_dictionary (defaultdict): Dictionary with interaction IDs as keys and lists of chunks as values. + + Returns: + Tuple[np.ndarray, np.ndarray]: A tuple containing an array of chunks and an array of corresponding interaction IDs. + """ + chunks = [] + chunk_interaction_ids = [] - # Specify the large language model to be used. - model_name = "models/meta-llama/Llama-2-7b-chat-hf" + for interaction_id, _chunks in chunk_dictionary.items(): + # De-duplicate chunks within the scope of an interaction ID + unique_chunks = list(set(_chunks)) + chunks.extend(unique_chunks) + chunk_interaction_ids.extend([interaction_id] * len(unique_chunks)) - # Load the tokenizer for the specified model. - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + # Convert to numpy arrays for convenient slicing/masking operations later + chunks = np.array(chunks) + chunk_interaction_ids = np.array(chunk_interaction_ids) - # Load the large language model with the specified quantization configuration. - self.llm = AutoModelForCausalLM.from_pretrained( - model_name, - device_map="auto", - quantization_config=bnb_config, - torch_dtype=torch.float16, + return chunks, chunk_interaction_ids + +class RAGModel: + """ + An example RAGModel for the KDDCup 2024 Meta CRAG Challenge + which includes all the key components of a RAG lifecycle. + """ + def __init__(self): + self.initialize_models() + self.chunk_extractor = ChunkExtractor() + + def initialize_models(self): + # Initialize Meta Llama 3 - 8B Instruct Model + self.model_name = "models/meta-llama/Meta-Llama-3-8B-Instruct" + + if not os.path.exists(self.model_name): + raise Exception( + f""" + The evaluators expect the model weights to be checked into the repository, + but we could not find the model weights at {self.model_name} + + Please follow the instructions in the docs below to download and check in the model weights. + + https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md + """ + ) + + # Initialize the model with vllm + self.llm = vllm.LLM( + self.model_name, + tensor_parallel_size=VLLM_TENSOR_PARALLEL_SIZE, + gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION, + trust_remote_code=True, + dtype="half", # note: bfloat16 is not supported on nvidia-T4 GPUs + enforce_eager=True ) + self.tokenizer = self.llm.get_tokenizer() - # Initialize a text generation pipeline with the loaded model and tokenizer. - self.generation_pipe = pipeline( - task="text-generation", - model=self.llm, - tokenizer=self.tokenizer, - max_new_tokens=10, + # Load a sentence transformer model optimized for sentence embeddings, using CUDA if available. + self.sentence_model = SentenceTransformer( + "models/sentence-transformers/all-MiniLM-L6-v2", + device=torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ), + ) + + def calculate_embeddings(self, sentences): + """ + Compute normalized embeddings for a list of sentences using a sentence encoding model. + + This function leverages multiprocessing to encode the sentences, which can enhance the + processing speed on multi-core machines. + + Args: + sentences (List[str]): A list of sentences for which embeddings are to be computed. + + Returns: + np.ndarray: An array of normalized embeddings for the given sentences. + + """ + embeddings = self.sentence_model.encode( + sentences=sentences, + normalize_embeddings=True, + batch_size=SENTENTENCE_TRANSFORMER_BATCH_SIZE, ) + # Note: There is an opportunity to parallelize the embedding generation across 4 GPUs + # but sentence_model.encode_multi_process seems to interefere with Ray + # on the evaluation servers. + # todo: this can also be done in a Ray native approach. + # + return embeddings - def generate_answer( - self, query: str, search_results: List[Dict], query_time: str - ) -> str: + def get_batch_size(self) -> int: """ - Generate an answer based on the provided query and a list of pre-cached search results. + Determines the batch size that is used by the evaluator when calling the `batch_generate_answer` function. + + The evaluation timeouts linearly scale with the batch size. + i.e.: time out for the `batch_generate_answer` call = batch_size * per_sample_timeout + + + Returns: + int: The batch size, an integer between 1 and 16. It can be dynamic + across different batch_generate_answer calls, or stay a static value. + """ + self.batch_size = AICROWD_SUBMISSION_BATCH_SIZE + return self.batch_size + + def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]: + """ + Generates answers for a batch of queries using associated (pre-cached) search results and query times. Parameters: - - query (str): The user's question. - - search_results (List[Dict]): A list containing the search result objects, - as described here: - https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail - - query_time (str): The time at which the query was made, represented as a string. + batch (Dict[str, Any]): A dictionary containing a batch of input queries with the following keys: + - 'interaction_id; (List[str]): List of interaction_ids for the associated queries + - 'query' (List[str]): List of user queries. + - 'search_results' (List[List[Dict]]): List of search result lists, each corresponding + to a query. Please refer to the following link for + more details about the individual search objects: + https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail + - 'query_time' (List[str]): List of timestamps (represented as a string), each corresponding to when a query was made. Returns: - - str: A text response that answers the query. Limited to 75 tokens. + List[str]: A list of plain text responses for each query in the batch. Each response is limited to 75 tokens. + If the generated response exceeds 75 tokens, it will be truncated to fit within this limit. - This method processes the search results to extract relevant sentences, generates embeddings for them, - and selects the top context sentences based on cosine similarity to the query embedding. It then formats - this information into a prompt for the language model, which generates an answer that is then trimmed to - meet the token limit. + Notes: + - If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid + the penalty for hallucination. + - Response Time: Ensure that your model processes and responds to each query within 10 seconds. + Failing to adhere to this time constraint **will** result in a timeout during evaluation. """ + batch_interaction_ids = batch["interaction_id"] + queries = batch["query"] + batch_search_results = batch["search_results"] + query_times = batch["query_time"] - # Initialize a list to hold all extracted sentences from the search results. - all_sentences = [] - - # Process each HTML text from the search results to extract text content. - for html_text in search_results: - # Parse the HTML content to extract text. - soup = BeautifulSoup(html_text["page_result"], features="lxml") - text = soup.get_text().replace("\n", "") - if len(text) > 0: - # Convert the text into sentences and extract their offsets. - offsets = text_to_sentences_and_offsets(text)[1] - for ofs in offsets: - # Extract each sentence based on its offset and limit its length. - sentence = text[ofs[0] : ofs[1]] - all_sentences.append( - sentence[: self.max_ctx_sentence_length] - ) - else: - # If no text is extracted, add an empty string as a placeholder. - all_sentences.append("") - - # Generate embeddings for all sentences and the query. - all_embeddings = self.sentence_model.encode( - all_sentences, normalize_embeddings=True + # Chunk all search results using ChunkExtractor + chunks, chunk_interaction_ids = self.chunk_extractor.extract_chunks( + batch_interaction_ids, batch_search_results ) - query_embedding = self.sentence_model.encode( - query, normalize_embeddings=True - )[None, :] - - # Calculate cosine similarity between query and sentence embeddings, and select the top sentences. - cosine_scores = (all_embeddings * query_embedding).sum(1) - top_sentences = np.array(all_sentences)[ - (-cosine_scores).argsort()[: self.num_context] - ] - # Format the top sentences as references in the model's prompt template. - references = "" - for snippet in top_sentences: - references += "<DOC>\n" + snippet + "\n</DOC>\n" - references = " ".join( - references.split()[:500] - ) # Limit the length of references to fit the model's input size. - final_prompt = self.prompt_template.format( - query=query, references=references + # Calculate all chunk embeddings + chunk_embeddings = self.calculate_embeddings(chunks) + + # Calculate embeddings for queries + query_embeddings = self.calculate_embeddings(queries) + + # Retrieve top matches for the whole batch + batch_retrieval_results = [] + for _idx, interaction_id in enumerate(batch_interaction_ids): + query = queries[_idx] + query_time = query_times[_idx] + query_embedding = query_embeddings[_idx] + + # Identify chunks that belong to this interaction_id + relevant_chunks_mask = chunk_interaction_ids == interaction_id + + # Filter out the said chunks and corresponding embeddings + relevant_chunks = chunks[relevant_chunks_mask] + relevant_chunks_embeddings = chunk_embeddings[relevant_chunks_mask] + + # Calculate cosine similarity between query and chunk embeddings, + cosine_scores = (relevant_chunks_embeddings * query_embedding).sum(1) + + # and retrieve top-N results. + retrieval_results = relevant_chunks[ + (-cosine_scores).argsort()[:NUM_CONTEXT_SENTENCES] + ] + + # You might also choose to skip the steps above and + # use a vectorDB directly. + batch_retrieval_results.append(retrieval_results) + + # Prepare formatted prompts from the LLM + formatted_prompts = self.format_prompts(queries, query_times, batch_retrieval_results) + + # Generate responses via vllm + responses = self.llm.generate( + formatted_prompts, + vllm.SamplingParams( + n=1, # Number of output sequences to return for each prompt. + top_p=0.9, # Float that controls the cumulative probability of the top tokens to consider. + temperature=0.1, # Randomness of the sampling + skip_special_tokens=True, # Whether to skip special tokens in the output. + max_tokens=50, # Maximum number of tokens to generate per output sequence. + + # Note: We are using 50 max new tokens instead of 75, + # because the 75 max token limit for the competition is checked using the Llama2 tokenizer. + # Llama3 instead uses a different tokenizer with a larger vocabulary + # This allows the Llama3 tokenizer to represent the same content more efficiently, + # while using fewer tokens. + ), + use_tqdm=False # you might consider setting this to True during local development ) - # Generate an answer using the formatted prompt. - result = self.generation_pipe(final_prompt) - result = result[0]["generated_text"] + # Aggregate answers into List[str] + answers = [] + for response in responses: + answers.append(response.outputs[0].text) + + return answers + + def format_prompts(self, queries, query_times, batch_retrieval_results=[]): + """ + Formats queries, corresponding query_times and retrieval results using the chat_template of the model. + + Parameters: + - queries (List[str]): A list of queries to be formatted into prompts. + - query_times (List[str]): A list of query_time strings corresponding to each query. + - batch_retrieval_results (List[str]) + """ + system_prompt = "You are provided with a question and various references. Your task is to answer the question succinctly, using the fewest words possible. If the references do not contain the necessary information to answer the question, respond with 'I don't know'. There is no need to explain the reasoning behind your answers." + formatted_prompts = [] + + for _idx, query in enumerate(queries): + query_time = query_times[_idx] + retrieval_results = batch_retrieval_results[_idx] - try: - # Extract the answer from the generated text. - answer = result.split("### Answer\n")[-1] - except IndexError: - # If the model fails to generate an answer, return a default response. - answer = "I don't know" + user_message = "" + references = "" + + if len(retrieval_results) > 0: + references += "# References \n" + # Format the top sentences as references in the model's prompt template. + for _snippet_idx, snippet in enumerate(retrieval_results): + references += f"- {snippet.strip()}\n" + + references = references[:MAX_CONTEXT_REFERENCES_LENGTH] + # Limit the length of references to fit the model's input size. - # Trim the prediction to a maximum of 75 tokens to meet the submission requirements. - trimmed_answer = trim_predictions_to_max_token_length(answer) + user_message += f"{references}\n------\n\n" + user_message + user_message += f"Using only the references listed above, answer the following question: \n" + user_message += f"Current Time: {query_time}\n" + user_message += f"Question: {query}\n" + + formatted_prompts.append( + self.tokenizer.apply_chat_template( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message}, + ], + tokenize=False, + add_generation_prompt=True, + ) + ) - return trimmed_answer + return formatted_prompts diff --git a/models/user_config.py b/models/user_config.py index 989758463fe59048b4e3e16bd429ba18f1207f5e..b09f9825d73ca0faab325fa20a37c2ae301b7f27 100644 --- a/models/user_config.py +++ b/models/user_config.py @@ -1,10 +1,11 @@ +# isort: skip_file from models.dummy_model import DummyModel UserModel = DummyModel # Uncomment the lines below to use the Vanilla LLAMA baseline -# from models.vanilla_llama_baseline import ChatModel -# UserModel = ChatModel +# from models.vanilla_llama_baseline import InstructModel +# UserModel = InstructModel # Uncomment the lines below to use the RAG LLAMA baseline diff --git a/models/vanilla_llama_baseline.py b/models/vanilla_llama_baseline.py index 314f8b0a63fbbd30cc263024c9b81731280ceb28..24a9424ca8d5a895005f0e0ad83d8b6feaf7fd9f 100644 --- a/models/vanilla_llama_baseline.py +++ b/models/vanilla_llama_baseline.py @@ -1,15 +1,10 @@ import os -from typing import Dict, List +from typing import Any, Dict, List import numpy as np import torch +import vllm from models.utils import trim_predictions_to_max_token_length -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - pipeline, -) ###################################################################################################### ###################################################################################################### @@ -20,6 +15,8 @@ from transformers import ( ### ### https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/download_baseline_model_weights.md ### +### And please pay special attention to the comments that start with "TUNE THIS VARIABLE" +### as they depend on your model and the available GPU resources. ### ### DISCLAIMER: This baseline has NOT been tuned for performance ### or efficiency, and is provided as is for demonstration. @@ -38,33 +35,35 @@ from transformers import ( CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000") -class ChatModel: +#### CONFIG PARAMETERS --- + +# Batch size you wish the evaluators will use to call the `batch_generate_answer` function +AICROWD_SUBMISSION_BATCH_SIZE = 8 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. + +# VLLM Parameters +VLLM_TENSOR_PARALLEL_SIZE = 4 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. +VLLM_GPU_MEMORY_UTILIZATION = 0.85 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. + +#### CONFIG PARAMETERS END--- + +class InstructModel: def __init__(self): """ Initialize your model(s) here if necessary. This is the constructor for your DummyModel class, where you can set up any required initialization steps for your model(s) to function correctly. """ - self.prompt_template = """You are given a quesition and references which may or may not help answer the question. Your goal is to answer the question in as few words as possible. -### Question -{query} - -### Answer""" + self.initialize_models() - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=False, - ) + def initialize_models(self): + # Initialize Meta Llama 3 - 8B Instruct Model + self.model_name = "models/meta-llama/Meta-Llama-3-8B-Instruct" - model_name = "models/meta-llama/Llama-2-7b-chat-hf" - - if not os.path.exists(model_name): + if not os.path.exists(self.model_name): raise Exception( f""" The evaluators expect the model weights to be checked into the repository, - but we could not find the model weights at {model_name} + but we could not find the model weights at {self.model_name} Please follow the instructions in the docs below to download and check in the model weights. @@ -72,38 +71,46 @@ class ChatModel: """ ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - - self.llm = AutoModelForCausalLM.from_pretrained( - model_name, - device_map="auto", - quantization_config=bnb_config, - torch_dtype=torch.float16, + # initialize the model with vllm + self.llm = vllm.LLM( + self.model_name, + tensor_parallel_size=VLLM_TENSOR_PARALLEL_SIZE, + gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION, + trust_remote_code=True, + dtype="half", # note: bfloat16 is not supported on nvidia-T4 GPUs + enforce_eager=True ) + self.tokenizer = self.llm.get_tokenizer() - self.generation_pipe = pipeline( - task="text-generation", - model=self.llm, - tokenizer=self.tokenizer, - max_new_tokens=75, - ) + def get_batch_size(self) -> int: + """ + Determines the batch size that is used by the evaluator when calling the `batch_generate_answer` function. - def generate_answer( - self, query: str, search_results: List[Dict], query_time: str - ) -> str: + Returns: + int: The batch size, an integer between 1 and 16. This value indicates how many + queries should be processed together in a single batch. It can be dynamic + across different batch_generate_answer calls, or stay a static value. """ - Generate an answer based on a provided query and a list of pre-cached search results. + self.batch_size = AICROWD_SUBMISSION_BATCH_SIZE + return self.batch_size + + def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]: + """ + Generates answers for a batch of queries using associated (pre-cached) search results and query times. Parameters: - - query (str): The user's question or query input. - - search_results (List[Dict]): A list containing the search result objects, - as described here: - https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail - - query_time (str): The time at which the query was made, represented as a string. + batch (Dict[str, Any]): A dictionary containing a batch of input queries with the following keys: + - 'interaction_id; (List[str]): List of interaction_ids for the associated queries + - 'query' (List[str]): List of user queries. + - 'search_results' (List[List[Dict]]): List of search result lists, each corresponding + to a query. Please refer to the following link for + more details about the individual search objects: + https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail + - 'query_time' (List[str]): List of timestamps (represented as a string), each corresponding to when a query was made. Returns: - - (str): A plain text response that answers the query. This response is limited to 75 tokens. - If the generated response exceeds 75 tokens, it will be truncated to fit within this limit. + List[str]: A list of plain text responses for each query in the batch. Each response is limited to 75 tokens. + If the generated response exceeds 75 tokens, it will be truncated to fit within this limit. Notes: - If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid @@ -111,12 +118,64 @@ class ChatModel: - Response Time: Ensure that your model processes and responds to each query within 10 seconds. Failing to adhere to this time constraint **will** result in a timeout during evaluation. """ + batch_interaction_ids = batch["interaction_id"] + queries = batch["query"] + batch_search_results = batch["search_results"] + query_times = batch["query_time"] + + formatted_prompts = self.format_prommpts(queries, query_times) + + # Generate responses via vllm + responses = self.llm.generate( + formatted_prompts, + vllm.SamplingParams( + n=1, # Number of output sequences to return for each prompt. + top_p=0.9, # Float that controls the cumulative probability of the top tokens to consider. + temperature=0.1, # randomness of the sampling + skip_special_tokens=True, # Whether to skip special tokens in the output. + max_tokens=50, # Maximum number of tokens to generate per output sequence. + # Note: We are using 50 max new tokens instead of 75, + # because the 75 max token limit is checked using the Llama2 tokenizer. + # The Llama3 model instead uses a differet tokenizer with a larger vocabulary + # This allows it to represent the same content more efficiently, using fewer tokens. + ), + use_tqdm = False + ) - final_prompt = self.prompt_template.format(query=query) - result = self.generation_pipe(final_prompt)[0]["generated_text"] - answer = result.split("### Answer")[1].strip() + # Aggregate answers into List[str] + answers = [] + for response in responses: + answers.append(response.outputs[0].text) - # Trim prediction to a max of 75 tokens - trimmed_answer = trim_predictions_to_max_token_length(answer) + return answers + + def format_prommpts(self, queries, query_times): + """ + Formats queries and corresponding query_times using the chat_template of the model. + + Parameters: + - queries (list of str): A list of queries to be formatted into prompts. + - query_times (list of str): A list of query_time strings corresponding to each query. + + """ + system_prompt = "You are provided with a question and various references. Your task is to answer the question succinctly, using the fewest words possible. If the references do not contain the necessary information to answer the question, respond with 'I don't know'." + formatted_prompts = [] + + for _idx, query in enumerate(queries): + query_time = query_times[_idx] + user_message = "" + user_message += f"Current Time: {query_time}\n" + user_message += f"Question: {query}\n" + + formatted_prompts.append( + self.tokenizer.apply_chat_template( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message}, + ], + tokenize=False, + add_generation_prompt=True, + ) + ) - return trimmed_answer + return formatted_prompts diff --git a/requirements.txt b/requirements.txt index 5fc9efd1f15725d33c6775af8f682525d0517ade..9936f7b82ff6c6c70dde2d4e14a49f711e5d5ce2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ lxml openai==1.13.3 sentence_transformers torch -transformers \ No newline at end of file +transformers +vllm>=0.4.2 \ No newline at end of file