# lecture-processor / lecture_processor.py
# Deployed via GitHub Actions on 2026-03-04 03:47:45 (commit f2532fa)
import os
import traceback
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Base checkpoint used for all generation (Gemma 3 4B, instruction-tuned).
DEFAULT_MODEL = "google/gemma-3-4b-it"
# Local directory checked first for the fine-tuned LoRA adapter weights.
ADAPTER_PATH = "./gemma-lecture-adapter"
# HuggingFace Hub repo used as the adapter fallback when the local dir is absent.
HUB_ADAPTER_ID = "noufwithy/gemma-lecture-adapter"
# System prompt for summarization: fixed Summary / Key Points / Action Points layout.
SUMMARIZE_SYSTEM_PROMPT = """You are a lecture summarization assistant.
Summarize the following lecture transcription into a comprehensive, structured summary with these sections:
- **Summary**: A concise overview of what the lecture covered
- **Key Points**: The main concepts, definitions, and important details covered in the lecture (use bullet points)
- **Action Points**: Any tasks, assignments, or follow-up actions mentioned by the lecturer
Cover ALL topics discussed. Do not omit any major points.
Output ONLY the summary. No explanations or extra commentary."""
# Quiz prompts match the training data format exactly (one question per call)
MCQ_SYSTEM_PROMPT = """You are an educational quiz generator.
Based on the following lecture transcription, generate a multiple choice question
with 4 options labeled A-D and indicate the correct answer.
Format:
Q1. [Question]
A) [Option]
B) [Option]
C) [Option]
D) [Option]
Correct Answer: [Letter]
Output ONLY the question. No explanations or extra commentary."""
SHORT_ANSWER_SYSTEM_PROMPT = """You are an educational quiz generator.
Based on the following lecture transcription, generate a short answer question
with the expected answer.
Format:
Q1. [Question]
Expected Answer: [Brief answer]
Output ONLY the question. No explanations or extra commentary."""
# Number of multiple-choice questions requested per quiz.
NUM_MCQ = 5
# Number of short-answer questions requested per quiz.
NUM_SHORT_ANSWER = 3
# Module-level lazy caches populated by _load_model() on first use.
_model = None
_tokenizer = None
def _load_model(model_id: str = DEFAULT_MODEL, adapter_path: str = ADAPTER_PATH):
    """Lazily load and cache the tokenizer and (adapter-wrapped) model.

    The loaded model/tokenizer are cached in module globals, so the
    ``model_id`` of the FIRST call wins; later calls return the cache.

    Args:
        model_id: HuggingFace model id of the base checkpoint.
        adapter_path: Local directory to check for the LoRA adapter;
            falls back to HUB_ADAPTER_ID, then to the bare base model.

    Returns:
        Tuple of (model, tokenizer).
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Try local adapter first, then HuggingFace Hub, then base model
    adapter_source = adapter_path if os.path.isdir(adapter_path) else HUB_ADAPTER_ID
    # Load in bfloat16 (bitsandbytes 4-bit/8-bit quantization broken with Gemma 3).
    # Load the base model ONCE so the adapter-failure fallback can reuse it
    # instead of downloading/allocating it a second time.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    try:
        print(f"Loading model with LoRA adapter from {adapter_source}...")
        _model = PeftModel.from_pretrained(base_model, adapter_source)
        _model.eval()
        print("LoRA adapter loaded successfully on bfloat16 base model.")
    except Exception as e:
        print(f"LoRA adapter failed ({e}), falling back to base model...")
        traceback.print_exc()
        _model = base_model  # reuse the already-loaded base model
    return _model, _tokenizer
def _generate(messages, max_new_tokens=2048, do_sample=False, temperature=0.7):
    """Run one chat-formatted generation pass and return the decoded completion."""
    model, tokenizer = _load_model()
    # Render the chat template to a plain prompt string, then tokenize it ourselves.
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    prompt_len = input_ids.shape[-1]
    print(f"[DEBUG] input length: {prompt_len} tokens")
    # Sampling knobs are only meaningful when do_sample is on; pass None otherwise.
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature if do_sample else None,
        top_p=0.9 if do_sample else None,
        repetition_penalty=1.3,
    )
    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)
    # Strip the prompt tokens; decode only what the model produced.
    new_tokens = outputs[0][prompt_len:]
    print(f"[DEBUG] generated {len(new_tokens)} new tokens")
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def _is_good_summary(text: str, transcript: str = "") -> bool:
"""Check if a summary meets minimum quality: long enough, not repetitive, not parroting."""
if len(text) < 100:
return False
# Check for excessive repetition (same line or sentence repeated 2+ times)
from collections import Counter
for chunks in [
[s.strip() for s in text.split("\n") if s.strip()],
[s.strip() for s in text.split(".") if s.strip()],
]:
if chunks:
counts = Counter(chunks)
most_common_count = counts.most_common(1)[0][1]
if most_common_count >= 2:
print(f"[QUALITY] Repetitive output detected ({most_common_count} repeats)")
return False
# Check if summary is just parroting the transcript (high word overlap)
if transcript:
summary_words = set(text.lower().split())
transcript_words = set(transcript.lower().split())
if summary_words and transcript_words:
overlap = len(summary_words & transcript_words) / len(summary_words)
if overlap > 0.85:
print(f"[QUALITY] Summary too similar to transcript ({overlap:.0%} word overlap)")
return False
# Check if summary has enough key points (at least 3 bullet points)
bullet_count = text.count("- ")
has_key_points = "key points" in text.lower()
if has_key_points and bullet_count < 3:
print(f"[QUALITY] Summary has too few key points ({bullet_count})")
return False
# Check minimum unique content (summary should have substance)
unique_lines = set(s.strip() for s in text.split("\n") if s.strip() and len(s.strip()) > 10)
if len(unique_lines) < 5:
print(f"[QUALITY] Summary too shallow ({len(unique_lines)} unique lines)")
return False
return True
def _generate_with_base_fallback(messages, transcript="", **kwargs):
    """Generate with the adapter; rerun on the base model if quality checks fail."""
    first_pass = _generate(messages, **kwargs)
    if _is_good_summary(first_pass, transcript=transcript):
        return first_pass
    # Adapter output failed quality checks; only a PeftModel can be retried bare.
    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        return first_pass
    print("[FALLBACK] Adapter output too short or repetitive, retrying with base model...")
    model.disable_adapter_layers()
    try:
        retry = _generate(messages, **kwargs)
    finally:
        # Always re-enable the adapter, even if the retry raised.
        model.enable_adapter_layers()
    print(f"[FALLBACK] base model response length: {len(retry)}")
    return retry
def _truncate_transcript(transcript: str, max_words: int = 4000) -> str:
"""Truncate transcript to fit model's effective context (trained on 3072 tokens)."""
words = transcript.split()
if len(words) <= max_words:
return transcript
print(f"[TRUNCATE] Transcript has {len(words)} words, truncating to {max_words}")
return " ".join(words[:max_words])
def summarize_lecture(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Summarize a lecture transcript using Gemma.

    Returns "" for empty/blank input.
    NOTE(review): the `model` parameter is accepted for API compatibility but is
    not forwarded anywhere — _generate always uses the model cached by
    _load_model(); confirm whether callers expect it to take effect.
    """
    if not transcript or not transcript.strip():
        return ""
    truncated = _truncate_transcript(transcript)
    chat = [
        {"role": "system", "content": SUMMARIZE_SYSTEM_PROMPT},
        {"role": "user", "content": f"Lecture transcription:\n\n{truncated}"},
    ]
    # Adapter first; quality check compares against the FULL transcript.
    summary = _generate_with_base_fallback(chat, transcript=transcript, do_sample=True, temperature=0.3)
    print(f"[DEBUG summarize] response length: {len(summary)}")
    return summary
def _extract_question_text(result: str) -> str:
"""Extract just the question text (first line after Q number) for dedup comparison."""
import re
match = re.search(r'Q\d+\.\s*(.+)', result)
return match.group(1).strip().lower() if match else result.strip().lower()
def _is_good_quiz_answer(result: str, transcript: str = "") -> bool:
"""Check if a generated quiz question is reasonable quality."""
# Reject if response doesn't match any expected format (no question generated)
if "Correct Answer:" not in result and "Expected Answer:" not in result:
print(f"[QUALITY] Response has no valid question format (missing Correct/Expected Answer)")
return False
# Reject if there's no actual question (Q1. pattern)
if "Q1." not in result:
print(f"[QUALITY] Response missing Q1. question marker")
return False
# Short answer: reject if expected answer is just a transcript fragment with no real content
if "Expected Answer:" in result:
answer = result.split("Expected Answer:")[-1].strip()
# Reject vague/pointer answers like "right here", "this arrow", "at this point"
vague_phrases = ["right here", "this arrow", "at this point", "this one", "over here", "right there"]
if any(phrase in answer.lower() for phrase in vague_phrases):
print(f"[QUALITY] Short answer too vague: {answer}")
return False
if len(answer.split()) < 2:
print(f"[QUALITY] Short answer too short: {answer}")
return False
# MCQ: reject if it doesn't have 4 options or has duplicate options
if "Correct Answer:" in result and "Expected Answer:" not in result:
import re
for label in ["A)", "B)", "C)", "D)"]:
if label not in result:
print(f"[QUALITY] MCQ missing option {label}")
return False
# Reject if options are mostly duplicated
options = re.findall(r'[A-D]\)\s*(.+)', result)
unique_options = set(opt.strip().lower() for opt in options)
if len(unique_options) < 3:
print(f"[QUALITY] MCQ has duplicate options ({len(unique_options)} unique out of {len(options)})")
return False
return True
def _dedup_mcq_options(result: str) -> str:
"""Remove duplicate MCQ options, keeping unique ones only."""
import re
options = re.findall(r'([A-D])\)\s*(.+)', result)
if len(options) != 4:
return result
seen = {}
unique = []
for label, text in options:
key = text.strip().lower()
if key not in seen:
seen[key] = True
unique.append((label, text.strip()))
if len(unique) == len(options):
return result # no duplicates
print(f"[QUALITY] Removed {len(options) - len(unique)} duplicate MCQ option(s)")
# Rebuild with correct labels
lines = result.split("\n")
new_lines = []
option_idx = 0
labels = ["A", "B", "C", "D"]
for line in lines:
if re.match(r'^[A-D]\)', line):
if option_idx < len(unique):
new_lines.append(f"{labels[option_idx]}) {unique[option_idx][1]}")
option_idx += 1
else:
new_lines.append(line)
return "\n".join(new_lines)
def _generate_quiz_with_fallback(messages, transcript="", **kwargs):
    """Generate a quiz question with the adapter; retry on the base model if bad."""
    answer = _generate(messages, **kwargs)
    if _is_good_quiz_answer(answer, transcript):
        return answer
    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        # No adapter wrapper to disable — return the (bad) answer as-is.
        return answer
    print("[FALLBACK] Quiz answer bad, retrying with base model...")
    model.disable_adapter_layers()
    try:
        answer = _generate(messages, **kwargs)
    finally:
        model.enable_adapter_layers()
    return answer
def _normalize_words(text: str) -> set[str]:
"""Strip punctuation from words for cleaner comparison."""
import re
return set(re.sub(r'[^\w\s]', '', word) for word in text.split() if word.strip())
def _is_duplicate(result: str, existing_parts: list[str]) -> bool:
    """Return True when the new question heavily overlaps an already-kept one."""
    # Normalize the candidate once; it is invariant across comparisons.
    candidate_words = _normalize_words(_extract_question_text(result))
    for previous in existing_parts:
        prior_words = _normalize_words(_extract_question_text(previous))
        if not candidate_words or not prior_words:
            continue
        # Overlap relative to the smaller question, so subset questions count.
        shared = candidate_words & prior_words
        overlap = len(shared) / min(len(candidate_words), len(prior_words))
        if overlap > 0.7:
            print(f"[QUALITY] Duplicate question detected ({overlap:.0%} word overlap)")
            return True
    return False
def _collect_question(system_prompt: str, transcript: str, parts: list, label: str,
                      max_retries: int, is_mcq: bool) -> bool:
    """Generate one quiz question, retrying on bad/duplicate output.

    On success the question is renumbered (Q1. -> Q<n>.) and appended to
    `parts`; MCQs additionally get duplicate options removed. Returns True
    when a question was kept, False when all attempts were dropped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Lecture transcription:\n\n{transcript}"},
    ]
    for attempt in range(1 + max_retries):
        result = _generate_quiz_with_fallback(messages, transcript=transcript, max_new_tokens=256, do_sample=True)
        if _is_good_quiz_answer(result, transcript) and not _is_duplicate(result, parts):
            if is_mcq:
                result = _dedup_mcq_options(result)
            result = result.replace("Q1.", f"Q{len(parts) + 1}.", 1)
            parts.append(result)
            return True
        print(f"[DEBUG quiz] {label} attempt {attempt + 1} was bad or duplicate, retrying...")
    print(f"[DEBUG quiz] {label} dropped (unreliable after {1 + max_retries} attempts)")
    return False
def generate_quiz(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Generate quiz questions from a lecture transcript using Gemma.
    Generates questions one at a time to match training format, then combines them.
    Skips duplicate questions automatically. Unreliable questions are dropped
    after 1 + max_retries attempts, so fewer than NUM_MCQ + NUM_SHORT_ANSWER
    questions may be returned.
    NOTE(review): `model` is accepted for API compatibility but not forwarded —
    generation always uses the model cached by _load_model().
    """
    if not transcript or not transcript.strip():
        return ""
    transcript = _truncate_transcript(transcript)
    parts: list = []
    max_retries = 2  # extra attempts per question if bad or duplicate
    # MCQs one at a time (matches training: one MCQ per example).
    for i in range(NUM_MCQ):
        print(f"[DEBUG quiz] generating MCQ {i + 1}/{NUM_MCQ}...")
        _collect_question(MCQ_SYSTEM_PROMPT, transcript, parts, f"MCQ {i + 1}", max_retries, is_mcq=True)
    # Short answer questions one at a time.
    for i in range(NUM_SHORT_ANSWER):
        print(f"[DEBUG quiz] generating short answer {i + 1}/{NUM_SHORT_ANSWER}...")
        _collect_question(SHORT_ANSWER_SYSTEM_PROMPT, transcript, parts, f"short answer {i + 1}", max_retries, is_mcq=False)
    combined = "\n\n".join(parts)
    print(f"[DEBUG quiz] total response length: {len(combined)}")
    return combined