| | import os |
| | import traceback |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | from peft import PeftModel |
| |
|
# Base model to load when no override is supplied.
DEFAULT_MODEL = "google/gemma-3-4b-it"
# Local directory checked first for the fine-tuned LoRA adapter.
ADAPTER_PATH = "./gemma-lecture-adapter"
# Hugging Face Hub fallback used when the local adapter directory is absent.
HUB_ADAPTER_ID = "noufwithy/gemma-lecture-adapter"


# System prompt for structured lecture summarization
# (Summary / Key Points / Action Points sections).
SUMMARIZE_SYSTEM_PROMPT = """You are a lecture summarization assistant.
Summarize the following lecture transcription into a comprehensive, structured summary with these sections:
- **Summary**: A concise overview of what the lecture covered
- **Key Points**: The main concepts, definitions, and important details covered in the lecture (use bullet points)
- **Action Points**: Any tasks, assignments, or follow-up actions mentioned by the lecturer

Cover ALL topics discussed. Do not omit any major points.
Output ONLY the summary. No explanations or extra commentary."""


# System prompt for generating ONE multiple choice question per call;
# the expected output format is parsed by _is_good_quiz_answer / _dedup_mcq_options.
MCQ_SYSTEM_PROMPT = """You are an educational quiz generator.
Based on the following lecture transcription, generate a multiple choice question
with 4 options labeled A-D and indicate the correct answer.

Format:
Q1. [Question]
A) [Option]
B) [Option]
C) [Option]
D) [Option]
Correct Answer: [Letter]

Output ONLY the question. No explanations or extra commentary."""


# System prompt for generating ONE short answer question per call.
SHORT_ANSWER_SYSTEM_PROMPT = """You are an educational quiz generator.
Based on the following lecture transcription, generate a short answer question
with the expected answer.

Format:
Q1. [Question]
Expected Answer: [Brief answer]

Output ONLY the question. No explanations or extra commentary."""


# Number of questions of each type produced by generate_quiz().
NUM_MCQ = 5
NUM_SHORT_ANSWER = 3


# Lazily-initialized process-wide singletons, populated by _load_model().
_model = None
_tokenizer = None
| |
|
| |
|
def _load_model(model_id: str = DEFAULT_MODEL, adapter_path: str = ADAPTER_PATH):
    """Load (and memoize) the tokenizer and causal-LM model, preferring the LoRA adapter.

    The adapter is loaded from ``adapter_path`` when that directory exists,
    otherwise from the Hub repo ``HUB_ADAPTER_ID``. If adapter loading fails
    for any reason, falls back to the plain bfloat16 base model.

    NOTE(review): the memoization ignores the arguments — a second call with a
    different ``model_id`` returns the first cached model. Presumably fine for
    this single-model app; confirm if multi-model use is ever intended.

    Returns:
        Tuple of ``(model, tokenizer)``.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    _tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Prefer a local adapter checkout; fall back to the Hub copy.
    adapter_source = adapter_path if os.path.isdir(adapter_path) else HUB_ADAPTER_ID

    try:
        print(f"Loading model with LoRA adapter from {adapter_source}...")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            dtype=torch.bfloat16,
            attn_implementation="eager",
        )
        _model = PeftModel.from_pretrained(base_model, adapter_source)
        _model.eval()
        print("LoRA adapter loaded successfully on bfloat16 base model.")
    except Exception as e:
        print(f"LoRA adapter failed ({e}), falling back to base model...")
        traceback.print_exc()
        _model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map="auto", dtype=torch.bfloat16,
        )
        # Fix: the fallback path previously skipped eval mode, unlike the
        # adapter path above (matters for dropout-bearing modules).
        _model.eval()

    return _model, _tokenizer
| |
|
| |
|
def _generate(messages, max_new_tokens=2048, do_sample=False, temperature=0.7):
    """Generate text using model.generate() directly.

    Renders `messages` through the tokenizer's chat template, runs greedy or
    sampled decoding, and returns only the newly generated text, stripped.
    """
    model, tokenizer = _load_model()

    # The chat template already inserts special tokens, so the tokenizer must
    # not add them a second time.
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    ids = encoded["input_ids"].to(model.device)
    mask = encoded["attention_mask"].to(model.device)

    print(f"[DEBUG] input length: {ids.shape[-1]} tokens")

    with torch.no_grad():
        generated = model.generate(
            input_ids=ids,
            attention_mask=mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            # Sampling knobs are only meaningful when do_sample is set.
            temperature=temperature if do_sample else None,
            top_p=0.9 if do_sample else None,
            repetition_penalty=1.3,
        )

    # Slice off the prompt so only freshly generated tokens are decoded.
    completion_ids = generated[0][ids.shape[-1]:]
    print(f"[DEBUG] generated {len(completion_ids)} new tokens")

    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
| |
|
| |
|
| | def _is_good_summary(text: str, transcript: str = "") -> bool: |
| | """Check if a summary meets minimum quality: long enough, not repetitive, not parroting.""" |
| | if len(text) < 100: |
| | return False |
| |
|
| | |
| | from collections import Counter |
| | for chunks in [ |
| | [s.strip() for s in text.split("\n") if s.strip()], |
| | [s.strip() for s in text.split(".") if s.strip()], |
| | ]: |
| | if chunks: |
| | counts = Counter(chunks) |
| | most_common_count = counts.most_common(1)[0][1] |
| | if most_common_count >= 2: |
| | print(f"[QUALITY] Repetitive output detected ({most_common_count} repeats)") |
| | return False |
| |
|
| | |
| | if transcript: |
| | summary_words = set(text.lower().split()) |
| | transcript_words = set(transcript.lower().split()) |
| | if summary_words and transcript_words: |
| | overlap = len(summary_words & transcript_words) / len(summary_words) |
| | if overlap > 0.85: |
| | print(f"[QUALITY] Summary too similar to transcript ({overlap:.0%} word overlap)") |
| | return False |
| |
|
| | |
| | bullet_count = text.count("- ") |
| | has_key_points = "key points" in text.lower() |
| | if has_key_points and bullet_count < 3: |
| | print(f"[QUALITY] Summary has too few key points ({bullet_count})") |
| | return False |
| |
|
| | |
| | unique_lines = set(s.strip() for s in text.split("\n") if s.strip() and len(s.strip()) > 10) |
| | if len(unique_lines) < 5: |
| | print(f"[QUALITY] Summary too shallow ({len(unique_lines)} unique lines)") |
| | return False |
| |
|
| | return True |
| |
|
| |
|
def _generate_with_base_fallback(messages, transcript="", **kwargs):
    """Generate with adapter first. If output is bad, retry with base model."""
    result = _generate(messages, **kwargs)
    if _is_good_summary(result, transcript=transcript):
        return result

    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        # Already running the plain base model; nothing to fall back to.
        return result

    print("[FALLBACK] Adapter output too short or repetitive, retrying with base model...")
    # Temporarily bypass the LoRA layers; always re-enable them afterwards.
    model.disable_adapter_layers()
    try:
        result = _generate(messages, **kwargs)
    finally:
        model.enable_adapter_layers()
    print(f"[FALLBACK] base model response length: {len(result)}")
    return result
| |
|
| |
|
| | def _truncate_transcript(transcript: str, max_words: int = 4000) -> str: |
| | """Truncate transcript to fit model's effective context (trained on 3072 tokens).""" |
| | words = transcript.split() |
| | if len(words) <= max_words: |
| | return transcript |
| | print(f"[TRUNCATE] Transcript has {len(words)} words, truncating to {max_words}") |
| | return " ".join(words[:max_words]) |
| |
|
| |
|
def summarize_lecture(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Summarize a lecture transcript using Gemma."""
    # NOTE(review): the `model` parameter is accepted but not forwarded —
    # _load_model always uses its own defaults. Confirm whether intended.
    if not transcript or not transcript.strip():
        return ""

    chat = [
        {"role": "system", "content": SUMMARIZE_SYSTEM_PROMPT},
        {"role": "user", "content": f"Lecture transcription:\n\n{_truncate_transcript(transcript)}"},
    ]
    # Low-temperature sampling; the full transcript is passed for the
    # parroting check inside the quality gate.
    summary = _generate_with_base_fallback(chat, transcript=transcript, do_sample=True, temperature=0.3)
    print(f"[DEBUG summarize] response length: {len(summary)}")
    return summary
| |
|
| |
|
| | def _extract_question_text(result: str) -> str: |
| | """Extract just the question text (first line after Q number) for dedup comparison.""" |
| | import re |
| | match = re.search(r'Q\d+\.\s*(.+)', result) |
| | return match.group(1).strip().lower() if match else result.strip().lower() |
| |
|
| |
|
| | def _is_good_quiz_answer(result: str, transcript: str = "") -> bool: |
| | """Check if a generated quiz question is reasonable quality.""" |
| | |
| | if "Correct Answer:" not in result and "Expected Answer:" not in result: |
| | print(f"[QUALITY] Response has no valid question format (missing Correct/Expected Answer)") |
| | return False |
| |
|
| | |
| | if "Q1." not in result: |
| | print(f"[QUALITY] Response missing Q1. question marker") |
| | return False |
| |
|
| | |
| | if "Expected Answer:" in result: |
| | answer = result.split("Expected Answer:")[-1].strip() |
| | |
| | vague_phrases = ["right here", "this arrow", "at this point", "this one", "over here", "right there"] |
| | if any(phrase in answer.lower() for phrase in vague_phrases): |
| | print(f"[QUALITY] Short answer too vague: {answer}") |
| | return False |
| | if len(answer.split()) < 2: |
| | print(f"[QUALITY] Short answer too short: {answer}") |
| | return False |
| |
|
| | |
| | if "Correct Answer:" in result and "Expected Answer:" not in result: |
| | import re |
| | for label in ["A)", "B)", "C)", "D)"]: |
| | if label not in result: |
| | print(f"[QUALITY] MCQ missing option {label}") |
| | return False |
| | |
| | options = re.findall(r'[A-D]\)\s*(.+)', result) |
| | unique_options = set(opt.strip().lower() for opt in options) |
| | if len(unique_options) < 3: |
| | print(f"[QUALITY] MCQ has duplicate options ({len(unique_options)} unique out of {len(options)})") |
| | return False |
| |
|
| | return True |
| |
|
| |
|
| | def _dedup_mcq_options(result: str) -> str: |
| | """Remove duplicate MCQ options, keeping unique ones only.""" |
| | import re |
| | options = re.findall(r'([A-D])\)\s*(.+)', result) |
| | if len(options) != 4: |
| | return result |
| |
|
| | seen = {} |
| | unique = [] |
| | for label, text in options: |
| | key = text.strip().lower() |
| | if key not in seen: |
| | seen[key] = True |
| | unique.append((label, text.strip())) |
| |
|
| | if len(unique) == len(options): |
| | return result |
| |
|
| | print(f"[QUALITY] Removed {len(options) - len(unique)} duplicate MCQ option(s)") |
| | |
| | lines = result.split("\n") |
| | new_lines = [] |
| | option_idx = 0 |
| | labels = ["A", "B", "C", "D"] |
| | for line in lines: |
| | if re.match(r'^[A-D]\)', line): |
| | if option_idx < len(unique): |
| | new_lines.append(f"{labels[option_idx]}) {unique[option_idx][1]}") |
| | option_idx += 1 |
| | else: |
| | new_lines.append(line) |
| |
|
| | return "\n".join(new_lines) |
| |
|
| |
|
def _generate_quiz_with_fallback(messages, transcript="", **kwargs):
    """Generate a quiz question with adapter, fall back to base model if bad."""
    first_try = _generate(messages, **kwargs)
    if _is_good_quiz_answer(first_try, transcript):
        return first_try

    model, _ = _load_model()
    if not isinstance(model, PeftModel):
        # No adapter loaded, so a retry would be identical; keep the first answer.
        return first_try

    print("[FALLBACK] Quiz answer bad, retrying with base model...")
    # Temporarily bypass the LoRA layers; always re-enable them afterwards.
    model.disable_adapter_layers()
    try:
        retry = _generate(messages, **kwargs)
    finally:
        model.enable_adapter_layers()
    return retry
| |
|
| |
|
| | def _normalize_words(text: str) -> set[str]: |
| | """Strip punctuation from words for cleaner comparison.""" |
| | import re |
| | return set(re.sub(r'[^\w\s]', '', word) for word in text.split() if word.strip()) |
| |
|
| |
|
def _is_duplicate(result: str, existing_parts: list[str]) -> bool:
    """Check if a generated question is too similar to any already generated."""
    # Normalize the candidate once; it is invariant across the loop below.
    candidate_words = _normalize_words(_extract_question_text(result))
    if not candidate_words:
        return False

    for part in existing_parts:
        prior_words = _normalize_words(_extract_question_text(part))
        if not prior_words:
            continue
        # Overlap relative to the smaller word set, so a short question that
        # is contained in a longer one still counts as a duplicate.
        overlap = len(candidate_words & prior_words) / min(len(candidate_words), len(prior_words))
        if overlap > 0.7:
            print(f"[QUALITY] Duplicate question detected ({overlap:.0%} word overlap)")
            return True

    return False
| |
|
| |
|
def _collect_questions(system_prompt: str, count: int, transcript: str, parts: list[str],
                       label: str, *, dedup_options: bool = False, max_retries: int = 2) -> None:
    """Generate `count` questions of one type, appending accepted ones to `parts` in place."""
    for i in range(count):
        print(f"[DEBUG quiz] generating {label} {i + 1}/{count}...")
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Lecture transcription:\n\n{transcript}"},
        ]
        good = False
        for attempt in range(1 + max_retries):
            result = _generate_quiz_with_fallback(messages, transcript=transcript, max_new_tokens=256, do_sample=True)
            # Accept only well-formed, non-duplicate questions.
            if _is_good_quiz_answer(result, transcript) and not _is_duplicate(result, parts):
                good = True
                break
            print(f"[DEBUG quiz] {label} {i + 1} attempt {attempt + 1} was bad or duplicate, retrying...")
        if good:
            if dedup_options:
                result = _dedup_mcq_options(result)
            # Renumber from the per-question "Q1." to its position in the combined quiz.
            result = result.replace("Q1.", f"Q{len(parts) + 1}.", 1)
            parts.append(result)
        else:
            print(f"[DEBUG quiz] {label} {i + 1} dropped (unreliable after {1 + max_retries} attempts)")


def generate_quiz(transcript: str, model: str = DEFAULT_MODEL) -> str:
    """Generate quiz questions from a lecture transcript using Gemma.

    Generates questions one at a time to match training format, then combines them.
    Skips duplicate questions automatically.

    Refactor: the two near-identical MCQ / short-answer loops now share
    _collect_questions, and the unused `q_num` local was removed.

    NOTE(review): the `model` parameter is accepted but not forwarded —
    _load_model always uses its own defaults. Confirm whether intended.
    """
    if not transcript or not transcript.strip():
        return ""

    transcript = _truncate_transcript(transcript)
    parts: list[str] = []

    # MCQs get option de-duplication; short answers do not need it.
    _collect_questions(MCQ_SYSTEM_PROMPT, NUM_MCQ, transcript, parts, "MCQ", dedup_options=True)
    _collect_questions(SHORT_ANSWER_SYSTEM_PROMPT, NUM_SHORT_ANSWER, transcript, parts, "short answer")

    combined = "\n\n".join(parts)
    print(f"[DEBUG quiz] total response length: {len(combined)}")
    return combined
| |
|