Source code for eval_framework.metrics.loglikelihood.confidence_weighted_accuracy

from eval_framework.metrics.base import MetricResult
from eval_framework.metrics.loglikelihood.base import BaseLoglikelihoodMetric
from eval_framework.shared.types import Loglikelihood


[docs] class ConfidenceWeightedAccuracy(BaseLoglikelihoodMetric): NAME = "Confidence-weighted Accuracy" def __init__(self, *, len_normalised: bool = True) -> None: super().__init__(len_normalised=len_normalised)
[docs] def calculate(self, response: Loglikelihood) -> list[MetricResult]: if response.error is not None: return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)] loglikelihoods, probs = self._compute_probabilities(response.loglikelihoods) ground_truths = self._gather_ground_truths(response) best_key = max(loglikelihoods, key=loglikelihoods.get) # type: ignore[arg-type] best_key_norm = self._normalise_text(best_key) p_c = probs.get(best_key, 0.0) accuracy = p_c if best_key_norm in ground_truths else 0.0 return [MetricResult(metric_name=self.NAME, value=accuracy, higher_is_better=True, error=response.error)]