# Source code for eval_framework.tasks.registry

import contextlib
import importlib.util
import re
from collections.abc import Generator, Iterator, Sequence
from typing import Annotated, Any

import pydantic
from pydantic import AfterValidator

from eval_framework.tasks.base import BaseTask
from eval_framework.utils.packaging import is_extra_installed, validate_package_extras

# Names exported by `from eval_framework.tasks.registry import *`.
# `TaskPlaceholder` and `validate_import_path` are defined in this module but are not listed here.
__all__ = [
    "register_task",
    "register_lazy_task",
    "Registry",
    "with_registry",
    "get_task",
    "registered_tasks_iter",
    "is_registered",
    "validate_task_name",
    "registered_task_names",
]


def validate_import_path(import_path: str) -> str:
    """Check that ``import_path`` names a module that can be located.

    The module itself is not imported, but ``importlib.util.find_spec`` imports
    the parent packages of a dotted path in order to locate it.

    Args:
        import_path: Absolute dotted module path, e.g. ``"package.sub.module"``.

    Returns:
        ``import_path`` unchanged, so the function can be used as a validator.

    Raises:
        ValueError: If no module can be found under ``import_path``.
    """
    try:
        spec = importlib.util.find_spec(import_path)
    except (ImportError, ValueError) as err:
        # find_spec does not only signal failure by returning None: a missing
        # parent package raises ModuleNotFoundError, and malformed names can
        # raise ImportError or ValueError. Report all of them the same way.
        raise ValueError(f"Invalid import path: {import_path}") from err
    if spec is None:
        raise ValueError(f"Invalid import path: {import_path}")
    return import_path


class TaskPlaceholder(pydantic.BaseModel, extra="forbid", frozen=True):
    """Stand-in for a task class whose module has not been imported yet.

    The placeholder records where the class lives and which optional extras
    it needs; `load` performs the actual import.
    """

    name: Annotated[
        str,
        "The name of the Task class that we want to import",
    ]
    module: Annotated[
        str,
        "The module from where to import the task",
        # A bare callable inside `Annotated` is not treated as a validator by
        # pydantic; it has to be wrapped (as done for `extras` below) to run.
        AfterValidator(validate_import_path),
    ]
    extras: Annotated[
        tuple[str, ...],
        "Extra dependencies that are required for the task",
        AfterValidator(validate_package_extras),
    ] = ()

    def load(self) -> type[BaseTask]:
        """Import the module and return the attribute called `name` from it.

        Returns:
            The object found under `name` in the imported module.

        Raises:
            ImportError: If one of `extras` is not installed, or the module
                itself cannot be imported.
            AttributeError: If the module has no attribute called `name`.
        """
        # Check the extras first so that the user gets an actionable message
        # before the module import is attempted.
        for extra in self.extras:
            if not is_extra_installed(extra):
                raise ImportError(f"The required package eval_framework[{extra}] is not installed.")
        module = importlib.import_module(self.module)
        return getattr(module, self.name)


class Registry:
    """A registry for tasks with support for lazy loading.

    Task names are hashed based on the upper-case name, to avoid issues with
    ambiguous naming: whitespace, hyphens and underscores are dropped and the
    rest is upper-cased, so ``"truthful_qa"`` and ``"Truthful QA"`` address the
    same entry.
    """

    def __init__(self) -> None:
        # TODO: Lookup only with upper names
        # Maps normalised key -> (name as registered, task class or placeholder).
        self._registry: dict[str, tuple[str, type[BaseTask] | TaskPlaceholder]] = {}

    def __iter__(self) -> Iterator[str]:
        """Yield the names under which the tasks were registered."""
        for registered_name, _ in self._registry.values():
            yield registered_name

    @staticmethod
    def _task_key(name: str, /) -> str:
        """Return the normalised lookup key for ``name``.

        Raises:
            ValueError: If anything other than alphanumeric characters is left
                after removing whitespace, hyphens and underscores (this
                includes the empty string).
        """
        key = re.sub(r"[\s\-_]+", "", name).upper()
        if not key.isalnum():
            # Report the name as the caller wrote it, not the normalised form.
            raise ValueError(
                f"Task name '{name}' contains invalid characters. Only alphanumeric characters are allowed."
            )
        return key

    def __contains__(self, name: str) -> bool:
        """Return True if a task is registered under ``name``.

        Raises:
            ValueError: If ``name`` cannot be normalised into a key.
        """
        return self._task_key(name) in self._registry

    def __getitem__(self, name: str, /) -> type[BaseTask]:
        """Return the task class registered under ``name``.

        A lazily registered task is imported on first access and the loaded
        class replaces its placeholder, so the import happens only once.

        Raises:
            ValueError: If ``name`` cannot be normalised into a key.
            KeyError: If no task is registered under ``name``.
        """
        task_key = self._task_key(name)
        try:
            registered_name, task = self._registry[task_key]
        except KeyError:
            # Hide the internal lookup on the normalised key from the traceback.
            raise KeyError(f"Task not found: {name}") from None
        if isinstance(task, TaskPlaceholder):
            task = task.load()
            self._registry[task_key] = (registered_name, task)
        return task

    def add(self, task: type[BaseTask]) -> None:
        """Register ``task`` under its ``NAME`` attribute.

        Unlike item assignment, this replaces an existing entry with the same
        key instead of raising.
        """
        task_key = self._task_key(task.NAME)
        self._registry[task_key] = (task.NAME, task)

    def __setitem__(self, name: str, task: type[BaseTask] | TaskPlaceholder) -> None:
        """Register ``task`` under ``name``.

        Raises:
            ValueError: If ``name`` cannot be normalised into a key, or if a
                task with the same key is already registered.
        """
        task_key = self._task_key(name)
        if task_key in self._registry:
            raise ValueError(f"Cannot register duplicate task with key: {task_key}")
        self._registry[task_key] = (name, task)
# Module-level registry used by the functions below; `with_registry` temporarily rebinds it.
_REGISTRY = Registry()
@contextlib.contextmanager
def with_registry(registry: Registry) -> Generator[None, Any, None]:
    """Temporarily make ``registry`` the current registry.

    The registry that was active on entry is put back when the ``with`` block
    exits, whether it finishes normally or raises.
    """
    global _REGISTRY
    previous = _REGISTRY
    _REGISTRY = registry
    try:
        yield
    finally:
        _REGISTRY = previous
def registered_task_names() -> list[str]:
    """List the name of every task in the current registry."""
    return [*_REGISTRY]
def is_registered(name: str, /) -> bool:
    """Tell whether the current registry holds a task under ``name``."""
    found = name in _REGISTRY
    return found
def validate_task_name(name: str) -> str:
    """Pydantic-style validator that accepts only registered task names.

    Returns:
        ``name`` unchanged when a task is registered under it.

    Raises:
        ValueError: If no task is registered under ``name``.
    """
    if is_registered(name):
        return name
    raise ValueError(f"Task not registered: {name}")
def registered_tasks_iter() -> Iterator[tuple[str, type[BaseTask]]]:
    """Yield a ``(name, task class)`` pair for every registered task.

    Note:
        Each lazily registered task is imported when its pair is produced.
    """
    for task_name in registered_task_names():
        task_cls = get_task(task_name)
        yield task_name, task_cls
def get_task(name: str, /) -> type[BaseTask]:
    """Look up the task class registered under ``name``.

    Note:
        A lazily registered task is imported by this call.
    """
    task_cls = _REGISTRY[name]
    return task_cls
def register_task(task: type[BaseTask]) -> str:
    """Register a task class in the current registry.

    The class name is used as the task name.

    Args:
        task: The ``BaseTask`` subclass to register.

    Returns:
        The name under which the task was registered.

    Raises:
        ValueError: If ``task`` is not a subclass of ``BaseTask``, or if the
            registry rejects the name (for example because it is a duplicate).
    """
    # issubclass() raises TypeError when handed something that is not a class
    # (an instance, a string, ...); check that first so every bad argument
    # ends up in the ValueError below.
    if not (isinstance(task, type) and issubclass(task, BaseTask)):
        raise ValueError(f"Can only register subclasses of BaseTask, got {task}")
    name = task.__name__
    _REGISTRY[name] = task
    return name
def register_lazy_task(class_path: str, /, *, extras: Sequence[str] = ()) -> None:
    """Register a task without importing it.

    Only the location of the task class is recorded; the task's module is
    imported later, when the task is looked up.

    Args:
        class_path: The full path to the task class. For example,
            `eval_framework.tasks.benchmarks.truthfulqa.TRUTHFULQA`.
        extras: Any extra dependencies of `eval_framework` that need to be
            installed for this task. A single string is treated as one extra.

    Raises:
        ValueError: If `class_path` is not of the form `<module>.<ClassName>`,
            or if the registry rejects the class name (for example because it
            is a duplicate).
    """
    if isinstance(extras, str):
        # A lone string is also a sequence of strings; without this it would
        # be read as one extra per character.
        extras = [extras]

    # Split on the last dot; both the module part and the class part must be
    # non-empty, which also rejects paths such as "pkg." and ".Cls".
    base_module, _, class_name = class_path.rpartition(".")
    if not base_module or not class_name:
        raise ValueError(
            f"Invalid class path `{class_path}`. This needs to be a global path like "
            "`eval_framework.tasks.benchmarks.truthfulqa.TRUTHFULQA`."
        )

    placeholder = TaskPlaceholder(name=class_name, module=base_module, extras=extras)
    _REGISTRY[class_name] = placeholder