Source code for eval_framework.response_generator

import logging
import time
import traceback
from collections.abc import Callable, Iterable
from datetime import UTC, datetime
from functools import partial

from eval_framework.tasks.registry import get_task

try:
    from determined._info import get_cluster_info
except ImportError:
    get_cluster_info = None  # type: ignore[assignment]


from typing import Any

from tqdm import tqdm

from eval_framework import __version__ as eval_framework_version
from eval_framework.llm.base import BaseLLM
from eval_framework.result_processors.result_processor import ResultsFileProcessor
from eval_framework.shared.types import (
    Completion,
    Error,
    Loglikelihood,
    RawLoglikelihood,
)
from eval_framework.tasks.base import Language, ResponseType, Sample
from eval_framework.tasks.eval_config import EvalConfig
from eval_framework.tasks.perturbation import create_perturbation_class
from eval_framework.tasks.utils import raise_errors
from eval_framework.utils.constants import RED, RESET
from eval_framework.utils.tqdm_handler import get_disable_bar_flag, safe_tqdm_write

logger = logging.getLogger(__name__)


def map_language_to_value(
    language: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None,
) -> str | dict[str, str] | dict[str, tuple[str, str]] | None:
    if language is None:
        return None
    elif isinstance(language, Language):
        return language.value
    elif isinstance(language, dict):
        if isinstance(list(language.values())[0], Language):
            return {k: v.value for k, v in language.items()}  # type: ignore[union-attr]
        else:
            return {k: (v[0].value, v[1].value) for k, v in language.items()}  # type: ignore[index]
    else:
        raise ValueError(f"Invalid language: {language}")
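

# Illustrative usage sketch (not part of the original module): assuming Language is an Enum whose
# members carry string values and that members such as Language.ENG and Language.DEU exist, the
# helper flattens each accepted input shape into plain strings. The function below is hypothetical
# and only demonstrates the mapping.
def _example_map_language_to_value() -> None:
    assert map_language_to_value(None) is None
    assert map_language_to_value(Language.ENG) == Language.ENG.value
    assert map_language_to_value({"subject": Language.ENG}) == {"subject": Language.ENG.value}
    assert map_language_to_value({"subject": (Language.ENG, Language.DEU)}) == {
        "subject": (Language.ENG.value, Language.DEU.value)
    }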


class ResponseGenerator:
    def __init__(self, llm: BaseLLM, config: EvalConfig, result_processor: ResultsFileProcessor) -> None:
        self.few_shot = config.num_fewshot
        self.task_name = config.task_name
        self.llm = llm
        self.config = config
        self.result_processor = result_processor
        self.num_samples = config.num_samples
        self.save_intermediate_results = config.save_intermediate_results

        task_class = get_task(config.task_name)
        if config.perturbation_config is not None:
            perturbation_task_class = create_perturbation_class(task_class, config.perturbation_config)
            self.task = perturbation_task_class.with_overwrite(
                self.few_shot, custom_subjects=self.config.task_subjects, custom_hf_revision=self.config.hf_revision
            )
        else:
            self.task = task_class.with_overwrite(
                self.few_shot, custom_subjects=self.config.task_subjects, custom_hf_revision=self.config.hf_revision
            )
        self.response_type = task_class.RESPONSE_TYPE

    def _llm_task_param_precedence(self) -> tuple[list[str] | None, int | None]:
        """
        Determines the stop_sequences and max_tokens values to be used in completion generation.

        Max token and stop sequence values have an order of precedence: LLM attributes take precedence over
        task attributes and therefore override them.

        :return: stop_sequences, max_tokens
        """
        llm_stop_sequences = getattr(self.llm, "stop_sequences", None)
        llm_max_tokens = getattr(self.llm, "max_tokens", None)
        task_stop_sequences = getattr(self.task, "stop_sequences", None)
        task_max_tokens = self.config.max_tokens or getattr(self.task, "max_tokens", None)

        # if both task and model define max_tokens, the smaller value is used
        max_tokens = min([x for x in [llm_max_tokens, task_max_tokens] if x is not None], default=None)
        logger.info(f"Set max_tokens to {max_tokens}")

        # if both task and model define stop sequences, they are merged into one list
        stop_sequences_merged = (llm_stop_sequences or []) + (task_stop_sequences or [])
        stop_sequences = sorted(list(set(stop_sequences_merged))) if stop_sequences_merged else None
        logger.info(f"Set stop_sequences to {stop_sequences}")

        return stop_sequences, max_tokens

    def _generate_loglikelihoods(self, samples: list[Sample]) -> list[Loglikelihood]:
        """
        Generate log likelihoods when samples are run against the model.

        :param samples: samples to run the task against
        :return: loglikelihoods
        """
        raw_loglikelihoods: list[RawLoglikelihood]
        try:
            raw_loglikelihoods = self.llm.logprobs(samples)
        except Exception as e:
            if raise_errors():
                raise e
            logger.info(f"Error: {e.__class__.__name__} {e}")
            assert len(samples) == 1, "LLMs not handling errors are not supported in batch mode"
            raw_loglikelihoods = [
                RawLoglikelihood(
                    prompt="",
                    prompt_sequence_positions=0,
                    loglikelihoods={},
                    loglikelihoods_sequence_positions={},
                    raw_loglikelihood_error=Error(
                        error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc()
                    ),
                )
                for _ in range(len(samples))
            ]

        loglikelihood_list = []
        for idx, sample in enumerate(samples):
            raw_loglikelihood = raw_loglikelihoods[idx]
            assert sample.ground_truth is not None
            loglikelihood_list.append(
                Loglikelihood(
                    id=sample.id,
                    subject=sample.subject,
                    ground_truth=sample.ground_truth,
                    prompt=raw_loglikelihood.prompt,
                    prompt_sequence_positions=raw_loglikelihood.prompt_sequence_positions,
                    concat_compression=raw_loglikelihood.concat_compression,
                    loglikelihoods=raw_loglikelihood.loglikelihoods,
                    loglikelihoods_sequence_positions=raw_loglikelihood.loglikelihoods_sequence_positions,
                    error=raw_loglikelihood.raw_loglikelihood_error,
                )
            )
        return loglikelihood_list

    def _generative_output_type_selector(self) -> Callable[[list[Sample]], list[Completion] | list[Loglikelihood]]:
        """
        Selects the generative output function based on the response type.

        :return: function to generate responses
        """
        match self.response_type:
            case ResponseType.COMPLETION:
                stop_sequences, max_tokens = self._llm_task_param_precedence()
                return partial(
                    self.task.generate_completions, self.llm, stop_sequences=stop_sequences, max_tokens=max_tokens
                )  # type: ignore[call-arg]
            case ResponseType.LOGLIKELIHOODS:
                return self._generate_loglikelihoods
            case _:
                raise KeyError(f"Task type {self.task} not supported")

    def _run_task_against_model(
        self, should_preempt_callable: Callable[[], bool]
    ) -> tuple[list[Completion | Loglikelihood], bool]:
        """
        Runs the task against the model and generates responses.

        :param should_preempt_callable: function to check if preempt is called
        :return: list of responses, preempted
        """
        logger.info(f"{RED}[ Running task {self.task.NAME} against model ------------ ]{RESET}")
        self.start_time, monotonic_start = time.time(), time.monotonic()

        run_fn = self._generative_output_type_selector()
        self._verify_loaded_metadata_compatibility()
        responses = self.result_processor.load_responses()  # load responses if present
        subject_response_id_mapping = self._map_subject_response_ids(responses)
        self.result_processor.save_metadata(self._get_metadata())

        responses, preempted = self._curate_responses(
            responses, subject_response_id_mapping, run_fn, should_preempt_callable
        )

        self.end_time, monotonic_end = time.time(), time.monotonic()
        self.total_time = monotonic_end - monotonic_start
        self.result_processor.save_metadata(self._get_metadata())  # overwrite with updated timing
        return responses, preempted

    def _map_subject_response_ids(self, responses: list[Completion | Loglikelihood]) -> dict[str, set[int]]:
        """
        Maps each subject to the ids of responses already generated for it.

        :param responses: list of responses
        :return: mapping of subject to response ids
        """
        subject_response_id_mapping = {}
        if responses:
            response_subjects = {resp.subject for resp in responses}
            subject_response_id_mapping = {
                response_subject: set([resp.id for resp in responses if resp.subject == response_subject])
                for response_subject in response_subjects
            }
        return subject_response_id_mapping

    def _curate_responses(
        self,
        responses: list[Completion | Loglikelihood],
        subject_response_id_mapping: dict[str, set[int]],
        generative_output_function: Callable[[list[Sample]], list[Completion] | list[Loglikelihood]],
        should_preempt_callable: Callable[[], bool],
    ) -> tuple[list[Completion | Loglikelihood], bool]:
        """
        Generates responses for the task and saves them along with metadata.

        :param responses: previously generated responses to extend
        :param subject_response_id_mapping: mapping of subject to response ids
        :param generative_output_function: function to generate responses
        :param should_preempt_callable: function to check if preempt is called
        :return: list of responses, preempted
        """

        def _process_batch(samples_batch: list[Sample]) -> None:
            if not samples_batch:
                return
            if len(samples_batch) > 1:
                log_msg = "Processing batch..."
                logger.info(log_msg)  # For log files
                safe_tqdm_write(log_msg)  # For console display with tqdm
            responses_batch = generative_output_function(samples_batch)
            responses.extend(responses_batch)
            if self.save_intermediate_results:
                for response in responses_batch:
                    self.result_processor.save_response(response)

        # In order to enable parallelism we group samples in batches and send them in parallel to the `run_fn`.
        # The BaseLLM class is then in charge of managing the parallelism (e.g., using AsyncClient in API models).
        # If samples_batch_size = 1, samples are run sequentially; in any case, we return here after finishing each
        # individual batch to honor preemption requests and save cached results.
        samples_batch_size = self.config.batch_size
        repeats = self.config.repeats

        # Calculate total samples for progress bar - use num_samples or iterate to count
        if self.num_samples is None:
            # Count samples by iterating (this might be expensive for large datasets)
            total_num_samples = sum(1 for _ in self.task.iterate_samples(None)) * repeats
        else:
            total_num_samples = self.num_samples * repeats

        samples_batch: list[Sample] = []
        with tqdm(
            total=total_num_samples, desc=f"Processing {self.response_type.value}", disable=get_disable_bar_flag()
        ) as pbar:
            samples = self.task.iterate_samples(self.num_samples)
            for i, sample in enumerate(repeat_samples(samples, repeats)):
                subject = f" - Subject: {sample.subject}"
                sample_index = i + 1

                if sample.id in subject_response_id_mapping.get(sample.subject, []):
                    log_msg = (
                        f"Task: {self.response_type.value}{subject} - Sample: {sample_index} - skipping, already done."
                    )
                    logger.info(log_msg)  # For log files
                    safe_tqdm_write(log_msg)  # For console display with tqdm
                    pbar.update(1)
                    continue

                log_msg = f"Task: {self.response_type.value}{subject} - Sample: {sample_index}/{total_num_samples}"
                logger.info(log_msg)  # For log files
                safe_tqdm_write(log_msg)  # For console display with tqdm
                pbar.set_postfix_str(f"Sample {sample_index}/{total_num_samples}")
                pbar.update(1)

                samples_batch.append(sample)
                if len(samples_batch) >= samples_batch_size:
                    _process_batch(samples_batch)
                    samples_batch = []

                if should_preempt_callable():
                    log_msg = "Preempt"
                    logger.info(log_msg)  # For log files
                    safe_tqdm_write(log_msg)  # For console display with tqdm
                    if not self.save_intermediate_results:
                        self.result_processor.save_responses(responses)
                    return responses, True

        _process_batch(samples_batch)
        if not self.save_intermediate_results:
            self.result_processor.save_responses(responses)
        return responses, False

    def _get_metadata(self) -> dict[str, Any]:
        """Prepares metadata dictionary from the configuration."""
        all_metrics = getattr(self.task, "METRICS", None)
        metadata = self.config.model_dump(mode="json")
        metadata["llm_name"] = self.llm.name
        metadata["task_name"] = self.task_name
        language = getattr(self.task, "LANGUAGE", None)
        metadata["language"] = map_language_to_value(language)
        metadata["metrics"] = [m.NAME for m in all_metrics] if all_metrics is not None else []
        metadata["primary_metrics"] = getattr(self.task, "PRIMARY_METRICS", None)
        metadata["eval_framework_version"] = eval_framework_version
        metadata["task_output_dir"] = str(self.result_processor.output_dir)

        if hasattr(self, "total_time"):
            metadata["start_time"] = str(datetime.fromtimestamp(self.start_time, UTC))
            metadata["end_time"] = str(datetime.fromtimestamp(self.end_time, UTC))
            metadata["total_time"] = self.total_time

        # add task-specific metadata
        metadata["task_metadata"] = self.task.get_metadata()

        try:
            assert get_cluster_info is not None, "Determined cluster info not available"
            info = get_cluster_info()
            if info is not None:
                metadata["determined_agent_id"] = info.agent_id
                if info.task_type == "TRIAL":
                    metadata["determined_experiment_id"] = info.trial.experiment_id
                    metadata["determined_trial_id"] = info.trial.trial_id
        except Exception as e:
            logger.info(f"{e}; cluster info not available in local context")
        return metadata

    def _verify_loaded_metadata_compatibility(self) -> None:
        if not (loaded_metadata := self.result_processor.load_metadata()):
            return
        current_metadata = self._get_metadata()
        # check if crucial keys in metadata are the same as in the previous run
        keys = [
            "task_name",
            "task_subjects",
            "num_fewshot",
            "num_samples",
            "llm_name",
            "llm_args",
"perturbation_config", "repeats", ] for key in keys: if loaded_metadata[key] != current_metadata[key]: raise ValueError(f"Existing metadata does not match current metadata for {key}.") def __del__(self) -> None: self.llm.__del__()

    def generate(self, should_preempt_callable: Callable[[], bool]) -> tuple[list[Completion | Loglikelihood], bool]:
        """Generates responses and saves them along with metadata.

        :param should_preempt_callable: function to check if preempt is called
        :return: list of responses, preempted: whether the process was preempted or not
        """
        logger.info(f"{RED}[ Running responses generation ---------- ]{RESET}")
        logger.info(f"{RED}[ Will save into {self.result_processor.output_dir} ---------- ]{RESET}")
        responses, preempted = self._run_task_against_model(should_preempt_callable)
        logger.info("Completions generated and saved.")
        return responses, preempted
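

# Illustrative usage sketch (not part of the original module): a minimal driver showing how
# ResponseGenerator is typically wired together. `MyLLM` is a hypothetical BaseLLM subclass, and the
# EvalConfig and ResultsFileProcessor constructor arguments shown here are assumptions about their
# signatures rather than guaranteed APIs.
def _example_generate_responses() -> None:
    llm = MyLLM()  # hypothetical BaseLLM subclass, not defined in this module
    config = EvalConfig(task_name="ARC", num_fewshot=0, num_samples=10)  # assumed constructor args
    processor = ResultsFileProcessor(output_dir="results/arc")  # assumed constructor signature
    generator = ResponseGenerator(llm, config, processor)
    # should_preempt_callable lets a scheduler stop generation early; here we never preempt.
    responses, preempted = generator.generate(should_preempt_callable=lambda: False)
    print(f"Generated {len(responses)} responses (preempted={preempted})")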


def repeat_samples(samples: Iterable[Sample], repeats: int) -> Iterable[Sample]:
    """Flatten repeats into a single stream of samples.

    After expansion, original sample ids no longer point to the same samples. The original id can be
    recovered via `original_id = expanded_id // repeats`.
    """
    for sample in samples:
        base_id = sample.id * repeats
        for repeat_idx in range(repeats):
            repeated_sample = sample.model_copy()
            repeated_sample.id = base_id + repeat_idx
            yield repeated_sample
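

# Illustrative usage sketch (not part of the original module): demonstrates the id expansion
# performed by repeat_samples. With repeats=3, a sample with id 2 yields copies with ids 6, 7 and 8,
# and expanded_id // repeats recovers the original id. The Sample constructor arguments other than
# `id` are assumptions made for the sake of the example.
def _example_repeat_samples() -> None:
    original = Sample(id=2, subject="demo", ground_truth="42", messages=[])  # assumed fields
    for expanded in repeat_samples([original], repeats=3):
        assert expanded.id // 3 == 2  # recover the original sample id
        print(expanded.id)  # prints 6, then 7, then 8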