Source code for eval_framework.llm.base

from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import Any

from eval_framework.shared.types import RawCompletion, RawLoglikelihood
from eval_framework.tasks.base import Sample
from template_formatting.formatter import BaseFormatter, Message


class BaseLLM(ABC):
    @property
    def name(self) -> str:
        """This property is used to name the results folder and identify the eval results.

        Overwrite this property in the subclass with e.g. the checkpoint name/huggingface model name.
        """
        return self.__class__.__name__
    @abstractmethod
    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """stop_sequences and max_tokens are injected by the task if they exist. They should be
        overwritten or extended with the properties of the model. This includes, but is not limited to,
        the stop tokens of the evaluated checkpoint (e.g. <|eot_id|> for an instruction-finetuned
        Llama3.1, <|endoftext|> for a pretrained Llama3.1).

        This function is expected to raise errors, which are caught and reported when running the eval.
        Please also make sure to raise an error in case of sequence length issues. We expect an error
        to always be raised if something impedes the expected completion of a task.

        Important! The completion is expected to be detokenized and to NOT contain special tokens.

        Returns:
            list[RawCompletion]
        """
        raise NotImplementedError
    def generate_from_samples(
        self,
        samples: list[Sample],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """stop_sequences and max_tokens are injected by the task if they exist. They should be
        overwritten or extended with the properties of the model. This includes, but is not limited to,
        the stop tokens of the evaluated checkpoint (e.g. <|eot_id|> for an instruction-finetuned
        Llama3.1, <|endoftext|> for a pretrained Llama3.1).

        This function is expected to raise errors, which are caught and reported when running the eval.
        Please also make sure to raise an error in case of sequence length issues. We expect an error
        to always be raised if something impedes the expected completion of a task.

        Important! The completion is expected to be detokenized and to NOT contain special tokens.

        Returns:
            list[RawCompletion]
        """
        raise NotImplementedError
    @abstractmethod
    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """This function is expected to raise errors, which are caught and reported when running the
        eval. Please also make sure to raise an error in case of sequence length issues. We expect an
        error to always be raised if something prevents the expected completion of a task.
        """
        raise NotImplementedError
    def generate(
        self,
        samples: list[Sample],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Generates a model response for each sample.

        Uses 'generate_from_samples' to generate responses if implemented, otherwise falls back to
        'generate_from_messages'.
        """
        try:
            return self.generate_from_samples(samples, stop_sequences, max_tokens, temperature)
        except NotImplementedError:
            messages: list[Sequence[Message]] = [sample.messages for sample in samples]
            return self.generate_from_messages(messages, stop_sequences, max_tokens, temperature)
    def post_process_completion(self, completion: str, sample: Sample) -> str:
        """Model-specific post-processing of generated completions.

        Override this method to apply model-specific cleanup or transformations (e.g. removing specific
        artifacts such as reasoning traces, handling special tokens).

        Args:
            completion: The raw completion string from the model.
            sample: The sample that was used to generate the completion.

        Returns:
            The post-processed completion string.
        """
        return completion
    def __del__(self) -> None:
        """Method for custom resource cleanup (particularly GPUs)."""
        pass

    def _get_final_checkpoint(
        self,
        checkpoint_path: str | Path | None = None,
        model_name: str | None = None,
        artifact_name: str | None = None,
    ) -> tuple[str | Path | None, str | None]:
        if (num_provided := sum(x is not None for x in [checkpoint_path, model_name, artifact_name])) == 0:
            if not getattr(self, "LLM_NAME", ""):
                raise ValueError("Either LLM_NAME, checkpoint_path, model_name, or artifact_name must be provided.")
            return None, None  # no argument given, so will use the LLM_NAME of the class
        elif num_provided > 1:
            raise ValueError("At most one of `checkpoint_path`, `model_name`, or `artifact_name` must be provided.")
        elif checkpoint_path is not None:
            return checkpoint_path, str(checkpoint_path)
        elif model_name is not None:
            return model_name, model_name
        else:
            from eval_framework.utils.file_ops import WandbFs

            assert artifact_name is not None
            artifact_base, version = artifact_name.split(":", 1) if ":" in artifact_name else (artifact_name, "latest")
            with WandbFs() as wandb_fs:
                self.artifact = wandb_fs.get_artifact(artifact_base, version)  # self.artifact is read in main()
                wandb_fs.download_artifact(self.artifact)
                file_root = wandb_fs.find_hf_checkpoint_root_from_path_list()
            if file_root is None:
                raise ValueError(f"Could not find HuggingFace checkpoint in artifact {artifact_base}:{version}")
            return file_root, artifact_name

    def _get_final_formatter(
        self,
        formatter: BaseFormatter | None = None,
        formatter_name: str | None = None,
        formatter_kwargs: dict[str, Any] | None = None,
    ) -> BaseFormatter | None:
        if (num_provided := sum(x is not None for x in [formatter, formatter_name])) == 0:
            return None  # none given, so will use the default of the class
        elif num_provided > 1:
            raise ValueError("At most one of `formatter` or `formatter_name` must be provided.")

        if formatter:
            if formatter_kwargs:
                raise ValueError("Cannot provide `formatter_kwargs` when `formatter` is provided.")
            return formatter
        elif formatter_name:
            kwargs = formatter_kwargs or {}
            match formatter_name:
                case "Llama3Formatter":
                    from template_formatting.formatter import Llama3Formatter

                    return Llama3Formatter()
                case "MistralFormatter" | "MagistralFormatter":
                    from eval_framework.llm.mistral import MagistralFormatter

                    return MagistralFormatter(**kwargs)
                case "ConcatFormatter":
                    from template_formatting.formatter import ConcatFormatter

                    return ConcatFormatter()
                case "HFFormatter":
                    from template_formatting.formatter import HFFormatter

                    return HFFormatter(**kwargs)
                case _:
                    raise ValueError(f"Unsupported formatter: {formatter_name}.")
        return None
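
# --- Illustrative usage sketch (not part of eval_framework.llm.base) --------------------------
# A minimal subclass only needs to override the abstract methods `generate_from_messages` and
# `logprobs` (and usually `name`). `MyLLM` below and its LLM_NAME value are hypothetical
# placeholders: the method bodies only describe what a real backend would do; they do not
# implement a model.


class MyLLM(BaseLLM):
    LLM_NAME = "my-checkpoint"  # assumed subclass attribute, as read via getattr in _get_final_checkpoint

    @property
    def name(self) -> str:
        # Names the results folder; typically the checkpoint or huggingface model name.
        return self.LLM_NAME

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        # A real implementation would format `messages`, merge the model's own stop tokens into
        # `stop_sequences`, run inference, and return detokenized completions without special
        # tokens, raising an error on sequence-length issues.
        raise NotImplementedError

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        # A real implementation would score each sample and return its log-likelihoods,
        # raising an error on sequence-length issues.
        raise NotImplementedError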