diff --git a/local_evaluation.py b/local_evaluation.py
index 3e8ba70c0e8d5073c3617de2bae8381f7c0b5908..bfaf590046dde52002a0764d26fb94513bafca7c 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -4,7 +4,7 @@ import torch
 import numpy as np
 import os
 
-from sentence_transformers import SentenceTransformer
+
 import metrics
 import parsers
 
@@ -145,7 +145,7 @@ def aggregate_scores(per_task_metrics):
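+        # For the "micro f1" metric, sample_scores is expected to hold (tp, fp, fn)
+        # tuples, which are aggregated into a single micro-averaged F1 score below.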
         overall_score = (
             np.mean(sample_scores)
             if metric != "micro f1"
-            else metrics.compute_f1_score(sample_scores)
+            else metrics.calculate_f1_score(sample_scores)
         )
 
         overall_metrics["task_name"].append(task_name)
@@ -165,26 +165,28 @@ def get_evaluation_methods():
     Returns:
     - A dictionary mapping metric names to their respective evaluation functions.
     """
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    sentence_all_lm = SentenceTransformer("all-MiniLM-L6-v2").to(device)
-    sentence_multilingual = SentenceTransformer(
-        "paraphrase-multilingual-MiniLM-L12-v2"
-    ).to(device)
-
     return {
-        "accuracy": metrics.accuracy,
-        "hit rate@3": metrics.hit_rate_3,
-        "rougel": metrics.rougel,
-        "sent-transformer": lambda g, t: metrics.sent_transformer(
-            g, t, sentence_all_lm
+        "accuracy": metrics.calculate_per_sample_accuracy,
+        "hit rate@3": metrics.calculate_hit_rate_3,
+        "rougel": metrics.calculate_rougel,
+        "sent-transformer": lambda generated_text, reference_texts: metrics.calculate_cosine_similarity(
+            generated_text=generated_text, 
+            reference_texts=reference_texts, 
+            model_name="all-MiniLM-L6-v2"
+        ),
+        "multilingual-sent-transformer": lambda generated_text, reference_texts: metrics.calculate_cosine_similarity(
+            generated_text=generated_text, 
+            reference_texts=reference_texts, 
+            model_name="paraphrase-multilingual-MiniLM-L12-v2"
         ),
-        "multilingual-sent-transformer": lambda g, t: metrics.sent_transformer(
-            g, t, sentence_multilingual
+        "micro f1": metrics.calculate_true_positive_false_positives_false_negatives, 
+        "ndcg": metrics.calculate_ndcg,
+        "bleu": metrics.calculate_bleu_score,
+        "jp-bleu": lambda generated_text, reference_text: metrics.calculate_bleu_score(
+            generated_text=generated_text, 
+            reference_text=reference_text, 
+            is_japanese=True
         ),
-        "micro f1": metrics.tp_fp_fn,
-        "ndcg": metrics.ndcg_eval,
-        "bleu": metrics.bleu,
-        "jp-bleu": lambda g, t: metrics.bleu(g, t, jp=True),
     }
 
 
diff --git a/metrics.py b/metrics.py
index df402c300944c7d56ca67408f068b301c1a6fc01..9d6d3d61fe1e7a10b311252ec48f00a774330ad5 100644
--- a/metrics.py
+++ b/metrics.py
@@ -3,132 +3,265 @@ from sentence_transformers import SentenceTransformer
 import numpy as np
 import evaluate
 
-from typing import List
+import torch
+
+from typing import List, Union, Tuple
 
 sacrebleu = None
+sentence_transformer_model_cache = {}
+
 
+def calculate_per_sample_accuracy(prediction: int, truth: int) -> bool:
+    """
+    Computes the accuracy of a single prediction.
 
-def accuracy(prediction: int, truth: int):
+    This function checks if a given prediction matches the ground truth.
+    
+    Parameters:
+    - prediction (int): The predicted value.
+    - truth (int): The actual ground truth value.
+    
+    Returns:
+    - bool: True if the prediction matches the truth, False otherwise.
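+
+    Example (illustrative):
+        calculate_per_sample_accuracy(2, 2)  # -> True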
+    """
     return prediction == truth
 
 
-def hit_rate_3(retrieved_int: List[int], truth: List[int]):
+def calculate_hit_rate_3(retrieved_int: List[int], truth: List[int]) -> float:
+    """
+    Calculates the hit rate within the top 3 retrieved integers.
+
+    This function assesses how many of the truth integers are present 
+    within the first three elements of the retrieved list of integers.
+    
+    Parameters:
+    - retrieved_int (List[int]): The list of retrieved integers, ordered by relevance.
+    - truth (List[int]): The list of ground truth integers.
+    
+    Returns:
+    - float: The hit rate, i.e. the fraction of truth integers that appear
+      among the first three retrieved integers.
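+
+    Example (illustrative):
+        calculate_hit_rate_3([7, 3, 9, 1], [3, 5])  # -> 0.5 (one of the two truth ids is in the top 3)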
+    """
+    # Calculate the number of hits within the top 3 retrieved integers
     hit = len(set(truth).intersection(set(retrieved_int[:3])))
-    hit /= len(truth)
-    return hit
+    # Normalize the hit count by the total number of truth integers to get the hit rate
+    hit_rate = hit / len(truth)
+    return hit_rate
+
 
+def calculate_rougel(generation: str, truth: str) -> float:
+    """
+    Calculates the ROUGE-L F-measure score between a generated string and the truth string.
 
-def rougel(generation: str, truth: str):
+    ROUGE-L measures the longest common subsequence between the generated text and the truth text,
+    considering both the precision and recall of the sequences. It is widely used in evaluating
+    the quality of text generation systems.
+    
+    Parameters:
+    - generation (str): The generated text to evaluate.
+    - truth (str): The ground truth text to compare against.
+    
+    Returns:
+    - float: The ROUGE-L F-measure score, indicating the quality of the generated text.
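+
+    Example (illustrative; the exact value depends on the rouge_score package):
+        calculate_rougel("the cat sat on the mat", "a cat sat on the mat")  # -> roughly 0.83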
+    """
+    # Initialize the ROUGE scorer with the ROUGE-L metric
     scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
+    # Calculate the ROUGE scores between the generated text and the truth text
     scores = scorer.score(generation, truth)
+    # Extract and return the ROUGE-L F-measure score
     return scores["rougeL"].fmeasure
 
 
-def sent_transformer(generation: str, truth: str, sent_transformer_model):
-    generation_embedding = sent_transformer_model.encode([generation])[0]
+def load_sentence_transformer_model(model_name: str) -> SentenceTransformer:
+    """
+    Loads a Sentence Transformer model by its name and moves it to the appropriate device.
 
-    if isinstance(truth, str):
-        truth_embedding = sent_transformer_model.encode([truth])[0]
-        score = (generation_embedding * truth_embedding).sum()
-        score /= np.linalg.norm(generation_embedding, ord=2) * np.linalg.norm(
-            truth_embedding, ord=2
-        )
-        if score > 0:
-            return score
-        else:
-            return 0
+    Parameters:
+    - model_name (str): The name of the model to load.
+
+    Returns:
+    - SentenceTransformer: The loaded SentenceTransformer model.
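+
+    Example (illustrative; downloads the model weights on first use):
+        load_sentence_transformer_model("all-MiniLM-L6-v2")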
+    """
+    
+    global sentence_transformer_model_cache
+    
+    # A model cache ensures we do not reload the model on every call
+    if model_name not in sentence_transformer_model_cache:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model = SentenceTransformer(model_name).to(device)
+        sentence_transformer_model_cache[model_name] = model
+        
+    return sentence_transformer_model_cache[model_name]
+
+def calculate_cosine_similarity(generated_text: str, reference_texts: Union[str, List[str]], model_name: str) -> float:
+    """
+    Computes the cosine similarity score(s) between a generated text and reference text(s) using a sentence embedding model.
+    
+    This function calculates the cosine similarity between the embedding of the generated text and the embedding(s) 
+    of reference text(s). The embeddings are generated using a specified sentence embedding model. The cosine similarity 
+    score ranges from -1 (vectors pointing in opposite directions) to 1 (vectors pointing in the same direction).
+    
+    Parameters:
+    - generated_text (str): The text generated by the model.
+    - reference_texts (Union[str, List[str]]): The reference text(s) for comparison. Can be a single string or a list of strings.
+    - model_name (str): The name of the sentence embedding model used to generate the text embeddings.
+    
+    Returns:
+    - float: The average cosine similarity score between the generated text and the reference text(s). If reference_texts is a single 
+      string, a single score is returned. If reference_texts is a list of strings, the average score across all references is returned.
+      The score is bounded between 0 (no similarity) and 1 (identical), with negative scores adjusted to 0.
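+
+    Example (illustrative; the exact value depends on the embedding model):
+        calculate_cosine_similarity(
+            generated_text="The weather is sunny today.",
+            reference_texts=["It is sunny outside today."],
+            model_name="all-MiniLM-L6-v2",
+        )  # -> a score close to 1.0 for near-paraphrases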
+    """
+    # Load the sentence embedding model (retrieved from the cache after the first call)
+    model = load_sentence_transformer_model(model_name)
+    
+    # Embedding for the generated text
+    generated_embedding = model.encode([generated_text])[0]
+
+    # Handling a single reference text
+    if isinstance(reference_texts, str):
+        # Embedding for the single reference text
+        reference_embedding = model.encode([reference_texts])[0]
+        # Compute cosine similarity
+        similarity_score = np.dot(generated_embedding, reference_embedding) / (np.linalg.norm(generated_embedding) * np.linalg.norm(reference_embedding))
+        # Ensure non-negative score
+        return max(similarity_score, 0)
+    
+    # Handling multiple reference texts
     else:
-        scores = []
-        for label_item in truth:
-            truth_embedding = sent_transformer_model.encode([label_item])[0]
-            score_ = (generation_embedding * truth_embedding).sum()
-            score_ /= np.linalg.norm(
-                generation_embedding, ord=2
-            ) * np.linalg.norm(truth_embedding, ord=2)
-            scores.append(score_)
-        if np.mean(scores) > 0:
-            return np.mean(scores)
-        else:
-            return 0
-
-
-def tp_fp_fn(entity_list, truth):
-    answer_lower = []
-    for a in entity_list:
-        answer_lower.append(a.lower().lstrip(" ").rstrip(" "))
-    truth_lower = []
-    for l in truth:
-        truth_lower.append(l.lower())
-    true_positive = len(set(answer_lower).intersection(set(truth_lower)))
-    false_positive = len(answer_lower) - true_positive
-    false_negative = len(truth_lower) - true_positive
-    return true_positive, false_positive, false_negative
-
-
-def compute_f1_score(tp_fp_fn_list):
-    total_tp = 0
-    total_fp = 0
-    total_fn = 0
-    for tp, fp, fn in tp_fp_fn_list:
+        similarity_scores = []
+        for reference_text in reference_texts:
+            # Embedding for each reference text
+            reference_embedding = model.encode([reference_text])[0]
+            # Compute cosine similarity for each reference
+            individual_score = np.dot(generated_embedding, reference_embedding) / (np.linalg.norm(generated_embedding) * np.linalg.norm(reference_embedding))
+            similarity_scores.append(individual_score)
+        # Calculate and ensure non-negative average score
+        return max(np.mean(similarity_scores), 0)
+    
+def calculate_true_positive_false_positives_false_negatives(extracted_entities: List[str], ground_truth_entities: List[str]) -> Tuple[int, int, int]:
+    """
+    Calculates true positives, false positives, and false negatives for entity extraction.
+
+    This function compares a list of extracted entities against a list of ground truth entities
+    to determine the count of true positives (correctly extracted entities), false positives
+    (incorrectly extracted entities), and false negatives (missed entities).
+
+    Matching is case-insensitive, and leading/trailing spaces in extracted entities are ignored.
+
+    Parameters:
+    - extracted_entities (List[str]): The list of entities extracted by the model.
+    - ground_truth_entities (List[str]): The list of actual entities (ground truth).
+
+    Returns:
+    - Tuple[int, int, int]: A tuple containing the counts of true positives, false positives, and false negatives.
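+
+    Example (illustrative):
+        calculate_true_positive_false_positives_false_negatives(
+            ["Paris", " berlin "], ["paris", "rome"]
+        )  # -> (1, 1, 1): "paris" matches, "berlin" is a false positive, "rome" is missed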
+    """
+    # Normalize the extracted entities by making them lowercase and stripping leading/trailing spaces
+    normalized_extracted_entities = [entity.lower().strip() for entity in extracted_entities]
+    
+    # Normalize the ground truth entities by making them lowercase
+    normalized_ground_truth_entities = [entity.lower() for entity in ground_truth_entities]
+
+    # Calculate true positives by finding the intersection between extracted and ground truth entities
+    true_positives = len(set(normalized_extracted_entities).intersection(set(normalized_ground_truth_entities)))
+
+    # Calculate false positives as extracted entities not in ground truth
+    false_positives = len(normalized_extracted_entities) - true_positives
+
+    # Calculate false negatives as ground truth entities not extracted
+    false_negatives = len(normalized_ground_truth_entities) - true_positives
+
+    return true_positives, false_positives, false_negatives
+
+def calculate_f1_score(metrics_list: List[Tuple[int, int, int]]) -> float:
+    """
+    Calculates the F1 score from a list of tuples containing true positives, false positives, and false negatives.
+
+    Parameters:
+    - metrics_list (List[Tuple[int, int, int]]): A list of tuples, where each tuple contains counts of true positives,
+      false positives, and false negatives in that order for various classifications or entity extractions.
+
+    Returns:
+    - float: The computed F1 score, ranging from 0 to 1.
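+
+    Example (illustrative):
+        calculate_f1_score([(1, 1, 1), (2, 0, 1)])  # -> roughly 0.67 (precision 0.75, recall 0.6)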
+    """
+    total_tp, total_fp, total_fn = 0, 0, 0
+
+    # Aggregate total true positives, false positives, and false negatives
+    for tp, fp, fn in metrics_list:
         total_tp += tp
         total_fp += fp
         total_fn += fn
-    precision = total_tp / (total_tp + total_fp)
-    recall = total_tp / (total_tp + total_fn)
+
+    # Calculate precision and recall
+    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+
+    # Calculate F1 score, handling the case where precision + recall equals 0
     if precision + recall == 0:
         return 0
     else:
         return 2 * precision * recall / (precision + recall)
 
+def calculate_ndcg(predicted_relevance_scores: List[int], true_relevance_weights: List[float]) -> float:
+    """
+    Calculates the Normalized Discounted Cumulative Gain (NDCG) score for a predicted ranking against true relevance weights.
+    The score is normalized against the ideal ranking, and the predicted list is trimmed if necessary to match the length
+    of the true relevance weights.
+
+    Parameters:
+    - predicted_relevance_scores (List[int]): Indices of items ranked by the algorithm, expected to be integers starting from 1.
+    - true_relevance_weights (List[float]): Actual relevance weights for the items, with higher values indicating greater relevance.
+
+    Returns:
+    - float: The NDCG score, normalized against the ideal ranking, ranging from 0 to 1.
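+
+    Example (illustrative):
+        calculate_ndcg([1, 2, 3], [3.0, 2.0, 1.0])  # -> 1.0 (the predicted order matches the ideal order)
+        calculate_ndcg([3, 2, 1], [3.0, 2.0, 1.0])  # -> below 1.0 (the most relevant item is ranked last)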
+    """
+    # Trim the predicted scores to match the true scores length if necessary
+    if len(predicted_relevance_scores) > len(true_relevance_weights):
+        predicted_relevance_scores = predicted_relevance_scores[:len(true_relevance_weights)]
+
+    dcg, idcg = 0.0, 0.0
 
-def ndcg(ranked_list, weight):
-    idcg = 0
-    dcg = 0
-    for i in range(len(ranked_list)):
-        position = i + 1
-        if ranked_list[i] - 1 < len(weight):
-            relevance = weight[ranked_list[i] - 1]
+    # Calculate DCG for the predicted ranking
+    for i, score_index in enumerate(predicted_relevance_scores, start=1):
+        if score_index - 1 < len(true_relevance_weights):
+            relevance = true_relevance_weights[score_index - 1]
         else:
             relevance = 0
-        dcg += (np.power(2, relevance) - 1) / np.log2(position + 1)
-    weight.sort(reverse=True)
-    for i in range(len(weight)):
-        position = i + 1
-        relevance = weight[i]
-        idcg += (np.power(2, relevance) - 1) / np.log2(position + 1)
-    return dcg / idcg
+        dcg += (np.power(2, relevance) - 1) / np.log2(i + 1)
+    
+    # Calculate IDCG using sorted true relevance weights
+    for i, weight in enumerate(sorted(true_relevance_weights, reverse=True), start=1):
+        idcg += (np.power(2, weight) - 1) / np.log2(i + 1)
+    
+    # Avoid division by zero
+    return 0 if idcg == 0 else dcg / idcg
 
 
-def ndcg_eval(relevance_scores: List[float], truth: List[float]):
-    if len(relevance_scores) > len(truth):
-        relevance_scores = relevance_scores[: len(truth)]
-    return ndcg(relevance_scores, truth)
+def calculate_bleu_score(generated_text: str, reference_text: str, is_japanese: bool = False) -> float:
+    """
+    Calculates the BLEU score for a generated text compared to a reference truth text. This function supports
+    both general text and Japanese-specific evaluation by using the sacrebleu library.
 
+    Parameters:
+    - generated_text (str): The generated text to be evaluated.
+    - reference_text (str): The reference truth text.
+    - is_japanese (bool, optional): Flag to indicate whether the text is in Japanese, requiring special tokenization.
 
-def bleu(generation, truth, jp=False):
+    Returns:
+    - float: The BLEU score, scaled to the 0-1 range, for the generated text against the reference truth.
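+
+    Example (illustrative; sacrebleu is loaded lazily via `evaluate`):
+        calculate_bleu_score("The cat sat on the mat.", "The cat sat on the mat.")  # -> 1.0
+        # Pass is_japanese=True to switch to the "ja-mecab" tokenizer.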
+    """
     global sacrebleu
     if sacrebleu is None:
-        print("\nsacrebleu loading...")
         sacrebleu = evaluate.load("sacrebleu")
 
-    generation = generation.lstrip("\n").rstrip("\n").split("\n")[0]
-    candidate = [generation]
-    reference = [[truth]]
-    if not jp:
-        score = (
-            sacrebleu.compute(
-                predictions=candidate, references=reference, lowercase=True
-            )["score"]
-            / 100
-        )
-    else:
-        score = (
-            sacrebleu.compute(
-                predictions=candidate,
-                references=reference,
-                lowercase=True,
-                tokenize="ja-mecab",
-            )["score"]
-            / 100
-        )
+    # Preprocess input texts
+    generated_text = generated_text.lstrip("\n").rstrip("\n").split("\n")[0]
+    candidate = [generated_text]
+    reference = [[reference_text]]
+
+    # Compute BLEU score with or without Japanese-specific tokenization
+    bleu_args = {"predictions": candidate, "references": reference, "lowercase": True}
+    if is_japanese:
+        bleu_args["tokenize"] = "ja-mecab"
+    score = sacrebleu.compute(**bleu_args)["score"] / 100
+
     return score