| |
| """ |
| Data preprocessing script. |
| |
| Convert the generated dataset into a format directly consumable by SFTTrainer. |
| FunctionGemma expects a specific chat template structure. |
| |
| Usage: |
| python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json |
| """ |
|
|
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
|
|
|
|
| PROJECT_ROOT = Path(__file__).resolve().parent.parent |
| DEFAULT_INPUT = PROJECT_ROOT / "data" / "training_data.json" |
| DEFAULT_OUTPUT = PROJECT_ROOT / "data" / "prepared_dataset.json" |
|
|
|
|
| def convert_tool_calls_to_text(tool_calls: List[Dict]) -> str: |
| """Convert tool_calls into plain text (FunctionGemma format).""" |
| if not tool_calls: |
| return "" |
| |
| result_parts = [] |
| for tc in tool_calls: |
| func = tc.get("function", {}) |
| name = func.get("name", "") |
| args = func.get("arguments", {}) |
| |
| |
| args_str = json.dumps(args, ensure_ascii=False) |
| result_parts.append(f"{name}({args_str})") |
| |
| return "\n".join(result_parts) |
|
|
|
|
| def convert_messages_for_sft(messages: List[Dict], tools: List[Dict] = None) -> List[Dict]: |
| """ |
| Convert message format for SFTTrainer. |
| |
| Input: |
| [ |
| {"role": "developer", "content": "..."}, |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."} |
| ] |
| |
| Output: |
| [ |
| {"role": "system", "content": "..."}, # developer -> system |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} # tool_calls flattened to text |
| ] |
| """ |
| converted = [] |
| |
| |
| tools_description = "" |
| if tools: |
| tools_desc_parts = [] |
| for tool in tools: |
| if tool.get("type") == "function": |
| func = tool.get("function", {}) |
| name = func.get("name", "") |
| desc = func.get("description", "") |
| params = func.get("parameters", {}) |
| tools_desc_parts.append(f"- {name}: {desc}") |
| if tools_desc_parts: |
| tools_description = "\n\nAvailable tools:\n" + "\n".join(tools_desc_parts) |
| |
| for msg in messages: |
| role = msg.get("role", "") |
| |
| if role == "developer": |
| |
| content = msg.get("content", "") |
| if tools_description: |
| content = content + tools_description |
| converted.append({ |
| "role": "system", |
| "content": content |
| }) |
| |
| elif role == "user": |
| converted.append({ |
| "role": "user", |
| "content": msg.get("content", "") |
| }) |
| |
| elif role == "assistant": |
| if "tool_calls" in msg: |
| |
| tool_calls_text = convert_tool_calls_to_text(msg["tool_calls"]) |
| converted.append({ |
| "role": "assistant", |
| "content": tool_calls_text |
| }) |
| else: |
| converted.append({ |
| "role": "assistant", |
| "content": msg.get("content", "") |
| }) |
| |
| elif role == "tool": |
| |
| converted.append({ |
| "role": "tool", |
| "content": msg.get("content", "") |
| }) |
| |
| return converted |
|
|
|
|
| def prepare_dataset(input_path: str, output_path: str, format_type: str = "messages"): |
| """ |
| Prepare dataset. |
| |
| format_type: |
| - "messages": output {"messages": [...]} |
| - "text": output {"text": "..."} (flattened text) |
| """ |
| print(f"Loading dataset: {input_path}") |
| |
| with open(input_path, 'r', encoding='utf-8') as f: |
| data = json.load(f) |
| |
| print(f"Raw samples: {len(data)}") |
| |
| prepared_data = [] |
| |
| for i, item in enumerate(data): |
| messages = item.get("messages", []) |
| tools = item.get("tools", []) |
| |
| |
| converted_messages = convert_messages_for_sft(messages, tools) |
| |
| if format_type == "messages": |
| prepared_data.append({ |
| "messages": converted_messages |
| }) |
| elif format_type == "text": |
| |
| text_parts = [] |
| for msg in converted_messages: |
| role = msg["role"] |
| content = msg["content"] |
| if role == "system": |
| text_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>") |
| elif role == "user": |
| text_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>") |
| elif role == "assistant": |
| text_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>") |
| |
| prepared_data.append({ |
| "text": "\n".join(text_parts) |
| }) |
| |
| print(f"Processed samples: {len(prepared_data)}") |
| |
| |
| with open(output_path, 'w', encoding='utf-8') as f: |
| json.dump(prepared_data, f, ensure_ascii=False, indent=2) |
| |
| print(f"Saved to: {output_path}") |
| |
| |
| print("\n" + "=" * 60) |
| print("Example:") |
| print("=" * 60) |
| |
| if format_type == "messages": |
| example = prepared_data[0] |
| for msg in example["messages"]: |
| print(f"\n[{msg['role']}]") |
| print(msg["content"][:200] + "..." if len(msg["content"]) > 200 else msg["content"]) |
| else: |
| print(prepared_data[0]["text"][:500] + "...") |
| |
| return prepared_data |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Dataset preparation") |
| parser.add_argument("--input", type=str, default=str(DEFAULT_INPUT), help="Input file path") |
| parser.add_argument("--output", type=str, default=str(DEFAULT_OUTPUT), help="Output file path") |
| parser.add_argument("--format", type=str, choices=["messages", "text"], default="messages", help="Output format") |
| |
| args = parser.parse_args() |
| |
| prepare_dataset(args.input, args.output, args.format) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|