Source code for eval_framework.metrics.completion.length_control

import json
from enum import Enum

from eval_framework.metrics.base import BaseMetric, MetricResult
from eval_framework.metrics.completion.text_counter import ParagraphCounter, SentenceCounter, WordCounter
from eval_framework.shared.types import Completion


class LengthRequirementUnit(Enum):
    """Unit in which a length requirement is expressed.

    Each member's value is the string used to look the member up, e.g.
    ``LengthRequirementUnit("words")``.
    """

    # Requirement is counted in words.
    WORDS = "words"
    # Requirement is counted in sentences.
    SENTENCES = "sentences"
    # Requirement is counted in paragraphs.
    PARAGRAPHS = "paragraphs"
class LengthRequirementType(Enum):
    """Kind of bound a length requirement places on a count.

    Each member's value is the string used to look the member up, e.g.
    ``LengthRequirementType("minimum")``.
    """

    # The count is a lower bound.
    MIN = "minimum"
    # The count is an upper bound.
    MAX = "maximum"
    # The count is a value to be approximated.
    TARGET = "target"
[docs] class LengthControl(BaseMetric[Completion]): NAME = "length_control" def __init__(self, tolerance: float = 1 / 6) -> None: super().__init__() self.tolerance = tolerance
[docs] def calculate(self, response: Completion) -> list[MetricResult]: if response.error is not None: return [ MetricResult( metric_name=f"{self.NAME}/fulfills_length_requirement", value=None, higher_is_better=True, error=response.error if response.error is not None else None, ) ] expectations = json.loads(str(response.ground_truth)) stripped_completion = response.completion.strip() match LengthRequirementUnit(expectations["unit"]): case LengthRequirementUnit.WORDS: count = WordCounter._count_words(stripped_completion) case LengthRequirementUnit.SENTENCES: count = SentenceCounter._count_sentences(stripped_completion) case LengthRequirementUnit.PARAGRAPHS: count = ParagraphCounter._count_paragraphs(stripped_completion) case _: raise NotImplementedError(f"LengthRequirementUnit {expectations['unit']} is not supported.") expected_count = int(expectations["count"]) normalized_distance_to_target = (count - expected_count) / float(expected_count) absolute_normalized_distance_to_target = abs(normalized_distance_to_target) match LengthRequirementType(expectations["type"]): case LengthRequirementType.TARGET: fulfills_length_requirement = absolute_normalized_distance_to_target <= self.tolerance case LengthRequirementType.MIN: fulfills_length_requirement = count >= expected_count case LengthRequirementType.MAX: fulfills_length_requirement = count <= expected_count case _: raise NotImplementedError(f"LengthRequirementType {expectations['type']} is not supported.") return [ MetricResult( metric_name=f"{self.NAME}/normalized_distance_to_target", value=float(normalized_distance_to_target), higher_is_better=False, ), MetricResult( metric_name=f"{self.NAME}/absolute_normalized_distance_to_target", value=float(absolute_normalized_distance_to_target), higher_is_better=False, ), MetricResult( metric_name=f"{self.NAME}/fulfills_length_requirement", value=float(fulfills_length_requirement), higher_is_better=True, error=response.error, ), ]