# Source code for eval_framework.tasks.benchmarks.mbpp

import ast
import logging
import re
from typing import Any

from eval_framework.metrics.completion.code_assertion import (
    CodeCompletionAssertion,
)
from eval_framework.shared.types import BaseMetricContext
from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Markdown code-fence delimiters used in prompts and parsed out of completions.
BEGIN = "```python"
END = "```"


class MBPPMetricContext(BaseMetricContext):
    """Metric context for MBPP samples, passed to the code-assertion metric."""

    # Newline-joined assert statements taken from the sample's "test_list".
    tests_code: str
class MBPP(BaseTask[str]):
    """Evaluation task for the MBPP code-generation benchmark.

    MBPP provides both the problem statement and the test cases upfront. It
    says, "Here's the problem and here are the tests; write code that passes
    them.". Note that LLMs can cheat and only write code that passes the tests
    without solving the given problem. MBPP_PROMPT_WITHOUT_TESTS, on the other
    hand, only gives you the problem statement and function signature
    initially. It says, "Here's the problem and function signature; write
    code, then we'll run tests later."
    """

    NAME = "MBPP"
    DATASET_PATH = "google-research-datasets/mbpp"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [CodeCompletionAssertion]
    SUBJECTS = ["full"]  # , "sanitized"]  # these are HF dataset SUBSETS!
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Stop generation at the closing code fence.
        self.stop_sequences = [END]

    @staticmethod
    def _code_expander(code: str, gt_asserts: str) -> str:
        """Append the ground-truth asserts to the LLM-generated code.

        `code` carries the LLM-generated code snippet; we append the asserts
        for code testing here. If no valid code is found in the LLM output,
        this function is not called.

        Important: `gt_asserts` comes as a stringified list of assert strings.
        We safely reconvert this string back to the list of individual assert
        statements (also strings) via ast.literal_eval.
        """
        if not gt_asserts:
            # No ground truth (data asserts) given; return the original code.
            return code
        parsed_asserts = ast.literal_eval(gt_asserts)  # never use eval!
        if not isinstance(parsed_asserts, list):
            # Fixed: this warning-severity message was previously emitted via
            # logger.info; it now uses the matching warning level.
            logger.warning(
                "*** WARNING, we expect a list of ground truth asserts here! Sample can not be finalized"
            )
            return code
        if not parsed_asserts:
            # Empty assert list: nothing to score against, return code as-is.
            return code
        stacked_asserts = ""
        for gt_assert in parsed_asserts:
            stacked_asserts += "    " + gt_assert + "\n"
        # Harness printed by the executed candidate: True iff every assert
        # passes. The bare except is deliberate inside the generated code.
        postfix = (
            "try:\n" + stacked_asserts + "    score = True\nexcept:\n    score = False\nprint(score)"
        )
        return code + postfix

    @staticmethod
    def _get_function_name(line: str) -> str:
        """Return the function name from a `def ...(` occurrence, or ""."""
        match = re.search(r"def\s+(\w+)\s*\(", line)
        function_name = ""
        if match:
            function_name = match.group(1)
        return function_name

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Passing selected task and tests depending on zero or few-shot setting"""
        tests = "\n".join(item["test_list"])
        # "text" key in the full subset; "prompt" in the sanitized subset.
        text = item["text"] if "text" in item else item["prompt"]
        instruction_text = f"You are an expert Python programmer, and here is your task: {text} Your code should pass these tests:\n\n{tests}\n"  # noqa E501
        return instruction_text

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        """Cue the model to answer inside a Python code fence."""
        return BEGIN

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """asserts are being passed as ground_truth, this is expected by
        CodeCompletionAssertion metrics"""
        return f"{item['test_list']}"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Render a few-shot target as the reference solution in a fence."""
        target = item["code"]
        assert target is not None
        assert isinstance(target, str)
        return f"{BEGIN}\n" + target + f"\n{END}"

    def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
        """Draw `num_fewshot` examples at random from the few-shot split."""
        fewshot_examples = self.rnd.sample(self.dataset[self.FEWSHOT_SPLIT], self.num_fewshot)
        return fewshot_examples

    def _get_context(self, item: dict[str, Any]) -> MBPPMetricContext:
        """Expose the sample's tests to the metric as newline-joined code."""
        return MBPPMetricContext(tests_code="\n".join(item["test_list"]))

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Strip the code fence from the completion and append GT asserts."""
        assert sample is not None
        if BEGIN in completion_text:
            # Fixed: splitting on f"{BEGIN}\n" raised IndexError whenever the
            # opening fence was present but not immediately followed by a
            # newline; split on the fence itself, then drop the newline.
            completion_text = completion_text.split(BEGIN, 1)[1]
            completion_text = completion_text.removeprefix("\n")
        if END in completion_text:
            completion_text = completion_text.split(END)[0]
        extracted_code = completion_text + "\n"
        mbpp_ground_truth = str(sample.ground_truth)
        code = self._code_expander(extracted_code, mbpp_ground_truth)
        return code
class MBPP_SANITIZED(MBPP):
    """MBPP evaluated on the hand-verified "sanitized" HF dataset subset."""

    # Fixed typo: was "MBPP_SANITZED" (missing "I"), inconsistent with the
    # class name and the other NAME constants in this module.
    # NOTE(review): if external configs/registries reference the misspelled
    # name, they must be updated in lockstep — confirm before release.
    NAME = "MBPP_SANITIZED"
    SUBJECTS = ["sanitized"]
class MBPP_PROMPT_WITHOUT_TESTS(MBPP):
    """MBPP variant that withholds the test cases from the prompt.

    MBPP provides both the problem statement and the test cases upfront. It
    says, "Here's the problem and here are the tests; write code that passes
    them.". Note that LLMs can cheat and only write code that passes the tests
    without solving the given problem. MBPP_PROMPT_WITHOUT_TESTS, on the other
    hand, only gives you the problem statement and function signature
    initially. It says, "Here's the problem and function signature; write
    code, then we'll run tests later."
    """

    NAME = "MBPP_PROMPT_WITHOUT_TESTS"

    # NOTE: a byte-identical override of _get_fewshot_target_text was removed;
    # the parent MBPP implementation is inherited unchanged.

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Passing selected task and tests depending on zero or few-shot setting"""
        text = item["text"] if "text" in item else item["prompt"]
        instruction_text = f"You are an expert Python programmer, and here is your task: {text}\n\n"  # noqa E501
        return instruction_text

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        """Cue with the opening fence plus the reference function header."""
        function_header = self._get_function_header(item["code"])
        return f"{BEGIN}\n{function_header}"

    @staticmethod
    def _get_function_header(line: str) -> str:
        """Extract a `def name(args):` header from `line`, or "" if absent.

        NOTE(review): the extraction cuts at the FIRST ")" after "def", so
        headers whose argument list itself contains ")" (e.g. tuple defaults)
        would be truncated — presumably absent from MBPP data; confirm.
        """
        header_match = re.search(r"^\s*def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*?\)\s*:", line, re.MULTILINE)
        postfix = ""
        if header_match is not None:
            # Keep everything from "def" onward, then trim after the first
            # closing parenthesis of the argument list.
            postfix = line[header_match.start() :]
            paren_match = re.search(r"\)", postfix)
            if paren_match is not None:
                end = paren_match.start()
                postfix = postfix[: end + 1]
            else:
                postfix = ""
        if postfix == "":
            return postfix
        return f"{postfix.strip()}:"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Strip the fence, re-attach the cued header, and append asserts."""
        assert sample is not None
        if BEGIN in completion_text:
            completion_text = completion_text.split(BEGIN)[1]
        if END in completion_text:
            completion_text = completion_text.split(END)[0]
        extracted_code = completion_text + "\n"
        mbpp_ground_truth = str(sample.ground_truth)
        # The cue ended with the function header, so the model's completion
        # begins mid-function; the header must be prepended before testing.
        function_header = self._get_function_header(sample.messages[-1].content)
        code = self._code_expander(extracted_code, mbpp_ground_truth)
        return function_header + code
class MBPP_PROMPT_WITHOUT_TESTS_SANITIZED(MBPP_PROMPT_WITHOUT_TESTS):
    """Tests-withheld MBPP variant on the "sanitized" HF dataset subset."""

    NAME = "MBPP_PROMPT_WITHOUT_TESTS_SANITIZED"
    SUBJECTS = ["sanitized"]