Source code for eval_framework.tasks.benchmarks.casehold
import random
from typing import Any
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
AccuracyLoglikelihood,
AccuracyNormLoglikelihood,
)
from eval_framework.tasks.base import NO_SUBJECT, RANDOM_SEED, BaseTask, Language, ResponseType
[docs]
class CASEHOLD(BaseTask[str]):
    """CaseHOLD multiple-choice task scored by log-likelihood.

    Each item's ``context`` contains a ``(<HOLDING>)`` placeholder. The text
    before the placeholder is the prompt. Each entry of ``endings`` followed by
    the text after the placeholder is a candidate completion. The candidate at
    index ``label`` is the ground truth.
    """

    NAME = "CaseHold"
    DATASET_PATH = "lex_glue"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["case_hold"]
    LANGUAGE = Language.ENG

    # Marker inside item["context"] that the candidate holdings replace.
    _HOLDING_PLACEHOLDER = "(<HOLDING>)"

    def _split_context(self, item: dict[str, Any]) -> list[str]:
        """Split ``item["context"]`` at the first holding placeholder.

        Returns ``[before, after]`` when the placeholder is present. Returns
        ``[context]`` when it is absent (plain ``str.split`` with
        ``maxsplit=1`` semantics).
        """
        return item["context"].split(self._HOLDING_PLACEHOLDER, maxsplit=1)

    def _load_dataset(self, subject: str) -> None:
        """Load the sample and few-shot splits, keeping only well-formed items.

        Only items whose context contains exactly one placeholder are kept, so
        the prompt/completion split used by the other methods is unambiguous.
        The sample split is shuffled with a freshly seeded RNG
        (``RANDOM_SEED``) stored as ``self.rnd``.

        Args:
            subject: Dataset config name, or ``NO_SUBJECT`` to load the
                dataset without a config name.
        """
        name = subject if subject != NO_SUBJECT else None
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)
        used_splits = {self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT}
        placeholder = self._HOLDING_PLACEHOLDER
        for split, data in hf_dataset.items():
            # Any other split (e.g. a validation split, if present) is never
            # stored, so skip materializing it. The RNG is not used for it.
            if split not in used_splits:
                continue
            data_list = list(data)
            if split == self.SAMPLE_SPLIT:
                # Shuffle *before* filtering. The resulting sample order depends
                # on the unfiltered list; swapping the two steps would reorder
                # samples for the same seed.
                self.rnd.shuffle(data_list)
            self.dataset[split] = [
                item for item in data_list if item["context"].count(placeholder) == 1
            ]

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the context text that precedes the holding placeholder.

        Returns the whole context if the placeholder is absent.
        """
        return self._split_context(item)[0]

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return ``endings[label]`` followed by the context after the placeholder.

        Raises:
            IndexError: If the context contains no placeholder.
        """
        suffix = self._split_context(item)[1]
        return f"{item['endings'][item['label']]}{suffix}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return every candidate ending, each followed by the post-placeholder context.

        The entry at index ``item["label"]`` equals ``_get_ground_truth(item)``.

        Raises:
            IndexError: If the context contains no placeholder.
        """
        suffix = self._split_context(item)[1]
        return [f"{ending}{suffix}" for ending in item["endings"]]