# Source code for eval_framework.tasks.benchmarks.balancedcopa
from datasets import Dataset, DatasetDict
from eval_framework.tasks.base import NO_SUBJECT, SubjectType
from eval_framework.tasks.benchmarks.copa import COPA
def split_dataset_by_id_ranges(
dataset: Dataset, id_column: str, ranges: list[tuple[int, int]]
) -> tuple[Dataset, Dataset]:
"""Split a dataset into two based on whether the id column falls within given ranges.
Args:
dataset: The dataset to split.
id_column: The name of the column containing the id values.
ranges: A list of (low, high) tuples defining inclusive ranges.
Rows whose id is within any of these ranges go into the first split.
"""
def in_any_range(id_value: int) -> bool:
return any(low <= id_value <= high for low, high in ranges)
in_indices = [i for i, id_val in enumerate(dataset[id_column]) if in_any_range(id_val)]
not_in_indices = [i for i, id_val in enumerate(dataset[id_column]) if not in_any_range(id_val)]
return dataset.select(in_indices), dataset.select(not_in_indices)
class BalancedCOPA(COPA):
    """Balanced-COPA dataset: https://huggingface.co/datasets/pkavumba/balanced-copa"""

    NAME = "BalancedCOPA"
    DATASET_PATH = "pkavumba/balanced-copa"
    HF_REVISION = "813bd03cd6e07d9bd8d7333896ad5d40abb95ea9"
    SUBJECTS = ["no_subject"]

    def _split_dataset_into_train_and_val(self, dataset: DatasetDict) -> DatasetDict:
        # Carve a validation split out of the train data so that it matches the
        # validation split of the original COPA dataset.
        # The id ranges below were determined by manual inspection of the data:
        # ids 401-500 are the original COPA validation examples, and ids
        # 1401-1500 are their mirrored counterparts. Pinning HF_REVISION above
        # keeps these magic numbers valid.
        val_split, train_split = split_dataset_by_id_ranges(
            dataset["train"], "id", [(401, 500), (1401, 1500)]
        )
        dataset["validation"] = val_split
        dataset["train"] = train_split
        return dataset

    def _load_dataset(self, subject: SubjectType) -> None:
        # Largely reimplements _load_dataset from the base class: its
        # _shuffle_splits step drops any column not in FEWSHOT_SPLIT /
        # SAMPLE_SPLIT, so the train/validation split has to happen before
        # the shuffle.
        name = None if subject == NO_SUBJECT else subject
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
        hf_dataset = self._split_dataset_into_train_and_val(hf_dataset)
        self.dataset = self._shuffle_splits(hf_dataset=hf_dataset)