Spaces:
Running
Running
timchen0618
Derive new_status from new_trajectory; fix sidebar check mark; fix question for incomplete
d14bce3 | from flask import Blueprint, jsonify, request | |
| from datasets import load_dataset | |
| import json | |
# Blueprint that groups this module's endpoints under the /api/selected-tools
# URL prefix.
# NOTE(review): no route decorators are visible for the view functions in this
# excerpt — confirm how they are registered on `bp`.
bp = Blueprint(
    "selected_tools",
    __name__,
    url_prefix="/api/selected-tools",
)
# Seeds for which a random-selection variant is registered below.
# There is no seed-2 entry.
_RANDOM_SEEDS = (0, 1, 3, 4, 5, 6, 7)

# Registry of dataset variants. Each key is the public variant id accepted via
# the `variant` query parameter; each value carries the dataset repo to load
# plus a human-readable label and description.
VARIANTS: dict[str, dict] = {
    "traj_summary_orig_ext": {
        "repo": "timchen0618/browsecomp-plus-selected-tools-orig-analysis-v1",
        "label": "Summary of Trajectory",
        "description": "Selected tools (orig messages) · traj_summary_orig_ext conditioned",
    },
    # test300 variants — excerpt extracted from <trajectory_summary> in run files
    # NOTE(review): the next two labels say "Gemini" while their keys/repos say
    # gpt-oss-120b — confirm the labels are intended.
    "test300-gpt-oss-120b-less-chars": {
        "repo": "timchen0618/browsecomp-plus-sel-tools-test300-gpt-oss-120b-less-chars-v1",
        "label": "Gemini 2.5 Pro 0",
        "description": "test300 · traj_summary_orig_ext_selected_tools_gpt-oss-120b_seed0_less_chars",
    },
    "test300-gpt-oss-120b": {
        "repo": "timchen0618/browsecomp-plus-sel-tools-test300-gpt-oss-120b-v1",
        "label": "Gemini 2.5 Pro 1",
        "description": "test300 · traj_summary_orig_ext_selected_tools_gpt-oss-120b_seed0",
    },
    "test300-gemini-2p5-pro": {
        "repo": "timchen0618/browsecomp-plus-sel-tools-test300-gemini-2p5-pro-v1",
        "label": "Gemini 2.5 Pro 2",
        "description": "test300 · traj_summary_orig_ext_selected_tools_gpt-oss-120b_gemini-2.5-pro_1_seed0",
    },
    "test300-gemini-3p1-pro": {
        "repo": "timchen0618/browsecomp-plus-sel-tools-test300-gemini-3p1-pro-v1",
        "label": "Gemini 3.1 Pro Preview",
        "description": "test300 · traj_summary_orig_ext_selected_tools_gpt-oss-120b_gemini_3.1-pro-preview_seed0",
    },
    # Random-selection variants share one naming pattern, so they are generated
    # per seed; insertion order follows _RANDOM_SEEDS.
    **{
        f"test300-random-seed{seed}": {
            "repo": f"timchen0618/browsecomp-plus-sel-tools-test300-random-seed{seed}-v1",
            "label": f"Random Seed {seed}",
            "description": (
                "test300 · traj_summary_orig_ext_selected_tools_"
                f"random_seed{seed}_gpt-oss-120b_seed0"
            ),
        }
        for seed in _RANDOM_SEEDS
    },
}

# Variant used when a request does not specify one.
DEFAULT_VARIANT = "traj_summary_orig_ext"

# In-process cache: variant id -> list of normalised row dicts.
_cache: dict[str, list] = {}
def _load(variant: str) -> list:
    """Return the normalised rows for *variant*, loading them on first use.

    The "train" split of the variant's dataset repo is loaded, each row is
    converted into a plain dict, and the resulting list is stored in
    ``_cache`` so later calls for the same variant return it directly.

    Per-row normalisation:
      * ``tool_call_counts`` is decoded from a JSON string (a value that is
        already a dict is used as-is); anything missing, undecodable, or not
        a JSON object becomes ``{}``.
      * ``total_tool_calls`` is the sum of the numeric values in those counts.
      * ``selected_indices`` is JSON-decoded when it is a string, falling
        back to ``[]`` when the string is not valid JSON.
      * ``new_status`` is ``"completed"`` when ``new_trajectory`` contains the
        literal ``"[Final Answer]"``, otherwise ``"incomplete"``.

    Args:
        variant: A key of ``VARIANTS``.

    Returns:
        The list of normalised row dicts for the variant.

    Raises:
        ValueError: If *variant* is not a key of ``VARIANTS``.
    """
    if variant in _cache:
        return _cache[variant]
    if variant not in VARIANTS:
        raise ValueError(f"Unknown variant: {variant!r}")

    ds = load_dataset(VARIANTS[variant]["repo"], split="train")
    rows = []
    for row in ds:
        raw_counts = row.get("tool_call_counts")
        if isinstance(raw_counts, dict):
            # Already decoded; re-parsing would fail and discard the counts.
            tool_counts = raw_counts
        else:
            try:
                tool_counts = json.loads(raw_counts or "{}")
            except (TypeError, ValueError):
                # Best effort: a malformed cell must not fail the whole variant.
                tool_counts = {}
            if not isinstance(tool_counts, dict):
                # Valid JSON that is not an object (list, number, ...) has no
                # per-tool counts to report.
                tool_counts = {}
        # Skip non-numeric values so one bad entry cannot break the total.
        total_tool_calls = sum(
            count
            for count in tool_counts.values()
            if isinstance(count, (int, float))
        )

        sel_idx = row["selected_indices"]
        if isinstance(sel_idx, str):
            try:
                sel_idx = json.loads(sel_idx)
            except ValueError:
                sel_idx = []

        new_traj = row.get("new_trajectory") or ""
        # Status is derived from the trajectory text rather than read from the row.
        new_status = "completed" if "[Final Answer]" in new_traj else "incomplete"

        rows.append({
            "query_id": str(row["query_id"]),
            "rationale": row.get("rationale") or "",
            "selected_indices": sel_idx,
            "k_requested": int(row["k_requested"]),
            "k_effective": int(row["k_effective"]),
            "excerpt": row["excerpt"],
            "new_trajectory": new_traj,
            "direct_answer": bool(row["direct_answer"]),
            "tool_call_counts": tool_counts,
            "total_tool_calls": total_tool_calls,
            "status": row["status"],
            "new_status": new_status,
            "question": row.get("question") or "",
            "correct_answer": row.get("correct_answer") or "",
            "correct": row.get("correct"),  # None if not available
        })

    _cache[variant] = rows
    return rows
def get_data():
    """Return the rows of the requested variant as JSON.

    The variant is read from the ``variant`` query parameter and defaults to
    ``DEFAULT_VARIANT``.

    Returns:
        On success, a JSON body with ``rows``, the resolved ``variant`` and
        the full ``variants`` registry. For an unknown variant, a JSON
        ``error`` body with status 400 (the same status ``reload_data`` uses
        for that case). For any failure while loading or serialising, a JSON
        ``error`` body with status 500.
    """
    variant = request.args.get("variant", DEFAULT_VARIANT)
    # A bad variant name is a client error, not a server failure.
    if variant not in VARIANTS:
        return jsonify({"error": f"Unknown variant: {variant!r}"}), 400
    try:
        rows = _load(variant)
        return jsonify({"rows": rows, "variant": variant, "variants": VARIANTS})
    except Exception as e:
        # Endpoint boundary: report the failure as JSON instead of an HTML 500.
        return jsonify({"error": str(e)}), 500
def get_variants():
    """Return the variant registry and the default variant key as JSON."""
    payload = {
        "variants": VARIANTS,
        "default": DEFAULT_VARIANT,
    }
    return jsonify(payload)
def reload_data():
    """Drop cached data for a variant and load it again.

    The variant is read from the ``variant`` query parameter and defaults to
    ``DEFAULT_VARIANT``. The in-process cache entry is removed, the variant's
    ``datasets--<owner>--<name>`` directory in the Hugging Face hub cache is
    deleted if present, and the variant is then reloaded via ``_load``.

    The hub cache directory is resolved as ``HF_HUB_CACHE``, else
    ``$HF_HOME/hub``, else ``$XDG_CACHE_HOME/huggingface/hub``, else
    ``~/.cache/huggingface/hub``.

    Returns:
        On success, a JSON body with ``status``, the reloaded row ``count``
        and the ``variant``. For an unknown variant, a JSON ``error`` body
        with status 400. For any failure while purging or reloading, a JSON
        ``error`` body with status 500.
    """
    import os
    import shutil

    variant = request.args.get("variant", DEFAULT_VARIANT)
    if variant not in VARIANTS:
        return jsonify({"error": f"Unknown variant: {variant!r}"}), 400

    _cache.pop(variant, None)
    repo = VARIANTS[variant]["repo"]
    try:
        # Delete cached dataset dir so stale schema metadata doesn't block new columns.
        # Honour the cache-location environment variables so the purge hits
        # the directory actually in use when the cache has been relocated.
        cache_base = os.environ.get("HF_HUB_CACHE")
        if not cache_base:
            hf_home = os.environ.get("HF_HOME")
            if not hf_home:
                xdg_cache = os.environ.get("XDG_CACHE_HOME") or "~/.cache"
                hf_home = os.path.join(xdg_cache, "huggingface")
            cache_base = os.path.join(hf_home, "hub")
        cache_base = os.path.expanduser(cache_base)

        dataset_cache_name = "datasets--" + repo.replace("/", "--")
        dataset_cache_path = os.path.join(cache_base, dataset_cache_name)
        try:
            shutil.rmtree(dataset_cache_path)
        except FileNotFoundError:
            # Nothing cached yet, or removed concurrently — either is fine.
            pass

        rows = _load(variant)
        return jsonify({"status": "ok", "count": len(rows), "variant": variant})
    except Exception as e:
        # Endpoint boundary: report the failure as JSON instead of an HTML 500.
        return jsonify({"error": str(e)}), 500