File size: 5,784 Bytes
7849935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import json
import logging
import re
from pathlib import Path
from typing import Any, Dict

# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)


class TranscriptCorrector:
    """
    Utility class for correcting transcripts using a multi-pass approach:

    1. Basic structural cleanup (filler-word removal, whitespace normalization).
    2. Deterministic terminology replacement driven by a JSON config.
    3. LLM-based contextual correction (currently a no-op mock).
    """

    def __init__(self, config_path: str = "ai/config/therapeutic_terminology.json"):
        """
        Initialize the TranscriptCorrector with terminology configuration.

        Args:
            config_path: Path to the JSON configuration file containing
                therapeutic terms.
        """
        self.config_path = Path(config_path)
        # Loaded once at construction; degrades to an empty config on error.
        self.terms: Dict[str, Any] = self._load_terminology()

    @staticmethod
    def _empty_config() -> Dict[str, Any]:
        """Return the default, empty terminology structure (used as fallback)."""
        return {
            "cptsd_terms": [],
            "medical_terms": [],
            "common_misinterpretations": {},
        }

    def _load_terminology(self) -> Dict[str, Any]:
        """
        Load therapeutic terminology from the JSON config.

        Falls back to a path resolved relative to this module when the
        configured path does not exist, and to an empty config when neither
        path is found or loading fails. Never raises.
        """
        try:
            if not self.config_path.exists():
                # Project layout is usually ai/utils/transcript_corrector.py
                # with config at ai/config/therapeutic_terminology.json, so
                # resolve two levels up from this file.
                base_path = Path(__file__).parent.parent
                alt_path = base_path / "config" / "therapeutic_terminology.json"

                if alt_path.exists():
                    self.config_path = alt_path
                else:
                    logger.warning(
                        f"Terminology config not found at {self.config_path} or "
                        f"{alt_path}. Using empty config."
                    )
                    return self._empty_config()

            with open(self.config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            # Best-effort: a corrupt or unreadable config must not break callers.
            logger.error(f"Failed to load terminology config: {e}")
            return self._empty_config()

    def correct_transcript(self, text: str, context: str = "therapy_session") -> str:
        """
        Main entry point for transcript correction.

        Args:
            text: Single string containing the transcript text to correct.
            context: Context hint for LLM correction.

        Returns:
            Corrected transcript text; empty string for empty/whitespace input.
        """
        if not text or not text.strip():
            return ""

        # Pass 1: Basic structural cleanup (fillers, whitespace).
        text = self._clean_structure(text)

        # Pass 2: Deterministic terminology replacement.
        text = self._apply_terminology_fixes(text)

        # Pass 3: LLM contextual correction (mocked; currently identity).
        text = self._llm_contextual_correction(text, context)

        return text

    def _clean_structure(self, text: str) -> str:
        """Remove filler words and normalize whitespace."""
        # Common filler words in speech, optionally followed by a comma.
        # NOTE: "like" is removed unconditionally, which can also strip the
        # verb "like" ("I like it" -> "I it") — intentional per the original.
        fillers = r"\b(um|uh|err|ah|like|you know|I mean)\b,?\s*"

        # Remove fillers (case-insensitive).
        cleaned = re.sub(fillers, "", text, flags=re.IGNORECASE)

        # Collapse runs of whitespace to single spaces and trim the ends.
        cleaned = re.sub(r"\s+", " ", cleaned).strip()

        return cleaned

    def _apply_terminology_fixes(self, text: str) -> str:
        """
        Apply deterministic terminology fixes from config.

        Matches are whole-word and case-insensitive, so a "cat" -> "dog"
        rule does not rewrite "category".
        """
        misinterpretations = self.terms.get("common_misinterpretations", {})

        for bad_term, good_term in misinterpretations.items():
            # \b anchors make the match whole-word; the previous pattern used
            # only the escaped term and so also matched inside larger words.
            # NOTE(review): \b is a no-op if a term starts/ends with a
            # non-word character — acceptable for phrase-style config keys.
            pattern = re.compile(rf"\b{re.escape(bad_term)}\b", re.IGNORECASE)
            text = pattern.sub(good_term, text)

        return text

    def _llm_contextual_correction(self, text: str, context: str) -> str:
        """
        Mock function for GPT-4 based correction.

        In the future, this will call the LLM service to fix grammar and
        nuances; for now it returns the text unchanged.
        """
        # TODO: Implement actual LLM call via external service or local model.
        # If we had an LLM, we'd ask it: "Correct this transcript keeping
        # CPTSD context in mind."
        return text

    def validate_term_coverage(self, text: str) -> Dict[str, Any]:
        """
        Calculate metrics on how well the transcript uses domain terminology.

        Returns:
            Dict with int counts ("cptsd_term_count", "medical_term_count")
            and a float "domain_coverage_score" in [0, 1]. (The previous
            annotation claimed all-float values; the counts are ints.)
        """
        cptsd_terms = {t.lower() for t in self.terms.get("cptsd_terms", [])}
        medical_terms = {t.lower() for t in self.terms.get("medical_terms", [])}

        text_lower = text.lower()

        # Substring containment per term; sum of bools yields an int count.
        found_cptsd = sum(term in text_lower for term in cptsd_terms)
        found_medical = sum(term in text_lower for term in medical_terms)

        total_domain_terms = len(cptsd_terms) + len(medical_terms)
        found_total = found_cptsd + found_medical

        # Naive ratio metric, just for basic validation; guards divide-by-zero.
        coverage_score = (
            found_total / total_domain_terms if total_domain_terms > 0 else 0.0
        )

        return {
            "cptsd_term_count": found_cptsd,
            "medical_term_count": found_medical,
            "domain_coverage_score": round(coverage_score, 4),
        }