Source code for eval_framework.tasks.registry

import contextlib
import importlib
import re
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterator, Sequence
from typing import TYPE_CHECKING, Any

from eval_framework.tasks.base import BaseTask, ResponseType
from eval_framework.tasks.perturbation import PerturbationConfig, create_perturbation_class
from eval_framework.utils.packaging import is_extra_installed, validate_package_extras

if TYPE_CHECKING:
    from eval_framework.metrics.base import BaseMetric

__all__ = [
    "register_task",
    "register_lazy_task",
    "EvalFactory",
    "Registry",
    "with_registry",
    "get_task",
    "registered_tasks_iter",
    "is_registered",
    "validate_task_name",
    "registered_task_names",
]


[docs] class EvalFactory(ABC): """Produces a registered benchmark's eval. The registry stores one factory per eval. This allows the factory to be constructed without constructing all evals. Going via this ABC allows the factory instances to contain state specifically relevant to the eval, as well as supporting different strategies for instantiating it. E.g. eager vs lazy loading of the required dependencies. """
[docs] @abstractmethod def task_class(self) -> type[BaseTask]: """Return the task class, importing it on first access if necessary."""
@property @abstractmethod def source_module(self) -> str: """Module the task class is defined in, resolvable without importing it."""
[docs] @abstractmethod def response_type(self) -> ResponseType: """The eval's response type"""
[docs] @abstractmethod def metrics(self) -> list[type["BaseMetric"]]: """The eval's metrics"""
[docs] @abstractmethod def display_name(self) -> str: """Human-readable display name. Is allowed to have special characters and whitespaces."""
[docs] @abstractmethod def create(
self, num_fewshot: int, custom_subjects: list[str] | None, custom_hf_revision: str | None ) -> BaseTask: ...
[docs] @abstractmethod def create_perturbation(
self, perturbation_config: PerturbationConfig, num_fewshot: int, custom_subjects: list[str] | None, custom_hf_revision: str | None, ) -> BaseTask: ...
class _Lazy(EvalFactory): """ Create eval from qualified class path; Delays importing modules until eval is constructed. """ def __init__(self, class_name: str, module: str, extras: Sequence[str] = ()) -> None: """ Args: class_name: The name of the task class to import. module: The module to import the task class from. extras: Extra dependencies of `eval_framework` required for this task. """ self._class_name = class_name self._module = module self._extras = tuple(validate_package_extras(extras)) self._loaded: type[BaseTask] | None = None @property def source_module(self) -> str: return self._module def task_class(self) -> type[BaseTask]: if self._loaded is None: for extra in self._extras: if not is_extra_installed(extra): raise ImportError(f"The required package eval_framework[{extra}] is not installed.") module = importlib.import_module(self._module) self._loaded = getattr(module, self._class_name) return self._loaded def create(self, num_fewshot: int, custom_subjects: list[str] | None, custom_hf_revision: str | None) -> BaseTask: return self.task_class().with_overwrite( num_fewshot=num_fewshot, custom_subjects=custom_subjects, custom_hf_revision=custom_hf_revision ) def create_perturbation( self, perturbation_config: PerturbationConfig, num_fewshot: int, custom_subjects: list[str] | None, custom_hf_revision: str | None, ) -> BaseTask: perturbation_task_class = create_perturbation_class(self.task_class(), perturbation_config) return perturbation_task_class.with_overwrite( num_fewshot=num_fewshot, custom_subjects=custom_subjects, custom_hf_revision=custom_hf_revision, ) def response_type(self) -> ResponseType: """The eval's response type""" return self.task_class().get_response_type() def metrics(self) -> list[type["BaseMetric"]]: """The eval's metrics""" return self.task_class().get_metrics() def display_name(self) -> str: """The eval's human-readable display name (the task's ``NAME``).""" return self.task_class().NAME class _Eager(EvalFactory): """Wraps an already-imported task class.""" def __init__(self, task: type[BaseTask]) -> None: self._task = task @property def source_module(self) -> str: return self._task.__module__ def task_class(self) -> type[BaseTask]: return self._task def create(self, num_fewshot: int, custom_subjects: list[str] | None, custom_hf_revision: str | None) -> BaseTask: return self.task_class().with_overwrite( num_fewshot=num_fewshot, custom_subjects=custom_subjects, custom_hf_revision=custom_hf_revision ) def create_perturbation( self, perturbation_config: PerturbationConfig, num_fewshot: int, custom_subjects: list[str] | None, custom_hf_revision: str | None, ) -> BaseTask: perturbation_task_class = create_perturbation_class(self.task_class(), perturbation_config) return perturbation_task_class.with_overwrite( num_fewshot=num_fewshot, custom_subjects=custom_subjects, custom_hf_revision=custom_hf_revision, ) def response_type(self) -> ResponseType: """The eval's response type""" return self.task_class().get_response_type() def metrics(self) -> list[type["BaseMetric"]]: """The eval's metrics""" return self.task_class().get_metrics() def display_name(self) -> str: """The eval's human-readable display name (the task's ``NAME``).""" return self.task_class().NAME
[docs] class Registry: """A registry for tasks with support for lazy loading. Task names are hashed based on the upper-case name, to avoid issues with ambiguous naming. """ def __init__(self) -> None: # TODO: Lookup only with upper names self._registry: dict[str, tuple[str, EvalFactory]] = dict() def __iter__(self) -> Iterator[str]: for name, _ in self._registry.values(): yield name @staticmethod def _task_key(name: str, /) -> str: name = re.sub(r"[\s\-_]+", "", name).upper() if not name.isalnum(): raise ValueError( f"Task name '{name}' contains invalid characters. Only alphanumeric characters are allowed." ) return name def __contains__(self, name: str) -> bool: task_key = self._task_key(name) return task_key in self._registry def __getitem__(self, name: str, /) -> EvalFactory: task_key = self._task_key(name) try: _, factory = self._registry[task_key] except KeyError: raise KeyError(f"Task not found: {name=} with task_key {task_key=}") return factory
[docs] def add(self, task: type[BaseTask]) -> None: task_key = self._task_key(task.NAME) self._registry[task_key] = (task.NAME, _Eager(task))
def __setitem__(self, name: str, factory: EvalFactory) -> None: task_key = self._task_key(name) if task_key in self._registry: raise ValueError(f"Cannot register duplicate task with key: {task_key}") self._registry[task_key] = (name, factory)
_REGISTRY = Registry() def registry() -> Registry: return _REGISTRY
[docs] @contextlib.contextmanager def with_registry(registry: Registry) -> Generator[None, Any, None]: """Contextmanager to change the current registry.""" global _REGISTRY old_registry = _REGISTRY try: _REGISTRY = registry yield finally: _REGISTRY = old_registry
[docs] def registered_task_names() -> list[str]: """Return the names of all registered tasks.""" return list(_REGISTRY)
[docs] def is_registered(name: str, /) -> bool: """Return True if a task is registered.""" return name in _REGISTRY
[docs] def validate_task_name(name: str) -> str: """Pydantic-style validator for task names.""" if not is_registered(name): raise ValueError(f"Task not registered: {name}") return name
[docs] def registered_tasks_iter() -> Iterator[tuple[str, type[BaseTask]]]: """Iterate over the names and classes of all registered tasks. Note: This method will import any lazily registered task. """ for name in registered_task_names(): yield name, get_task(name)
[docs] def get_task(name: str, /) -> type[BaseTask]: """Return a registered task for a given name. Note: This method will import any lazily registered task. """ return _REGISTRY[name].task_class()
[docs] def register_task(task: type[BaseTask]) -> str: """The class name is used as the task name.""" if not issubclass(task, BaseTask): raise ValueError(f"Can only register subclasses of BaseTask, got {task}") name = task.__name__ _REGISTRY[name] = _Eager(task) return name
[docs] def register_lazy_task(class_path: str, /, *, extras: Sequence[str] = ()) -> None: """Register a task without importing it. Lazily register a task without importing the module. Args: class_path: The full path to the task class. For example, `eval_framework.tasks.benchmarks.truthfulqa.TRUTHFULQA`. extras: Any extra dependencies of `eval_framework` that need to be installed for this task. """ if isinstance(extras, str): extras = [extras] if "." not in class_path: raise ValueError( f"Invalid class path `{class_path}`. This needs to be a global path like " "`eval_framework.tasks.benchmarks.truthfulqa.TRUTHFULQA`): " ) base_module, class_name = class_path.rsplit(".", maxsplit=1) _REGISTRY[class_name] = _Lazy(class_name=class_name, module=base_module, extras=extras)