Sravanth18 commited on
Commit
9c8bac7
Β·
verified Β·
1 Parent(s): 78896e6

Add src/baseline_runner.py

Browse files
Files changed (1) hide show
  1. src/baseline_runner.py +114 -0
src/baseline_runner.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Baseline runners β€” normal LLM answer and prompt-only honesty baseline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import time
8
+ from pathlib import Path
9
+
10
+ from . import config
11
+ from .llm_client import llm_call
12
+ from .pipeline_runner import load_gold_cases
13
+ from .schemas import BaselineResult, GoldCase
14
+
15
+ # ── Prompts ────────────────────────────────────────────────────────────
16
+
17
+ BASELINE_NORMAL_SYSTEM = "Answer the user's question."
18
+
19
+ BASELINE_HONESTY_SYSTEM = (
20
+ "Answer only if supported by the provided evidence. "
21
+ "If not supported, say you do not know."
22
+ )
23
+
24
+ BASELINE_USER_TEMPLATE = """\
25
+ QUESTION:
26
+ {question}
27
+
28
+ EVIDENCE:
29
+ {evidence_text}
30
+ """
31
+
32
+
33
+ # ── Runner ─────────────────────────────────────────────────────────────
34
+
35
+ def run_baseline(
36
+ cases: list[GoldCase],
37
+ mode: str,
38
+ ) -> list[BaselineResult]:
39
+ system = BASELINE_NORMAL_SYSTEM if mode == "normal" else BASELINE_HONESTY_SYSTEM
40
+ results: list[BaselineResult] = []
41
+
42
+ for i, case in enumerate(cases, 1):
43
+ print(f"[baseline:{mode}] {i}/{len(cases)} case={case.id}")
44
+ t0 = time.perf_counter()
45
+ try:
46
+ user_msg = BASELINE_USER_TEMPLATE.format(
47
+ question=case.question,
48
+ evidence_text=case.evidence_text,
49
+ )
50
+ answer = llm_call(system, user_msg)
51
+ elapsed = (time.perf_counter() - t0) * 1000
52
+ results.append(
53
+ BaselineResult(
54
+ case_id=case.id,
55
+ category=case.category,
56
+ question=case.question,
57
+ answer=answer,
58
+ latency_ms=round(elapsed, 2),
59
+ )
60
+ )
61
+ except Exception as exc:
62
+ elapsed = (time.perf_counter() - t0) * 1000
63
+ results.append(
64
+ BaselineResult(
65
+ case_id=case.id,
66
+ category=case.category,
67
+ question=case.question,
68
+ answer="",
69
+ error=str(exc),
70
+ latency_ms=round(elapsed, 2),
71
+ )
72
+ )
73
+ return results
74
+
75
+
76
+ def save_baseline(results: list[BaselineResult], path: Path) -> None:
77
+ path.parent.mkdir(parents=True, exist_ok=True)
78
+ with open(path, "w") as f:
79
+ for r in results:
80
+ f.write(r.model_dump_json() + "\n")
81
+ print(f"[baseline] saved {len(results)} results β†’ {path}")
82
+
83
+
84
+ # ── CLI ────────────────────────────────────────────────────────────────
85
+
86
+ def main() -> None:
87
+ parser = argparse.ArgumentParser(description="Run baselines")
88
+ parser.add_argument(
89
+ "--mode",
90
+ choices=["normal", "honesty"],
91
+ required=True,
92
+ help="Baseline mode: 'normal' or 'honesty'",
93
+ )
94
+ parser.add_argument("--cases", type=str, default=None)
95
+ parser.add_argument("--output", type=str, default=None)
96
+ args = parser.parse_args()
97
+
98
+ cases = load_gold_cases(Path(args.cases) if args.cases else None)
99
+ print(f"[baseline:{args.mode}] loaded {len(cases)} cases")
100
+
101
+ results = run_baseline(cases, args.mode)
102
+
103
+ if args.output:
104
+ out = Path(args.output)
105
+ elif args.mode == "normal":
106
+ out = config.BASELINE_NORMAL_PATH
107
+ else:
108
+ out = config.BASELINE_HONESTY_PATH
109
+
110
+ save_baseline(results, out)
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main()