Source code for eval_framework.tasks.benchmarks.squad

import os
import random
from pathlib import Path
from typing import Any

import requests
from datasets import Dataset, DatasetDict, DownloadConfig, load_dataset
from huggingface_hub import HfApi
from huggingface_hub.errors import RevisionNotFoundError

from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
from eval_framework.metrics.completion.f1 import F1
from eval_framework.tasks.base import NO_SUBJECT, RANDOM_SEED, BaseTask, Language, ResponseType, SubjectType


class SQUAD2(BaseTask[str]):
    """SQuAD v2 reading-comprehension task.

    Dataset: https://huggingface.co/datasets/rajpurkar/squad_v2

    Each item provides a context paragraph and a question. Items whose
    ``answers["text"]`` list is empty are treated as unanswerable, and the
    expected completion is then ``UNANSWERABLE_STR``.

    The dataset is loaded from the HuggingFace Hub. If that fails with the
    specific "Feature type 'List' not found" ``ValueError``, the raw SQuAD
    JSON files are downloaded from GitHub and flattened instead.
    """

    NAME = "SQuAD2"
    DATASET_PATH = "rajpurkar/squad_v2"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [AccuracyCompletion, F1]
    SUBJECTS = [NO_SUBJECT]
    UNANSWERABLE_STR = "unanswerable"
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "Answer", "Context", "unanswerable"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        """Configure generation limits for the task.

        Args:
            num_fewshot: Number of few-shot examples, forwarded to the base class.
        """
        super().__init__(num_fewshot)
        self.stop_sequences = [".\n"]
        self.max_tokens = 300  # the max length of the ground truth is 160 characters while the average is ~19
        # Unseeded RNG; not used inside this class -- presumably consumed by BaseTask (verify).
        self.rnd_choice_shuffle = random.Random()

    def _get_squad_urls(self) -> dict[str, str]:
        """Return the raw-JSON download URL for each split of this SQuAD version."""
        return {
            "train": "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-v2.0.json",
            "validation": "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v2.0.json",
        }

    def _load_hf_dataset(self, **kwargs: Any) -> Any:
        """Load the dataset from HuggingFace, falling back to raw JSON if needed.

        Args:
            **kwargs: Arguments forwarded to the loaders. ``path`` (defaulting to
                ``DATASET_PATH``) is used for revision validation and messages;
                ``split`` is honoured by the JSON fallback.

        Returns:
            Whatever ``load_dataset`` returns, or the result of
            ``_load_from_json`` when the fallback is taken.

        Raises:
            ValueError: Any ``ValueError`` from the HuggingFace loader other than
                the known feature-type incompatibility, or a failure of the
                JSON fallback to load any split.
        """
        dataset_path = kwargs.get("path", self.DATASET_PATH)

        # Validate HF revision if specified
        self._validate_hf_revision(dataset_path)

        # Try HuggingFace first
        try:
            return self._load_from_huggingface(**kwargs)
        except ValueError as e:
            # Only the known feature-type incompatibility triggers the fallback;
            # every other ValueError is propagated unchanged.
            if "Feature type 'List' not found" not in str(e):
                raise

            import warnings

            warnings.warn(
                f"Dataset {dataset_path} has incompatible feature types "
                "(List instead of Sequence), loading directly from JSON files"
            )
            return self._load_from_json(**kwargs)

    def _validate_hf_revision(self, dataset_path: str) -> None:
        """Check that ``self.HF_REVISION`` exists for ``dataset_path`` on the Hub.

        Does nothing when ``self.HF_REVISION`` is falsy.

        Args:
            dataset_path: HuggingFace dataset repository id.

        Raises:
            RevisionNotFoundError: Propagated from ``HfApi.dataset_info`` when the
                revision does not exist.
        """
        if self.HF_REVISION:
            # No handler needed: the lookup error is simply propagated to the caller.
            HfApi().dataset_info(repo_id=dataset_path, revision=self.HF_REVISION, timeout=100.0)

    def _load_from_huggingface(self, **kwargs: Any) -> Any:
        """Load the dataset via ``datasets.load_dataset``.

        The cache directory is taken from the ``HF_DATASET_CACHE_DIR``
        environment variable, defaulting to ``~/.cache/huggingface/datasets``.

        Args:
            **kwargs: Forwarded verbatim to ``load_dataset``.

        Returns:
            The object returned by ``load_dataset``.
        """
        cache_dir = os.environ.get("HF_DATASET_CACHE_DIR", f"{Path.home()}/.cache/huggingface/datasets")
        download_config = DownloadConfig(cache_dir=cache_dir, max_retries=5)
        # NOTE(review): trust_remote_code=True lets a dataset repository run its own
        # loading script -- confirm this is still required for the SQuAD datasets.
        return load_dataset(
            **kwargs,
            revision=self.HF_REVISION,
            trust_remote_code=True,
            cache_dir=cache_dir,
            download_config=download_config,
        )

    def _load_from_json(self, **kwargs: Any) -> Dataset | DatasetDict:
        """Load SQuAD directly from the GitHub JSON files.

        Args:
            **kwargs: Only ``split`` (which split to load) and ``path`` (used in
                the error message) are read. Without ``split``, every split
                returned by ``_get_squad_urls`` is attempted.

        Returns:
            The single requested ``Dataset`` when ``split`` was given, otherwise
            a ``DatasetDict`` of all splits that loaded successfully.

        Raises:
            ValueError: If no split could be loaded (unknown split name or every
                download/parse attempt failed).
        """
        urls = self._get_squad_urls()
        requested_split = kwargs.get("split")
        splits_to_load = [requested_split] if requested_split else list(urls)

        datasets: dict[str, Dataset] = {}
        for split in splits_to_load:
            if split not in urls:
                continue

            dataset = self._download_and_parse_split(split, urls[split])
            # Compare against the None failure sentinel explicitly: an empty
            # Dataset is falsy (it defines __len__) but is still a valid result.
            if dataset is not None:
                datasets[split] = dataset

        if not datasets:
            raise ValueError(f"Failed to load any splits for {kwargs.get('path', self.DATASET_PATH)}")

        # Return single dataset or DatasetDict depending on what was requested
        return datasets[requested_split] if requested_split else DatasetDict(datasets)

    def _download_and_parse_split(self, split: str, url: str) -> Dataset | None:
        """Download one SQuAD JSON file and convert it into a ``Dataset``.

        This is best-effort: any failure is reported as a warning rather than
        raised, so the caller can continue with the remaining splits.

        Args:
            split: Split name, used only in the warning message.
            url: URL of the SQuAD JSON file.

        Returns:
            The parsed ``Dataset``, or ``None`` if downloading or parsing failed.
        """
        try:
            # Download the data
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            squad_data = response.json()

            # Flatten the nested structure
            examples = self._flatten_squad_data(squad_data)
            return Dataset.from_list(examples)
        except Exception as e:  # deliberate catch-all: failure is signalled via None
            import warnings

            warnings.warn(f"Failed to download {split} split: {e}")
            return None

    def _flatten_squad_data(self, squad_data: dict) -> list[dict]:
        """Flatten the nested SQuAD JSON structure into one record per question.

        Args:
            squad_data: Parsed SQuAD JSON with the layout
                ``data -> paragraphs -> qas``.

        Returns:
            A list of dicts with keys ``id``, ``title``, ``context``,
            ``question`` and ``answers``. ``answers`` holds the parallel lists
            ``text`` and ``answer_start``; both are empty when the question has
            no ``answers`` entry.
        """
        examples = []
        for article in squad_data["data"]:
            title = article["title"]
            for paragraph in article["paragraphs"]:
                context = paragraph["context"]
                for qa in paragraph["qas"]:
                    # Read the answers once; a missing key yields empty lists.
                    answers = qa.get("answers", [])
                    examples.append(
                        {
                            "id": qa["id"],
                            "title": title,
                            "context": context,
                            "question": qa["question"],
                            "answers": {
                                "text": [answer["text"] for answer in answers],
                                "answer_start": [answer["answer_start"] for answer in answers],
                            },
                        }
                    )
        return examples

    def _load_dataset(self, subject: SubjectType) -> None:
        """Populate ``self.dataset`` with the sample and few-shot splits.

        Only ``SAMPLE_SPLIT`` and ``FEWSHOT_SPLIT`` are kept. The sample split is
        shuffled with an RNG seeded by ``RANDOM_SEED``, so its order is
        reproducible.

        Args:
            subject: Dataset configuration name; ``NO_SUBJECT`` maps to ``None``.
        """
        name = subject if subject != NO_SUBJECT else None

        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
        self.dataset = {}

        self.rnd = random.Random(RANDOM_SEED)

        for split, data in hf_dataset.items():
            data_list = list(data)

            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)

            if split in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
                self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the prompt: instruction, context, question and an ``Answer:`` cue.

        Args:
            item: Dataset record with ``context`` and ``question`` keys.

        Returns:
            The prompt string, ending in ``"Answer:"``.
        """
        prompt = (
            "Given the following context, answer the question. If the question cannot be answered based "
            f"on the context alone, respond with '{self.UNANSWERABLE_STR}'.\n\n"
            "Context:\n"
            f"{item['context']}\n\n"
            f"Question:\n{item['question']}\nAnswer:"
        )
        return prompt

    def _get_ground_truth(self, item: dict[str, Any]) -> list[str]:
        """Return the accepted answers, each prefixed with a single space.

        Args:
            item: Dataset record whose ``answers["text"]`` lists the gold answers.

        Returns:
            The gold answers, or three spellings of ``UNANSWERABLE_STR`` when the
            record has no answers. Every entry carries a leading space.
        """
        text_ = item["answers"]["text"]
        ground_truth_for_unanswerable = [
            self.UNANSWERABLE_STR,
            self.UNANSWERABLE_STR + " ",
            self.UNANSWERABLE_STR.capitalize(),
        ]
        ground_truths = text_ if text_ else ground_truth_for_unanswerable
        return [f" {ground_truth}" for ground_truth in ground_truths]

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Return the first ground-truth answer as the few-shot target.

        Args:
            item: Dataset record, as accepted by ``_get_ground_truth``.

        Returns:
            The first entry of ``_get_ground_truth(item)``.
        """
        target = self._get_ground_truth(item)[0]
        assert target is not None
        assert isinstance(target, str)
        return target
class SQUAD(SQUAD2):
    """SQuAD v1 reading-comprehension task.

    Dataset: https://huggingface.co/datasets/rajpurkar/squad

    Reuses the SQUAD2 loading logic but points at the v1.1 JSON files, uses a
    prompt with no unanswerable instruction, and returns the gold answers
    unmodified.
    """

    NAME = "SQuAD"
    DATASET_PATH = "rajpurkar/squad"

    def _get_squad_urls(self) -> dict[str, str]:
        """Return the raw-JSON download URLs for the SQuAD v1.1 splits."""
        train_url = "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-v1.1.json"
        validation_url = "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v1.1.json"
        return {"train": train_url, "validation": validation_url}

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the prompt from the record's context and question.

        Args:
            item: Dataset record with ``context`` and ``question`` keys.

        Returns:
            The prompt string; it ends with a newline after the question.
        """
        return (
            "Given the following context, answer the question.\n\n"
            "Context:\n"
            f"{item['context']}\n\n"
            f"Question:\n{item['question']}\n"
        )

    def _get_ground_truth(self, item: dict[str, Any]) -> list[str]:
        """Return the gold answer texts exactly as stored in the record.

        Args:
            item: Dataset record whose ``answers["text"]`` lists the gold answers.

        Returns:
            The ``answers["text"]`` list, unmodified.
        """
        answers = item["answers"]
        return answers["text"]