pavankumarvk commited on
Commit
a7e89c0
Β·
verified Β·
1 Parent(s): a685dfd

Upload 2 files

Browse files
Files changed (2) hide show
  1. text_detector_inference.py +131 -0
  2. text_detector_model.py +203 -0
text_detector_inference.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ text_detector_inference.py
3
+ ==========================
4
+ Inference wrapper for HybridAITextDetector.
5
+ Designed to be imported by app.py (Gradio) in the Hugging Face Space.
6
+
7
+ Usage
8
+ -----
9
+ from text_detector_inference import TextDetectorInference
10
+
11
+ detector = TextDetectorInference(
12
+ checkpoint="best_text_detector.pt",
13
+ threshold=0.5
14
+ )
15
+ result = detector.predict("Some text here...")
16
+ """
17
+
18
+ import os
19
+ import torch
20
+ from transformers import AutoTokenizer
21
+ from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH
22
+
23
+
24
+ class TextDetectorInference:
25
+ """
26
+ Thin wrapper around HybridAITextDetector for single-text prediction.
27
+
28
+ Parameters
29
+ ----------
30
+ checkpoint : str
31
+ Path to the .pt state-dict file.
32
+ threshold : float
33
+ Decision boundary for the sigmoid probability (default 0.5).
34
+ Set to the optimal F1 threshold found during evaluation.
35
+ device : torch.device or None
36
+ Auto-detects CUDA if None.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ checkpoint: str = "best_text_detector.pt",
42
+ threshold: float = 0.5,
43
+ device: torch.device = None,
44
+ ):
45
+ self.threshold = threshold
46
+ self.device = device or torch.device(
47
+ "cuda" if torch.cuda.is_available() else "cpu"
48
+ )
49
+
50
+ print(f"[TextDetector] Loading tokenizer from {MODEL_NAME}...")
51
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
52
+
53
+ if os.path.exists(checkpoint):
54
+ print(f"[TextDetector] Loading checkpoint: {checkpoint}")
55
+ self.model = HybridAITextDetector()
56
+ self.model.load_state_dict(
57
+ torch.load(checkpoint, map_location=self.device)
58
+ )
59
+ self.model.eval().to(self.device)
60
+ print("[TextDetector] βœ… Model ready")
61
+ else:
62
+ print(f"[TextDetector] ⚠️ Checkpoint '{checkpoint}' not found. "
63
+ "Model NOT loaded β€” predictions will fail.")
64
+ self.model = None
65
+
66
+ # ------------------------------------------------------------------
67
+ def predict(self, text: str) -> dict:
68
+ """
69
+ Classify a single text string.
70
+
71
+ Returns
72
+ -------
73
+ dict with keys:
74
+ label : "AI-Generated" or "Human-Written"
75
+ confidence : probability of the predicted class (0-1)
76
+ ai_prob : raw P(AI-generated)
77
+ human_prob : 1 - ai_prob
78
+ """
79
+ if self.model is None:
80
+ return {"error": "Model not loaded β€” missing checkpoint file."}
81
+
82
+ text = text.strip()
83
+ if not text:
84
+ return {"error": "Input text is empty."}
85
+
86
+ enc = self.tokenizer(
87
+ text,
88
+ truncation=True,
89
+ padding="max_length",
90
+ max_length=MAX_LENGTH,
91
+ return_tensors="pt",
92
+ )
93
+
94
+ input_ids = enc["input_ids"].to(self.device)
95
+ attention_mask = enc["attention_mask"].to(self.device)
96
+ token_type_ids = enc.get(
97
+ "token_type_ids",
98
+ torch.zeros_like(enc["input_ids"]),
99
+ ).to(self.device)
100
+
101
+ with torch.no_grad():
102
+ logit = self.model(input_ids, attention_mask, token_type_ids)
103
+ ai_prob = torch.sigmoid(logit).item()
104
+
105
+ human_prob = 1.0 - ai_prob
106
+ is_ai = ai_prob >= self.threshold
107
+ label = "AI-Generated" if is_ai else "Human-Written"
108
+ confidence = ai_prob if is_ai else human_prob
109
+
110
+ return {
111
+ "label": label,
112
+ "confidence": round(confidence, 4),
113
+ "ai_prob": round(ai_prob, 4),
114
+ "human_prob": round(human_prob, 4),
115
+ }
116
+
117
+ # ------------------------------------------------------------------
118
+ def predict_batch(self, texts: list[str]) -> list[dict]:
119
+ """Run predict() on a list of texts. Returns list of result dicts."""
120
+ return [self.predict(t) for t in texts]
121
+
122
+ # ------------------------------------------------------------------
123
+ def format_for_gradio(self, text: str) -> tuple[str, float, dict]:
124
+ """
125
+ Convenience wrapper that returns values in a Gradio-friendly format:
126
+ (label_string, confidence_float, full_result_dict)
127
+ """
128
+ result = self.predict(text)
129
+ if "error" in result:
130
+ return result["error"], 0.0, result
131
+ return result["label"], result["confidence"], result
text_detector_model.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ text_detector_model.py
3
+ ======================
4
+ Standalone model definition for HybridAITextDetector.
5
+ Import this in both training scripts and the Gradio app.
6
+
7
+ Architecture:
8
+ DeBERTa-v3-small β†’ [BiLSTM | CNN | Transformer] β†’ CrossAttentionFusion β†’ Classifier
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from transformers import AutoModel
15
+
16
+ # ─── Constants ───────────────────────────────────────────────────────────────
17
+ MODEL_NAME = "microsoft/deberta-v3-small"
18
+ MAX_LENGTH = 128
19
+ NUM_CLASSES = 1 # binary: sigmoid output
20
+
21
+
22
+ # ─── Sub-modules ─────────────────────────────────────────────────────────────
23
+
24
+ class AttentionPool(nn.Module):
25
+ """Soft attention pooling over a sequence of vectors."""
26
+ def __init__(self, dim: int):
27
+ super().__init__()
28
+ self.attn = nn.Linear(dim, 1)
29
+
30
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
31
+ weights = self.attn(x) # (B, T, 1)
32
+ if mask is not None:
33
+ weights = weights.masked_fill(mask.unsqueeze(-1) == 0, float("-inf"))
34
+ weights = torch.softmax(weights, dim=1) # (B, T, 1)
35
+ return (weights * x).sum(dim=1) # (B, dim)
36
+
37
+
38
+ class BiLSTMBranch(nn.Module):
39
+ """2-layer Bidirectional LSTM with Attention Pooling."""
40
+ def __init__(self, input_dim: int, hidden_dim: int = 128):
41
+ super().__init__()
42
+ self.lstm = nn.LSTM(
43
+ input_dim, hidden_dim,
44
+ num_layers=2,
45
+ batch_first=True,
46
+ dropout=0.2,
47
+ bidirectional=True,
48
+ )
49
+ self.pool = AttentionPool(hidden_dim * 2)
50
+ self.proj = nn.Linear(hidden_dim * 2, 128)
51
+
52
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
53
+ out, _ = self.lstm(x) # (B, T, 256)
54
+ pooled = self.pool(out, mask) # (B, 256)
55
+ return F.gelu(self.proj(pooled)) # (B, 128)
56
+
57
+
58
+ class CNNBranch(nn.Module):
59
+ """Multi-kernel 1D CNN with Global MaxPooling."""
60
+ def __init__(self, input_dim: int):
61
+ super().__init__()
62
+ self.conv3 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1)
63
+ self.conv5 = nn.Conv1d(input_dim, 64, kernel_size=5, padding=2)
64
+ self.conv7 = nn.Conv1d(input_dim, 64, kernel_size=7, padding=3)
65
+ self.bn3 = nn.BatchNorm1d(64)
66
+ self.bn5 = nn.BatchNorm1d(64)
67
+ self.bn7 = nn.BatchNorm1d(64)
68
+ self.proj = nn.Linear(192, 128)
69
+
70
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
71
+ x_t = x.permute(0, 2, 1) # (B, D, T)
72
+ c3 = F.gelu(self.bn3(self.conv3(x_t)))
73
+ c5 = F.gelu(self.bn5(self.conv5(x_t)))
74
+ c7 = F.gelu(self.bn7(self.conv7(x_t)))
75
+ p3 = c3.max(dim=-1).values
76
+ p5 = c5.max(dim=-1).values
77
+ p7 = c7.max(dim=-1).values
78
+ cat = torch.cat([p3, p5, p7], dim=-1) # (B, 192)
79
+ return F.gelu(self.proj(cat)) # (B, 128)
80
+
81
+
82
+ class TransformerBranch(nn.Module):
83
+ """Lightweight Transformer Encoder with Attention Pooling."""
84
+ def __init__(self, input_dim: int):
85
+ super().__init__()
86
+ self.proj_in = nn.Linear(input_dim, 128)
87
+ encoder_layer = nn.TransformerEncoderLayer(
88
+ d_model=128, nhead=4,
89
+ dim_feedforward=256,
90
+ dropout=0.1,
91
+ batch_first=True,
92
+ norm_first=True,
93
+ )
94
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
95
+ self.pool = AttentionPool(128)
96
+
97
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
98
+ x = F.gelu(self.proj_in(x)) # (B, T, 128)
99
+ src_key_padding_mask = (mask == 0) if mask is not None else None
100
+ out = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
101
+ return self.pool(out, mask) # (B, 128)
102
+
103
+
104
+ class CrossAttentionFusion(nn.Module):
105
+ """Fuse 3 branch outputs via multi-head self-attention (3-token sequence)."""
106
+ def __init__(self, dim: int = 128):
107
+ super().__init__()
108
+ self.q = nn.Linear(dim, dim)
109
+ self.k = nn.Linear(dim, dim)
110
+ self.v = nn.Linear(dim, dim)
111
+ self.scale = dim ** 0.5
112
+ self.proj = nn.Linear(dim, dim)
113
+
114
+ def forward(
115
+ self,
116
+ lstm_out: torch.Tensor,
117
+ cnn_out: torch.Tensor,
118
+ trans_out: torch.Tensor,
119
+ ) -> torch.Tensor:
120
+ stacked = torch.stack([lstm_out, cnn_out, trans_out], dim=1) # (B, 3, 128)
121
+ Q = self.q(stacked)
122
+ K = self.k(stacked)
123
+ V = self.v(stacked)
124
+ attn = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / self.scale, dim=-1)
125
+ out = torch.bmm(attn, V).mean(dim=1) # (B, 128)
126
+ return F.gelu(self.proj(out))
127
+
128
+
129
+ # ─── Main Model ──────────────────────────────────────────────────────────────
130
+
131
+ class HybridAITextDetector(nn.Module):
132
+ """
133
+ Hybrid AI-generated text detector.
134
+
135
+ Inputs
136
+ ------
137
+ input_ids : (B, T) long tensor
138
+ attention_mask : (B, T) long tensor β€” 1 = real token, 0 = pad
139
+ token_type_ids : (B, T) long tensor
140
+
141
+ Output
142
+ ------
143
+ logits : (B, 1) float β€” apply sigmoid to get P(AI-generated)
144
+ """
145
+
146
+ def __init__(self):
147
+ super().__init__()
148
+ self.deberta = AutoModel.from_pretrained(MODEL_NAME)
149
+
150
+ # Freeze first 6 transformer layers
151
+ for name, param in self.deberta.named_parameters():
152
+ if any(f"layer.{i}." in name for i in range(6)):
153
+ param.requires_grad = False
154
+ else:
155
+ param.requires_grad = True
156
+
157
+ hidden = self.deberta.config.hidden_size # 768 for deberta-v3-small
158
+
159
+ self.lstm_branch = BiLSTMBranch(hidden)
160
+ self.cnn_branch = CNNBranch(hidden)
161
+ self.trans_branch = TransformerBranch(hidden)
162
+ self.fusion = CrossAttentionFusion(dim=128)
163
+
164
+ self.classifier = nn.Sequential(
165
+ nn.LayerNorm(128),
166
+ nn.Linear(128, 128),
167
+ nn.GELU(),
168
+ nn.Dropout(0.4),
169
+ nn.Linear(128, 64),
170
+ nn.GELU(),
171
+ nn.Dropout(0.3),
172
+ nn.Linear(64, 1),
173
+ )
174
+
175
+ def forward(
176
+ self,
177
+ input_ids: torch.Tensor,
178
+ attention_mask: torch.Tensor,
179
+ token_type_ids: torch.Tensor,
180
+ ) -> torch.Tensor:
181
+ out = self.deberta(
182
+ input_ids=input_ids,
183
+ attention_mask=attention_mask,
184
+ token_type_ids=token_type_ids,
185
+ )
186
+ hidden = out.last_hidden_state # (B, T, 768)
187
+ lstm_out = self.lstm_branch(hidden, attention_mask)
188
+ cnn_out = self.cnn_branch(hidden)
189
+ trans_out = self.trans_branch(hidden, attention_mask)
190
+ fused = self.fusion(lstm_out, cnn_out, trans_out)
191
+ return self.classifier(fused) # (B, 1)
192
+
193
+
194
+ # ─── Convenience inference helper ────────────────────────────────────────────
195
+
196
+ def load_model(checkpoint_path: str, device: torch.device = None) -> HybridAITextDetector:
197
+ """Load a trained HybridAITextDetector from a .pt checkpoint."""
198
+ if device is None:
199
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
200
+ model = HybridAITextDetector()
201
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
202
+ model.eval().to(device)
203
+ return model