Source code for eval_framework.tasks.perturbation

import threading
from enum import Enum
from typing import Annotated, Any, TypeVar

from pydantic import BaseModel, ConfigDict, Field

from eval_framework.logger import logger
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Sample
from eval_framework.tasks.utils import Editor, HatPaperEditor


class PerturbationType(str, Enum):
    """Names of the text perturbation strategies that can be configured.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value (for example ``PerturbationType.DELETE == "delete"``).
    """

    # Strategy backed by ``Editor``.
    EDITOR = "editor"

    # Strategies backed by ``HatPaperEditor`` ("hat paper" methods).
    PERMUTE = "permute"
    REPLACE = "replace"
    DELETE = "delete"
    UPPERCASE = "uppercase"
class PerturbationConfig(BaseModel):
    """Settings that control how instruction text is perturbed.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Which perturbation strategy to apply; defaults to the ``Editor``-based one.
    type: PerturbationType = PerturbationType.EDITOR
    # Probability handed to the editor; validated to lie within [0.0, 1.0].
    probability: Annotated[float, Field(ge=0.0, le=1.0)] = 0.1
    # Seed handed to the editor when it is constructed.
    seed: int = RANDOM_SEED
    # When true, the text is logged before and after perturbation.
    verbose: bool = False
# NOTE(review): none of the three names below is referenced in the visible part
# of this module; confirm whether external code still relies on them.

# Module-wide lock; presumably meant to serialize the launch of a Docker-based
# augmenter service -- TODO confirm.
_DOCKER_LAUNCH_LOCK = threading.Lock()
# Presumably the port of that augmenter service, with 0 meaning "not set" -- TODO confirm.
_AUGMENTER_PORT = 0
# Type variable bound to any parametrization of BaseTask.
SomeBaseTask = TypeVar("SomeBaseTask", bound=BaseTask[Any])
[docs] def create_perturbation_class[T: BaseTask](base_class: type[T], perturbation_config: PerturbationConfig) -> type[T]: # mypy seems to have trouble inferring the type class EditorPerturbation(base_class): # type: ignore def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.perturbation_config = perturbation_config self.editor = Editor( language="de" if base_class.LANGUAGE == "German" else "en", seed=perturbation_config.seed ) def _get_instruction_text(self, sample: Sample) -> str: text = super()._get_instruction_text(sample) if self.perturbation_config.verbose: logger.info(f"Perturbating text: {text}") result = self.editor( text, self.perturbation_config.probability, getattr(self, "PERTURBATION_UNMODIFIABLE_WORDS", []) ) if self.perturbation_config.verbose: logger.info(f"Perturbed text: {result}") return result class HatPaperPerturbation(base_class): # type: ignore def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.perturbation_config = perturbation_config self.editor = HatPaperEditor(seed=perturbation_config.seed) def _get_instruction_text(self, sample: Sample) -> str: text = super()._get_instruction_text(sample) if self.perturbation_config.verbose: logger.info(f"Perturbating text: {text}") words = getattr(self, "PERTURBATION_UNMODIFIABLE_WORDS", []) if self.perturbation_config.type == PerturbationType.PERMUTE: result = self.editor.permute_chars_in_string(text, self.perturbation_config.probability, words) elif self.perturbation_config.type == PerturbationType.REPLACE: result = self.editor.replace_chars_in_string(text, self.perturbation_config.probability, words) elif self.perturbation_config.type == PerturbationType.DELETE: result = self.editor.delete_chars_in_string(text, self.perturbation_config.probability, words) elif self.perturbation_config.type == PerturbationType.UPPERCASE: result = self.editor.upper_case_string(text) if self.perturbation_config.verbose: 
logger.info(f"Perturbed text: {result}") return result if perturbation_config.type == PerturbationType.EDITOR: return EditorPerturbation else: return HatPaperPerturbation