Source code for eval_framework.main

import gc
import json
import logging
import os
import shutil
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal

import wandb

from eval_framework.evaluation_generator import EvaluationGenerator, Result
from eval_framework.llm.base import BaseLLM
from eval_framework.response_generator import ResponseGenerator
from eval_framework.result_processors.hf_uploader import HFUploader
from eval_framework.result_processors.result_processor import ResultsFileProcessor, generate_output_dir
from eval_framework.result_processors.wandb_uploader import WandbUploader
from eval_framework.tasks.eval_config import EvalConfig
from eval_framework.utils.constants import RED, RESET
from eval_framework.utils.logging import setup_logging

logger = logging.getLogger(__name__)


[docs] def main( llm: BaseLLM, config: EvalConfig, should_preempt_callable: Callable[[], bool] | None = None, trial_id: int | None = None, *args: Any, resource_cleanup: bool = False, verbosity: int = 1, ) -> list[Result]: """Runs the entire evaluation process: responses generation and evaluation.""" # Set up centralized logging early output_dir = generate_output_dir(llm.name, config) setup_logging(output_dir=output_dir, log_level=verbosity, log_filename="evaluation.log") logger.info(f"Output directory for evaluation: {output_dir}") logger.info(f"{RED}[ Running full evaluation process ------- ]{RESET}") logger.info(f"Evaluating {llm.name} on {config.task_name}") logger.info(f"Configuration: num_fewshot={config.num_fewshot}, num_samples={config.num_samples}") logger.info(f"Output directory: {output_dir}") if not should_preempt_callable: should_preempt_callable = lambda: False # noqa: E731 preemption_data = None if trial_id: preemption_data = _read_preemption_data(config, trial_id) if preemption_data is None: output_dir = generate_output_dir(llm.name, config) wandb_run_id = config.wandb_run_id # defaults to none, if no run_id is provided then it starts a new one else: logger.info("Found preempted run restarting ...") output_dir = preemption_data["output_dir"] wandb_run_id = preemption_data.get("wandb_run_id", None) logger.info(f"Output directory: {output_dir}") assert output_dir is not None file_processor = ResultsFileProcessor(output_dir) response_generator = ResponseGenerator(llm, config, file_processor) with wandb.init( entity=config.wandb_entity, project=config.wandb_project, group=llm.name[:127], job_type=config.task_name[:63], id=wandb_run_id, # (potentially resuming run after preemption) config=response_generator._get_metadata(), resume="allow", mode=_wandb_mode(config.wandb_project), settings=wandb.Settings(disable_code=True), # ("wandb-history" artifacts not needed) ) as run: artifact = getattr(llm, "artifact", None) if artifact is not None: wandb.use_artifact(artifact) for additional_artifact in os.getenv("WANDB_ADDITIONAL_ARTIFACT_REFERENCES", "").split(","): if additional_artifact.strip(): wandb.use_artifact(additional_artifact.strip()) _, preempted = response_generator.generate(should_preempt_callable) if preempted: logger.info("Response generation was preempted") assert trial_id is not None run.mark_preempting() _save_preemption_data(config, trial_id, output_dir, wandb_run_id=run.id) wandb.finish(exit_code=1) return [] # update config from response generator with get metadata if trial_id is not None: _delete_preemption_file(config, trial_id) if resource_cleanup: del response_generator gc.collect() evaluator = EvaluationGenerator(config, file_processor) results = evaluator.run_eval() upload_success = False for uploader in [HFUploader(config), WandbUploader(config)]: upload_success |= uploader.upload(llm.name, config, output_dir) if config.delete_output_dir_after_upload and upload_success: logger.warning(f"Deleting output directory '{output_dir}' after successful upload(s)!") shutil.rmtree(output_dir, ignore_errors=True) if output_dir.exists() and any(output_dir.iterdir()): logger.warning("Could not delete output directory, some files remain.") return results
def _read_preemption_data(config: EvalConfig, trial_id: int) -> dict[str, Any] | None: preemption_file = config.output_dir / f"preemption_trial_{trial_id}.json" if not preemption_file.is_file(): return None with open(preemption_file, "rb") as f: preemption_data = json.load(f) preemption_data["output_dir"] = Path(preemption_data["output_dir"]) preemption_data["wandb_run_id"] = preemption_data.get("wandb_run_id", "") logger.info(f"Loaded preemption data from {preemption_file}") return preemption_data def _save_preemption_data(config: EvalConfig, trial_id: int, output_dir: Path, wandb_run_id: str = "") -> None: preemption_file = config.output_dir / f"preemption_trial_{trial_id}.json" with open(preemption_file, "w") as f: json.dump({"output_dir": str(output_dir), "wandb_run_id": wandb_run_id}, f) def _delete_preemption_file(config: EvalConfig, trial_id: int) -> None: preemption_file = config.output_dir / f"preemption_trial_{trial_id}.json" if preemption_file.is_file(): preemption_file.unlink() logger.info(f"Deleted preemption file: {preemption_file}") else: logger.info(f"No preemption file found to delete: {preemption_file}") logger.info(f"Saved preemption data to {preemption_file}") def _wandb_mode(project: str | None) -> Literal["online", "disabled"] | None: """ Checks to see if a WandB API key is found. If not, wandb starts in offline mode. """ if project is None: logger.warning("No WandB project specified, disabling logging.") return "disabled" else: try: api_key = wandb.api.api_key if api_key is None: logger.warning( """No wandb API key found. Disabling Wandb logging. If you have a WandB account set the environment variable 'WANDB_API_KEY'""" ) return "disabled" else: logger.info("Wandb login detected. Using online mode.") except Exception as e: logger.warning(f"Wandb login check failed: {e}. Disabling Wandb logging.") return "disabled" return "online"