import gc
import logging
import math
import os
import warnings
from collections.abc import Callable, Sequence
from functools import partial
from pathlib import Path
from typing import Any
import torch
from tokenizers import Tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from eval_framework.llm.base import BaseLLM
from eval_framework.shared.types import (
ConcatCompression,
Error,
PromptTooLongException,
RawCompletion,
RawLoglikelihood,
)
from eval_framework.tasks.base import Sample
from eval_framework.tasks.utils import raise_errors
from eval_framework.utils.constants import RED, RESET
from template_formatting.formatter import BaseFormatter, ConcatFormatter, HFFormatter, Message
logger = logging.getLogger(__name__)
class StopSequenceCriteria(StoppingCriteria):
def __init__(self, tokenizer: Tokenizer, stop_sequences: list[str], prompt_token_count: int) -> None:
self.tokenizer = tokenizer
self.stop_sequences = stop_sequences
self.prompt_token_count = prompt_token_count
# the character length of the longest stop sequence is a (relatively weak) upper bound on
# the number of trailing tokens that need to be decoded to check for stop sequences
self.token_history_length = max(map(len, stop_sequences), default=0)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> bool:
if not self.stop_sequences:
return False
sequence = input_ids[0].tolist()
sequence = sequence[self.prompt_token_count :]
if len(sequence) > self.token_history_length:
sequence = sequence[-self.token_history_length :]
decoded_text = self.tokenizer.decode(sequence, skip_special_tokens=True)
for stop_sequence in self.stop_sequences:
if stop_sequence in decoded_text:
return True
return False
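# Usage sketch (illustrative only, not executed by this module): the criterion is intended to be
# wrapped in a `transformers.StoppingCriteriaList` and handed to `model.generate`, e.g.
#
#     criteria = StoppingCriteriaList(
#         [StopSequenceCriteria(tokenizer, stop_sequences=["\n\n"], prompt_token_count=prompt_len)]
#     )
#     model.generate(input_ids, stopping_criteria=criteria, max_new_tokens=64)
#
# `tokenizer`, `input_ids` and `prompt_len` are placeholders for a loaded tokenizer, the encoded
# prompt and its token count.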
class RepeatedTokenSequenceCriteria(StoppingCriteria):
def __init__(self, tokenizer: Tokenizer, completion_start_index: int) -> None:
self.tokenizer = tokenizer
# Initialize with an empty string to store the last line
self.last_line = ""
self.completion_start_index = completion_start_index
# self.newline_token_id = tokenizer.encode('\n')
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> torch.Tensor:
# decode the generated completion (everything after the prompt) into text
current_text = self.tokenizer.decode(input_ids[0][self.completion_start_index :])
# split the completion into lines
lines = current_text.split("\n")
# stop if the last two lines are identical (and not both empty), i.e. the model has started
# repeating a full line verbatim
if len(lines) > 1 and lines[-2] == lines[-1] and not (lines[-1] == "" and lines[-2] == ""):
return torch.BoolTensor([True]).to(input_ids.device) # Stop generation if repeated line is found
return torch.BoolTensor([False]).to(input_ids.device)
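# Behaviour sketch (illustrative only): the criterion decodes everything generated after
# `completion_start_index` and fires once the completion ends in two identical lines, e.g.
#
#     crit = RepeatedTokenSequenceCriteria(tokenizer, completion_start_index=prompt_len)
#     # a decoded completion of "foo\nbar\nbar" would make crit(input_ids, scores) return True
#
# `tokenizer` and `prompt_len` are placeholders as above.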
class BaseHFLLM(BaseLLM):
LLM_NAME: str
DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
SEQ_LENGTH: int | None = None
BYTES_PER_TOKEN: float = 4.0 # rule of thumb according to https://platform.openai.com/tokenizer
def __init__(self, formatter: BaseFormatter | None = None, bytes_per_token: float | None = None) -> None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(self.LLM_NAME)
self.model = AutoModelForCausalLM.from_pretrained(self.LLM_NAME, device_map="auto")
logger.info(f"{RED}[ Model initialized --------------------- {RESET}{self.LLM_NAME} {RED}]{RESET}")
self._set_formatter(formatter)
# set bytes_per_token_scalar for non-standard models
if bytes_per_token is not None and bytes_per_token <= 0:
raise ValueError("bytes_per_token must be positive")
self.bytes_per_token_scalar = (
4.0 / bytes_per_token if bytes_per_token is not None else 4.0 / self.BYTES_PER_TOKEN
)
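# Worked example (sketch): for a tokenizer that averages 2 bytes per token,
# bytes_per_token_scalar = 4.0 / 2.0 = 2.0, so a task budget of max_tokens=100 is stretched to
# 200 generated tokens in `generate_from_messages` to cover roughly the same amount of text.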
def _set_formatter(self, formatter: BaseFormatter | None = None) -> None:
# if formatter is being set at initialization time, use it
if formatter is not None:
self._formatter = formatter
# if formatter is not being set at initialization time, but DEFAULT_FORMATTER was specified, use it
elif self.DEFAULT_FORMATTER is not None:
self._formatter = self.DEFAULT_FORMATTER()
# if no formatter was provided and there is no default formatter, fall back to the
# tokenizer's chat template via HFFormatter, if one exists
elif self.tokenizer.chat_template is not None:
self._formatter = HFFormatter(self.LLM_NAME)
# if no formatter was provided and there is neither a default formatter nor a chat template,
# there is nothing to fall back to
else:
raise ValueError("No formatter specified and no default formatter available.")
logger.info(
f"{RED}[ Using formatter --------------------- {RESET}{self._formatter.__class__.__name__} {RED}]{RESET}" # noqa: E501
)
def count_tokens(self, text: str, /) -> int:
"""Count the number of tokens in a string."""
return len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
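# e.g. count_tokens("Hello world") returns the number of sub-word tokens the model's tokenizer
# produces for the string (the exact count is tokenizer-dependent)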
def __del__(self) -> None:
if hasattr(self, "model"):
num_gpus = len(self.model.hf_device_map)
del self.model
if num_gpus > 1 and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
torch.cuda.empty_cache()
gc.collect()
def generate_from_messages(
self,
messages: list[Sequence[Message]],
stop_sequences: list[str] | None = None,
max_tokens: int | None = None,
temperature: float | None = None,
) -> list[RawCompletion]:
if temperature is None:
effective_temperature = 0.0 # Current default, TODO: refactor to use model's default
logger.info(
f"Using default temperature value: {effective_temperature} as no custom temperature value was provided"
)
else:
effective_temperature = temperature
raw_completions = []
for single_messages in messages:
# format
prompt = self._formatter.format(single_messages, output_mode="string")
# without add_special_tokens=False the tokenizer would prepend a second BOS token to the already formatted prompt
inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.device)
prompt_token_count = len(inputs["input_ids"][0])
pad_token_id = self.tokenizer.eos_token_id
# Prepare stopping criteria
stopping_criteria = StoppingCriteriaList()
if stop_sequences is not None:
stopping_criteria.append(StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_token_count)) # type: ignore[attr-defined]
stopping_criteria.append( # type: ignore[attr-defined]
RepeatedTokenSequenceCriteria(
self.tokenizer,
prompt_token_count,
)
)
min_seq_length = min(filter(None, [self.seq_length, self.SEQ_LENGTH]))
# Calculate the maximum number of tokens to generate
max_tokens_to_generate = min_seq_length - prompt_token_count
# Adjust max tokens based on bytes_per_token_scalar so that non-standard models generate full responses
scaled_max_tokens = math.ceil(max_tokens * self.bytes_per_token_scalar) if max_tokens is not None else None
# If max_tokens is specified, use the smaller of the two
max_tokens_to_generate = min(filter(None, [max_tokens_to_generate, scaled_max_tokens]))
if max_tokens_to_generate < 1:
if raise_errors():
raise PromptTooLongException("Prompt exceeded context size.")
raw_completions.append(
RawCompletion(
prompt=prompt,
prompt_sequence_positions=prompt_token_count,
completion="",
completion_sequence_positions=0,
raw_completion_error=Error(
error_class=PromptTooLongException.__name__,
message="Prompt exceeded context size.",
traceback="",
),
)
)
continue
completion, completion_token_count = self._model_generate(
redis_key=(prompt, stop_sequences, max_tokens_to_generate, effective_temperature),
prompt_token_count=prompt_token_count,
inputs=inputs["input_ids"],
max_new_tokens=max_tokens_to_generate,
stopping_criteria=stopping_criteria,
num_return_sequences=1,
pad_token_id=pad_token_id,
return_dict_in_generate=False,
output_scores=False,
do_sample=effective_temperature > 0,
temperature=effective_temperature if effective_temperature > 0 else None,
)
raw_completions.append(
RawCompletion(
prompt=prompt,
prompt_sequence_positions=prompt_token_count,
concat_compression=ConcatCompression.calculate(
single_messages, count_tokens=self.count_tokens, completion=completion
),
completion=completion,
completion_sequence_positions=completion_token_count,
)
)
return raw_completions
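# Usage sketch (illustrative only): build one message list per sample and read the completions
# from the returned RawCompletion objects. The `Message` constructor below is assumed to take a
# role and content, as defined in template_formatting.formatter.
#
#     llm = Pythia410m()
#     [result] = llm.generate_from_messages(
#         [[Message(role="user", content="What is the capital of France?")]],  # assumed fields
#         stop_sequences=["\n"],
#         max_tokens=32,
#     )
#     print(result.completion)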
def _model_generate(self, redis_key: Any, prompt_token_count: int, **kwargs: Any) -> tuple[str, int]:
with torch.no_grad():
outputs = self.model.generate(**kwargs)[0]
completion = self.tokenizer.decode(outputs[prompt_token_count:], skip_special_tokens=True)
if kwargs["stopping_criteria"][0].__class__.__name__ == "StopSequenceCriteria":
for stop_sequence in kwargs["stopping_criteria"][0].stop_sequences:
completion = completion.split(stop_sequence)[0]
return completion, len(outputs[prompt_token_count:])
def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
results = []
for sample in samples:
# format
prompt = self._formatter.format(sample.messages, output_mode="string")
choices_log_probs: dict[str, float] = {}
choices_log_probs_sequence_positions: dict[str, float] = {}
error: Error | None = None
for choice in sample.possible_completions or []:
num_choice_tokens = len(self.tokenizer.encode(choice, add_special_tokens=False))
prompt_and_choice = f"{prompt}{choice}"
total_tokens_count = len(self.tokenizer.encode(prompt_and_choice, add_special_tokens=False))
min_max_tokens = min(filter(None, [self.SEQ_LENGTH, self.seq_length]))
if min_max_tokens < total_tokens_count:
if raise_errors():
raise PromptTooLongException("Prompt exceeded context size.")
choices_log_probs = {}
choices_log_probs_sequence_positions = {}
error = Error(
error_class=PromptTooLongException.__name__,
message="Prompt and choice exceeded context size.",
traceback="",
)
break
else:
# Calculate log-likelihoods for each token in the completion
sum_log_probs = self._model_log_probs(prompt_and_choice, num_choice_tokens)
choices_log_probs.update({choice: sum_log_probs})
choices_log_probs_sequence_positions.update({choice: num_choice_tokens})
results.append(
RawLoglikelihood(
prompt=prompt,
prompt_sequence_positions=len(self.tokenizer.encode(prompt, add_special_tokens=False)),
concat_compression=ConcatCompression.calculate(
sample.messages, count_tokens=self.count_tokens, choices=sample.possible_completions
),
loglikelihoods=choices_log_probs,
loglikelihoods_sequence_positions=choices_log_probs_sequence_positions,
raw_loglikelihood_error=error,
)
)
return results
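# Scoring sketch (illustrative only): for a multiple-choice sample with
# possible_completions=["Paris", "London"], each choice is appended to the formatted prompt and
# the returned loglikelihoods map each choice string to the sum of log-probabilities of its
# tokens; the higher (less negative) value marks the choice the model finds more likely.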
def _model_log_probs(self, prompt: str, num_choice_tokens: int) -> float:
with torch.no_grad():
inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.device)
outputs = self.model(**inputs, labels=inputs["input_ids"])
logits = outputs.logits[:, :-1, :].squeeze(0)
target_ids = inputs["input_ids"][:, 1:].squeeze(0)
token_loglikelihoods = []
for i in range(0, len(target_ids)):
token_id = target_ids[i].item()
token = self.tokenizer.decode([token_id])
loglikelihood = torch.log_softmax(logits[i], dim=-1)[token_id].item()
token_loglikelihoods.append({token: loglikelihood})
return sum([list(log_prob.values())[0] for log_prob in token_loglikelihoods[-num_choice_tokens:]])
@property
def seq_length(self) -> int | None:
config = self.model.config
return config.max_position_embeddings if hasattr(config, "max_position_embeddings") else None
class HFLLM(BaseHFLLM):
"""A class to create HFLLM instances from various model sources."""
def __init__(
self,
# Model source (3 options: file path, HuggingFace model name, Wandb artifact name):
checkpoint_path: str | Path | None = None,
model_name: str | None = None,
artifact_name: str | None = None,
# Formatter (2 options):
formatter: BaseFormatter | None = None,
formatter_name: str | None = None,
formatter_kwargs: dict[str, Any] | None = None,
# Explicit name for the `name` property:
checkpoint_name: str | None = None,
# HFLLM parameters:
bytes_per_token: float | None = None,
**kwargs: Any,
) -> None:
final_path, possible_name = self._get_final_checkpoint(checkpoint_path, model_name, artifact_name)
self.checkpoint_name = checkpoint_name
if self.checkpoint_name is None and possible_name is not None:
self.checkpoint_name = possible_name.replace("/", "_").replace(":", "_").strip("_") # sanitize pathname
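# e.g. "org/model:latest" -> "org_model_latest"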
if final_path:
self.LLM_NAME = str(final_path)
final_formatter = self._get_final_formatter(formatter, formatter_name, formatter_kwargs)
super().__init__(
formatter=final_formatter,
bytes_per_token=bytes_per_token,
**kwargs,
)
@property
def name(self) -> str:
if self.checkpoint_name is not None:
return f"{self.__class__.__name__}_checkpoint_{self.checkpoint_name}"
return super().name
class HFLLM_from_name(HFLLM): # deprecated
"""
A generic class to create HFLLM instances from a given model name.
"""
def __init__(self, model_name: str, formatter: str = "Llama3Formatter", **kwargs: Any) -> None:
warnings.warn("`HFLLM_from_name` is deprecated, please use `HFLLM`.", DeprecationWarning)
super().__init__(
model_name=model_name,
formatter_name=formatter,
**kwargs,
)
class HFLLMRegistryModel(HFLLM): # deprecated
"""
A class to create HFLLM instances from registered models in Wandb registry.
Downloads the model artifacts from Wandb and creates a local HFLLM instance.
"""
def __init__(
self,
artifact_name: str,
version: str = "latest",
formatter: str = "",
formatter_identifier: str = "",
**kwargs: Any,
) -> None:
"""
Initialize HFLLM from a Wandb registered model artifact.
Args:
artifact_name: Name of the artifact in the Wandb registry
version: Version of the artifact to download (default: "latest")
formatter: Name of the formatter to use (default: "")
formatter_identifier: HF model name passed to the formatter as `hf_llm_name` (default: "")
**kwargs: Additional arguments passed to the parent class
"""
warnings.warn("`HFLLMRegistryModel` is deprecated, please use `HFLLM`.", DeprecationWarning)
download_path = kwargs.pop("download_path", None)
if download_path is not None and os.getenv("WANDB_ARTIFACT_DIR") is None:
os.environ["WANDB_ARTIFACT_DIR"] = download_path
super().__init__(
artifact_name=f"{artifact_name}:{version}",
formatter_name=formatter,
formatter_kwargs={"hf_llm_name": formatter_identifier} if formatter_identifier else {},
checkpoint_name=f"{artifact_name}/{version}",
**kwargs,
)
class Pythia410m(HFLLM):
LLM_NAME = "EleutherAI/pythia-410m"
DEFAULT_FORMATTER = ConcatFormatter
class SmolLM135M(HFLLM):
LLM_NAME = "HuggingFaceTB/SmolLM-135M"
DEFAULT_FORMATTER = ConcatFormatter
class Smollm135MInstruct(HFLLM):
LLM_NAME = "HuggingFaceTB/SmolLM-135M-Instruct"
DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME)
class Qwen3_0_6B(HFLLM):
LLM_NAME = "Qwen/Qwen3-0.6B"
DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME, chat_template_kwargs={"enable_thinking": True})
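# Minimal end-to-end sketch (illustrative only; assumes network access to the Hugging Face Hub
# and that the convenience subclasses above can be constructed without extra arguments):
#
#     if __name__ == "__main__":
#         llm = SmolLM135M()
#         print(llm.count_tokens("The quick brown fox"), llm.seq_length)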