| | import json |
| | import logging |
| | import re |
| | from pathlib import Path |
| | from typing import Any, Dict |
| |
|
| | |
logger = logging.getLogger(__name__)


class TranscriptCorrector:
    """
    Utility class for correcting transcripts with a multi-pass pipeline.

    ``correct_transcript`` runs the passes in this order:

    1. Structural cleanup: regex-based removal of hesitation fillers and
       comma-delimited discourse markers, then whitespace normalization.
    2. Therapeutic terminology fixes: whole-word replacement of known
       misinterpretations loaded from a JSON config.
    3. LLM-based contextual correction (currently a no-op placeholder).

    ``validate_term_coverage`` is a separate validation helper that reports
    how many configured domain terms appear in a transcript.
    """

    # Hesitation sounds; an immediately following comma and whitespace are
    # removed together with the filler.
    _HESITATION_RE = re.compile(r"\b(?:um|uh|err|ah)\b,?\s*", re.IGNORECASE)

    # "like", "you know" and "I mean" are also ordinary words ("I like him",
    # "I mean it"), so they are only stripped when followed by a comma AND
    # either preceded by a comma or starting a line/sentence. Runs of
    # consecutive markers ("You know, I mean, ...") are removed in one match.
    # Matches are replaced by a single space (see _clean_structure).
    _DISCOURSE_MARKER_RE = re.compile(
        r"(?:,\s*|(?:^|(?<=[.!?]))\s*)(?:\b(?:like|you know|I mean)\b,\s*)+",
        re.IGNORECASE | re.MULTILINE,
    )

    _WHITESPACE_RE = re.compile(r"\s+")

    def __init__(self, config_path: str = "ai/config/therapeutic_terminology.json"):
        """
        Initialize the TranscriptCorrector with terminology configuration.

        Args:
            config_path: Path to the JSON configuration file containing
                therapeutic terms. A relative path is resolved against the
                current working directory.
        """
        self.config_path = Path(config_path)
        self.terms: Dict[str, Any] = self._load_terminology()

    @staticmethod
    def _empty_terminology() -> Dict[str, Any]:
        """Return a fresh empty terminology config (new containers per call)."""
        return {
            "cptsd_terms": [],
            "medical_terms": [],
            "common_misinterpretations": {},
        }

    def _load_terminology(self) -> Dict[str, Any]:
        """
        Load therapeutic terminology from the JSON config.

        If ``self.config_path`` does not exist, falls back to
        ``<parent of this file's directory>/config/therapeutic_terminology.json``
        and, when that file exists, updates ``self.config_path`` to it.

        Returns:
            The parsed top-level JSON object. Returns an empty terminology
            config when no file is found, when the file cannot be read or
            parsed, or when its top-level value is not a JSON object. Failures
            are logged, never raised.
        """
        try:
            if not self.config_path.exists():
                # NOTE: this fallback is used even when the caller passed an
                # explicit custom path that happens to be missing.
                base_path = Path(__file__).parent.parent
                alt_path = base_path / "config" / "therapeutic_terminology.json"
                if not alt_path.exists():
                    logger.warning(
                        "Terminology config not found at %s or %s. "
                        "Using empty config.",
                        self.config_path,
                        alt_path,
                    )
                    return self._empty_terminology()
                self.config_path = alt_path

            with open(self.config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error("Failed to load terminology config: %s", e)
            return self._empty_terminology()

        if not isinstance(data, dict):
            # Every consumer calls .get() on self.terms, so a list/str/number
            # at the top level would only crash later; reject it here.
            logger.error(
                "Failed to load terminology config: expected a JSON object "
                "in %s, got %s",
                self.config_path,
                type(data).__name__,
            )
            return self._empty_terminology()
        return data

    def correct_transcript(self, text: str, context: str = "therapy_session") -> str:
        """
        Main entry point for transcript correction.

        Args:
            text: Single string containing the transcript text to correct.
            context: Context hint forwarded to the LLM correction pass.

        Returns:
            Corrected transcript text, or ``""`` when ``text`` is empty,
            ``None`` or whitespace-only.
        """
        if not text or not text.strip():
            return ""

        # Pass 1: drop fillers and collapse whitespace so the later passes
        # operate on clean, single-spaced text.
        text = self._clean_structure(text)

        # Pass 2: deterministic dictionary fixes for known mis-transcriptions.
        text = self._apply_terminology_fixes(text)

        # Pass 3: contextual LLM correction (placeholder, currently a no-op).
        text = self._llm_contextual_correction(text, context)

        return text

    def _clean_structure(self, text: str) -> str:
        """
        Remove filler words and normalize whitespace.

        The hesitation sounds "um", "uh", "err" and "ah" are always removed.
        The discourse markers "like", "you know" and "I mean" are removed only
        when used as comma-delimited interjections ("it was, like, awful",
        "You know, it hurts"). Ordinary uses such as "I like my therapist" or
        "I mean it" are preserved. Every whitespace run, including newlines,
        is then collapsed to a single space and the result is stripped.
        """
        cleaned = self._HESITATION_RE.sub("", text)
        # Replace with a space: the match may have consumed the whitespace
        # on both sides of the marker.
        cleaned = self._DISCOURSE_MARKER_RE.sub(" ", cleaned)
        return self._WHITESPACE_RE.sub(" ", cleaned).strip()

    def _apply_terminology_fixes(self, text: str) -> str:
        """
        Apply deterministic terminology fixes from config.

        Each key of ``common_misinterpretations`` is matched case-insensitively
        as a whole word/phrase (never inside a longer word) and replaced with
        its value verbatim. All keys are applied in a single pass, longest
        first. As a result, a replacement is never re-matched by another key,
        and a multi-word phrase wins over the shorter keys it contains. Empty
        or whitespace-only keys and non-string entries are ignored.
        """
        misinterpretations = self.terms.get("common_misinterpretations", {})
        if not isinstance(misinterpretations, dict):
            return text

        pairs = [
            (bad, good)
            for bad, good in misinterpretations.items()
            if isinstance(bad, str) and bad.strip() and isinstance(good, str)
        ]
        if not pairs:
            return text

        pairs.sort(key=lambda pair: len(pair[0]), reverse=True)
        # One capturing group per key: m.lastindex identifies which key
        # matched, regardless of how the matched text is cased.
        alternatives = "|".join(f"({re.escape(bad)})" for bad, _ in pairs)
        pattern = re.compile(rf"(?<!\w)(?:{alternatives})(?!\w)", re.IGNORECASE)
        replacements = [good for _, good in pairs]

        # A callable replacement inserts the value literally; a plain string
        # would have its backslashes parsed as escapes / group references.
        return pattern.sub(lambda m: replacements[m.lastindex - 1], text)

    def _llm_contextual_correction(self, text: str, context: str) -> str:
        """
        Placeholder for the LLM-based (e.g. GPT-4) contextual correction pass.

        In the future, this will call the LLM service to fix grammar and
        nuances, using ``context`` as a hint. Currently a no-op: it returns
        ``text`` unchanged and ignores ``context``.
        """
        return text

    @staticmethod
    def _mentions_term(text_lower: str, term: str) -> bool:
        """
        Return True if lowercase ``term`` occurs in ``text_lower`` as a word.

        A trailing "s"/"es" is allowed so simple plurals still count
        ("flashbacks" matches "flashback"). The term may not be embedded in a
        longer word ("cptsd" does not count as "ptsd").
        """
        pattern = rf"(?<!\w){re.escape(term)}(?:e?s)?(?!\w)"
        return re.search(pattern, text_lower) is not None

    def validate_term_coverage(self, text: str) -> Dict[str, float]:
        """
        Calculate metrics on how well the transcript uses domain terminology.

        Useful for the validation pass. Configured terms are compared
        case-insensitively as whole words/phrases (simple plurals allowed).
        Empty and non-string config entries are ignored.

        Returns:
            ``cptsd_term_count`` / ``medical_term_count``: number of distinct
            configured terms from each list found in ``text``.
            ``domain_coverage_score``: fraction of all configured terms found,
            rounded to 4 decimals (0.0 when no terms are configured).
        """

        def configured(key: str) -> set:
            # Normalize once: strip, lowercase, dedupe, drop unusable entries.
            return {
                term.strip().lower()
                for term in (self.terms.get(key) or [])
                if isinstance(term, str) and term.strip()
            }

        cptsd_terms = configured("cptsd_terms")
        medical_terms = configured("medical_terms")
        text_lower = (text or "").lower()

        found_cptsd = sum(self._mentions_term(text_lower, t) for t in cptsd_terms)
        found_medical = sum(
            self._mentions_term(text_lower, t) for t in medical_terms
        )

        total_domain_terms = len(cptsd_terms) + len(medical_terms)
        found_total = found_cptsd + found_medical

        # Guard against division by zero when the config has no terms.
        coverage_score = (
            found_total / total_domain_terms if total_domain_terms > 0 else 0.0
        )

        return {
            "cptsd_term_count": found_cptsd,
            "medical_term_count": found_medical,
            "domain_coverage_score": round(coverage_score, 4),
        }
| |
|