# dashboard/backend/api/selected_tools.py
# timchen0618
# Derive new_status from new_trajectory; fix sidebar check mark; fix question for incomplete
# d14bce3
from flask import Blueprint, jsonify, request
from datasets import load_dataset
import json
# Flask blueprint for the selected-tools endpoints; every route registered on
# it is served under the /api/selected-tools URL prefix.
bp = Blueprint("selected_tools", __name__, url_prefix="/api/selected-tools")
# Registry of dataset variants. Each key is a variant identifier; each entry
# holds the dataset "repo" to load plus a "label" and a "description" string.
VARIANTS: dict[str, dict] = {
    "traj_summary_orig_ext": {
        "repo": "timchen0618/browsecomp-plus-selected-tools-orig-analysis-v1",
        "label": "Summary of Trajectory",
        "description": "Selected tools (orig messages) · traj_summary_orig_ext conditioned",
    },
    # test300 variants — excerpt extracted from <trajectory_summary> in run files
    "test300-gpt-oss-120b-less-chars": {
        "repo": "timchen0618/browsecomp-plus-sel-tools-test300-gpt-oss-120b-less-chars-v1",
        "label": "Gemini 2.5 Pro 0",
        "description": "test300 · traj_summary_orig_ext_selected_tools_gpt-oss-120b_seed0_less_chars",
    },
    "test300-gpt-oss-120b": {
        "repo": "timchen0618/browsecomp-plus-sel-tools-test300-gpt-oss-120b-v1",
        "label": "Gemini 2.5 Pro 1",
        "description": "test300 · traj_summary_orig_ext_selected_tools_gpt-oss-120b_seed0",
    },
    "test300-gemini-2p5-pro": {
        "repo": "timchen0618/browsecomp-plus-sel-tools-test300-gemini-2p5-pro-v1",
        "label": "Gemini 2.5 Pro 2",
        "description": "test300 · traj_summary_orig_ext_selected_tools_gpt-oss-120b_gemini-2.5-pro_1_seed0",
    },
    "test300-gemini-3p1-pro": {
        "repo": "timchen0618/browsecomp-plus-sel-tools-test300-gemini-3p1-pro-v1",
        "label": "Gemini 3.1 Pro Preview",
        "description": "test300 · traj_summary_orig_ext_selected_tools_gpt-oss-120b_gemini_3.1-pro-preview_seed0",
    },
    # The random-seed variants all follow one naming template, so they are
    # generated per seed. The registered seeds are 0, 1, 3, 4, 5, 6 and 7
    # (there is no seed 2 entry).
    **{
        f"test300-random-seed{seed}": {
            "repo": f"timchen0618/browsecomp-plus-sel-tools-test300-random-seed{seed}-v1",
            "label": f"Random Seed {seed}",
            "description": (
                f"test300 · traj_summary_orig_ext_selected_tools_random_seed{seed}"
                "_gpt-oss-120b_seed0"
            ),
        }
        for seed in (0, 1, 3, 4, 5, 6, 7)
    },
}

# Variant used when a request does not name one.
DEFAULT_VARIANT = "traj_summary_orig_ext"

# In-process memo of parsed rows, keyed by variant identifier.
_cache: dict[str, list] = {}
def _parse_tool_counts(raw) -> dict:
    """Decode a ``tool_call_counts`` cell into a dict, or ``{}`` if unusable.

    Accepts a JSON string (or bytes), an already-decoded dict, or a missing /
    empty value. Malformed JSON and non-object payloads yield ``{}`` because
    the counts are best-effort metadata and must not fail the whole load.
    """
    if raw is None or isinstance(raw, (str, bytes, bytearray)):
        try:
            raw = json.loads(raw or "{}")
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            return {}
    return raw if isinstance(raw, dict) else {}


def _parse_selected_indices(raw):
    """Decode ``selected_indices``: JSON strings are parsed, ``None`` -> ``[]``.

    Values that are neither a string nor ``None`` are returned unchanged.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        try:
            return json.loads(raw)
        except ValueError:
            return []
    return raw


def _row_to_record(row) -> dict:
    """Convert one dataset row into the JSON-serialisable dict sent to clients."""
    tool_counts = _parse_tool_counts(row.get("tool_call_counts"))
    # Only numeric values contribute, so a stray null/string cannot crash sum().
    total_tool_calls = sum(
        count for count in tool_counts.values() if isinstance(count, (int, float))
    )

    new_traj = row.get("new_trajectory") or ""
    # The status is derived from the trajectory text itself: it counts as
    # completed only when the "[Final Answer]" marker is present.
    new_status = "completed" if "[Final Answer]" in new_traj else "incomplete"

    return {
        "query_id": str(row["query_id"]),
        "rationale": row.get("rationale") or "",
        "selected_indices": _parse_selected_indices(row["selected_indices"]),
        "k_requested": int(row["k_requested"]),
        "k_effective": int(row["k_effective"]),
        "excerpt": row["excerpt"],
        "new_trajectory": new_traj,
        "direct_answer": bool(row["direct_answer"]),
        "tool_call_counts": tool_counts,
        "total_tool_calls": total_tool_calls,
        "status": row["status"],
        "new_status": new_status,
        "question": row.get("question") or "",
        "correct_answer": row.get("correct_answer") or "",
        "correct": row.get("correct"),  # None if not available
    }


def _load(variant: str) -> list:
    """Return the parsed rows for *variant*, loading and caching on first use.

    Args:
        variant: Key into ``VARIANTS``.

    Returns:
        A list with one dict per row of the variant's "train" split. The same
        list object is returned on later calls until the ``_cache`` entry is
        removed.

    Raises:
        ValueError: If *variant* is not a key of ``VARIANTS``.
        KeyError: If a row lacks one of the required columns.
    """
    if variant in _cache:
        return _cache[variant]
    if variant not in VARIANTS:
        raise ValueError(f"Unknown variant: {variant!r}")

    ds = load_dataset(VARIANTS[variant]["repo"], split="train")
    rows = [_row_to_record(row) for row in ds]
    _cache[variant] = rows
    return rows
@bp.get("/")
def get_data():
    """Return the rows of the requested variant as JSON.

    The variant is read from the ``variant`` query parameter and defaults to
    ``DEFAULT_VARIANT``.

    Returns:
        200 with ``{"rows", "variant", "variants"}`` on success;
        400 with ``{"error"}`` when the variant is unknown;
        500 with ``{"error"}`` when loading or serialising the data fails.
    """
    variant = request.args.get("variant", DEFAULT_VARIANT)
    if variant not in VARIANTS:
        # An unknown variant is a client error, so answer 400 the same way the
        # /reload endpoint does instead of surfacing _load's ValueError as 500.
        return jsonify({"error": f"Unknown variant: {variant!r}"}), 400
    try:
        rows = _load(variant)
        # jsonify stays inside the try so serialisation errors are also
        # reported as a JSON error body.
        return jsonify({"rows": rows, "variant": variant, "variants": VARIANTS})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@bp.get("/variants")
def get_variants():
    """Return the full variant registry together with the default variant key."""
    payload = {"variants": VARIANTS, "default": DEFAULT_VARIANT}
    return jsonify(payload)
def _hf_hub_cache_dir() -> str:
    """Return the Hugging Face hub cache directory, honoring env overrides.

    Resolution order: ``HF_HUB_CACHE``, then ``$HF_HOME/hub``, then
    ``$XDG_CACHE_HOME/huggingface/hub``, then ``~/.cache/huggingface/hub``.
    """
    import os

    hub_cache = os.environ.get("HF_HUB_CACHE")
    if not hub_cache:
        cache_root = os.environ.get("XDG_CACHE_HOME") or os.path.join("~", ".cache")
        hf_home = os.environ.get("HF_HOME") or os.path.join(cache_root, "huggingface")
        hub_cache = os.path.join(hf_home, "hub")
    return os.path.expanduser(hub_cache)


@bp.post("/reload")
def reload_data():
    """Drop cached data for a variant and load it again.

    The variant is read from the ``variant`` query parameter and defaults to
    ``DEFAULT_VARIANT``. Both the in-process cache entry and the on-disk hub
    cache directory for the variant's dataset repo are removed before the
    reload.

    Returns:
        200 with ``{"status": "ok", "count", "variant"}`` on success;
        400 with ``{"error"}`` when the variant is unknown;
        500 with ``{"error"}`` when deleting the cache or reloading fails.
    """
    import os
    import shutil

    variant = request.args.get("variant", DEFAULT_VARIANT)
    if variant not in VARIANTS:
        return jsonify({"error": f"Unknown variant: {variant!r}"}), 400

    # Evict the in-memory copy so _load() goes back to the dataset.
    _cache.pop(variant, None)

    repo = VARIANTS[variant]["repo"]
    try:
        # Delete cached dataset dir so stale schema metadata doesn't block new columns
        dataset_cache_name = "datasets--" + repo.replace("/", "--")
        dataset_cache_path = os.path.join(_hf_hub_cache_dir(), dataset_cache_name)
        if os.path.exists(dataset_cache_path):
            shutil.rmtree(dataset_cache_path)
        rows = _load(variant)
        return jsonify({"status": "ok", "count": len(rows), "variant": variant})
    except Exception as e:
        return jsonify({"error": str(e)}), 500