Source code for eval_framework.tasks.dataset_revisions
"""Fetch latest Hugging Face dataset commit SHAs for registered tasks.
Overwrites ``task-dataset-revisions.json`` in this package with task class name → SHA.
Usage::
uv run python -m eval_framework.tasks.dataset_revisions
"""
import json
import logging
from functools import lru_cache
from pathlib import Path
from huggingface_hub import HfApi
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# JSON file of pinned revisions, located in the same directory as this module.
DEFAULT_REVISIONS_FILE = Path(__file__).resolve().parent / "task-dataset-revisions.json"
# Path that main() overwrites; initially the default location above.
REVISIONS_FILE = DEFAULT_REVISIONS_FILE
@lru_cache(maxsize=128)  # 128 is lru_cache's default bound, spelled out here
def _pinned_revisions(revisions_file: Path) -> dict[str, str]:
    """Parse ``revisions_file`` as UTF-8 JSON and memoize the result per path.

    Repeated calls with the same path return the cached object without
    re-reading the file, so later on-disk changes are not picked up.

    Args:
        revisions_file: Path of the JSON file to load.

    Returns:
        The decoded JSON content. The shape is not validated here.
    """
    with revisions_file.open(encoding="utf-8") as handle:
        return json.load(handle)
[docs]
class DatasetRevision:
_INSTANCE: "DatasetRevision | None" = None
def __init__(self) -> None:
self._cache: dict[str, str] = {}
@classmethod
def _get_instance(cls) -> "DatasetRevision":
if cls._INSTANCE is None:
cls._INSTANCE = cls()
return cls._INSTANCE
[docs]
@classmethod
def add_revision_file(cls, file_path: Path | str) -> None:
instance = cls._get_instance()
instance._append_revision_file(Path(file_path))
[docs]
@classmethod
def pinned_revision(cls, task_class_name: str) -> str | None:
return cls._get_instance()._cache.get(task_class_name)
[docs]
@classmethod
def reset(cls) -> None:
# for unit tests only.
cls._INSTANCE = None
def _append_revision_file(self, file_path: Path) -> None:
revisions = _pinned_revisions(file_path)
self._cache |= revisions
def _repo_sha(api: HfApi, repo_id: str, cache: dict[str, str | None]) -> str | None:
    """Return the ``sha`` reported by ``api.dataset_info`` for ``repo_id``, memoized in ``cache``.

    Args:
        api: Client whose ``dataset_info`` method is queried.
        repo_id: Dataset repository identifier to look up.
        cache: Caller-owned memo, updated in place. Failed lookups are
            stored as ``None`` so the same repository is not queried twice.

    Returns:
        The SHA, or ``None`` when the lookup raised.
    """
    if repo_id not in cache:
        try:
            sha = api.dataset_info(repo_id, timeout=100.0).sha
            cache[repo_id] = sha
            logger.info("%s -> %s", repo_id, sha)
        except Exception as exc:
            # Best effort: one failing repository must not abort the whole run.
            logger.warning("Skipping %s: %s", repo_id, exc)
            cache[repo_id] = None
    return cache[repo_id]
[docs]
def collect_dataset_revisions(
    task_names: list[str],
    api: HfApi,
) -> dict[str, str]:
    """Map each task's class name to the SHA of its ``DATASET_PATH`` repository.

    Tasks are skipped when ``get_task`` raises (a warning is logged), when
    they have no non-blank ``DATASET_PATH``, or when no SHA could be obtained.

    Args:
        task_names: Names to resolve through the task registry.
        api: Client used for the SHA lookups.

    Returns:
        Task class name → SHA for every task that was not skipped.
    """
    # Imported inside the function, as in the original, so the registry is
    # only loaded when revisions are actually collected.
    from eval_framework.tasks.registry import get_task

    sha_by_repo: dict[str, str | None] = {}  # one lookup per distinct repository
    sha_by_task: dict[str, str] = {}
    for task_name in task_names:
        try:
            task_cls = get_task(task_name)
        except Exception as exc:
            logger.warning("Skipping task %s: %s", task_name, exc)
            continue

        dataset_path = (getattr(task_cls, "DATASET_PATH", None) or "").strip()
        if not dataset_path:
            continue

        sha = _repo_sha(api, dataset_path, sha_by_repo)
        if sha:
            sha_by_task[task_cls.__name__] = sha
    return sha_by_task
[docs]
def main() -> None:
    """Collect revisions for all registered tasks and overwrite ``REVISIONS_FILE``.

    The output is JSON with keys in sorted order, 4-space indentation,
    non-ASCII characters left unescaped, and a trailing newline.
    """
    from eval_framework.tasks.registry import registered_task_names

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    revisions = collect_dataset_revisions(registered_task_names(), HfApi())

    output_path = REVISIONS_FILE
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Sorted keys keep the written file stable from run to run.
    ordered = {task: revisions[task] for task in sorted(revisions)}
    payload = json.dumps(ordered, indent=4, ensure_ascii=False) + "\n"
    output_path.write_text(payload, encoding="utf-8")

    logger.info("Wrote %d revisions to %s", len(revisions), REVISIONS_FILE)
# Run the refresh only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()