diff --git a/metrics.py b/metrics.py index a2f24c9fa204db890b2e62a08a11f5b0cf83b574..df402c300944c7d56ca67408f068b301c1a6fc01 100644 --- a/metrics.py +++ b/metrics.py @@ -5,8 +5,7 @@ import evaluate from typing import List -print("\nsacrebleu loading...") -sacrebleu = evaluate.load("sacrebleu") +sacrebleu = None def accuracy(prediction: int, truth: int): @@ -107,6 +106,11 @@ def ndcg_eval(relevance_scores: List[float], truth: List[float]): def bleu(generation, truth, jp=False): + global sacrebleu + if sacrebleu is None: + print("\nsacrebleu loading...") + sacrebleu = evaluate.load("sacrebleu") + generation = generation.lstrip("\n").rstrip("\n").split("\n")[0] candidate = [generation] reference = [[truth]]