"""
FunctionGemma evaluation script (v2).

Uses a unified system prompt for evaluation.

Usage:
    python -m src.evaluate --model_path ./runs/<run>/final_model --benchmark_path ./data/benchmark_dataset.json
"""

import os
import re
import sys
import json
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from tqdm import tqdm

# Make the project root importable before pulling in src.config.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

DEFAULT_BENCHMARK_PATH = PROJECT_ROOT / "data" / "benchmark_dataset.json"
DEFAULT_RESULTS_DIR = PROJECT_ROOT / "results"

from src.config import (  # noqa: E402  (must come after the sys.path tweak above)
    get_system_prompt, get_system_prompt_short, TOOLS,
    SOLANA_TOKENS, get_token_address
)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def load_model(
    model_path: str,
    lora_path: Optional[str] = None,
    device: str = "auto",
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
):
    """Load model and tokenizer, optionally quantized and/or with a LoRA adapter."""
    logger.info(f"Loading model: {model_path}")

    kwargs = {
        "device_map": device,
        "trust_remote_code": True,
    }

    if load_in_8bit:
        # Passing load_in_8bit directly to from_pretrained is deprecated in recent
        # transformers releases; route it through BitsAndBytesConfig instead.
        from transformers import BitsAndBytesConfig
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif load_in_4bit:
        from transformers import BitsAndBytesConfig
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)

    if lora_path:
        logger.info(f"Loading LoRA adapter: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)

    model.eval()
    return model, tokenizer
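

# Example usage (paths are illustrative, not part of this repo's fixed layout):
#   model, tokenizer = load_model("./runs/my_run/final_model")
#   model, tokenizer = load_model("./runs/my_run/base", lora_path="./runs/my_run/adapter",
#                                 load_in_4bit=True)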


def parse_functiongemma_output(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    """
    Parse FunctionGemma formatted output.

    Format: <start_function_call>call:FUNC_NAME{key:<escape>value<escape>,...}<end_function_call>
    """
    # First try the complete, well-formed call format.
    pattern = r'<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>'
    match = re.search(pattern, response)

    if not match:
        # Fallback: the closing tag may be missing (e.g. generation was truncated).
        pattern = r'<start_function_call>call:(\w+)\{([^}]*)\}'
        match = re.search(pattern, response)

    if not match:
        # Fallback: only the function name was emitted, with no argument braces.
        pattern = r'<start_function_call>call:(\w+)'
        match = re.search(pattern, response)
        if match:
            return match.group(1), {}

        # Last resort: a known function name mentioned anywhere in the response.
        for func in ["SEARCH_TOKEN", "EXECUTE_SWAP"]:
            if func in response:
                return func, {}

        return None, None

    func_name = match.group(1)
    params_str = match.group(2) if len(match.groups()) > 1 else ""

    args = parse_params_string(params_str)

    return func_name, args
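

# Worked example (hypothetical model output; token names are illustrative):
#   parse_functiongemma_output(
#       "<start_function_call>call:EXECUTE_SWAP"
#       "{from_token:<escape>SOL<escape>,amount:1.5}<end_function_call>")
#   -> ("EXECUTE_SWAP", {"from_token": "SOL", "amount": 1.5})
# A response with no recognizable call yields (None, None).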


def parse_params_string(params_str: str) -> Dict:
    """Parse a FunctionGemma parameter string into a dict."""
    args = {}
    if not params_str:
        return args

    # Values are either <escape>-delimited strings or bare tokens up to , or }.
    param_pattern = r'(\w+):(?:<escape>([^<]*)<escape>|([^,}]+))'

    for match in re.finditer(param_pattern, params_str):
        key = match.group(1)
        value = match.group(2) if match.group(2) is not None else match.group(3)

        if value is None:
            continue

        value = value.strip()

        # Percentages are normalized to fractions (e.g. "0.5%" -> 0.005).
        if value.endswith('%'):
            try:
                args[key] = float(value[:-1]) / 100
                continue
            except ValueError:
                pass

        # Coerce to int/float where possible; otherwise keep the raw string.
        try:
            if '.' in value:
                args[key] = float(value)
            else:
                args[key] = int(value)
        except ValueError:
            args[key] = value

    return args
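

# Coercion examples (illustrative inputs):
#   parse_params_string("slippage:0.5%,amount:2")     -> {"slippage": 0.005, "amount": 2}
#   parse_params_string("token:<escape>BONK<escape>") -> {"token": "BONK"}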


def is_rejection_response(response: str) -> bool:
    """Check if the response is a rejection/clarification rather than a function call."""
    # No function-call marker at all counts as a rejection.
    if '<start_function_call>' not in response:
        return True

    # Clarification/refusal phrases, in both English and Chinese.
    rejection_keywords = [
        "please specify", "could you", "what token", "which token",
        "请问", "请提供", "请告诉", "您能", "什么代币", "哪个代币",
        "sorry", "can't", "cannot", "unable", "抱歉", "无法",
        "more information", "more details", "更多信息",
    ]

    response_lower = response.lower()
    for keyword in rejection_keywords:
        if keyword.lower() in response_lower:
            return True

    return False
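

# Examples (illustrative):
#   is_rejection_response("Which token would you like to swap?")          -> True
#   is_rejection_response("<start_function_call>call:SEARCH_TOKEN{...}")  -> False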


def format_messages_for_model(
    messages: List[Dict],
    tokenizer,
    tools: Optional[List[Dict]] = None,
) -> str:
    """Format messages into the model chat template."""
    if hasattr(tokenizer, 'apply_chat_template'):
        try:
            return tokenizer.apply_chat_template(
                messages,
                tools=tools,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            pass

    # Fallback: build a Gemma-style turn format by hand.
    formatted = ""
    for msg in messages:
        role = msg["role"]
        content = msg["content"]

        if role == "system":
            formatted += f"<start_of_turn>system\n{content}<end_of_turn>\n"
        elif role == "user":
            formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
        elif role == "assistant":
            formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"

    formatted += "<start_of_turn>model\n"
    return formatted
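

# Fallback template output for a user turn (content illustrative; a system turn,
# if present, precedes it in the same format):
#   <start_of_turn>user
#   Swap 1 SOL to USDC<end_of_turn>
#   <start_of_turn>model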


def generate_response(
    model,
    tokenizer,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 256,
) -> str:
    """Generate a model response for a single user prompt."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    input_text = format_messages_for_model(messages, tokenizer, TOOLS)
    inputs = tokenizer(input_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.1,  # near-greedy sampling for reproducible evaluation
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep special tokens: the <start_function_call> markers are needed for parsing.
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
    response = response.replace("<end_of_turn>", "").strip()

    return response
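

# Usage sketch (system prompt comes from src.config; prompt text is illustrative):
#   reply = generate_response(model, tokenizer, "Swap 1 SOL to USDC",
#                             get_system_prompt_short("solana"))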


def compare_arguments(expected: Dict, actual: Dict) -> Tuple[float, List[str]]:
    """Compare expected vs actual arguments; return (score in [0, 1], errors)."""
    if not expected:
        return 1.0 if not actual else 0.0, []

    if not actual:
        return 0.0, ["No arguments extracted"]

    errors = []
    total_keys = set(expected.keys()) | set(actual.keys())

    if not total_keys:
        return 1.0, []

    matched = 0

    for key in expected.keys():
        exp_val = expected.get(key)
        act_val = actual.get(key)

        if exp_val is None:
            continue

        if act_val is None:
            errors.append(f"Missing key: {key}")
            continue

        if str(exp_val) == str(act_val):
            matched += 1
        elif isinstance(exp_val, str) and isinstance(act_val, str):
            # Prefix match (e.g. truncated addresses) earns half credit.
            if exp_val[:10] == act_val[:10]:
                matched += 0.5
                errors.append(f"Partial match for {key}")
            else:
                errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
        elif isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)):
            if abs(float(exp_val) - float(act_val)) < 0.01:
                matched += 1
            else:
                errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
        else:
            errors.append(f"Type mismatch for {key}")

    for key in actual.keys():
        if key not in expected:
            errors.append(f"Extra key: {key}")

    # Score over the non-None expected keys; guard against division by zero
    # when every expected value is None.
    num_scored = len([k for k in expected if expected[k] is not None])
    score = matched / num_scored if num_scored else 1.0
    return score, errors
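

# Scoring example (illustrative):
#   compare_arguments({"token": "USDC", "amount": 1.0},
#                     {"token": "USDC", "amount": 1.0})  -> (1.0, [])
#   compare_arguments({"token": "USDC"}, {"token": "USDT"})
#       -> (0.0, ["Value mismatch for token: expected USDC, got USDT"])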


def process_single_sample(
    sample: Dict,
    idx: int,
    model,
    tokenizer,
    system_prompt: str,
) -> Dict:
    """Process one sample and return evaluation result."""
    sample_id = sample.get("id", idx + 1)
    category = sample.get("category", "unknown")
    user_input = sample["input"]
    expected_func = sample["expected"]["function_name"]
    expected_args = sample["expected"].get("arguments", {})

    # The input may be a plain string or a chat-style {"messages": [...]} dict.
    if isinstance(user_input, dict) and "messages" in user_input:
        prompt = ""
        for msg in user_input["messages"]:
            if msg.get("role") == "user":
                prompt = msg.get("content", "")
                break
    else:
        prompt = str(user_input)

    response = generate_response(model, tokenizer, prompt, system_prompt)

    actual_func, actual_args = parse_functiongemma_output(response)
    is_rejection = is_rejection_response(response)

    func_correct = False
    args_correct = False
    exact_match = False
    arg_score = 0.0
    error_msg = None
    rejection_correct = False

    if expected_func is None:
        # Rejection case: the model should refuse or ask for clarification.
        func_correct = is_rejection or actual_func is None
        args_correct = func_correct
        exact_match = func_correct
        arg_score = 1.0 if func_correct else 0.0
        rejection_correct = func_correct

        if not func_correct:
            error_msg = f"Expected rejection, got {actual_func}"
    else:
        func_correct = actual_func == expected_func

        if func_correct:
            arg_score, arg_errors = compare_arguments(expected_args, actual_args or {})
            args_correct = arg_score >= 0.99
            exact_match = args_correct

            if not args_correct:
                error_msg = "; ".join(arg_errors)
        else:
            error_msg = f"Expected {expected_func}, got {actual_func}"

    result = {
        "sample_id": sample_id,
        "category": category,
        "expected_func": expected_func,
        "actual_func": actual_func,
        "func_correct": func_correct,
        "args_correct": args_correct,
        "exact_match": exact_match,
        "rejection_correct": rejection_correct,
        "arg_score": arg_score,
        "error_msg": error_msg,
        "user_input": user_input,
        "expected_args": expected_args,
        "actual_args": actual_args,
        "response": response,
    }

    return result


def evaluate_benchmark(
    model,
    tokenizer,
    benchmark: List[Dict],
    chain: str = "solana",
    verbose: bool = False,
    num_workers: int = 1,
) -> Dict:
    """Evaluate the benchmark (supports concurrency)."""
    system_prompt = get_system_prompt_short(chain)

    results = {
        "total": len(benchmark),
        "function_correct": 0,
        "arguments_correct": 0,
        "exact_match": 0,
        "rejection_correct": 0,
        "total_arg_score": 0.0,
        "by_category": {},
        "by_function": {},
        "errors": [],
    }

    results_lock = Lock()

    if num_workers > 1:
        # Note: the threads share a single model, so GPU generation is largely
        # serialized; the gain mainly comes from overlapping tokenization and
        # result bookkeeping.
        logger.info(f"Evaluating with {num_workers} worker threads")

        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(
                    process_single_sample,
                    sample, i, model, tokenizer, system_prompt
                ): i for i, sample in enumerate(benchmark)
            }

            with tqdm(total=len(benchmark), desc="Evaluation") as pbar:
                for future in as_completed(futures):
                    result = future.result()

                    with results_lock:
                        _update_results(results, result, verbose)

                    pbar.update(1)
    else:
        logger.info("Evaluating with a single thread")
        for i, sample in enumerate(tqdm(benchmark, desc="Evaluation")):
            result = process_single_sample(sample, i, model, tokenizer, system_prompt)
            _update_results(results, result, verbose)

    return results


def _update_results(results: Dict, result: Dict, verbose: bool):
    """Update aggregated evaluation results with one sample's result."""
    sample_id = result["sample_id"]
    category = result["category"]
    expected_func = result["expected_func"]
    actual_func = result["actual_func"]
    func_correct = result["func_correct"]
    args_correct = result["args_correct"]
    exact_match = result["exact_match"]
    rejection_correct = result["rejection_correct"]
    arg_score = result["arg_score"]
    error_msg = result["error_msg"]

    # Overall counters.
    if func_correct:
        results["function_correct"] += 1
    if args_correct:
        results["arguments_correct"] += 1
    if exact_match:
        results["exact_match"] += 1
    if rejection_correct:
        results["rejection_correct"] += 1
    results["total_arg_score"] += arg_score

    # Per-category stats.
    if category not in results["by_category"]:
        results["by_category"][category] = {
            "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
        }
    results["by_category"][category]["total"] += 1
    if func_correct:
        results["by_category"][category]["func_correct"] += 1
    if exact_match:
        results["by_category"][category]["exact_match"] += 1
    results["by_category"][category]["arg_score"] += arg_score

    # Per-function stats (rejection cases are keyed as "None").
    func_key = expected_func or "None"
    if func_key not in results["by_function"]:
        results["by_function"][func_key] = {
            "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
        }
    results["by_function"][func_key]["total"] += 1
    if func_correct:
        results["by_function"][func_key]["func_correct"] += 1
    if exact_match:
        results["by_function"][func_key]["exact_match"] += 1
    results["by_function"][func_key]["arg_score"] += arg_score

    # Keep the first few failures for the report.
    if error_msg and len(results["errors"]) < 10:
        results["errors"].append({
            "id": sample_id,
            "category": category,
            "input": result["user_input"],
            "expected_func": expected_func,
            "actual_func": actual_func,
            "expected_args": result["expected_args"],
            "actual_args": result["actual_args"],
            "error": error_msg,
            "response": result["response"][:200],
        })

    if verbose:
        status = "✓" if exact_match else "✗"

        user_input = result["user_input"]
        if isinstance(user_input, dict):
            user_msg = ""
            if "messages" in user_input:
                for msg in user_input["messages"]:
                    if msg.get("role") == "user":
                        user_msg = msg.get("content", "")
                        break
            input_preview = user_msg[:50] if user_msg else str(user_input)[:50]
        else:
            input_preview = str(user_input)[:50]
        logger.info(f"[{sample_id}] {status} {category}: {input_preview}...")


def print_report(results: Dict):
    """Print evaluation report."""
    total = results["total"]

    print("\n" + "=" * 70)
    print("FunctionGemma Evaluation Report")
    print("=" * 70)
    print(f"\nTotal samples: {total}")

    print("\n" + "-" * 70)
    print("Overall metrics")
    print("-" * 70)

    func_acc = results["function_correct"] / total * 100 if total > 0 else 0
    arg_acc = results["arguments_correct"] / total * 100 if total > 0 else 0
    exact_acc = results["exact_match"] / total * 100 if total > 0 else 0
    avg_arg_score = results["total_arg_score"] / total * 100 if total > 0 else 0

    # Rejection cases are stored under the "None" function key.
    rejection_total = results["by_function"].get("None", {}).get("total", 0)
    rejection_acc = results["rejection_correct"] / rejection_total * 100 if rejection_total > 0 else 0

    print(f"Function selection accuracy: {func_acc:.2f}%")
    print(f"Argument accuracy: {arg_acc:.2f}%")
    print(f"Exact match accuracy: {exact_acc:.2f}%")
    print(f"Average argument score: {avg_arg_score:.2f}%")
    print(f"Rejection accuracy: {rejection_acc:.2f}%")

    print("\n" + "-" * 70)
    print("By function")
    print("-" * 70)

    for func, stats in sorted(results["by_function"].items()):
        func_total = stats["total"]
        func_correct = stats["func_correct"] / func_total * 100 if func_total > 0 else 0
        func_arg_score = stats["arg_score"] / func_total * 100 if func_total > 0 else 0
        func_exact = stats["exact_match"] / func_total * 100 if func_total > 0 else 0

        print(f"{func:15} | samples: {func_total:3} | func acc: {func_correct:6.2f}% | "
              f"arg score: {func_arg_score:6.2f}% | exact: {func_exact:6.2f}%")

    if results["errors"]:
        print("\n" + "-" * 70)
        print("Error samples")
        print("-" * 70)

        for err in results["errors"][:5]:
            print(f"\nID: {err['id']} | category: {err['category']}")
            print(f"Input: {err['input']}")
            print(f"Expected: {err['expected_func']} | Actual: {err['actual_func']}")
            print(f"Error: {err['error']}")

    print("\n" + "=" * 70)


def main():
    parser = argparse.ArgumentParser(description="FunctionGemma evaluation (v2)")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--lora_path", type=str, default=None, help="LoRA adapter path")
    parser.add_argument("--benchmark_path", type=str, default=str(DEFAULT_BENCHMARK_PATH), help="Benchmark dataset path")
    parser.add_argument("--output_path", type=str, default=None, help="Output path (defaults to results/ with timestamp)")
    parser.add_argument("--chain", type=str, default="solana", help="Chain name")
    parser.add_argument("--load_in_8bit", action="store_true", help="Enable 8-bit quantization")
    parser.add_argument("--load_in_4bit", action="store_true", help="Enable 4-bit quantization")
    parser.add_argument("--verbose", action="store_true", help="Verbose logging")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of worker threads (default 4)")
    args = parser.parse_args()

    model, tokenizer = load_model(
        args.model_path,
        lora_path=args.lora_path,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    )

    benchmark_path = Path(args.benchmark_path)
    logger.info(f"Loading benchmark: {benchmark_path}")
    with open(benchmark_path, 'r', encoding='utf-8') as f:
        benchmark = json.load(f)

    logger.info(f"Benchmark samples: {len(benchmark)}")

    logger.info("Starting evaluation...")
    results = evaluate_benchmark(
        model, tokenizer, benchmark,
        chain=args.chain,
        verbose=args.verbose,
        num_workers=args.num_workers,
    )

    print_report(results)

    if args.output_path:
        output_path = Path(args.output_path)
    else:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_path = DEFAULT_RESULTS_DIR / f"evaluation_{timestamp}.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    logger.info(f"Evaluation saved to: {output_path}")


if __name__ == "__main__":
    main()