Source code for eval_framework.context.determined

import logging
from pathlib import Path
from typing import Annotated, Any

from determined._info import get_cluster_info
from determined.core._context import Context
from determined.core._context import init as determined_core_init
from determined.core._distributed import DummyDistributedContext
from pydantic import AfterValidator, BaseModel, ConfigDict

from eval_framework.context.eval import EvalContext
from eval_framework.context.local import _load_model
from eval_framework.llm.base import BaseLLM
from eval_framework.tasks.eval_config import EvalConfig
from eval_framework.tasks.perturbation import PerturbationConfig
from eval_framework.tasks.registry import validate_task_name
from eval_framework.tasks.task_loader import load_extra_tasks

logger = logging.getLogger(__name__)


class TaskArgs(BaseModel):
    """Per-task settings nested under ``task_args`` in the trial hyperparameters.

    Unknown keys are rejected (``extra="forbid"``). ``task_name`` is run
    through ``validate_task_name`` after the standard string validation.
    Only ``task_name`` and ``num_fewshot`` are required; every other field
    has a default.
    """

    model_config = ConfigDict(extra="forbid")

    # Required fields.
    task_name: Annotated[str, AfterValidator(validate_task_name)]
    num_fewshot: int

    # Optional sizing / generation limits.
    num_samples: int | None = None
    max_tokens: int | None = None
    batch_size: int | None = None

    # Optional judge model selection and its constructor arguments.
    judge_model_name: str | None = None
    judge_model_args: dict[str, Any] = {}

    # Optional task/dataset selection.
    task_subjects: list[str] | None = None
    hf_revision: str | None = None

    # Optional perturbation settings and repeat count.
    perturbation_config: PerturbationConfig | None = None
    repeats: int | None = None
class Hyperparameters(BaseModel):
    """Schema for the hyperparameters of a Determined trial.

    Unknown keys are rejected (``extra="forbid"``). ``llm_name``,
    ``output_dir`` and ``task_args`` are required; every other field has a
    default.
    """

    model_config = ConfigDict(extra="forbid")

    # Required: model identifier and where results are written.
    llm_name: str
    output_dir: Path

    # Optional Hugging Face upload target.
    hf_upload_dir: str | None = None
    hf_upload_repo: str | None = None

    # Optional Weights & Biases settings.
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_run_id: str | None = None
    wandb_upload_results: bool | None = None

    description: str | None = None

    # Required nested task settings (validated by ``TaskArgs``).
    task_args: TaskArgs

    # Defaults to an empty dict, but ``None`` is accepted as well.
    llm_args: dict[str, Any] | None = {}

    extra_task_modules: list[str] | None = None
    delete_output_dir_after_upload: bool | None = None
class DeterminedContext(EvalContext):
    """Evaluation context that builds its ``EvalConfig`` from Determined trial hyperparameters.

    Entering the context initialises and starts a Determined core context,
    reads the trial hyperparameters from the cluster info, validates them
    with ``Hyperparameters`` and stores the resulting ``EvalConfig`` on
    ``self.config``. Leaving the context closes the core context.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Forward all keyword arguments to ``EvalContext``.

        The Determined core context is not created here; it only exists
        between ``__enter__`` and ``__exit__``.
        """
        super().__init__(**kwargs)
        self._core_context: Context | None = None

    def __enter__(self) -> "DeterminedContext":
        """Start the core context and populate ``self.hparams`` / ``self.config``.

        Returns:
            This context instance.

        Raises:
            RuntimeError: If no cluster info is available.
        """
        distributed_context = DummyDistributedContext()
        self._core_context = determined_core_init(distributed=distributed_context)
        self._core_context.start()
        try:
            self._configure_from_trial()
        except BaseException:
            # ``__exit__`` is not invoked when ``__enter__`` raises, so the
            # already started core context has to be released here.
            self._close_core_context()
            raise
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any | None,
    ) -> None:
        """Close the core context, if one is open. Exceptions are not suppressed."""
        self._close_core_context()

    def should_preempt(self) -> bool:
        """Return the core context's preemption flag, or ``False`` when no core context is open."""
        if self._core_context is None:
            return False
        return self._core_context.preempt.should_preempt()

    def get_trial_id(self) -> int | None:
        """Return the trial id held by the core context, or ``None`` when no core context is open."""
        if self._core_context is None:
            return None
        return self._core_context.train._trial_id

    def _close_core_context(self) -> None:
        """Close the core context once and forget it, even if ``close()`` raises."""
        core_context, self._core_context = self._core_context, None
        if core_context is not None:
            core_context.close()

    def _configure_from_trial(self) -> None:
        """Read the trial hyperparameters and derive ``self.hparams`` and ``self.config``."""
        info = get_cluster_info()
        if info is None:
            raise RuntimeError("Failed to retrieve cluster info.")

        # Extra task modules are loaded before the hyperparameters are
        # validated (presumably so task-name validation can see the tasks
        # they register - confirm against ``validate_task_name``).
        extra_task_modules = info.trial.hparams.get("extra_task_modules", None)
        if extra_task_modules:
            cli_modules = getattr(self, "extra_task_modules", None)
            if cli_modules and cli_modules != extra_task_modules:
                logger.info(
                    f"CLI argument extra_task_modules ({cli_modules}) is being overridden by hyperparameters: "
                    f"({extra_task_modules}). If it fails due to duplicate task names, remove the CLI argument and "
                    "consolidate as a determined hyperparameter instead."
                )
            load_extra_tasks(extra_task_modules)

        self.hparams = Hyperparameters(**info.trial.hparams)

        self._log_cli_overrides(
            self.hparams,
            (
                "llm_name",
                "llm_args",
                "output_dir",
                "hf_upload_dir",
                "hf_upload_repo",
                "wandb_project",
                "wandb_entity",
                "wandb_run_id",
                "wandb_upload_results",
                "description",
                "delete_output_dir_after_upload",
            ),
        )
        self._log_cli_overrides(
            self.hparams.task_args,
            (
                "num_samples",
                "max_tokens",
                "num_fewshot",
                "task_name",
                "task_subjects",
                "batch_size",
                "hf_revision",
                "judge_model_name",
                "judge_model_args",
                "perturbation_config",
                "repeats",
            ),
        )

        self.config = self._build_config()

    def _log_cli_overrides(self, source: Any, names: tuple[str, ...]) -> None:
        """Log every name whose truthy CLI value differs from a truthy value on ``source``.

        Falsy values on either side are not reported.
        """
        for name in names:
            val_cli = getattr(self, name, None)
            val_hparams = getattr(source, name, None)
            if val_cli and val_hparams and val_cli != val_hparams:
                logger.info(f"CLI argument {name} ({val_cli}) is being overridden by hyperparameters: ({val_hparams}).")

    @staticmethod
    def _prefer_set(hparam_value: Any, cli_value: Any) -> Any:
        """Return ``hparam_value`` unless it is ``None``, in which case return ``cli_value``.

        Used for boolean flags, where an explicit ``False`` hyperparameter
        must not be replaced by the CLI value.
        """
        return cli_value if hparam_value is None else hparam_value

    def _build_config(self) -> EvalConfig:
        """Load the model classes and merge hyperparameters with CLI arguments into an ``EvalConfig``."""
        task_args = self.hparams.task_args

        # Hyperparameters take precedence over core context
        llm_name = self.hparams.llm_name or self.llm_name
        judge_model_name = task_args.judge_model_name or self.judge_model_name

        llm_class = _load_model(llm_name, models_path=self.models_path)
        llm_judge_class: type[BaseLLM] | None = (
            _load_model(judge_model_name, models_path=self.judge_models_path, info="judge")
            if judge_model_name
            else None
        )

        # For all optional hyperparameters, resort to the respective CLI
        # argument if the hyperparameter is not set. Non-boolean fields use
        # ``or`` so that empty defaults (e.g. the ``{}`` of ``llm_args``)
        # still fall back to the CLI value.
        return EvalConfig(
            llm_class=llm_class,
            llm_args=self.hparams.llm_args or self.llm_args,
            num_samples=task_args.num_samples or self.num_samples,
            max_tokens=task_args.max_tokens or self.max_tokens,
            num_fewshot=task_args.num_fewshot,
            task_name=task_args.task_name,
            task_subjects=task_args.task_subjects,
            hf_revision=task_args.hf_revision or self.hf_revision,
            perturbation_config=task_args.perturbation_config or self.perturbation_config,
            output_dir=self.hparams.output_dir,
            llm_judge_class=llm_judge_class,
            judge_model_args=task_args.judge_model_args or self.judge_model_args,
            hf_upload_dir=self.hparams.hf_upload_dir or self.hf_upload_dir,
            hf_upload_repo=self.hparams.hf_upload_repo or self.hf_upload_repo,
            wandb_project=self.hparams.wandb_project or self.wandb_project,
            wandb_entity=self.hparams.wandb_entity or self.wandb_entity,
            wandb_run_id=self.hparams.wandb_run_id or self.wandb_run_id,
            wandb_upload_results=self._prefer_set(self.hparams.wandb_upload_results, self.wandb_upload_results),
            batch_size=task_args.batch_size or self.batch_size,
            description=self.hparams.description or self.description,
            randomize_judge_order=self.randomize_judge_order,
            delete_output_dir_after_upload=self._prefer_set(
                self.hparams.delete_output_dir_after_upload, self.delete_output_dir_after_upload
            ),
            repeats=task_args.repeats or self.repeats,
        )