import re

import nltk

from eval_framework.metrics.base import (
    BaseMetric,
    MetricResult,
)
from eval_framework.shared.types import BaseMetricContext, Completion, extract_context_metric

ALPHABETS = "([A-Za-z])"
PREFIXES = "(Mr|St|Mrs|Ms|Dr|www)[.]"
SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)"
STARTERS = (
r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
)
ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
WEBSITES = "[.](com|net|org|io|gov|edu|me)"
DIGITS = "([0-9])"
MULTIPLE_DOTS = r"\.{2,}"
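
# These fragments feed the rule-based sentence splitter in
# SentenceCounter._count_sentences below. A sketch of how two of the
# substitutions behave (illustrative):
#
#     re.sub(PREFIXES, "\\1<prd>", "Dr. Smith")    # -> "Dr<prd> Smith"
#     re.sub(WEBSITES, "<prd>\\1", "example.com")  # -> "example<prd>com"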


class WordCounterMetricContext(BaseMetricContext):
    comparison: str
    word_count: int


class WordCounter(BaseMetric[Completion]):
    NAME = "Word Count"

    @staticmethod
    def _count_words(text: str) -> int:
        # Words are runs of word characters (\w+): punctuation is dropped and
        # contractions are split, so the count is heuristic but deterministic.
        tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+")
        tokens = tokenizer.tokenize(text)
        return len(tokens)
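
    # A minimal sketch of the resulting behavior (illustrative):
    #
    #     WordCounter._count_words("Hello, world!")  # -> 2 (["Hello", "world"])
    #     WordCounter._count_words("don't")          # -> 2 (["don", "t"])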

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
        context = extract_context_metric(response, WordCounterMetricContext)
        assert context.comparison in ["less than", "at least"], f"'comparison' is not valid: {context.comparison}"
        num_words = self._count_words(response.completion)
        if context.comparison == "less than":
            valid_word_count = num_words < context.word_count
        elif context.comparison == "at least":
            valid_word_count = num_words >= context.word_count
        return [
            MetricResult(
                metric_name=self.NAME, value=float(valid_word_count), higher_is_better=True, error=response.error
            )
        ]
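
# A minimal usage sketch (illustrative: assumes the evaluated sample carries a
# matching context, e.g. WordCounterMetricContext(comparison="at least",
# word_count=100), and that `response` is a Completion):
#
#     results = WordCounter().calculate(response)
#     results[0].value  # 1.0 if the completion has at least 100 words, else 0.0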


class SentenceCounterMetricContext(BaseMetricContext):
    comparison: str
    sentence_count: int


class SentenceCounter(BaseMetric[Completion]):
    NAME = "Sentence Count"

    @staticmethod
    def _count_sentences(text: str) -> int:
        # Note that nltk.tokenize.sent_tokenize would be a straightforward alternative but is also not ideal. Example:
        #
        # "Mr. Jones gave me $10,000.00... And then he left. Numbers 5...10. Numbers 5..10. Review: bad food,
        # bad service,..., so I'd miss it."
        #
        # this implementation: ['Mr. Jones gave me $10,000.00...', 'And then he left.', 'Numbers 5...', '10.',
        #     'Numbers 5..', '10.', 'Review: bad food, bad service,...', ", so I'd miss it."]
        # nltk: ['Mr. Jones gave me $10,000.00... And then he left.', 'Numbers 5...10.',
        #     "Numbers 5..10. Review: bad food, bad service,..., so I'd miss it."]
        text = f" {text} "
        text = text.replace("\n", " ")
        # Shield non-boundary periods as <prd> (ellipses also get a <stop> marker).
        text = re.sub(PREFIXES, "\\1<prd>", text)
        text = re.sub(WEBSITES, "<prd>\\1", text)
        text = re.sub(DIGITS + "[.]" + DIGITS, "\\1<prd>\\2", text)
        text = re.sub(
            MULTIPLE_DOTS,
            lambda match: "<prd>" * len(match.group(0)) + "<stop>",
            text,
        )
        text = text.replace("Ph.D.", "Ph<prd>D<prd>")
        text = re.sub(r"\s" + ALPHABETS + "[.] ", " \\1<prd> ", text)
        text = re.sub(ACRONYMS + " " + STARTERS, "\\1<stop> \\2", text)
        text = re.sub(
            ALPHABETS + "[.]" + ALPHABETS + "[.]" + ALPHABETS + "[.]",
            "\\1<prd>\\2<prd>\\3<prd>",
            text,
        )
        text = re.sub(ALPHABETS + "[.]" + ALPHABETS + "[.]", "\\1<prd>\\2<prd>", text)
        text = re.sub(" " + SUFFIXES + "[.] " + STARTERS, " \\1<stop> \\2", text)
        text = re.sub(" " + SUFFIXES + "[.]", " \\1<prd>", text)
        text = re.sub(" " + ALPHABETS + "[.]", " \\1<prd>", text)
        # Move terminal punctuation outside closing quotes, mark the true
        # boundaries with <stop>, then restore the shielded periods.
        text = text.replace(".”", "”.")
        text = text.replace('."', '".')
        text = text.replace('!"', '"!')
        text = text.replace('?"', '"?')
        text = text.replace(".", ".<stop>")
        text = text.replace("?", "?<stop>")
        text = text.replace("!", "!<stop>")
        text = text.replace("<prd>", ".")
        sentences = text.split("<stop>")
        sentences = [s.strip() for s in sentences]
        # Drop the trailing empty segment produced by a final <stop>, if any.
        if sentences and not sentences[-1]:
            sentences = sentences[:-1]
        return len(sentences)
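
    # A traced example (illustrative, not a doctest): "Mr." is shielded via
    # PREFIXES, so only the two real boundaries split:
    #
    #     SentenceCounter._count_sentences("Mr. Jones left. He was happy.")  # -> 2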

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
        context = extract_context_metric(response, SentenceCounterMetricContext)
        assert context.comparison in ["less than", "at least"], f"'comparison' is not valid: {context.comparison}"
        num_sentences = self._count_sentences(response.completion)
        if context.comparison == "less than":
            valid_sentence_count = num_sentences < context.sentence_count
        elif context.comparison == "at least":
            valid_sentence_count = num_sentences >= context.sentence_count
        return [
            MetricResult(
                metric_name=self.NAME, value=float(valid_sentence_count), higher_is_better=True, error=response.error
            )
        ]


class ParagraphCounterMetricContext(BaseMetricContext):
    comparison: str
    paragraph_count: int


class ParagraphCounter(BaseMetric[Completion]):
    NAME = "Paragraph Count"

    @staticmethod
    def _count_paragraphs(text: str) -> int:
        # Split on blank lines (two consecutive newlines, tolerating one stray
        # whitespace character on either side).
        paragraphs = re.split(r"\s?\n\n\s?", text)
        return len(paragraphs)
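
    # Illustrative behavior (no filtering of empty segments is applied):
    #
    #     ParagraphCounter._count_paragraphs("One.\n\nTwo.")   # -> 2
    #     ParagraphCounter._count_paragraphs("Single block.")  # -> 1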

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
        context = extract_context_metric(response, ParagraphCounterMetricContext)
        assert context.comparison in ["less than", "at least"], f"'comparison' is not valid: {context.comparison}"
        num_paragraphs = self._count_paragraphs(response.completion)
        if context.comparison == "less than":
            valid_paragraph_count = num_paragraphs < context.paragraph_count
        elif context.comparison == "at least":
            valid_paragraph_count = num_paragraphs >= context.paragraph_count
        return [
            MetricResult(
                metric_name=self.NAME, value=float(valid_paragraph_count), higher_is_better=True, error=response.error
            )
        ]


class ResponseToOriginalLengthRatio(BaseMetric[Completion]):
    NAME = "Response to Original Length Ratio"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            # Lower ratios are better for this metric, matching the result below.
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=False, error=response.error)]
        len_original = len(response.last_user_instruction)
        if len_original > 0:
            # Character-length ratio of the completion to the last user instruction.
            score = len(response.completion) / len_original
            return [MetricResult(metric_name=self.NAME, value=score, higher_is_better=False, error=response.error)]
        else:
            return []
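
# A minimal sketch of the score (character counts; numbers are illustrative):
#
#     len(response.completion)             # e.g. 150
#     len(response.last_user_instruction)  # e.g. 50
#     score = 150 / 50                     # -> 3.0; lower is better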