File size: 3,704 Bytes
8097081
 
 
 
 
be50021
1435892
 
 
 
 
 
8097081
 
1435892
 
8097081
 
1435892
8097081
1435892
8097081
 
1435892
8097081
1435892
8097081
1435892
 
 
8097081
 
 
1435892
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1435892
8097081
 
 
1435892
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1435892
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1435892
8097081
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# src/pytorch_debug_env/reward.py
from __future__ import annotations

import math

from .bug_library import BUG_CATEGORIES

# Margin that keeps clamped scores strictly inside the open interval (0, 1).
EPSILON = 1e-2


def clamp_score(value: float) -> float:
    """Clamp a score into ``[EPSILON, 1 - EPSILON]`` for validator compliance.

    The result always lies strictly inside the open interval (0, 1). A NaN
    input is mapped to ``EPSILON`` (the lowest allowed score): ``min`` and
    ``max`` propagate NaN unchanged, so without the explicit check a NaN
    score would escape the clamp.

    Args:
        value: Raw score; any real number (including +/-inf) or NaN.

    Returns:
        ``value`` limited to ``[EPSILON, 1 - EPSILON]``.
    """
    if math.isnan(value):
        return EPSILON
    return min(max(value, EPSILON), 1.0 - EPSILON)


def hypothesis_quality(hypothesis: dict, ground_truth: dict) -> float:
    """Score how well the current hypothesis matches the ground truth.

    Components (summed, then capped at 1.0):

    * file: 0.45 if ``affected_file`` is the primary bug file, else 0.15 if it
      is one of ``related_files``;
    * bug type: 0.40 for an exact ``bug_type`` match, else 0.13 when both
      types map to the same (non-``None``) entry in ``BUG_CATEGORIES``;
    * calibration: ``0.15 * (1 - |confidence - accuracy|)``, where accuracy is
      the score earned from the two components above.

    Args:
        hypothesis: Reads ``affected_file``, ``bug_type`` and ``confidence``.
            A missing or ``None`` confidence counts as 0.5; values outside
            ``[0, 1]`` are clamped into that range.
        ground_truth: Reads ``primary_bug_file`` and ``bug_type`` (required)
            and ``related_files`` (optional).

    Returns:
        The quality rounded to 4 decimals; in ``[0.0, 1.0]`` for any finite
        confidence.
    """
    quality = 0.0

    affected_file = hypothesis.get("affected_file")
    if affected_file == ground_truth["primary_bug_file"]:
        quality += 0.45
    elif affected_file in ground_truth.get("related_files", []):
        quality += 0.15

    bug_type = hypothesis.get("bug_type")
    if bug_type == ground_truth["bug_type"]:
        quality += 0.40
    else:
        # Partial credit only for a shared *known* category: two types that
        # are both absent from BUG_CATEGORIES (None == None) must not match.
        category = BUG_CATEGORIES.get(bug_type)
        if category is not None and category == BUG_CATEGORIES.get(ground_truth["bug_type"]):
            quality += 0.13

    confidence = hypothesis.get("confidence")
    if confidence is None:
        confidence = 0.5
    # Keep confidence in [0, 1] so the calibration term stays in [0, 1];
    # otherwise an extreme confidence would subtract from the quality.
    confidence = min(max(confidence, 0.0), 1.0)
    calibration = 1.0 - abs(confidence - min(quality, 1.0))
    quality += 0.15 * calibration
    return round(min(quality, 1.0), 4)


def final_diagnosis_score(diagnosis: dict, ground_truth: dict) -> float:
    """Score the committed diagnosis against the ground truth.

    Weighted components:

    * 0.40 for an exact ``bug_type`` match;
    * 0.25 when ``affected_file`` equals ``ground_truth["primary_bug_file"]``;
    * up to 0.20, scaled by ``line_overlap`` of the two ``line_range`` values;
    * 0.15 for an exact ``fix_strategy`` match.

    A diagnosis ``line_range`` that is missing, ``None``, or not a
    two-element list/tuple is scored as ``[0, 0]`` rather than raising
    (previously only a missing key was tolerated).

    Args:
        diagnosis: The committed diagnosis dict.
        ground_truth: Reads ``bug_type``, ``primary_bug_file`` and
            ``fix_strategy`` (required) and ``line_range`` (optional).

    Returns:
        The summed score passed through ``clamp_score`` and rounded to 4
        decimals.
    """
    score = 0.0

    if diagnosis.get("bug_type") == ground_truth["bug_type"]:
        score += 0.40
    if diagnosis.get("affected_file") == ground_truth["primary_bug_file"]:
        score += 0.25

    predicted = diagnosis.get("line_range")
    # line_overlap unpacks exactly two endpoints; anything else would crash it.
    if not (isinstance(predicted, (list, tuple)) and len(predicted) == 2):
        predicted = [0, 0]
    actual = ground_truth.get("line_range", [0, 0])
    score += 0.20 * line_overlap(predicted, actual)

    if diagnosis.get("fix_strategy") == ground_truth["fix_strategy"]:
        score += 0.15

    return round(clamp_score(min(score, 1.0)), 4)


def line_overlap(pred: list[int], actual: list[int]) -> float:
    """Compute the intersection-over-union ratio of two inclusive line ranges.

    Each range is a ``[start, end]`` pair of line numbers, both ends
    inclusive. Endpoints are put in ascending order first, so ``[10, 5]`` is
    treated the same as ``[5, 10]`` instead of being scored as empty.

    Args:
        pred: Predicted two-element line range.
        actual: Reference two-element line range.

    Returns:
        Shared lines divided by lines covered by either range, in
        ``[0.0, 1.0]``; ``0.0`` for disjoint ranges.

    Raises:
        ValueError: If either range does not have exactly two endpoints.
    """
    p1, p2 = sorted(pred)
    a1, a2 = sorted(actual)
    inter = max(0, min(p2, a2) - max(p1, a1) + 1)
    # For overlapping ranges the hull (smallest start .. largest end) equals
    # the union; for disjoint ranges inter is 0, so the ratio is 0 anyway.
    # With ordered endpoints the hull always spans at least one line.
    union = max(p2, a2) - min(p1, a1) + 1
    return inter / union


def compute_step_reward(
    previous_quality: float,
    current_hypothesis: dict,
    ground_truth: dict,
    investigation_target: str | None = None,
    committed_diagnosis: dict | None = None,
    step_num: int = 1,
    max_steps: int = 5,
) -> tuple[float, dict]:
    """Return the clamped reward for one step together with its breakdown.

    The reward is the sum of:

    * 0.60 x the change in ``hypothesis_quality`` since ``previous_quality``;
    * 0.20 x an investigation signal for ``investigation_target``: +0.07 for
      the primary bug file, +0.025 for a related file, -0.04 for the red
      herring file, -0.01 for any other file, 0 when nothing was inspected;
    * 0.20 x ``final_diagnosis_score`` of a committed diagnosis, plus
      ``0.08 * (max_steps - step_num)`` (never negative) when that score is
      above 0.7;
    * a confirmation bonus of ``0.03 * quality`` when the quality moved by
      less than 0.01.

    The sum goes through ``clamp_score`` and is rounded to 4 decimals. The
    second tuple element maps each component name to its (rounded) value.
    """
    quality_now = hypothesis_quality(current_hypothesis, ground_truth)
    quality_change = quality_now - previous_quality

    # A hypothesis that held steady earns a bonus proportional to its quality.
    bonus = 0.0
    if abs(quality_change) < 0.01:
        bonus = 0.03 * quality_now

    if not investigation_target:
        inspect_signal = 0.0
    elif investigation_target == ground_truth["primary_bug_file"]:
        inspect_signal = 0.07
    elif investigation_target in ground_truth.get("related_files", []):
        inspect_signal = 0.025
    elif investigation_target == ground_truth.get("red_herring_file"):
        inspect_signal = -0.04
    else:
        inspect_signal = -0.01

    commit_signal = 0.0
    if committed_diagnosis:
        commit_signal = final_diagnosis_score(committed_diagnosis, ground_truth)
        if commit_signal > 0.7:
            # Reward finishing early: 0.08 for each step left in the budget.
            steps_left = max_steps - step_num
            commit_signal += max(0.0, 0.08 * steps_left)

    weighted = (
        0.60 * quality_change
        + 0.20 * inspect_signal
        + 0.20 * commit_signal
        + bonus
    )

    breakdown = {
        "hypothesis_quality": quality_now,
        "hypothesis_delta": round(quality_change, 4),
        "investigation_reward": round(inspect_signal, 4),
        "diagnosis_reward": round(commit_signal, 4),
        "confirmation_bonus": round(bonus, 4),
    }
    # clamp_score bounds the total on its own, so no separate [0, 1] clip.
    return round(clamp_score(weighted), 4), breakdown