# pixelated-training-code / utils / transcript_corrector.py
import json
import logging
import re
from pathlib import Path
from typing import Any, Dict
# Module-level logger, one per module per the stdlib logging convention.
logger = logging.getLogger(__name__)
class TranscriptCorrector:
    """
    Utility class for correcting transcripts using a multi-pass approach:

    1. Structural cleanup (filler-word removal, whitespace normalization).
    2. Deterministic terminology replacement driven by a JSON config.
    3. LLM-based contextual correction (currently a no-op mock).
    """

    # Compiled once at class level so the per-call correction path does not
    # recompile. Matches common speech fillers, optionally followed by a
    # comma, plus any trailing whitespace.
    _FILLER_RE = re.compile(
        r"\b(um|uh|err|ah|like|you know|I mean)\b,?\s*", re.IGNORECASE
    )
    # Collapses any run of whitespace into a single space.
    _WS_RE = re.compile(r"\s+")

    def __init__(self, config_path: str = "ai/config/therapeutic_terminology.json"):
        """
        Initialize the TranscriptCorrector with terminology configuration.

        Args:
            config_path: Path to the JSON configuration file containing
                therapeutic terms.
        """
        self.config_path = Path(config_path)
        self.terms: Dict[str, Any] = self._load_terminology()

    @staticmethod
    def _empty_terms() -> Dict[str, Any]:
        """Return a fresh empty terminology config (never a shared dict)."""
        return {
            "cptsd_terms": [],
            "medical_terms": [],
            "common_misinterpretations": {},
        }

    def _load_terminology(self) -> Dict[str, Any]:
        """Load therapeutic terminology from the JSON config.

        Degrades to an empty config (with a log message) instead of raising,
        so a missing or broken config never blocks transcript processing.

        Returns:
            The parsed config dict, or an empty config on any failure.
        """
        try:
            if not self.config_path.exists():
                # Fall back to a path relative to this file: the layout is
                # usually ai/utils/transcript_corrector.py with the config at
                # ai/config/therapeutic_terminology.json, i.e. two levels up.
                base_path = Path(__file__).parent.parent
                alt_path = base_path / "config" / "therapeutic_terminology.json"
                if alt_path.exists():
                    self.config_path = alt_path
                else:
                    logger.warning(
                        "Terminology config not found at %s or %s. "
                        "Using empty config.",
                        self.config_path,
                        alt_path,
                    )
                    return self._empty_terms()
            with open(self.config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            # Deliberate broad catch: config loading is best-effort.
            logger.error("Failed to load terminology config: %s", e)
            return self._empty_terms()

    def correct_transcript(self, text: str, context: str = "therapy_session") -> str:
        """
        Main entry point for transcript correction.

        Args:
            text: Single string containing the transcript text to correct.
            context: Context hint for LLM correction.

        Returns:
            Corrected transcript text ("" for empty/blank input).
        """
        if not text or not text.strip():
            return ""
        # Pass 1: Basic Structural Cleanup
        text = self._clean_structure(text)
        # Pass 2: Terminology Replacement
        text = self._apply_terminology_fixes(text)
        # Pass 3: LLM Contextual Correction (Mocked)
        text = self._llm_contextual_correction(text, context)
        return text

    def _clean_structure(self, text: str) -> str:
        """Remove filler words (case-insensitive) and normalize whitespace."""
        cleaned = self._FILLER_RE.sub("", text)
        return self._WS_RE.sub(" ", cleaned).strip()

    def _apply_terminology_fixes(self, text: str) -> str:
        """Apply deterministic terminology fixes from config.

        Fix: the original compiled ``re.escape(bad_term)`` with no anchors
        despite its stated intent of whole-word matching, so replacing e.g.
        "cat" would corrupt "concatenate". ``\\b`` anchors now enforce
        whole-word/phrase matches (assumes terms start and end with word
        characters, which therapeutic vocabulary does).
        """
        misinterpretations = self.terms.get("common_misinterpretations", {})
        for bad_term, good_term in misinterpretations.items():
            pattern = re.compile(rf"\b{re.escape(bad_term)}\b", re.IGNORECASE)
            text = pattern.sub(good_term, text)
        return text

    def _llm_contextual_correction(self, text: str, context: str) -> str:
        """
        Mock function for GPT-4 based correction.

        In the future, this will call the LLM service to fix grammar and
        nuances (e.g. "Correct this transcript keeping CPTSD context in
        mind"). For now it returns the text unchanged.
        """
        # TODO: Implement actual LLM call via external service or local model
        return text

    def validate_term_coverage(self, text: str) -> Dict[str, Any]:
        """
        Calculate metrics on how well the transcript uses domain terminology.

        Note: substring containment, not whole-word matching — a naive
        metric intended only for a basic validation pass.

        Returns:
            Dict with int term counts and a float coverage score in [0, 1].
        """
        cptsd_terms = {t.lower() for t in self.terms.get("cptsd_terms", [])}
        medical_terms = {t.lower() for t in self.terms.get("medical_terms", [])}
        text_lower = text.lower()
        found_cptsd = sum(term in text_lower for term in cptsd_terms)
        found_medical = sum(term in text_lower for term in medical_terms)
        total_domain_terms = len(cptsd_terms) + len(medical_terms)
        found_total = found_cptsd + found_medical
        coverage_score = (
            found_total / total_domain_terms if total_domain_terms > 0 else 0.0
        )
        return {
            "cptsd_term_count": found_cptsd,
            "medical_term_count": found_medical,
            "domain_coverage_score": round(coverage_score, 4),
        }