Spaces:
Paused
Paused
"""
inference.py — PayOps LLM Inference Script
===========================================
Runs an LLM agent (via an OpenAI-compatible client) against the PayOps
OpenEnv environment. Machine-readable [START] / [STEP] / [END] lines are
written to stdout; the per-task breakdown and the final normalised score are
written to stderr.

Required environment variables
--------------------------------
HF_TOKEN         API key / HuggingFace token used as the Bearer credential
                 (OPENAI_API_KEY is also accepted and takes precedence when set)

Optional
--------
API_BASE_URL     OpenAI-compatible API endpoint, e.g. https://api.openai.com/v1
                 or a HuggingFace Inference endpoint
                 (default: https://router.huggingface.co/v1)
MODEL_NAME       Model identifier, e.g. gpt-4o-mini, meta-llama/Llama-3-8B-Instruct
                 (default: Qwen/Qwen2.5-72B-Instruct)
PAYOPS_BASE_URL  Base URL of the running PayOps server (default: http://localhost:7860)
INFERENCE_SEED   Integer seed sent to /reset, or "random" to send no seed (default: 42)
PAYOPS_TASK      Task label printed in the [START] line (default: payops)

Usage
-----
export API_BASE_URL="https://api.openai.com/v1"
export MODEL_NAME="gpt-4o-mini"
export HF_TOKEN="sk-..."
python inference.py
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| import ssl | |
| import sys | |
| import time | |
| import urllib.request | |
| import urllib.error | |
| from typing import Optional | |
# TLS context used ONLY for requests to the PayOps server (see _make_request);
# the OpenAI client has its own SSL handling.
#
# HF Spaces use Let's Encrypt certs that Python 3.14 on macOS cannot verify
# because the system trust store is not used by default, so certificate
# verification stays OFF by default to keep the script runnable there.
#
# SECURITY: with verification off, traffic to the PayOps server can be
# intercepted or tampered with. Set PAYOPS_VERIFY_SSL=1 (or true/yes/on) to
# restore full certificate and hostname verification.
_SSL_CTX = ssl.create_default_context()
if os.environ.get("PAYOPS_VERIFY_SSL", "").strip().lower() not in {"1", "true", "yes", "on"}:
    # Order matters: check_hostname must be cleared before verify_mode may be
    # set to CERT_NONE, otherwise the ssl module raises ValueError.
    _SSL_CTX.check_hostname = False
    _SSL_CTX.verify_mode = ssl.CERT_NONE
| # ── OpenAI client (uses env vars) ───────────────────────────────────────── | |
| try: | |
| from openai import OpenAI | |
| except ImportError: | |
| print("ERROR: openai package not installed. Run: pip install openai", file=sys.stderr) | |
| sys.exit(1) | |
# ── Config ─────────────────────────────────────────────────────────────────
# `or` (rather than a .get default) makes an empty-string variable fall back
# to the built-in default as well.
API_BASE_URL: str = (os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1").rstrip("/")
MODEL_NAME: str = os.environ.get("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"

# HF_TOKEN is mandatory per competition spec; OPENAI_API_KEY is accepted as an
# alternative and wins when both are set.
HF_TOKEN: str = os.environ.get("HF_TOKEN") or ""
OPENAI_API_KEY: str = os.environ.get("OPENAI_API_KEY") or ""
_API_KEY: str = OPENAI_API_KEY or HF_TOKEN  # resolved credential
# NOTE: a missing credential is deliberately NOT raised here. run_inference()
# validates _API_KEY itself; raising at import time would abort the script
# before the [START]/[END] log lines could be emitted.

PAYOPS_URL: str = (os.environ.get("PAYOPS_BASE_URL") or "http://localhost:7860").rstrip("/")

_DEFAULT_SEED = 42


def _parse_seed(raw: str) -> Optional[int]:
    """Translate the INFERENCE_SEED setting into a seed value.

    Returns ``None`` for ``"random"`` (case-insensitive), the parsed integer
    for a numeric string, and ``_DEFAULT_SEED`` (with a warning on stderr)
    for anything that is not a valid integer, so a typo cannot crash the
    script at import time.
    """
    text = raw.strip()
    if text.lower() == "random":
        return None
    try:
        return int(text)
    except ValueError:
        print(
            f"WARNING: invalid INFERENCE_SEED={raw!r}; falling back to {_DEFAULT_SEED}",
            file=sys.stderr,
        )
        return _DEFAULT_SEED


# Fixed seed keeps per-episode amount/risk jitter deterministic across runs.
# Override with INFERENCE_SEED=<int> or INFERENCE_SEED=random for a fresh episode.
_SEED_ENV = os.environ.get("INFERENCE_SEED", str(_DEFAULT_SEED))
INFERENCE_SEED: Optional[int] = _parse_seed(_SEED_ENV)

VALID_ACTIONS = {
    "approve", "reject", "flag", "escalate", "hold",
    "inspect", "request_docs", "verify_kyc", "contact_sender", "file_sar",
}

# ── Contest-spec structured logging ──────────────────────────────────────────
# Exactly three line types to stdout (machine-parseable by the evaluator):
#   [START]  once per episode
#   [STEP]   once per env.step() call
#   [END]    always, even on exception
TASK_NAME = os.environ.get("PAYOPS_TASK", "payops")
ENV_NAME = "payops_env"
def log_start(task: str, env: str, model: str) -> None:
    """Write the ``[START]`` line that opens an episode to stdout.

    The line has the form ``[START] task=<task> env=<env> model=<model>`` and
    is flushed immediately.
    """
    pairs = (("task", task), ("env", env), ("model", model))
    rendered = " ".join(f"{key}={value}" for key, value in pairs)
    print("[START] " + rendered, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Write one ``[STEP]`` line to stdout.

    ``reward`` is rendered with two decimals, ``done`` as its lower-cased
    string form, and a missing or empty ``error`` as the literal ``null``.
    The line is flushed immediately.
    """
    parts = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)
def log_end(success: bool, steps: int, rewards: list, **_kw: object) -> None:
    """Write the closing ``[END]`` line to stdout.

    This runs from the caller's ``finally`` clause, so it must never raise:
    a reward that is ``None`` or otherwise non-numeric is rendered as ``0.00``
    instead of aborting the line, and ``rewards=None`` is treated as empty.

    ``success`` is emitted as lowercase ``true``/``false`` so that booleans
    look the same as ``done=`` in the ``[STEP]`` lines.

    Extra keyword arguments are accepted and ignored.
    """
    rendered = []
    for reward in rewards or []:
        try:
            value = float(reward)
        except (TypeError, ValueError):
            value = 0.0
        rendered.append(f"{value:.2f}")
    rewards_str = "[" + ", ".join(rendered) + "]"
    success_val = str(bool(success)).lower()
    print(f"[END] success={success_val} steps={steps} rewards={rewards_str}", flush=True)
# System prompt sent as the "system" message on every LLM call (see call_llm).
# It lists the same ten action names as VALID_ACTIONS and instructs the model
# to reply with the bare action name only. This text is sent to the model
# verbatim, so any edit here changes agent behaviour.
SYSTEM_PROMPT = """You are a Payment Operations analyst reviewing financial transactions.
For each transaction you will receive structured data including:
- transaction_id, amount, currency, sender, receiver
- risk_score (0.0=safe, 1.0=highly risky)
- ml_confidence (model confidence in fraud detection)
- flags (list of risk indicators)
- kyc_status (verified/pending/failed/none)
- velocity_1h, velocity_24h (transaction counts in time windows)
- country_risk (sender/receiver country risk levels)
- account_age_days, previous_sars
- inspection_notes, docs_notes, kyc_notes, contact_notes (if investigation was done)
You must respond with EXACTLY one of these action names (lowercase, no punctuation):
approve - transaction appears legitimate, release funds
reject - transaction is fraudulent, block and reverse
flag - suspicious, mark for manual review
escalate - complex case, send to senior analyst
hold - pause transaction pending further info
inspect - examine transaction history and patterns (investigation, costs 0.1 budget)
request_docs - request supporting documents from sender (costs 0.2 budget)
verify_kyc - run full KYC verification (costs 0.2 budget)
contact_sender - contact transaction sender for confirmation (costs 0.3 budget)
file_sar - file Suspicious Activity Report with regulators (costs 0.05 budget)
Investigation actions (inspect/request_docs/verify_kyc/contact_sender/file_sar) do NOT
advance the task — they reveal more information. Terminal actions (approve/reject/flag/
escalate/hold) resolve the task and advance to the next one.
Budget: you start with 5.0 units. Investigation actions cost budget. Stay within budget.
Respond with ONLY the action name. No explanation, no punctuation, no quotes.
"""
def _make_request(url: str, method: str = "GET", body: Optional[dict] = None) -> dict:
    """Send one JSON request to the PayOps server using only the stdlib.

    Args:
        url: Absolute URL to call.
        method: HTTP verb.
        body: Optional JSON-serialisable payload. ``None`` sends no body; any
            other value, including an empty dict, is serialised and sent.

    Returns:
        The decoded JSON response.

    Raises:
        RuntimeError: on an HTTP error status (message includes the response
            body), when the server cannot be reached, or when the response is
            not valid JSON. The underlying exception is chained as the cause.
    """
    # `is not None` (not truthiness) so an explicit empty payload is still sent.
    data = json.dumps(body).encode("utf-8") if body is not None else None
    headers = {"Content-Type": "application/json"}
    req = urllib.request.Request(url, data=data, headers=headers, method=method)
    try:
        with urllib.request.urlopen(req, timeout=30, context=_SSL_CTX) as resp:
            payload = resp.read()
    except urllib.error.HTTPError as e:
        # HTTPError subclasses URLError, so it has to be handled first.
        # errors="replace": a non-UTF-8 error page must not mask the HTTP error.
        err_body = e.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {e.code} from {url}: {err_body}") from e
    except urllib.error.URLError as e:
        raise RuntimeError(f"Cannot reach {url}: {e.reason}") from e
    try:
        return json.loads(payload)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
        raise RuntimeError(f"Invalid JSON response from {url}: {e}") from e
| def _unwrap_response(resp: dict) -> dict: | |
| """ | |
| Normalise server responses that may use the official OpenEnv wire format | |
| ``{"observation": {...}, "reward": ..., "done": ...}`` or the legacy flat | |
| format where the observation fields are at the top level. Always returns | |
| a flat observation dict with ``reward`` and ``done`` populated. | |
| """ | |
| if "observation" in resp and isinstance(resp["observation"], dict): | |
| obs = dict(resp["observation"]) | |
| # Promote top-level reward/done into the obs dict for uniform access | |
| if "reward" in resp: | |
| obs["reward"] = resp["reward"] | |
| if "done" in resp: | |
| obs["done"] = resp["done"] | |
| return obs | |
| # Legacy flat format — already the observation | |
| return resp | |
def build_observation_text(obs: dict) -> str:
    """Format the observation dict into a readable prompt for the LLM.

    A field that is missing OR explicitly ``None`` (JSON ``null``) is rendered
    with its default, and numeric fields that cannot be converted to a number
    fall back to their default too. This function is called outside any
    ``try`` in the main loop, so a single null field must not abort the
    episode with a formatting ``TypeError``.

    Args:
        obs: Flat observation dict as returned by ``_unwrap_response``.

    Returns:
        One ``Label : value`` line per field, joined with newlines, followed
        by any non-empty investigation notes.
    """

    def value_of(key: str, default: object) -> object:
        """Return obs[key], treating a missing key and None alike."""
        found = obs.get(key)
        return default if found is None else found

    def number_of(key: str, default: float) -> float:
        """Return obs[key] as a float, or *default* if absent or non-numeric."""
        try:
            return float(value_of(key, default))
        except (TypeError, ValueError):
            return float(default)

    raw_flags = value_of("flags", [])
    if isinstance(raw_flags, str) or not hasattr(raw_flags, "__iter__"):
        # A lone flag must not be joined character by character.
        raw_flags = [raw_flags]
    flags_text = ", ".join(str(flag) for flag in raw_flags) or "none"

    lines = [
        f"Transaction ID : {value_of('transaction_id', 'N/A')}",
        f"Task : {value_of('task_id', 'N/A')} ({value_of('task_difficulty', 'N/A')} difficulty)",
        f"Amount : {value_of('currency', 'USD')} {number_of('amount', 0):,.2f}",
        f"Sender : {value_of('sender', 'N/A')}",
        f"Receiver : {value_of('receiver', 'N/A')}",
        f"Risk Score : {number_of('risk_score', 0):.2f}",
        f"ML Confidence : {number_of('ml_confidence', 0):.2f}",
        f"Flags : {flags_text}",
        f"KYC Status : {value_of('kyc_status', 'N/A')}",
        f"Velocity 1h : {value_of('velocity_1h', 0)} txns",
        f"Velocity 24h : {value_of('velocity_24h', 0)} txns",
        f"Country Risk : {value_of('country_risk', 'N/A')}",
        f"Account Age : {value_of('account_age_days', 0)} days",
        f"Previous SARs : {value_of('previous_sars', 0)}",
        f"Budget Remaining: {number_of('budget_remaining', 5.0):.2f}",
        f"Step : {value_of('chain_step', 1)}/{value_of('chain_total', 1)}",
    ]
    # Append revealed intel if present
    for field, label in [
        ("inspection_notes", "Inspection Notes"),
        ("docs_notes", "Documents Notes"),
        ("kyc_notes", "KYC Notes"),
        ("contact_notes", "Contact Notes"),
    ]:
        if obs.get(field):
            lines.append(f"{label:15}: {obs[field]}")
    return "\n".join(lines)
# Characters stripped from both ends of the model's answer token.
_ACTION_STRIP_CHARS = "\"'.,;: \n\t"


def _clean_action(raw: Optional[str]) -> str:
    """Reduce a raw completion to a single lower-cased action token.

    Returns the first whitespace-separated token that is non-empty after
    surrounding quotes/punctuation are stripped, so ``"Reject.\\nBecause ..."``
    yields ``reject``. Returns ``""`` when there is nothing usable; the caller
    treats that as an invalid action.
    """
    if not raw:
        return ""
    for token in raw.lower().split():
        candidate = token.strip(_ACTION_STRIP_CHARS)
        if candidate:
            return candidate
    return ""


def call_llm(client: OpenAI, observation_text: str) -> str:
    """Call the LLM and return a cleaned action string.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        observation_text: Formatted transaction details for the user message.

    Returns:
        The model's answer reduced to one lower-cased token, or ``""`` if the
        completion had no choices or no text content. The result is NOT
        checked against the set of valid actions; that is the caller's job.

    Raises:
        Whatever the client raises for transport or API errors.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Transaction details:\n\n{observation_text}\n\nYour action:"},
        ],
        max_tokens=16,
        temperature=0.0,
    )
    choices = getattr(response, "choices", None) or []
    # message.content can be None, e.g. for an empty or filtered completion.
    content = choices[0].message.content if choices else None
    return _clean_action(content)
def _to_float(value: object, default: float = 0.0) -> float:
    """Coerce a JSON value to float; ``None`` or non-numeric values give *default*."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _print_grader_report(grader: dict) -> None:
    """Write the score summary and the per-task table to stderr.

    Null or missing grader fields are rendered as ``0`` / ``?`` so that a
    partially filled result (e.g. a task without a terminal action) cannot
    crash the report.
    """

    def text(value: object) -> str:
        return "?" if value is None else str(value)

    score = _to_float(grader.get("normalised_score"))
    passed = grader.get("passed", False)
    budgets = _to_float(grader.get("budget_spent"))
    print(f"\n{'='*50}", file=sys.stderr)
    print(f" Normalised Score : {score:.4f}", file=sys.stderr)
    print(f" Total Reward : {_to_float(grader.get('total_reward')):.3f}", file=sys.stderr)
    print(f" Max Possible : {_to_float(grader.get('max_possible_reward')):.3f}", file=sys.stderr)
    print(f" Budget Spent : {budgets:.2f}", file=sys.stderr)
    print(f" Budget Penalty : {_to_float(grader.get('budget_penalty')):.3f}", file=sys.stderr)
    print(f" Passed : {'YES ✓' if passed else 'NO ✗'}", file=sys.stderr)
    print(f"{'='*50}", file=sys.stderr)

    # Per-task breakdown
    per_task = grader.get("per_task") or []
    correct_count = sum(1 for t in per_task if t.get("correct", False))
    wrong_count = len(per_task) - correct_count
    print("\nPer-task breakdown:", file=sys.stderr)
    print(f" {'Task':12s} {'Difficulty':10s} {'Agent Action':15s} {'Correct Action':15s} {'Reward':>8s}", file=sys.stderr)
    print(f" {'-'*65}", file=sys.stderr)
    for t in per_task:
        sym = "✓" if t.get("correct", False) else "✗"
        print(f" [{sym}] {text(t.get('task_id')):12s} {text(t.get('difficulty')):10s} "
              f"{text(t.get('terminal_action')):15s} {text(t.get('correct_action')):15s} "
              f"{_to_float(t.get('weighted_reward')):+8.3f}", file=sys.stderr)
    pct = (100 * correct_count / len(per_task)) if per_task else 0
    print(f"\n Tasks correct : {correct_count}/{len(per_task)} "
          f"({pct:.0f}%) "
          f"Wrong: {wrong_count}", file=sys.stderr)


def run_inference() -> tuple:
    """
    Run one full episode of PayOps with the LLM agent.

    Flow: validate the credential, check server health, reset the environment
    (seeded unless INFERENCE_SEED is None), then loop choosing one action per
    step until the server reports ``done`` or MAX_STEPS is reached, and
    finally fetch and print the grader result.

    All progress output goes to stderr; only ``log_step`` writes to stdout.

    Returns:
        (grader_result_dict, per_step_rewards, total_steps).

    Raises:
        RuntimeError: if no API credential is configured, the PayOps server is
            unreachable, or the grader result cannot be retrieved. Errors from
            the /reset call propagate unchanged.
    """
    # ── validate env vars ───────────────────────────────────────────────
    # OPENAI_API_KEY or HF_TOKEN must be set (either is acceptable).
    # Raised as an ordinary exception rather than sys.exit(): SystemExit is not
    # an Exception subclass, so an `except Exception` handler in the caller
    # would not see it.
    if not _API_KEY:
        raise RuntimeError("Set OPENAI_API_KEY (or HF_TOKEN) with your API credential.")
    # API_BASE_URL and MODEL_NAME always have built-in defaults (via 'or' fallback above).
    # Only HF_TOKEN / OPENAI_API_KEY genuinely requires the caller to supply a value.
    print("PayOps Inference", file=sys.stderr)
    print(f" Model : {MODEL_NAME}", file=sys.stderr)
    print(f" API Base : {API_BASE_URL}", file=sys.stderr)
    print(f" Server : {PAYOPS_URL}", file=sys.stderr)
    print(file=sys.stderr)

    # ── init OpenAI client ──────────────────────────────────────────────
    client = OpenAI(
        api_key=_API_KEY,
        base_url=API_BASE_URL,
    )

    # ── health check ────────────────────────────────────────────────────
    try:
        health = _make_request(f"{PAYOPS_URL}/health")
        print(f"Server health : {health.get('status')} (v{health.get('version', '?')})", file=sys.stderr)
    except Exception as e:
        raise RuntimeError(f"Cannot reach PayOps server at {PAYOPS_URL}: {e}") from e

    # ── reset environment ────────────────────────────────────────────────
    # Pass a fixed seed so per-episode jitter is deterministic (reproducible scores).
    reset_body: dict = {}
    if INFERENCE_SEED is not None:
        reset_body["seed"] = INFERENCE_SEED
        print(f"Episode seed : {INFERENCE_SEED} (set INFERENCE_SEED=random for a fresh run)", file=sys.stderr)
    _reset_resp = _make_request(f"{PAYOPS_URL}/reset", method="POST", body=reset_body or None)
    # Support official wrapped format {"observation":{...},"reward":...} and legacy flat format
    obs = _unwrap_response(_reset_resp)
    # .get(): a reset payload without these fields must not crash a status print.
    print(f"Episode start : task={obs.get('task_id', '?')}, budget={obs.get('budget_remaining', '?')}", file=sys.stderr)
    print(file=sys.stderr)

    step_count = 0
    rewards_list: list = []
    start_time = time.time()

    # Hard caps to guarantee <20 min runtime on any inference endpoint.
    # Worst case: 60 steps × 15 s/call ≈ 900 s (15 min); inside the 20 min limit.
    MAX_STEPS = 60         # absolute hard cap across the whole episode
    MAX_INV_PER_TASK = 3   # max LLM-chosen investigation actions per task before forcing a terminal decision
    INVESTIGATION_ACTIONS = {"inspect", "request_docs", "verify_kyc", "contact_sender", "file_sar"}
    # Investigation actions rotated when chain gate forces us to investigate
    _CHAIN_INV_ROTATION = ["inspect", "request_docs", "verify_kyc", "contact_sender"]
    # Budget cost of the investigation actions covered by the budget guard.
    EXPENSIVE = {"contact_sender": 0.3, "request_docs": 0.2, "verify_kyc": 0.2}
    # Per-task investigation counter: task_id → count
    inv_counts: dict = {}
    # Per-task chain-gate forced-investigation counter: task_id → cumulative count
    # (never reset mid-task — only new task IDs start at 0 implicitly)
    chain_forced: dict = {}

    # ── main loop ────────────────────────────────────────────────────────
    while not obs.get("done", False) and step_count < MAX_STEPS:
        raw_task = obs.get("task_id")
        task_id = "?" if raw_task is None else str(raw_task)
        budget = _to_float(obs.get("budget_remaining"), 5.0)

        # Chain gate detection: if the env blocked our last terminal action because
        # the task needs investigation sub-actions first, force one automatically
        # without wasting an LLM call.
        info = obs.get("info")
        last_event = info.get("event", "") if isinstance(info, dict) else ""
        forced = last_event == "chain_gate_blocked"

        if forced:
            # Accumulate (never reset): ensures we rotate to new inv actions each block
            chain_forced[task_id] = chain_forced.get(task_id, 0) + 1
            action = _CHAIN_INV_ROTATION[(chain_forced[task_id] - 1) % len(_CHAIN_INV_ROTATION)]
            print(f" ⚠ Chain gate on {task_id} (forced inv #{chain_forced[task_id]}) "
                  f"→ '{action}'", file=sys.stderr)
        else:
            # Get action from LLM
            try:
                action = call_llm(client, build_observation_text(obs))
            except Exception as e:
                print(f" LLM error on {task_id}: {e} — defaulting to 'flag'", file=sys.stderr)
                action = "flag"
            # Validate action
            if action not in VALID_ACTIONS:
                print(f" ⚠ LLM returned invalid action '{action}' → defaulting to 'flag'", file=sys.stderr)
                action = "flag"

        # Budget guard (applies to every action): if budget is low, skip
        # expensive investigation actions.
        if action in EXPENSIVE and budget < EXPENSIVE[action] + 0.1:
            print(f" ⚠ Insufficient budget ({budget:.2f}) for '{action}' → using 'inspect'", file=sys.stderr)
            action = "inspect" if budget >= 0.1 else "flag"

        # Loop guard: if the model keeps investigating without deciding, force a
        # terminal action. Deliberately skipped for chain-gate forced actions:
        # rewriting a forced investigation to 'flag' would just be blocked by
        # the gate again and spin until MAX_STEPS.
        if not forced:
            if action in INVESTIGATION_ACTIONS:
                inv_counts[task_id] = inv_counts.get(task_id, 0) + 1
                if inv_counts[task_id] > MAX_INV_PER_TASK:
                    print(f" ⚠ Investigation loop on {task_id} ({inv_counts[task_id]} sub-actions) "
                          f"→ forcing 'flag'", file=sys.stderr)
                    action = "flag"
            else:
                # Reset counter on terminal action (task will advance)
                inv_counts[task_id] = 0

        # Step
        try:
            _step_resp = _make_request(
                f"{PAYOPS_URL}/step",
                method="POST",
                body={"action_type": action, "transaction_id": obs.get("transaction_id")},
            )
        except Exception as e:
            print(f" Step error: {e}", file=sys.stderr)
            break

        # Unwrap official {"observation":{...},"reward":...} or legacy flat format
        obs = _unwrap_response(_step_resp)
        # A null reward is coerced to 0.0 so the formatting below and the
        # rewards list handed back to the caller stay numeric.
        reward = _to_float(obs.get("reward"))
        done_flag = bool(obs.get("done", False))
        step_count += 1
        rewards_list.append(reward)
        log_step(step=step_count, action=action, reward=reward, done=done_flag, error=None)
        marker = "✓" if reward > 0 else ("✗" if reward < 0 else "·")
        print(f" [{marker}] {task_id:12s} → {action:15s} reward={reward:+.3f} "
              f"budget={_to_float(obs.get('budget_remaining')):.2f}", file=sys.stderr)

    elapsed = time.time() - start_time
    print(f"\nEpisode complete: {step_count} steps in {elapsed:.1f}s", file=sys.stderr)

    # ── grade ────────────────────────────────────────────────────────────
    try:
        grader = _make_request(f"{PAYOPS_URL}/grader")
    except Exception as e:
        raise RuntimeError(f"Could not retrieve grader results: {e}") from e

    _print_grader_report(grader)
    return grader, rewards_list, step_count
if __name__ == "__main__":
    # Script entry point: run one episode and ALWAYS emit [START] and [END].
    import traceback

    _rewards: list = []
    _steps = 0
    _success = False
    try:
        log_start(task=TASK_NAME, env=ENV_NAME, model=MODEL_NAME)
        result, _rewards, _steps = run_inference()
        _success = bool(result.get("passed", False))
    except (Exception, SystemExit):
        # SystemExit is caught as well: a sys.exit() raised anywhere below
        # would otherwise bypass this handler and could leave the process with
        # a non-zero exit status.
        print("[ERROR] Fatal exception in inference.py:", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
    finally:
        log_end(success=_success, steps=_steps, rewards=_rewards)
    # Always exit 0 — non-zero exit is interpreted by the evaluator as a crash.
    # Agent performance is communicated through the grader score in log_end.
    sys.exit(0)