Source code for eval_framework.metrics.completion.grid_difference
import re
from eval_framework.metrics.base import BaseMetric, MetricResult
from eval_framework.shared.types import Completion


class GridDifference(BaseMetric[Completion]):
    """Scores a completed grid against the ground-truth grid, relative to how many of the input
    grid's cells (whitespace-separated tokens) already differed from the ground truth.
    """

    NAME = "grid_difference"

    def count_differences(self, character_list_1: list[str], character_list_2: list[str]) -> int:
        """Count position-wise mismatches between two token lists (tokens beyond the shorter list are ignored)."""
        count = 0
        for character_1, character_2 in zip(character_list_1, character_list_2):
            if character_1 != character_2:
                count += 1
        return count

    def calculate_score(
        self, output_ground_truth_difference_count: int, input_ground_truth_difference_count: int
    ) -> float:
        """Return 1.0 - (output-vs-ground-truth diffs / input-vs-ground-truth diffs): 1.0 means the output
        matches the ground truth, 0.0 means it is no closer than the input, and negative values mean it
        drifted further away. Assumes the denominator is non-zero whenever the numerator is non-zero.
        """
        if output_ground_truth_difference_count == 0 and input_ground_truth_difference_count == 0:
            return 1.0
        score = 1.0 - (float(output_ground_truth_difference_count) / float(input_ground_truth_difference_count))
        return score
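
    # Example: if the input grid differs from the ground truth in 4 cells and the model's completion
    # still differs in 1 cell, calculate_score(1, 4) returns 1.0 - 1/4 = 0.75; calculate_score(6, 4)
    # returns -0.5, which calculate() later floors at 0.0 for the normalized score.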

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Compute exact-match, raw, and normalized grid-difference scores for a completion."""
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        input_grid = self.extract_grid_from_prompt(prompt=response.last_user_instruction).split()
        output_grid = response.completion.split()
        assert response.ground_truth_list and response.ground_truth_list[0], (
            "Ground truth list is empty or not provided in the response."
        )
        ground_truth_grid = response.ground_truth_list[0].split()

        input_ground_truth_differences_count = self.count_differences(input_grid, ground_truth_grid)
        output_ground_truth_differences_count = self.count_differences(output_grid, ground_truth_grid)

        exact_match = True
        score = 1.0
        normalized_score = 1.0
        if output_ground_truth_differences_count != 0:
            exact_match = False
            score = self.calculate_score(
                output_ground_truth_differences_count,
                input_ground_truth_differences_count,
            )
            normalized_score = max(score, 0.0)

        return [
            MetricResult(
                metric_name=f"{self.NAME}_exact_match",
                value=float(exact_match),
                higher_is_better=True,
                error=response.error,
            ),
            MetricResult(metric_name=f"{self.NAME}_score", value=score, higher_is_better=True, error=response.error),
            MetricResult(
                metric_name=f"{self.NAME}_normalized_score",
                value=normalized_score,
                higher_is_better=True,
                error=response.error,
            ),
        ]
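
A minimal standalone sketch of the scoring arithmetic on made-up grids (plain Python only; it does not
construct the framework's Completion objects, and the grid values are invented for illustration):

# Illustration only: reproduces the metric's arithmetic on hypothetical space-separated grids.
input_grid = "1 0 0 0 2 0 0 0 0".split()
output_grid = "1 0 0 0 2 0 0 0 1".split()
ground_truth_grid = "1 0 0 0 2 0 0 1 1".split()

def count_differences(a: list[str], b: list[str]) -> int:
    return sum(x != y for x, y in zip(a, b))

input_diffs = count_differences(input_grid, ground_truth_grid)    # 2 cells still wrong in the input
output_diffs = count_differences(output_grid, ground_truth_grid)  # 1 cell still wrong in the output
score = 1.0 - output_diffs / input_diffs                          # 0.5: halfway from input to ground truth
normalized_score = max(score, 0.0)                                # floors negative (regression) scores at 0.0
print(score, normalized_score)                                    # 0.5 0.5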