# Source code for eval_framework.metrics.completion.repetition

import re
from collections import Counter
from collections.abc import Iterator, Sequence
from typing import Final

from eval_framework.metrics.base import BaseMetric, MetricResult
from eval_framework.shared.types import Completion


class WordRepetition(BaseMetric[Completion]):
    """Flag completions that repeat a run of consecutive words.

    The completion text is split into words, and every contiguous run of
    ``window_size`` words is counted (runs overlap, one starting at each word).
    The metric fires when some run repeats at least ``min_repetitions`` times,
    i.e. occurs at least ``min_repetitions + 1`` times in total.

    Scores:
        * ``1.0`` -- a sufficiently repeated run was found.
        * ``0.0`` -- no run repeats often enough.
        * ``None`` -- the completion carries an error; the error is forwarded
          on the result instead of scoring the text.

    Example: "hello world hello world" contains the two-word run
    "hello world" twice (one repetition), so it fires with ``window_size=2``
    and ``min_repetitions=1``.

    Lower scores are better (``HIGHER_IS_BETTER`` is ``False``).
    """

    NAME = "WordRepetition"
    HIGHER_IS_BETTER: Final[bool] = False

    def __init__(self, window_size: int = 128, min_repetitions: int = 1) -> None:
        """Configure the repetition check.

        Args:
            window_size (int): How many consecutive words make up one run.
                Must be at least 1.
            min_repetitions (int): How many times a run must repeat (beyond its
                first occurrence) for the metric to fire. Must be at least 1;
                the value 1 catches any repetition at all.

        Raises:
            ValueError: If ``min_repetitions`` or ``window_size`` is below 1.
                ``min_repetitions`` is checked first.
        """
        super().__init__()
        self.window_size = window_size
        self.min_repetitions = min_repetitions

        if self.min_repetitions < 1:
            raise ValueError("min_repetitions must be at least 1")
        if self.window_size < 1:
            raise ValueError("window_size must be at least 1")

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score a single completion.

        Args:
            response: The completion to inspect. Its ``completion`` text is
                tokenised and scanned only when ``response.error`` is ``None``.

        Returns:
            A one-element list holding the ``MetricResult``. ``value`` is ``1.0``
            or ``0.0`` on success, or ``None`` when the completion has an error.
            ``error`` is always ``response.error``.
        """
        error = response.error
        if error is None:
            found = _has_repetition(
                text=response.completion,
                window_size=self.window_size,
                min_repetitions=self.min_repetitions,
            )
            value = float(found)
        else:
            # Errored completions are not scored; the error travels on the result.
            value = None

        result = MetricResult(
            metric_name=self.NAME,
            value=value,
            higher_is_better=self.HIGHER_IS_BETTER,
            error=error,
        )
        return [result]
def _has_repetition(text: str, window_size: int, min_repetitions: int) -> bool:
    """Check if the text contains any word sequences of a given size that repeat.

    Every contiguous run of ``window_size`` words is counted (runs overlap, one
    starting at each word position). Returns True as soon as any run occurs more
    than ``min_repetitions`` times, i.e. repeats at least ``min_repetitions``
    times beyond its first occurrence; returns False otherwise, including when
    the text has fewer than ``window_size`` words.

    Windows are generated lazily and counting stops at the first run that
    crosses the threshold, so long texts with large windows do not build every
    window up front.
    """
    counts: Counter[tuple[str, ...]] = Counter()
    for sequence in _iter_word_sequences(_to_words(text), window_size):
        counts[sequence] += 1
        # Strictly greater: min_repetitions counts *repeats*, so a run seen
        # twice has repeated once.
        if counts[sequence] > min_repetitions:
            return True
    return False


def _to_words(text: str) -> Sequence[str]:
    """A somewhat crude function to tokenize a string into words.

    Returns the maximal runs of (Unicode-aware) word characters in ``text``, in
    order; punctuation and whitespace are dropped.
    """
    return re.findall(r"\b\w+\b", text, re.UNICODE)


def _iter_word_sequences(text_words: Sequence[str], window_size: int) -> Iterator[tuple[str, ...]]:
    """Lazily yield each contiguous ``window_size``-word window of ``text_words`` as a tuple.

    Yields nothing when ``text_words`` holds fewer than ``window_size`` words.
    """
    for start in range(len(text_words) - window_size + 1):
        yield tuple(text_words[start : start + window_size])


def _word_sequences(text_words: Sequence[str], window_size: int) -> Sequence[Sequence[str]]:
    """Get all contiguous sub-sequences of a given size from a word sequence.

    Returns a list of tuples, which is empty when there are fewer words than
    ``window_size``.
    """
    return list(_iter_word_sequences(text_words, window_size))