Source code for eval_framework.metrics.completion.f1
import re
import string
from collections import Counter
from typing import Any

from eval_framework.metrics.base import BaseMetric, MetricResult
from eval_framework.shared.types import Completion


class F1(BaseMetric[Completion]):
"""
Token-overlap F1 metric.
Default behavior is backward-compatible with the previous implementation:
lowercase + whitespace tokenization.
"""
NAME = "F1"
[docs]
def normalize(self, text: str) -> str:
"""Normalizes text to use lower case."""
return text.lower()
[docs]
def tokenize(self, text: str) -> list[str]:
"""Tokenizes text into a list of tokens using whitespace as the delimiter."""
return self.normalize(text).split()
[docs]
def calculate(self, response: Completion) -> list[MetricResult]:
if response.error is not None:
return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
ground_truths = [gt for gt in response.ground_truth_list if gt is not None]
if not ground_truths:
return [MetricResult(metric_name=self.NAME, value=0.0, higher_is_better=True, error=response.error)]
prediction_tokens = self.tokenize(response.completion)
ground_truths_tokens = [self.tokenize(gt) for gt in ground_truths]
f1_scores = [calculate_f1(gt_tokens, prediction_tokens) for gt_tokens in ground_truths_tokens]
max_f1 = max(f1_scores)
return [MetricResult(metric_name=self.NAME, value=max_f1, higher_is_better=True, error=response.error)]
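# Illustrative use of F1 (a sketch, not part of the module; it assumes a
# Completion can be constructed directly with the `completion`,
# `ground_truth_list`, and `error` fields that `calculate` reads above):
#
#     metric = F1()
#     response = Completion(completion="The cat sat", ground_truth_list=["cat sat", "a dog"], error=None)
#     [result] = metric.calculate(response)
#     # "the cat sat" vs "cat sat": precision 2/3, recall 2/2 -> F1 = 0.8
#     # "the cat sat" vs "a dog":   no overlap                -> F1 = 0.0
#     result.value  # max over ground truths = 0.8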
class F1SquadNormalized(F1):
"""
SQuAD-style normalized F1:
- lowercase
- remove punctuation
- remove articles (a, an, the)
- collapse extra whitespace
"""
NAME = "F1 SQuAD Normalized"
_ARTICLES_RE = re.compile(r"\b(a|an|the)\b")
_PUNCTUATION = set(string.punctuation)
[docs]
def normalize(self, text: str) -> str:
text = text.lower()
text = "".join(ch for ch in text if ch not in self._PUNCTUATION)
text = self._ARTICLES_RE.sub(" ", text)
text = " ".join(text.split())
return text
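# Example of the SQuAD normalization above (illustrative, not part of the
# module):
#
#     F1SquadNormalized().normalize("The quick, brown fox!")
#     # -> "quick brown fox"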
def calculate_f1(ref_tokens: list[Any], hyp_tokens: list[Any]) -> float:
    """Calculate the F1 score between two token lists based on token overlap."""
    if not ref_tokens and not hyp_tokens:
        return 1.0
    if not ref_tokens or not hyp_tokens:
        return 0.0
    # Multiset intersection: each shared token counts at most as often as it
    # appears in both lists.
    common = Counter(ref_tokens) & Counter(hyp_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(hyp_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
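# Worked example (illustrative): ref_tokens = ["the", "cat", "sat"] and
# hyp_tokens = ["the", "cat"] share two tokens, so
# precision = 2/2 = 1.0, recall = 2/3, and
# F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8:
#
#     calculate_f1(["the", "cat", "sat"], ["the", "cat"])  # -> 0.8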