Source code for eval_framework.metrics.completion.accuracy_completion

import re
import string
from typing import Any

import numpy as np

from eval_framework.metrics.base import BaseMetric, MetricResult
from eval_framework.shared.types import Completion


class AccuracyCompletion(BaseMetric[Completion]):
    NAME = "Accuracy Completion"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
        ground_truths = response.ground_truth_list
        is_correct = any(response.completion == gt for gt in ground_truths)
        return [
            MetricResult(metric_name=self.NAME, value=float(is_correct), higher_is_better=True, error=response.error)
        ]
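For intuition, this strict check requires the completion to equal one of the ground truths verbatim; the sketch below (plain strings standing in for a real `Completion`, an assumption for illustration only) shows how even a trailing period causes a miss, which motivates the normalizing variants that follow:

    # Hedged sketch: strict exact match demands verbatim equality with at
    # least one reference, so "42." fails to match "42".
    completion = "42."
    ground_truths = ["42", "forty-two"]
    assert not any(completion == gt for gt in ground_truths)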
class AccuracyCompletionWithEvaluate(AccuracyCompletion):
    def __init__(self, regexes_to_ignore: list[str], ignore_case: bool = False, ignore_punctuation: bool = False):
        self.regexes_to_ignore = regexes_to_ignore
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
        ground_truths = response.ground_truth_list
        model_answer = response.completion
        is_correct = exact_match_hf_evaluate(
            predictions=[model_answer] * len(ground_truths),
            references=ground_truths,  # type: ignore[arg-type]
            regexes_to_ignore=self.regexes_to_ignore,
            ignore_case=self.ignore_case,
            ignore_punctuation=self.ignore_punctuation,
        )["exact_match"]
        return [
            MetricResult(metric_name=self.NAME, value=float(is_correct), higher_is_better=True, error=response.error)
        ]
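Note that `calculate` repeats the single model answer against every ground truth and inherits the mean-based aggregation of `exact_match_hf_evaluate` (see the NOTE in that function below), so with several ground truths a single match yields a fractional score rather than 1.0. A hedged sketch of that behaviour, calling the helper directly with made-up strings:

    # Hedged sketch: one prediction repeated against two references.
    # "exact_match" is the mean of the per-reference matches (0.5 here),
    # not the max (which would be 1.0).
    result = exact_match_hf_evaluate(predictions=["4", "4"], references=["4", "four"])
    assert result["exact_match"] == 0.5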
class AccuracyCompletionOLMES(AccuracyCompletionWithEvaluate):
    # If we used functools.partial instead, the code would fail, as there is an
    # issubclass check that doesn't work with partial. These specific regexes are taken from
    # https://github.com/allenai/olmes/blob/main/oe_eval/tasks/oe_eval_tasks/gsm8k.py#L70
    def __init__(self) -> None:
        super().__init__(regexes_to_ignore=[",", "\\$", "(?s).*#### ", "\\.$"])
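As a concrete illustration (a hedged sketch with made-up sample strings, not taken from the OLMES test suite), these regexes strip commas, dollar signs, everything up to the final `#### ` marker, and a trailing period, reducing a GSM8K-style completion to the bare number:

    # Hedged sketch of the OLMES/GSM8K normalization.
    result = exact_match_hf_evaluate(
        predictions=["The total is $1,000.\n#### $1,000."],
        references=["1,000"],
        regexes_to_ignore=[",", "\\$", "(?s).*#### ", "\\.$"],
    )
    assert result["exact_match"] == 1.0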
# The following code is (largely) reproduced from
# https://github.com/allenai/olmes/blob/main/oe_eval/dependencies/hf_evaluate/exact_match.py#L25
# OLMES is released under the Apache 2.0 license, and so is the HF evaluate library.
# Some cosmetic modifications have been made to fit our codebase and linting rules.
# -------------------------------------------------------------------------------------
### Code ported from Huggingface's `evaluate` library at
### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
### which is under the Apache license.
### Port taken from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/metrics.py
### and used to fix the issue: https://github.com/EleutherAI/lm-evaluation-harness/pull/2045
#
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def exact_match_hf_evaluate(
    predictions: list[str],
    references: list[str],
    regexes_to_ignore: list[str] | None = None,
    ignore_case: bool = False,
    ignore_punctuation: bool = False,
    ignore_numbers: bool = False,
) -> dict[str, Any]:  # type: ignore
    # Normalize predictions and references, then compare them element-wise.
    if regexes_to_ignore is not None:
        for s in regexes_to_ignore:
            predictions = np.array([re.sub(s, "", x) for x in predictions])  # type: ignore
            references = np.array([re.sub(s, "", x) for x in references])  # type: ignore
    else:
        predictions = np.asarray(predictions)  # type: ignore
        references = np.asarray(references)  # type: ignore

    if ignore_case:
        predictions = np.char.lower(predictions)  # type: ignore
        references = np.char.lower(references)  # type: ignore

    if ignore_punctuation:
        repl_table = string.punctuation.maketrans("", "", string.punctuation)
        predictions = np.char.translate(predictions, table=repl_table)  # type: ignore
        references = np.char.translate(references, table=repl_table)  # type: ignore

    if ignore_numbers:
        repl_table = string.digits.maketrans("", "", string.digits)
        predictions = np.char.translate(predictions, table=repl_table)  # type: ignore
        references = np.char.translate(references, table=repl_table)  # type: ignore

    # NOTE: For multiple ground truths OLMES returns the mean over their scores. The max over
    # them would be more meaningful, but we leave it here for parity.
    score_list = predictions == references
    return {"exact_match": np.mean(score_list)}
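Finally, a hedged usage sketch of the normalization flags (the sample strings are illustrative assumptions):

    # Hedged sketch: with ignore_case and ignore_punctuation set, case is
    # folded and punctuation deleted before comparison, so "Paris." matches
    # "paris".
    result = exact_match_hf_evaluate(
        predictions=["Paris."],
        references=["paris"],
        ignore_case=True,
        ignore_punctuation=True,
    )
    assert result["exact_match"] == 1.0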