"""Source code for eval_framework.tasks.benchmarks.flores_plus."""
import random
from itertools import product
from typing import Any
from eval_framework.metrics.completion.bleu import BLEU
from eval_framework.metrics.completion.chrf import CHRF
from eval_framework.metrics.completion.comet import COMET
from eval_framework.shared.types import BaseMetricContext, UntemplatedPrompt
from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample
# Flores-Plus language identifiers ("<ISO 639-3>_<ISO 15924 script>") mapped to
# the English language names that are interpolated into the translation prompt.
LANG_MAP = {
    "deu_Latn": "German",
    "eng_Latn": "English",
    "fra_Latn": "French",
    "ita_Latn": "Italian",
    "nld_Latn": "Dutch",
    "pol_Latn": "Polish",
    "rus_Cyrl": "Russian",
    "spa_Latn": "Spanish",
    "ukr_Cyrl": "Ukrainian",
}
class FloresPlus(BaseTask[str]):
    """Flores-Plus machine-translation task.

    Dataset: https://huggingface.co/datasets/openlanguagedata/flores_plus

    Each subject is a ``"<src>-<tgt>"`` language-pair string built from
    ``LANG_MAP`` keys (e.g. ``"deu_Latn-eng_Latn"``). The model is prompted to
    translate the source sentence into the target language, and completions
    are scored with BLEU, CHRF and COMET.
    """

    NAME = "Flores-Plus"
    DATASET_PATH = "openlanguagedata/flores_plus"
    SAMPLE_SPLIT = "dev"
    FEWSHOT_SPLIT = "devtest"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [BLEU, CHRF, COMET]
    # All ordered language pairs over LANG_MAP, excluding identity pairs.
    SUBJECTS = [f"{s}-{t}" for s, t in product(LANG_MAP, LANG_MAP) if s != t]
    PERTURBATION_UNMODIFIABLE_WORDS = ["sentence"]
    LANGUAGE = {
        "deu_Latn": Language.DEU,
        "eng_Latn": Language.ENG,
        "fra_Latn": Language.FRA,
        "ita_Latn": Language.ITA,
        "nld_Latn": Language.NLD,
        "pol_Latn": Language.POL,
        "rus_Cyrl": Language.RUS,
        "spa_Latn": Language.SPA,
        "ukr_Cyrl": Language.UKR,
    }

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Generation stops at the first newline; translations are expected to
        # be single-line completions.
        self.stop_sequences = ["\n"]

    def _load_dataset(self, subject: str) -> None:
        """Load and align the source/target subsets for one language pair.

        Args:
            subject: Language pair of the form ``"<src>-<tgt>"`` (LANG_MAP keys).

        Populates ``self.dataset`` with one list of row dicts per split
        (keys: ``iso_source``, ``iso_target``, ``source``, ``target``).
        The sample split is shuffled with a fixed seed for reproducibility.

        Raises:
            ValueError: If the source and target rows are not aligned by id.
        """
        # Split the pair once instead of re-splitting per dataset load.
        src_lang, tgt_lang = subject.split("-")
        hf_dataset_src = self._load_hf_dataset(path=self.DATASET_PATH, name=src_lang)
        hf_dataset_tgt = self._load_hf_dataset(path=self.DATASET_PATH, name=tgt_lang)
        self.dataset = {}
        self.rnd = random.Random(42)  # fixed seed -> deterministic shuffle
        for split in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
            data_list = []
            for item_src, item_tgt in zip(hf_dataset_src[split], hf_dataset_tgt[split]):
                # Rows must be aligned by id. Raise instead of assert so the
                # check is not stripped under `python -O`.
                if item_src["id"] != item_tgt["id"]:
                    raise ValueError(
                        f"Misaligned Flores-Plus rows in split '{split}': "
                        f"{item_src['id']} != {item_tgt['id']}"
                    )
                data_list.append(
                    {
                        "iso_source": f"{item_src['iso_639_3']}_{item_src['iso_15924']}",
                        "iso_target": f"{item_tgt['iso_639_3']}_{item_tgt['iso_15924']}",
                        "source": item_src["text"],
                        "target": item_tgt["text"],
                    }
                )
            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)
            self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the translation prompt for one dataset row."""
        target_language = LANG_MAP[item["iso_target"]]
        return f"Translate the following text into {target_language}:\n{item['source']}"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the reference translation for one dataset row."""
        return item["target"]

    def _get_context(self, item: dict[str, Any]) -> BaseMetricContext | list[BaseMetricContext] | None:
        """Expose the raw source text as metric context.

        NOTE(review): presumably consumed by source-aware metrics such as
        COMET — confirm against the metric implementations.
        """
        return UntemplatedPrompt(untemplated_prompt=item["source"])

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Strip surrounding whitespace from the model completion."""
        return completion_text.strip()