Source code for eval_framework.tasks.benchmarks.mmlu

import re
from typing import Any

from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
    AccuracyLoglikelihood,
    AccuracyNormLoglikelihood,
)
from eval_framework.metrics.loglikelihood.confidence_weighted_accuracy import ConfidenceWeightedAccuracy
from eval_framework.metrics.loglikelihood.dcs import DistributionalCorrectnessScore
from eval_framework.metrics.loglikelihood.ternary import TernaryScore
from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample
from eval_framework.tasks.utils import get_n_letters

MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]


class MMLU(BaseTask[str]):
    """MMLU dataset: https://huggingface.co/datasets/cais/mmlu"""

    NAME = "MMLU"
    DATASET_PATH = "cais/mmlu"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "dev"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = MMLU_SUBJECTS
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "Answer"] + get_n_letters(4)
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        self.keys = get_n_letters(4)

    def _get_subject_name(self, item: dict[str, Any]) -> str:
        return " ".join(item["subject"].split("_"))

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        return f"The following are multiple choice questions (with answers) about {self._get_subject_name(item)}."

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        question = item["question"].strip()
        choices = "".join([f"{key}. {choice}\n" for key, choice in zip(self.keys, item["choices"])])
        return f"Question: {question}\n{choices}"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        ground_truth = self._get_ground_truth(item)
        assert ground_truth is not None
        return f"{self._get_cue_text(item)}{ground_truth}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        return f" {self.keys[item['answer']]}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        return [f" {key}" for key in self.keys]
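

# Illustrative usage sketch (not part of the original module). It assumes the
# task can be instantiated without extra setup and that get_n_letters(4)
# returns ["A", "B", "C", "D"], as the surrounding code implies. The `item`
# dict mirrors the fields of the cais/mmlu dataset.
def _demo_mmlu_prompt() -> None:
    task = MMLU(num_fewshot=0)
    item = {
        "subject": "abstract_algebra",
        "question": "What is the order of the group Z_4?",
        "choices": ["2", "3", "4", "8"],
        "answer": 2,
    }
    print(task._get_initial_prompt_text(item))
    # -> The following are multiple choice questions (with answers) about abstract algebra.
    print(task._get_instruction_text(item) + task._get_cue_text(item))
    # -> Question: What is the order of the group Z_4?
    #    A. 2
    #    B. 3
    #    C. 4
    #    D. 8
    #    Answer:
    print(repr(task._get_ground_truth(item)))     # -> ' C'
    print(task._get_possible_completions(item))   # -> [' A', ' B', ' C', ' D']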


class FullTextMMLU(MMLU):
    """MMLU variant where the model is expected to reproduce the full choice text rather than just the letter key."""

    NAME = "Full Text MMLU"
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "answers"] + get_n_letters(4)

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        subject_name = self._get_subject_name(item)
        return (
            f"The following are multiple choice questions (with possible answers) about {subject_name}. "
            "Answer with the full text of the correct answer."
        )

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        question = item["question"].strip()
        choices = "".join([f"- {choice}\n" for choice in item["choices"]])
        return f"Question: {question}\nPossible answers:\n{choices}"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        return f" {item['choices'][item['answer']]}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        return [f" {choice}" for choice in item["choices"]]
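

# Illustrative sketch (not part of the original module): the same item rendered
# by FullTextMMLU, whose target is the full choice text rather than a letter.
# Same assumptions as the MMLU sketch above.
def _demo_full_text_mmlu_prompt() -> None:
    task = FullTextMMLU(num_fewshot=0)
    item = {
        "subject": "abstract_algebra",
        "question": "What is the order of the group Z_4?",
        "choices": ["2", "3", "4", "8"],
        "answer": 2,
    }
    print(task._get_instruction_text(item))
    # -> Question: What is the order of the group Z_4?
    #    Possible answers:
    #    - 2
    #    - 3
    #    - 4
    #    - 8
    print(repr(task._get_ground_truth(item)))     # -> ' 4' (choice text, not ' C')
    print(task._get_possible_completions(item))   # -> [' 2', ' 3', ' 4', ' 8']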


class MMLU_IDK(MMLU):
    NAME = "MMLU_IDK"
    METRICS = [
        AccuracyLoglikelihood,
        AccuracyNormLoglikelihood,
        ConfidenceWeightedAccuracy,
        DistributionalCorrectnessScore,
        TernaryScore,
    ]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        return (
            f"The following are multiple choice questions (with answers) about {item['subject']}. "
            "Answer only if you are confident, since mistakes may be penalised, while correct answers receive points. "
            "It is acceptable to answer with '?' if you are unsure, and you will receive 0 points."
        )

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        completions = super()._get_possible_completions(item)
        return (completions or []) + [" ?"]
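

# Illustrative sketch (not part of the original module): MMLU_IDK extends the
# answer space with an abstention option, so the scored completions gain " ?".
# Note that its prompt also uses the raw item["subject"] (underscores intact),
# unlike the parent class, which space-separates the subject name.
def _demo_mmlu_idk_completions() -> None:
    task = MMLU_IDK(num_fewshot=0)
    item = {
        "subject": "abstract_algebra",
        "question": "What is the order of the group Z_4?",
        "choices": ["2", "3", "4", "8"],
        "answer": 2,
    }
    print(task._get_possible_completions(item))
    # -> [' A', ' B', ' C', ' D', ' ?']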


class MMLU_COT(MMLU):
    """
    MMLU dataset with instruction to summarize reasoning and conclude with answer.

    Inspired by https://arxiv.org/pdf/2411.15124 (Table 44)
    """

    NAME = "MMLU_COT"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [AccuracyCompletion]
    PERTURBATION_UNMODIFIABLE_WORDS = [
        "Question", "Therefore", "the", "answer", "is", "ANSWER_LETTER",
    ] + get_n_letters(4)
    ANS_RE = re.compile(r"Therefore, the answer is: ([ABCD])")

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "Fewshot is not supported for MMLU_COT"
        super().__init__(num_fewshot)
        self.stop_sequences: list[str] = ["Question:"]

    def _extract_answer(self, completion: str) -> str:
        match = self.ANS_RE.search(completion)
        if match:
            return match.group(1)
        return "[invalid]"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return ""

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        return self.keys[item["answer"]]

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        for stop_sequence in self.stop_sequences:
            if stop_sequence in completion_text:
                completion_text = completion_text.split(stop_sequence)[0]
        return self._extract_answer(completion_text)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        question = item["question"].strip()
        choices = "\n".join([f"{key}. {choice}" for key, choice in zip(self.keys, item["choices"])])
        return f"Question: {question}\n{choices}"

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        return (
            f"The following are multiple choice questions about {self._get_subject_name(item)}. "
            'Summarize your reasoning concisely, then conclude with "Therefore, the answer is: X", where X is '
            "one of A, B, C, or D."
        )
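

# Illustrative sketch (not part of the original module): answer extraction for
# MMLU_COT. Text after the "Question:" stop sequence is discarded, then ANS_RE
# pulls out the final letter; completions that never match yield "[invalid]".
def _demo_mmlu_cot_extraction() -> None:
    task = MMLU_COT(num_fewshot=0)
    completion = (
        "Z_4 is cyclic with four elements, so its order is 4. "
        "Therefore, the answer is: C\nQuestion: (a leaked follow-up question)"
    )
    print(task.post_process_generated_completion(completion))   # -> C
    print(task.post_process_generated_completion("No idea."))   # -> [invalid]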