Source code for eval_framework.tasks.benchmarks.triviaqa
import random
from typing import Any

from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
from eval_framework.metrics.completion.f1 import F1
from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample

class TRIVIAQA(BaseTask[str]):
    """TriviaQA dataset: https://huggingface.co/datasets/mandarjoshi/trivia_qa"""

    NAME = "TriviaQA"
    DATASET_PATH = "mandarjoshi/trivia_qa"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [AccuracyCompletion, F1]
    SUBJECTS = ["rc.wikipedia.nocontext"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "Answer"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        self.stop_sequences = ["\n"]
        # The longest ground-truth answer is 282 characters, while the average is ~16.
        self.max_tokens = 400
        self.rnd_choice_shuffle = random.Random()

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return f"Question: {item['question'].strip()}\nAnswer:"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        # Use the first answer alias as the few-shot target.
        target = self._get_ground_truth(item)[0]
        assert target is not None
        assert isinstance(target, str)
        return f" {target}"

    def _get_ground_truth(self, item: dict[str, Any]) -> list[str]:
        return item["answer"]["aliases"]

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        # Strip surrounding whitespace and any trailing periods before scoring.
        return completion_text.strip().rstrip(".")
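

# A minimal usage sketch, not part of the module above. It assumes that
# BaseTask.__init__ (defined elsewhere) performs no heavy dataset loading at
# construction time; the `item` dict is fabricated purely to mirror the dataset
# fields the methods above access (`question` and `answer` -> `aliases`).
if __name__ == "__main__":
    task = TRIVIAQA(num_fewshot=0)
    item = {
        "question": " Which metal has the chemical symbol Fe? ",
        "answer": {"aliases": ["Iron", "Fe"]},
    }
    print(task._get_instruction_text(item))
    # -> Question: Which metal has the chemical symbol Fe?
    #    Answer:
    print(repr(task._get_fewshot_target_text(item)))           # -> ' Iron'
    print(task.post_process_generated_completion(" Iron.\n"))  # -> Iron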