Source code for eval_framework.context.local
import importlib
from os import PathLike
from typing import Any
from eval_framework.context.eval import EvalContext, import_models
from eval_framework.llm.base import BaseLLM
from eval_framework.tasks.eval_config import EvalConfig
def _load_model(llm_name: str, models_path: str | PathLike | None, *, info: str = "") -> type[BaseLLM]:
    """Load a model class either from a models file or as a fully qualified module path.

    The module-path route is taken when ``models_path`` is ``None`` or when
    ``llm_name`` contains a dot; otherwise ``llm_name`` is looked up among the
    models returned by ``import_models(models_path)``.

    Args:
        llm_name: The name of the model class to load, or a fully qualified module path.
        models_path: The path to a Python file containing model class definitions.
        info: Additional info to include in error messages (e.g. ``"judge"``).

    Returns:
        The model class.

    Raises:
        ValueError: If ``llm_name`` has to be a module path but contains no dot,
            if the class is missing from the imported module, or if the name is
            not among the models loaded from ``models_path``.

    Errors raised by ``importlib.import_module`` for the module part propagate unchanged.
    """
    # Build the message prefix once so that an empty or padded ``info`` never
    # leaves doubled spaces in the error text.
    info = info.strip()
    label = f"LLM {info} '{llm_name}'" if info else f"LLM '{llm_name}'"

    if models_path is None or "." in llm_name:
        # The llm_name must be a fully qualified module path.
        if "." not in llm_name:
            raise ValueError(f"{label} is not a fully qualified module path.")
        module_path, llm_class_name = llm_name.rsplit(".", 1)
        module = importlib.import_module(module_path)
        try:
            return getattr(module, llm_class_name)
        except AttributeError:
            raise ValueError(f"LLM '{llm_class_name}' not found in module '{module_path}'.") from None

    models_dict = import_models(models_path)
    if llm_name not in models_dict:
        raise ValueError(f"{label} not found in {models_path}.")
    return models_dict[llm_name]
[docs]
class LocalContext(EvalContext):
    """Evaluation context that resolves model classes and builds the ``EvalConfig`` on entry."""

    def __enter__(self) -> "LocalContext":
        """Resolve the model (and optional judge) class, then populate ``self.config``.

        Returns:
            This context instance, with ``llm_judge_class`` and ``config`` set.
        """
        model_cls = _load_model(self.llm_name, models_path=self.models_path)

        # The judge is optional: start from None and only resolve a class when
        # a judge model name has been configured.
        self.llm_judge_class: type[BaseLLM] | None = None
        judge_name = self.judge_model_name
        if judge_name is not None:
            self.llm_judge_class = _load_model(
                judge_name,
                models_path=self.judge_models_path,
                info="judge",
            )

        self.config = EvalConfig(
            # Model under evaluation and sampling settings.
            llm_class=model_cls,
            llm_args=self.llm_args,
            num_samples=self.num_samples,
            max_tokens=self.max_tokens,
            num_fewshot=self.num_fewshot,
            perturbation_config=self.perturbation_config,
            # Task selection.
            task_name=self.task_name,
            task_subjects=self.task_subjects,
            hf_revision=self.hf_revision,
            # Output and upload destinations.
            output_dir=self.output_dir,
            hf_upload_dir=self.hf_upload_dir,
            hf_upload_repo=self.hf_upload_repo,
            wandb_entity=self.wandb_entity,
            wandb_project=self.wandb_project,
            wandb_run_id=self.wandb_run_id,
            wandb_upload_results=self.wandb_upload_results,
            # Judge model.
            llm_judge_class=self.llm_judge_class,
            judge_model_args=self.judge_model_args,
            # Remaining run options.
            batch_size=self.batch_size,
            description=self.description,
            randomize_judge_order=self.randomize_judge_order,
            delete_output_dir_after_upload=self.delete_output_dir_after_upload,
            repeats=self.repeats,
        )
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any | None,
    ) -> None:
        """Do nothing on exit; returning ``None`` lets any exception propagate."""