"""Lecture summarization and quiz generation with Gemma 3 plus a LoRA adapter.

Public entry points are ``summarize_lecture`` and ``generate_quiz``. The model
and tokenizer are loaded lazily on first use and cached in module globals.
Progress and quality diagnostics are written with ``print``.
"""

import gc
import os
import re
import traceback
from collections import Counter
from typing import Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

DEFAULT_MODEL = "google/gemma-3-4b-it"
ADAPTER_PATH = "./gemma-lecture-adapter"
HUB_ADAPTER_ID = "noufwithy/gemma-lecture-adapter"

# NOTE(review): the prompt texts below are reproduced exactly as they appear in
# this file (single spaces, no line breaks). If the training prompts used line
# breaks between the sections / options, restore them here -- TODO confirm
# against the training data.
SUMMARIZE_SYSTEM_PROMPT = (
    "You are a lecture summarization assistant. "
    "Summarize the following lecture transcription into a comprehensive, "
    "structured summary with these sections: "
    "- **Summary**: A concise overview of what the lecture covered "
    "- **Key Points**: The main concepts, definitions, and important details "
    "covered in the lecture (use bullet points) "
    "- **Action Points**: Any tasks, assignments, or follow-up actions "
    "mentioned by the lecturer "
    "Cover ALL topics discussed. Do not omit any major points. "
    "Output ONLY the summary. No explanations or extra commentary."
)

# Quiz prompts match the training data format exactly (one question per call)
MCQ_SYSTEM_PROMPT = (
    "You are an educational quiz generator. "
    "Based on the following lecture transcription, generate a multiple choice "
    "question with 4 options labeled A-D and indicate the correct answer. "
    "Format: Q1. [Question] A) [Option] B) [Option] C) [Option] D) [Option] "
    "Correct Answer: [Letter] "
    "Output ONLY the question. No explanations or extra commentary."
)

SHORT_ANSWER_SYSTEM_PROMPT = (
    "You are an educational quiz generator. "
    "Based on the following lecture transcription, generate a short answer "
    "question with the expected answer. "
    "Format: Q1. [Question] Expected Answer: [Brief answer] "
    "Output ONLY the question. No explanations or extra commentary."
)

NUM_MCQ = 5
NUM_SHORT_ANSWER = 3

# Lazily populated by _load_model(); shared by every call in this process.
_model = None
_tokenizer = None

_OPTION_LABELS = ("A", "B", "C", "D")
_QUESTION_RE = re.compile(r'Q\d+\.\s*(.+)')
# Unanchored on purpose: used only to count distinct option texts.
_OPTION_RE = re.compile(r'[A-D]\)\s*(.+)')
# Anchored to line starts so it selects the same lines the rebuild loop rewrites.
_OPTION_LINE_RE = re.compile(r'^([A-D])\)\s*(.+)', re.MULTILINE)
_OPTION_LINE_START_RE = re.compile(r'^[A-D]\)')
_CORRECT_ANSWER_RE = re.compile(r'(Correct Answer:\s*)([A-D])\b')
_PUNCT_RE = re.compile(r'[^\w\s]')

# Pointer-style answers that only make sense while looking at the slides.
_VAGUE_ANSWER_PHRASES = (
    "right here",
    "this arrow",
    "at this point",
    "this one",
    "over here",
    "right there",
)


def _load_model(model_id: str = DEFAULT_MODEL, adapter_path: str = ADAPTER_PATH):
    """Load the tokenizer and model once and return the cached pair.

    The adapter is taken from ``adapter_path`` when that directory exists,
    otherwise from ``HUB_ADAPTER_ID``. If loading the base model with the
    adapter raises, the plain base model is loaded instead.

    Args:
        model_id: Base model identifier passed to ``from_pretrained``.
        adapter_path: Local directory that may hold the LoRA adapter.

    Returns:
        ``(model, tokenizer)``. After the first successful call the arguments
        are ignored and the cached objects are returned.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    _tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Try local adapter first, then HuggingFace Hub, then base model
    adapter_source = adapter_path if os.path.isdir(adapter_path) else HUB_ADAPTER_ID

    # Load in bfloat16 (bitsandbytes 4-bit/8-bit quantization broken with Gemma 3)
    base_model = None
    loaded = None
    try:
        print(f"Loading model with LoRA adapter from {adapter_source}...")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            dtype=torch.bfloat16,
            attn_implementation="eager",
        )
        loaded = PeftModel.from_pretrained(base_model, adapter_source)
        loaded.eval()
        print("LoRA adapter loaded successfully on bfloat16 base model.")
    except Exception as e:
        print(f"LoRA adapter failed ({e}), falling back to base model...")
        traceback.print_exc()
        loaded = None

    if loaded is None:
        # The fallback is loaded outside the ``except`` block so the failed
        # attempt (and the traceback that references it) can be released
        # first; otherwise two copies of the weights may be resident at once.
        base_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # best effort; harmless if nothing is cached
        loaded = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            dtype=torch.bfloat16,
        )
        loaded.eval()  # same inference mode as the adapter path

    _model = loaded
    return _model, _tokenizer


def _generate(messages, max_new_tokens=2048, do_sample=False, temperature=0.7):
    """Generate text using model.generate() directly.

    Args:
        messages: Chat messages (``role``/``content`` dicts) for the chat template.
        max_new_tokens: Upper bound on generated tokens.
        do_sample: Sample when true; otherwise ``temperature``/``top_p`` are
            passed as ``None``.
        temperature: Sampling temperature, used only when ``do_sample`` is true.

    Returns:
        The decoded continuation (prompt tokens excluded), stripped of
        surrounding whitespace.
    """
    model, tokenizer = _load_model()

    # Format chat messages into a string, then tokenize. The template already
    # inserts special tokens, hence add_special_tokens=False below.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    prompt_length = input_ids.shape[-1]
    print(f"[DEBUG] input length: {prompt_length} tokens")

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_p=0.9 if do_sample else None,
            repetition_penalty=1.3,
        )

    # Decode only the new tokens (skip the input)
    new_tokens = outputs[0][prompt_length:]
    print(f"[DEBUG] generated {len(new_tokens)} new tokens")
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()


def _is_good_summary(text: str, transcript: str = "") -> bool:
    """Check if a summary meets minimum quality: long enough, not repetitive, not parroting.

    Args:
        text: Candidate summary.
        transcript: Source transcript; when empty the overlap check is skipped.

    Returns:
        ``False`` if any check fails (a ``[QUALITY]`` line names the reason for
        every check except the minimum-length one), otherwise ``True``.
    """
    if len(text) < 100:
        return False

    lines = [s.strip() for s in text.split("\n") if s.strip()]
    sentences = [s.strip() for s in text.split(".") if s.strip()]

    # Check for excessive repetition (same line or sentence repeated 2+ times)
    for chunks in (lines, sentences):
        if not chunks:
            continue
        most_common_count = Counter(chunks).most_common(1)[0][1]
        if most_common_count >= 2:
            print(f"[QUALITY] Repetitive output detected ({most_common_count} repeats)")
            return False

    # Check if summary is just parroting the transcript (high word overlap)
    if transcript:
        summary_words = set(text.lower().split())
        transcript_words = set(transcript.lower().split())
        if summary_words and transcript_words:
            overlap = len(summary_words & transcript_words) / len(summary_words)
            if overlap > 0.85:
                print(f"[QUALITY] Summary too similar to transcript ({overlap:.0%} word overlap)")
                return False

    # Check if summary has enough key points (at least 3 bullet points)
    bullet_count = text.count("- ")
    if "key points" in text.lower() and bullet_count < 3:
        print(f"[QUALITY] Summary has too few key points ({bullet_count})")
        return False

    # Check minimum unique content (summary should have substance)
    unique_lines = {line for line in lines if len(line) > 10}
    if len(unique_lines) < 5:
        print(f"[QUALITY] Summary too shallow ({len(unique_lines)} unique lines)")
        return False

    return True


def _regenerate_without_adapter(messages, notice: str, generation_kwargs: dict) -> Optional[str]:
    """Re-run ``_generate`` with the adapter layers switched off.

    Args:
        messages: Chat messages to regenerate from.
        notice: Log line printed before the retry.
        generation_kwargs: Keyword arguments forwarded to ``_generate``.

    Returns:
        The new output, or ``None`` when the loaded model is not a
        ``PeftModel`` (there is no adapter to disable, so no retry is made).
    """
    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        return None

    print(notice)
    model.disable_adapter_layers()
    try:
        return _generate(messages, **generation_kwargs)
    finally:
        # Always restore the adapter, even if generation raised.
        model.enable_adapter_layers()


def _generate_with_base_fallback(messages, transcript="", **kwargs):
    """Generate with adapter first. If output is bad, retry with base model.

    The retry output is returned without a second quality check. When no
    adapter is loaded the first output is returned as-is.
    """
    result = _generate(messages, **kwargs)
    if _is_good_summary(result, transcript=transcript):
        return result

    # Adapter output is bad, try base model
    retried = _regenerate_without_adapter(
        messages,
        "[FALLBACK] Adapter output too short or repetitive, retrying with base model...",
        kwargs,
    )
    if retried is None:
        return result
    print(f"[FALLBACK] base model response length: {len(retried)}")
    return retried


def _truncate_transcript(transcript: str, max_words: int = 4000) -> str:
    """Truncate transcript to fit model's effective context (trained on 3072 tokens).

    The limit is applied to whitespace-separated words. Text within the limit
    is returned unchanged; longer text is cut and re-joined with single spaces.

    NOTE(review): the limit counts words, not tokens, so 4000 words may exceed
    the 3072-token figure quoted above -- confirm the intended budget.
    """
    words = transcript.split()
    if len(words) <= max_words:
        return transcript
    print(f"[TRUNCATE] Transcript has {len(words)} words, truncating to {max_words}")
    return " ".join(words[:max_words])


def summarize_lecture(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Summarize a lecture transcript using Gemma.

    Args:
        transcript: Lecture transcription text.
        model: Not used by this function; the model comes from ``_load_model()``
            with its defaults.

    Returns:
        The generated summary, or ``""`` for an empty / whitespace-only transcript.
    """
    if not transcript or not transcript.strip():
        return ""

    truncated = _truncate_transcript(transcript)
    messages = [
        {"role": "system", "content": SUMMARIZE_SYSTEM_PROMPT},
        {"role": "user", "content": f"Lecture transcription:\n\n{truncated}"},
    ]

    # Try adapter first, fall back to base model if quality is bad.
    # The quality check compares against the full (untruncated) transcript.
    result = _generate_with_base_fallback(
        messages, transcript=transcript, do_sample=True, temperature=0.3
    )
    print(f"[DEBUG summarize] response length: {len(result)}")
    return result


def _extract_question_text(result: str) -> str:
    """Extract just the question text (first line after Q number) for dedup comparison.

    Falls back to the whole stripped text when no ``Q<n>.`` marker is present.
    The return value is lower-cased.
    """
    match = _QUESTION_RE.search(result)
    question = match.group(1) if match else result
    return question.strip().lower()


def _is_good_quiz_answer(result: str, transcript: str = "") -> bool:
    """Check if a generated quiz question is reasonable quality.

    Args:
        result: Raw model output for a single question.
        transcript: Not used by the current checks.

    Returns:
        ``True`` when the output has a ``Q1.`` marker and a usable answer
        section; otherwise ``False`` after printing a ``[QUALITY]`` reason.
    """
    has_correct = "Correct Answer:" in result
    has_expected = "Expected Answer:" in result

    # Reject if response doesn't match any expected format (no question generated)
    if not has_correct and not has_expected:
        print("[QUALITY] Response has no valid question format (missing Correct/Expected Answer)")
        return False

    # Reject if there's no actual question (Q1. pattern)
    if "Q1." not in result:
        print("[QUALITY] Response missing Q1. question marker")
        return False

    # Short answer: reject if expected answer is just a transcript fragment with no real content
    if has_expected:
        answer = result.split("Expected Answer:")[-1].strip()
        lowered = answer.lower()
        # Reject vague/pointer answers like "right here", "this arrow", "at this point"
        if any(phrase in lowered for phrase in _VAGUE_ANSWER_PHRASES):
            print(f"[QUALITY] Short answer too vague: {answer}")
            return False
        if len(answer.split()) < 2:
            print(f"[QUALITY] Short answer too short: {answer}")
            return False

    # MCQ: reject if it doesn't have 4 options or has duplicate options
    if has_correct and not has_expected:
        for label in ("A)", "B)", "C)", "D)"):
            if label not in result:
                print(f"[QUALITY] MCQ missing option {label}")
                return False
        # Reject if options are mostly duplicated
        options = _OPTION_RE.findall(result)
        unique_options = {opt.strip().lower() for opt in options}
        if len(unique_options) < 3:
            print(f"[QUALITY] MCQ has duplicate options ({len(unique_options)} unique out of {len(options)})")
            return False

    return True


def _dedup_mcq_options(result: str) -> str:
    """Remove duplicate MCQ options, keeping unique ones only.

    Options are compared case-insensitively after stripping. Surviving options
    are relabelled A, B, C... in order, and the letter after ``Correct Answer:``
    is remapped to the new label of the option it referred to, so the answer
    key stays valid after relabelling.

    Returns:
        ``result`` unchanged unless exactly four option lines are found and at
        least one is a duplicate; otherwise the rebuilt text.
    """
    options = _OPTION_LINE_RE.findall(result)
    if len(options) != len(_OPTION_LABELS):
        return result

    unique_texts: list[str] = []
    new_label_by_text: dict[str, str] = {}
    new_label_by_old: dict[str, str] = {}
    for old_label, text in options:
        key = text.strip().lower()
        if key not in new_label_by_text:
            new_label_by_text[key] = _OPTION_LABELS[len(unique_texts)]
            unique_texts.append(text.strip())
        # A dropped duplicate points at the label of the copy that was kept.
        new_label_by_old.setdefault(old_label, new_label_by_text[key])

    removed = len(options) - len(unique_texts)
    if not removed:
        return result  # no duplicates
    print(f"[QUALITY] Removed {removed} duplicate MCQ option(s)")

    def remap_answer(match):
        letter = match.group(2)
        return match.group(1) + new_label_by_old.get(letter, letter)

    # Rebuild with correct labels
    relabelled = iter(zip(_OPTION_LABELS, unique_texts))
    new_lines = []
    for line in result.split("\n"):
        if _OPTION_LINE_START_RE.match(line):
            replacement = next(relabelled, None)
            if replacement is not None:
                new_lines.append(f"{replacement[0]}) {replacement[1]}")
            # Surplus option lines (the removed duplicates) are dropped.
        else:
            new_lines.append(_CORRECT_ANSWER_RE.sub(remap_answer, line))
    return "\n".join(new_lines)


def _generate_quiz_with_fallback(messages, transcript="", **kwargs):
    """Generate a quiz question with adapter, fall back to base model if bad.

    The retry output is returned without being re-validated; callers that need
    a quality guarantee must check it themselves.
    """
    result = _generate(messages, **kwargs)
    if _is_good_quiz_answer(result, transcript):
        return result

    retried = _regenerate_without_adapter(
        messages,
        "[FALLBACK] Quiz answer bad, retrying with base model...",
        kwargs,
    )
    return result if retried is None else retried


def _normalize_words(text: str) -> set[str]:
    """Strip punctuation from words for cleaner comparison.

    Tokens that consist only of punctuation are dropped, so they cannot show
    up as a shared empty "word" and inflate overlap ratios.
    """
    stripped = (_PUNCT_RE.sub('', word) for word in text.split())
    return {word for word in stripped if word}


def _is_duplicate(result: str, existing_parts: list[str]) -> bool:
    """Check if a generated question is too similar to any already generated.

    Two questions count as duplicates when more than 70% of the smaller
    question's (punctuation-stripped) words also appear in the other.
    """
    new_words = _normalize_words(_extract_question_text(result))
    if not new_words:
        return False

    for part in existing_parts:
        old_words = _normalize_words(_extract_question_text(part))
        if not old_words:
            continue
        overlap = len(new_words & old_words) / min(len(new_words), len(old_words))
        if overlap > 0.7:
            print(f"[QUALITY] Duplicate question detected ({overlap:.0%} word overlap)")
            return True
    return False


def _generate_unique_question(
    system_prompt: str,
    label: str,
    number: int,
    total: int,
    transcript: str,
    existing_parts: list[str],
    max_retries: int,
) -> Optional[str]:
    """Generate one acceptable, non-duplicate question, retrying on failure.

    Args:
        system_prompt: Prompt selecting the question type.
        label: Question-type name used in log lines ("MCQ" / "short answer").
        number: 1-based index of this question within its type.
        total: How many questions of this type are requested.
        transcript: (Truncated) transcript given to the model.
        existing_parts: Questions accepted so far, used for duplicate detection.
        max_retries: Extra attempts allowed after the first one.

    Returns:
        The raw question text (still numbered ``Q1.``), or ``None`` if every
        attempt was rejected.
    """
    print(f"[DEBUG quiz] generating {label} {number}/{total}...")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Lecture transcription:\n\n{transcript}"},
    ]
    attempts = 1 + max_retries
    for attempt in range(attempts):
        candidate = _generate_quiz_with_fallback(
            messages, transcript=transcript, max_new_tokens=256, do_sample=True
        )
        # Re-validate here: the fallback output is returned unchecked.
        if _is_good_quiz_answer(candidate, transcript) and not _is_duplicate(candidate, existing_parts):
            return candidate
        print(f"[DEBUG quiz] {label} {number} attempt {attempt + 1} was bad or duplicate, retrying...")

    print(f"[DEBUG quiz] {label} {number} dropped (unreliable after {attempts} attempts)")
    return None


def generate_quiz(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Generate quiz questions from a lecture transcript using Gemma.

    Generates questions one at a time to match training format, then combines them.
    Skips duplicate questions automatically.

    Args:
        transcript: Lecture transcription text.
        model: Not used by this function; the model comes from ``_load_model()``
            with its defaults.

    Returns:
        Accepted questions joined by blank lines and numbered consecutively
        (dropped questions leave no gap), or ``""`` for an empty /
        whitespace-only transcript.
    """
    if not transcript or not transcript.strip():
        return ""

    transcript = _truncate_transcript(transcript)
    parts: list[str] = []
    max_retries = 2  # extra attempts per question if duplicate

    # Generate MCQs one at a time (matches training: one MCQ per example)
    for i in range(NUM_MCQ):
        question = _generate_unique_question(
            MCQ_SYSTEM_PROMPT, "MCQ", i + 1, NUM_MCQ, transcript, parts, max_retries
        )
        if question is not None:
            question = _dedup_mcq_options(question)
            parts.append(question.replace("Q1.", f"Q{len(parts) + 1}.", 1))

    # Generate short answer questions one at a time
    for i in range(NUM_SHORT_ANSWER):
        question = _generate_unique_question(
            SHORT_ANSWER_SYSTEM_PROMPT,
            "short answer",
            i + 1,
            NUM_SHORT_ANSWER,
            transcript,
            parts,
            max_retries,
        )
        if question is not None:
            parts.append(question.replace("Q1.", f"Q{len(parts) + 1}.", 1))

    combined = "\n\n".join(parts)
    print(f"[DEBUG quiz] total response length: {len(combined)}")
    return combined