Source code for eval_framework.result_processors.wandb_uploader

"""
Module for writing result folder to a W&B artifact
"""

import hashlib
import logging
import subprocess
from collections.abc import Callable
from pathlib import Path

import wandb
from wandb.sdk.artifacts._validators import NAME_MAXLEN

from eval_framework.result_processors.base import ResultsUploader
from eval_framework.tasks.eval_config import EvalConfig

logger = logging.getLogger(__name__)

ArtifactUploadFunction = Callable[[str, str, list[Path]], str | None]  # returns reference path for W&B or None
_ARTIFACT_UPLOAD_FUNCTION: ArtifactUploadFunction | None = None


def register_artifact_upload_function(func: ArtifactUploadFunction | None) -> None:
    global _ARTIFACT_UPLOAD_FUNCTION
    _ARTIFACT_UPLOAD_FUNCTION = func

def artifact_upload_function(artifact_name: str, subpath: str, file_paths: list[Path]) -> str | None:
    if _ARTIFACT_UPLOAD_FUNCTION is None:
        return None
    logger.info(f"Uploading '{artifact_name}'.")
    reference_path = _ARTIFACT_UPLOAD_FUNCTION(artifact_name, subpath, file_paths)
    if reference_path is None:
        logger.warning("Failed uploading, the custom upload function returned empty destination path!")
    else:
        logger.info(f"Successfully uploaded to {reference_path}.")
    return reference_path
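
# Example (illustrative only, not part of this module): a registered upload function
# can push the result files to any external store and return a reference URI for W&B
# to track instead of the raw bytes. A minimal sketch, assuming a shared filesystem
# mounted at the hypothetical path `/mnt/eval-artifacts`; `copy_to_shared_fs` is not
# a real helper of this framework:
#
#     import shutil
#
#     def copy_to_shared_fs(artifact_name: str, subpath: str, file_paths: list[Path]) -> str | None:
#         dest = Path("/mnt/eval-artifacts") / artifact_name / subpath
#         dest.mkdir(parents=True, exist_ok=True)
#         for fp in file_paths:
#             shutil.copy2(fp, dest / fp.name)  # copy contents and timestamps
#         return f"file://{dest}"  # becomes the artifact's reference path
#
#     register_artifact_upload_function(copy_to_shared_fs)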
class WandbUploader(ResultsUploader):
    def __init__(
        self,
        config: EvalConfig,
        include_all: bool = True,
        compress_non_json: bool = True,
        wandb_registry: str | None = None,
    ) -> None:
        if not config.wandb_upload_results:
            logger.warning("Results will not be persisted in WandB (`wandb_upload_results` not set).")
            return
        if config.output_dir is None:
            raise ValueError("Output directory is not set in the configuration.")
        if wandb.run is None or wandb.run.settings._noop:
            logger.warning("Results will not be persisted in WandB (no WandB run active / `wandb_project` not set).")
            return
        self._include_all = include_all
        self._compress_non_json = compress_non_json
        self._wandb_registry = wandb_registry
    def upload(self, llm_name: str, config: EvalConfig, output_dir: Path) -> bool:
        if not hasattr(self, "_wandb_registry"):
            return False  # not initialized
        file_paths: list[Path] = []  # initialized up front so the `finally` cleanup below never hits an unbound name
        try:
            if self._include_all and self._compress_non_json:
                # note: individual gz files can be easily read by `less` or `zgrep`, unlike a tar of multiple files
                subprocess.run(
                    ["find", output_dir, "-type", "f", "!", "-name", "*.json", "-exec", "gzip", "-k", "{}", ";"],
                    check=True,
                )
                file_paths = list(output_dir.glob("*.json")) + list(output_dir.glob("*.gz"))
            elif self._include_all:
                file_paths = list(output_dir.glob("*"))
            else:
                file_paths = list(output_dir.glob("*.json"))

            artifact_name = self._get_artifact_name(llm_name, config)
            alias_name = self._get_alias(output_dir)

            try:
                rel_upload_dir = str(output_dir.relative_to(config.output_dir))
                reference_path = artifact_upload_function(artifact_name, rel_upload_dir, file_paths)
            except Exception as e:
                logger.error(f"Problem during artifact upload function, aborting registration: {e}.")
                return False

            artifact = wandb.Artifact(name=artifact_name, type="eval")  # note: metadata is added from run automatically
            if reference_path:
                artifact.add_reference(reference_path, checksum=True)
            else:
                logger.info("Uploading results directly to WandB.")
                for fp in file_paths:
                    artifact.add_file(str(fp))
            # Because metadata and e.g. logs are added to the artifact, we get a new version every time!
            # To mitigate this, we also add an alias based on the content hash of the "pure" result files.
            wandb.log_artifact(artifact, aliases=[alias_name])
            if self._wandb_registry:
                artifact.link(f"wandb-registry-{self._wandb_registry}/{artifact_name}")
            logger.info(f"Successfully registered '{artifact_name}' in WandB")
        finally:
            for fp in file_paths:
                if fp.suffix == ".gz" and self._compress_non_json:
                    fp.unlink(missing_ok=True)
        return True
    def _get_artifact_name(self, llm_name: str, config: EvalConfig) -> str:
        llm_name = llm_name.replace("/", "_")  # assuming this has class name and checkpoint name in it
        # As in generate_output_dir() for consistency, though shorter.
        # But we don't include the eval_framework version and timestamp here (-> handled via W&B versioning)
        fewshot_str = f"fs{config.num_fewshot}" if config.num_fewshot is not None else ""
        samples_str = f"s{config.num_samples}" if config.num_samples is not None else ""
        tokens_str = f"t{config.max_tokens}" if config.max_tokens is not None else ""
        params_str = f"{fewshot_str}{samples_str}{tokens_str}"
        config_json = config.model_json_robust_subset_dump()
        config_hash = hashlib.sha256(config_json.encode("utf-8")).hexdigest()[:5]
        # Respect W&B artifact name length limit
        eval_name = f"__{config.task_name}__{params_str}_{config_hash}"
        max_llm_name_len = NAME_MAXLEN - len(eval_name)
        return llm_name[:max_llm_name_len] + eval_name

    def _get_alias(self, output_dir: Path) -> str:
        digests = []
        # These files don't contain result-irrelevant things such as timestamps or paths, and their fields are
        # ordered. This makes them good for generating a hash that identifies the actual results and not "random"
        # metadata.
        for filename in ["aggregated_results.json", "output.jsonl", "results.jsonl"]:
            if (output_dir / filename).exists():
                with open(output_dir / filename, "rb") as f:
                    digests.append(hashlib.file_digest(f, "sha256").hexdigest())
        combined_hash = hashlib.sha256("".join(digests).encode("utf-8")).hexdigest()
        return f"H-{combined_hash[:10]}"
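
# Example usage (a minimal sketch, assuming an active W&B run and an EvalConfig with
# `wandb_upload_results` enabled; `my_config` and `results_dir` are hypothetical names):
#
#     import wandb
#
#     with wandb.init(project="my-evals"):
#         uploader = WandbUploader(my_config, wandb_registry="eval-results")
#         uploader.upload("meta-llama/Llama-3-8B", my_config, results_dir)
#
# Re-running the same evaluation yields a new artifact version (metadata differs), but
# the content-hash alias from _get_alias() stays identical, marking the results as unchanged.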