import os
import traceback
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# --- Model / adapter configuration -----------------------------------------
DEFAULT_MODEL = "google/gemma-3-4b-it"  # base instruction-tuned Gemma model
ADAPTER_PATH = "./gemma-lecture-adapter"  # local LoRA adapter checkpoint directory
HUB_ADAPTER_ID = "noufwithy/gemma-lecture-adapter"  # Hub fallback when the local dir is missing

# System prompt for the summarization task (used by summarize_lecture).
SUMMARIZE_SYSTEM_PROMPT = """You are a lecture summarization assistant.
Summarize the following lecture transcription into a comprehensive, structured summary with these sections:
- **Summary**: A concise overview of what the lecture covered
- **Key Points**: The main concepts, definitions, and important details covered in the lecture (use bullet points)
- **Action Points**: Any tasks, assignments, or follow-up actions mentioned by the lecturer
Cover ALL topics discussed. Do not omit any major points.
Output ONLY the summary. No explanations or extra commentary."""

# Quiz prompts match the training data format exactly (one question per call)
MCQ_SYSTEM_PROMPT = """You are an educational quiz generator.
Based on the following lecture transcription, generate a multiple choice question
with 4 options labeled A-D and indicate the correct answer.
Format:
Q1. [Question]
A) [Option]
B) [Option]
C) [Option]
D) [Option]
Correct Answer: [Letter]
Output ONLY the question. No explanations or extra commentary."""

SHORT_ANSWER_SYSTEM_PROMPT = """You are an educational quiz generator.
Based on the following lecture transcription, generate a short answer question
with the expected answer.
Format:
Q1. [Question]
Expected Answer: [Brief answer]
Output ONLY the question. No explanations or extra commentary."""

# How many questions of each type generate_quiz() attempts to produce.
NUM_MCQ = 5
NUM_SHORT_ANSWER = 3

# Lazily-populated module-level singletons; see _load_model().
_model = None
_tokenizer = None
def _load_model(model_id: str = DEFAULT_MODEL, adapter_path: str = ADAPTER_PATH):
    """Load (and cache) the tokenizer and model, preferring the LoRA adapter.

    Returns a ``(model, tokenizer)`` tuple. The pair is cached in module
    globals, so every call after the first returns the cached objects —
    NOTE(review): a later call with a *different* model_id silently gets the
    originally loaded model; confirm no caller relies on switching models.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Try local adapter first, then HuggingFace Hub, then base model
    adapter_source = adapter_path if os.path.isdir(adapter_path) else HUB_ADAPTER_ID
    # Load in bfloat16 (bitsandbytes 4-bit/8-bit quantization broken with Gemma 3)
    try:
        print(f"Loading model with LoRA adapter from {adapter_source}...")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            dtype=torch.bfloat16,
            attn_implementation="eager",
        )
        _model = PeftModel.from_pretrained(base_model, adapter_source)
        _model.eval()
        print("LoRA adapter loaded successfully on bfloat16 base model.")
    except Exception as e:
        # Best-effort fallback: if the adapter can't be fetched/applied,
        # continue with the plain base model rather than crashing.
        print(f"LoRA adapter failed ({e}), falling back to base model...")
        traceback.print_exc()
        _model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map="auto", dtype=torch.bfloat16,
        )
    return _model, _tokenizer
def _generate(messages, max_new_tokens=2048, do_sample=False, temperature=0.7):
    """Generate text using model.generate() directly.

    Args:
        messages: chat messages as ``[{"role": ..., "content": ...}, ...]``.
        max_new_tokens: generation budget for the reply.
        do_sample: greedy decoding when False; temperature/top_p are then
            passed as None so transformers does not warn about unused values.
        temperature: sampling temperature (only applied when do_sample=True).

    Returns:
        The decoded reply text (input prompt excluded), stripped of
        special tokens and surrounding whitespace.
    """
    model, tokenizer = _load_model()
    # Format chat messages into a string, then tokenize
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # add_special_tokens=False: the chat template already inserted them.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    print(f"[DEBUG] input length: {input_ids.shape[-1]} tokens")
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_p=0.9 if do_sample else None,
            repetition_penalty=1.3,  # discourages the looping the quality checks reject
        )
    # Decode only the new tokens (skip the input)
    new_tokens = outputs[0][input_ids.shape[-1]:]
    print(f"[DEBUG] generated {len(new_tokens)} new tokens")
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()
def _is_good_summary(text: str, transcript: str = "") -> bool:
"""Check if a summary meets minimum quality: long enough, not repetitive, not parroting."""
if len(text) < 100:
return False
# Check for excessive repetition (same line or sentence repeated 2+ times)
from collections import Counter
for chunks in [
[s.strip() for s in text.split("\n") if s.strip()],
[s.strip() for s in text.split(".") if s.strip()],
]:
if chunks:
counts = Counter(chunks)
most_common_count = counts.most_common(1)[0][1]
if most_common_count >= 2:
print(f"[QUALITY] Repetitive output detected ({most_common_count} repeats)")
return False
# Check if summary is just parroting the transcript (high word overlap)
if transcript:
summary_words = set(text.lower().split())
transcript_words = set(transcript.lower().split())
if summary_words and transcript_words:
overlap = len(summary_words & transcript_words) / len(summary_words)
if overlap > 0.85:
print(f"[QUALITY] Summary too similar to transcript ({overlap:.0%} word overlap)")
return False
# Check if summary has enough key points (at least 3 bullet points)
bullet_count = text.count("- ")
has_key_points = "key points" in text.lower()
if has_key_points and bullet_count < 3:
print(f"[QUALITY] Summary has too few key points ({bullet_count})")
return False
# Check minimum unique content (summary should have substance)
unique_lines = set(s.strip() for s in text.split("\n") if s.strip() and len(s.strip()) > 10)
if len(unique_lines) < 5:
print(f"[QUALITY] Summary too shallow ({len(unique_lines)} unique lines)")
return False
return True
def _generate_with_base_fallback(messages, transcript="", **kwargs):
    """Generate with adapter first. If output is bad, retry with base model."""
    result = _generate(messages, **kwargs)
    if _is_good_summary(result, transcript=transcript):
        return result
    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        # Already on the plain base model — nothing further to fall back to.
        return result
    print("[FALLBACK] Adapter output too short or repetitive, retrying with base model...")
    model.disable_adapter_layers()
    try:
        result = _generate(messages, **kwargs)
    finally:
        # Always restore the adapter, even if generation raises.
        model.enable_adapter_layers()
    print(f"[FALLBACK] base model response length: {len(result)}")
    return result
def _truncate_transcript(transcript: str, max_words: int = 4000) -> str:
"""Truncate transcript to fit model's effective context (trained on 3072 tokens)."""
words = transcript.split()
if len(words) <= max_words:
return transcript
print(f"[TRUNCATE] Transcript has {len(words)} words, truncating to {max_words}")
return " ".join(words[:max_words])
def summarize_lecture(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Summarize a lecture transcript using Gemma."""
    # NOTE(review): `model` is accepted but unused — generation always goes
    # through _load_model()'s cached default model; confirm this is intended.
    if not transcript or not transcript.strip():
        return ""
    messages = [
        {"role": "system", "content": SUMMARIZE_SYSTEM_PROMPT},
        {"role": "user", "content": f"Lecture transcription:\n\n{_truncate_transcript(transcript)}"},
    ]
    # Try adapter first, fall back to base model if quality is bad
    result = _generate_with_base_fallback(messages, transcript=transcript, do_sample=True, temperature=0.3)
    print(f"[DEBUG summarize] response length: {len(result)}")
    return result
def _extract_question_text(result: str) -> str:
"""Extract just the question text (first line after Q number) for dedup comparison."""
import re
match = re.search(r'Q\d+\.\s*(.+)', result)
return match.group(1).strip().lower() if match else result.strip().lower()
def _is_good_quiz_answer(result: str, transcript: str = "") -> bool:
"""Check if a generated quiz question is reasonable quality."""
# Reject if response doesn't match any expected format (no question generated)
if "Correct Answer:" not in result and "Expected Answer:" not in result:
print(f"[QUALITY] Response has no valid question format (missing Correct/Expected Answer)")
return False
# Reject if there's no actual question (Q1. pattern)
if "Q1." not in result:
print(f"[QUALITY] Response missing Q1. question marker")
return False
# Short answer: reject if expected answer is just a transcript fragment with no real content
if "Expected Answer:" in result:
answer = result.split("Expected Answer:")[-1].strip()
# Reject vague/pointer answers like "right here", "this arrow", "at this point"
vague_phrases = ["right here", "this arrow", "at this point", "this one", "over here", "right there"]
if any(phrase in answer.lower() for phrase in vague_phrases):
print(f"[QUALITY] Short answer too vague: {answer}")
return False
if len(answer.split()) < 2:
print(f"[QUALITY] Short answer too short: {answer}")
return False
# MCQ: reject if it doesn't have 4 options or has duplicate options
if "Correct Answer:" in result and "Expected Answer:" not in result:
import re
for label in ["A)", "B)", "C)", "D)"]:
if label not in result:
print(f"[QUALITY] MCQ missing option {label}")
return False
# Reject if options are mostly duplicated
options = re.findall(r'[A-D]\)\s*(.+)', result)
unique_options = set(opt.strip().lower() for opt in options)
if len(unique_options) < 3:
print(f"[QUALITY] MCQ has duplicate options ({len(unique_options)} unique out of {len(options)})")
return False
return True
def _dedup_mcq_options(result: str) -> str:
"""Remove duplicate MCQ options, keeping unique ones only."""
import re
options = re.findall(r'([A-D])\)\s*(.+)', result)
if len(options) != 4:
return result
seen = {}
unique = []
for label, text in options:
key = text.strip().lower()
if key not in seen:
seen[key] = True
unique.append((label, text.strip()))
if len(unique) == len(options):
return result # no duplicates
print(f"[QUALITY] Removed {len(options) - len(unique)} duplicate MCQ option(s)")
# Rebuild with correct labels
lines = result.split("\n")
new_lines = []
option_idx = 0
labels = ["A", "B", "C", "D"]
for line in lines:
if re.match(r'^[A-D]\)', line):
if option_idx < len(unique):
new_lines.append(f"{labels[option_idx]}) {unique[option_idx][1]}")
option_idx += 1
else:
new_lines.append(line)
return "\n".join(new_lines)
def _generate_quiz_with_fallback(messages, transcript="", **kwargs):
    """Generate a quiz question with adapter, fall back to base model if bad."""
    result = _generate(messages, **kwargs)
    if _is_good_quiz_answer(result, transcript):
        return result
    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        # No adapter loaded — the first result is the best we can do.
        return result
    print("[FALLBACK] Quiz answer bad, retrying with base model...")
    model.disable_adapter_layers()
    try:
        return _generate(messages, **kwargs)
    finally:
        # Always restore the adapter, even if generation raises.
        model.enable_adapter_layers()
def _normalize_words(text: str) -> set[str]:
"""Strip punctuation from words for cleaner comparison."""
import re
return set(re.sub(r'[^\w\s]', '', word) for word in text.split() if word.strip())
def _is_duplicate(result: str, existing_parts: list[str]) -> bool:
    """Check if a generated question is too similar to any already generated."""
    # Normalize the candidate once; the original recomputed it per iteration.
    candidate_words = _normalize_words(_extract_question_text(result))
    if not candidate_words:
        return False
    for previous in existing_parts:
        prior_words = _normalize_words(_extract_question_text(previous))
        if not prior_words:
            continue
        # Similarity relative to the smaller question's vocabulary.
        overlap = len(candidate_words & prior_words) / min(len(candidate_words), len(prior_words))
        if overlap > 0.7:
            print(f"[QUALITY] Duplicate question detected ({overlap:.0%} word overlap)")
            return True
    return False
def _generate_one_question(system_prompt: str, transcript: str, label: str,
                           parts: list[str], max_retries: int) -> str | None:
    """Generate one quiz question, retrying on bad or duplicate output.

    Returns the raw question text, or None when every attempt produced a
    low-quality or duplicate question. `label` is only used for log messages.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Lecture transcription:\n\n{transcript}"},
    ]
    for attempt in range(1 + max_retries):
        result = _generate_quiz_with_fallback(messages, transcript=transcript, max_new_tokens=256, do_sample=True)
        if _is_good_quiz_answer(result, transcript) and not _is_duplicate(result, parts):
            return result
        print(f"[DEBUG quiz] {label} attempt {attempt + 1} was bad or duplicate, retrying...")
    return None


def generate_quiz(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Generate quiz questions from a lecture transcript using Gemma.

    Generates questions one at a time to match training format, then combines them.
    Skips duplicate questions automatically; questions are renumbered by their
    position in the final list, so dropped questions leave no gaps.

    NOTE(review): `model` is accepted but unused — generation always goes
    through _load_model()'s cached model.
    """
    if not transcript or not transcript.strip():
        return ""
    transcript = _truncate_transcript(transcript)
    parts: list[str] = []
    max_retries = 2  # extra attempts per question if duplicate
    # Generate MCQs one at a time (matches training: one MCQ per example)
    for i in range(NUM_MCQ):
        print(f"[DEBUG quiz] generating MCQ {i + 1}/{NUM_MCQ}...")
        result = _generate_one_question(MCQ_SYSTEM_PROMPT, transcript, f"MCQ {i + 1}", parts, max_retries)
        if result is None:
            print(f"[DEBUG quiz] MCQ {i + 1} dropped (unreliable after {1 + max_retries} attempts)")
            continue
        result = _dedup_mcq_options(result)
        parts.append(result.replace("Q1.", f"Q{len(parts) + 1}.", 1))
    # Generate short answer questions one at a time
    for i in range(NUM_SHORT_ANSWER):
        print(f"[DEBUG quiz] generating short answer {i + 1}/{NUM_SHORT_ANSWER}...")
        result = _generate_one_question(SHORT_ANSWER_SYSTEM_PROMPT, transcript, f"short answer {i + 1}", parts, max_retries)
        if result is None:
            print(f"[DEBUG quiz] short answer {i + 1} dropped (unreliable after {1 + max_retries} attempts)")
            continue
        parts.append(result.replace("Q1.", f"Q{len(parts) + 1}.", 1))
    combined = "\n\n".join(parts)
    print(f"[DEBUG quiz] total response length: {len(combined)}")
    return combined
|