import copy
import importlib.util
import json
import logging
import math
import os
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Annotated, Any, Self, cast
import numpy as np
import wandb
import yaml
from pydantic import BaseModel, BeforeValidator, ConfigDict, field_validator, model_validator
from eval_framework.context.local import _load_model
from eval_framework.result_processors.result_processor import generate_output_dir
from eval_framework.run import _run_single_task
from eval_framework.tasks.eval_config import EvalConfig
from eval_framework.tasks.registry import is_registered
logger = logging.getLogger(__name__)
# Fields on TaskSuite that are routed to llm_args when building run kwargs
_LLM_ARG_FIELDS = {"temperature", "top_p", "top_k"}
# Fields on TaskSuite that map directly to EvalConfig / run_with_kwargs keys
_EVAL_CONFIG_FIELDS = {
"num_samples",
"num_fewshot",
"max_tokens",
"repeats",
"batch_size",
"task_subjects",
"hf_revision",
}
_HYPERPARAM_FIELDS = _LLM_ARG_FIELDS | _EVAL_CONFIG_FIELDS
[docs]
def parse_strings_to_task_or_suite(v: str | list) -> str | list:
"""Expand bare strings in a list to leaf-suite dicts. Pydantic validates them into TaskSuite."""
if isinstance(v, str):
return v
return [{"tasks": item, "name": item} if isinstance(item, str) else item for item in v]
_VALID_METHODS = {"mean", "median"}
[docs]
class MetricSource(BaseModel):
"""A single (child, metric) pair used as an input to a SuiteAggregate. See the examples folder
for how these are used."""
model_config = ConfigDict(extra="forbid")
child: str
metric: str
[docs]
class SuiteAggregate(BaseModel):
"""Model to aggregate results from a suite of tasks."""
model_config = ConfigDict(extra="forbid")
name: str
sources: list[MetricSource]
method: str | Callable[[list[float]], float] = "mean"
[docs]
@field_validator("method")
@classmethod
def validate_method(cls, v: str | Callable) -> str | Callable:
if isinstance(v, str) and v not in _VALID_METHODS:
raise ValueError(f"Unknown method '{v}'. Valid string methods: {sorted(_VALID_METHODS)}.")
return v
[docs]
class TaskSuite(BaseModel):
# TODO: Figure out versioning for suites. This differs from the versioning of the eval_framework package.
model_config = ConfigDict(extra="forbid")
name: str | None = None
# Tasks can be a string or a list of strings (which becomes a suite) or a suite.
tasks: Annotated[str | list[str | Self], BeforeValidator(parse_strings_to_task_or_suite)] = []
aggregates: list[SuiteAggregate] = []
# things passed to LLM class:
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
# a dumping dict for all the non-standard like api keys.
extra_llm_args: dict[str, Any] = {}
# things passed to EvalConfig:
num_samples: int | None = None
num_fewshot: int | None = None
max_tokens: int | None = None
repeats: int | None = None
batch_size: int | None = None
task_subjects: list[str] | None = None
hf_revision: str | None = None
[docs]
@model_validator(mode="after")
def validate_suite(self) -> Self:
if isinstance(self.tasks, str):
if self.name is None:
self.name = self.tasks
if not is_registered(self.tasks):
raise ValueError(f"Task '{self.tasks}' is not registered.")
elif not self.tasks:
raise ValueError(f"TaskSuite '{self.name}': 'tasks' must not be empty.")
elif self.name is None:
raise ValueError("Composite TaskSuite must have a 'name'.")
return self
@property
def is_leaf(self) -> bool:
return isinstance(self.tasks, str)
@property
def task_name(self) -> str:
"""The registered task name. Only valid for leaf tasks."""
assert self.is_leaf, "task_name is only valid for leaf tasks."
return self.tasks # type: ignore[return-value]
[docs]
def get_hyperparam_overrides(self) -> dict[str, Any]:
"""Return hyperparam fields that were explicitly set in the suite definition."""
explicitly_set = self.model_fields_set
overrides: dict[str, Any] = {}
for field_name in _HYPERPARAM_FIELDS:
if field_name in explicitly_set:
overrides[field_name] = getattr(self, field_name)
if "extra_llm_args" in explicitly_set:
overrides["extra_llm_args"] = self.extra_llm_args
if "task_subjects" in explicitly_set:
overrides["task_subjects"] = self.task_subjects
return overrides
[docs]
@classmethod
def load_from_yaml(cls, path: Path) -> Self:
data = yaml.safe_load(path.read_text())
return cls.model_validate(data)
[docs]
@classmethod
def load_from_py(cls, path: Path | str) -> Self:
if isinstance(path, str):
path = Path(path)
path = path.resolve()
module_name = f"_suite_{path.stem}"
spec = importlib.util.spec_from_file_location(module_name, str(path))
if spec is None or spec.loader is None:
raise ImportError(f"Could not load suite module from '{path}'.")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
if not hasattr(module, "suite"):
raise ValueError(f"Suite file '{path}' must define a 'suite' variable.")
suite = module.suite
if not isinstance(suite, TaskSuite):
raise TypeError(f"'suite' in '{path}' must be a TaskSuite instance, got {type(suite).__name__}.")
return cast(Self, suite)
[docs]
@classmethod
def load(cls, path: Path | str) -> Self:
path = Path(path)
if path.suffix in (".yaml", ".yml"):
return cls.load_from_yaml(path)
elif path.suffix == ".py":
return cls.load_from_py(path)
else:
raise ValueError(f"Unsupported suite file format: '{path.suffix}'. Use .yaml, .yml, or .py.")
[docs]
class SuiteResult(BaseModel):
model_config = ConfigDict(extra="forbid")
name: str
task_results: dict[str, Self] = {} # this stores the full hierarchy of results. Not used at all.
aggregates: dict[str, float | None] = {}
[docs]
def resolve_to_evalconfig_kwargs(
leaf: TaskSuite, resolved_defaults: dict[str, Any], cli_kwargs: dict[str, Any]
) -> dict:
"""Build the kwargs dict expected by run_with_kwargs() for a single leaf task.
Merges CLI kwargs as the base, overlays resolved suite defaults, and routes
temperature/top_p/extra_llm_args into the llm_args dict.
"""
kwargs = copy.deepcopy(cli_kwargs)
kwargs["task_name"] = leaf.task_name
for key, value in resolved_defaults.items():
if key in _EVAL_CONFIG_FIELDS:
kwargs[key] = value
elif key in _LLM_ARG_FIELDS:
kwargs["llm_args"][key] = value
elif key == "extra_llm_args":
kwargs["llm_args"].update(value)
return kwargs
[docs]
def compute_aggregates(
aggregates: list[SuiteAggregate],
child_results: dict[str, SuiteResult],
) -> dict[str, float | None]:
"""Compute suite-level stats from explicitly named (child, metric) sources.
For each `SuiteAggregate`, the value from each `MetricSource` is looked up by
child name and exact metric key. Sources whose child is missing or whose metric is
None or NaN are silently skipped. If no sources yield a valid value the aggregate is None.
"""
result: dict[str, float | None] = {}
for agg in aggregates:
values: list[float] = []
for source in agg.sources:
child = child_results.get(source.child)
if child is None:
logger.warning(
f"SuiteAggregate '{agg.name}' uses source '{source.child}' which is not a child of the suite. "
f"Available children: {list(child_results.keys())}."
)
continue
val = child.aggregates.get(source.metric)
if val is not None and not math.isnan(val):
values.append(val)
else:
logger.warning(f"The value for source '{source.child}' with metric '{source.metric}' is None or NaN.")
result[agg.name] = _apply_method(agg.method, values) if values else None
return result
def _apply_method(
method: str | Callable[[list[float]], float],
values: list[float],
) -> float:
if callable(method):
return method(values)
elif method == "mean":
return float(np.mean(values))
elif method == "median":
return float(np.median(values))
else:
raise ValueError(f"Unknown aggregation method: '{method}'. Use mean or median.")
def _merge_defaults(parent: dict[str, Any], child: dict[str, Any]) -> dict[str, Any]:
"""Merge child overrides on top of parent defaults."""
return {**parent, **child}
[docs]
def run_suite(
suite: TaskSuite,
cli_kwargs: dict[str, Any],
parent_defaults: dict[str, Any] | None = None,
root_suite_name: str | None = None,
) -> SuiteResult:
"""Recursively run all tasks in a suite and compute aggregates bottom-up using
post-order traversal.
For a leaf suite: runs the single task via _run_single_task and returns
the aggregated results directly.
For a composite suite: recurses into each child, collects results, then
computes this suite's aggregates.
"""
parent_defaults = parent_defaults or {}
current_defaults = _merge_defaults(parent_defaults, suite.get_hyperparam_overrides())
suite_name = suite.name # guaranteed non-None by validate_suite
assert suite_name is not None
# Track the top-level suite name so all leaf tasks share the same W&B group.
if root_suite_name is None:
root_suite_name = suite_name
# Lets do post-order traversal here. If leaf, go to the code in run.py
if suite.is_leaf:
resolved = resolve_to_evalconfig_kwargs(suite, current_defaults, cli_kwargs)
# Each task in a suite gets its own W&B run (nulling a shared run_id prevents
# all tasks from piling into the same W&B run), and shares the suite group.
resolved["wandb_run_id"] = None
resolved["wandb_group"] = root_suite_name
logger.info(f"Running task: {suite.task_name}")
_run_single_task(resolved)
return SuiteResult(
name=suite_name,
task_results={},
aggregates=_load_aggregated_results(resolved),
)
# else keep going down the tree depth first.
children = cast(list[TaskSuite], suite.tasks)
child_results: dict[str, SuiteResult] = {}
for child in children:
assert child.name is not None
child_results[child.name] = run_suite(
child, cli_kwargs, parent_defaults=current_defaults, root_suite_name=root_suite_name
)
# we can only compute the aggregates after all the children are run.
suite_aggregates = compute_aggregates(suite.aggregates, child_results)
output_dir = Path(cli_kwargs.get("output_dir", "outputs"))
# check that individual task results are saved in the output directory.
save_suite_results(output_dir / suite_name, suite_aggregates)
_log_suite_aggregates_to_wandb(suite_name, root_suite_name, suite_aggregates, cli_kwargs)
return SuiteResult(name=suite_name, task_results=child_results, aggregates=suite_aggregates)
# I don't like this way of loading the results. But this is how it was done in the original
# eval_framework. I reconstruct the EvalConfig just to create a hash and load that file.
def _load_aggregated_results(resolved_kwargs: dict[str, Any]) -> dict[str, Any]:
"""Load the aggregated_results.json for a completed task run."""
llm_class = _load_model(resolved_kwargs["llm_name"], models_path=resolved_kwargs["models"])
llm_instance = llm_class(**resolved_kwargs.get("llm_args", {}))
config = EvalConfig(
llm_class=llm_class,
llm_args=resolved_kwargs.get("llm_args", {}),
num_samples=resolved_kwargs.get("num_samples"),
max_tokens=resolved_kwargs.get("max_tokens"),
num_fewshot=resolved_kwargs.get("num_fewshot", 0),
task_name=resolved_kwargs["task_name"],
task_subjects=resolved_kwargs.get("task_subjects"),
hf_revision=resolved_kwargs.get("hf_revision"),
output_dir=resolved_kwargs.get("output_dir", "outputs"),
wandb_project=resolved_kwargs.get("wandb_project"),
wandb_entity=resolved_kwargs.get("wandb_entity"),
wandb_run_id=resolved_kwargs.get("wandb_run_id"),
wandb_upload_results=resolved_kwargs.get("wandb_upload_results"),
hf_upload_dir=resolved_kwargs.get("hf_upload_dir"),
hf_upload_repo=resolved_kwargs.get("hf_upload_repo"),
batch_size=resolved_kwargs.get("batch_size", 1),
repeats=resolved_kwargs.get("repeats", 1),
description=resolved_kwargs.get("description"),
randomize_judge_order=resolved_kwargs.get("randomize_judge_order", False),
delete_output_dir_after_upload=resolved_kwargs.get("delete_output_dir_after_upload", False),
)
output_dir = generate_output_dir(llm_instance.name, config)
agg_file = output_dir / "aggregated_results.json"
if agg_file.exists():
with open(agg_file) as f:
return json.load(f)
raise ValueError(f"No aggregated_results.json found at {agg_file}")
def _log_suite_aggregates_to_wandb(
suite_name: str,
root_suite_name: str,
aggregates: dict[str, float | None],
cli_kwargs: dict[str, Any],
) -> None:
"""Create a W&B run for a composite suite and log its aggregate metrics."""
from eval_framework.main import _wandb_mode
wandb_project = cli_kwargs.get("wandb_project")
if not wandb_project:
return
with wandb.init(
entity=cli_kwargs.get("wandb_entity"),
project=wandb_project,
group=root_suite_name,
job_type="suite",
name=suite_name,
mode=_wandb_mode(wandb_project),
settings=wandb.Settings(disable_code=True),
) as run:
run.log({k: v for k, v in aggregates.items() if v is not None})
logger.info(f"Logged suite aggregates for '{suite_name}' to W&B project '{wandb_project}'.")
[docs]
def save_suite_results(output_dir: Path, results: dict[str, float | None]) -> None:
os.makedirs(output_dir, exist_ok=True)
with open(output_dir / "suite_aggregated_results.json", "w") as f:
json.dump(results, f, indent=4, sort_keys=True)
logger.info(f"Saved suite aggregated results to {output_dir / 'suite_aggregated_results.json'}")