| | """ |
| | PIX-149: Gestalt Fusion Engine |
| | |
| | Unifies the three core emotional intelligence models of Pixelated Empathy |
| | into a single, real-time inference call: |
| | |
| | 1. PsyDefDetect — DMRS Defense Mechanism Classifier (DeBERTa) |
| | 2. Plutchik — 8-emotion wheel scoring (passed in externally) |
| | 3. OCEAN — Big Five personality trait scoring (passed in externally) |
| | |
| | The fused ``GestaltState`` dataclass powers: |
| | - PIX-147: WebSocket "Live X-Ray" resistance monitor |
| | - PIX-148: Adversarial Persona Injection (dynamic defense-aware prompts) |
| | - PIX-150: Empathy PQ metric validation (trainee scoring) |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import logging |
| | from dataclasses import dataclass, field |
| | from enum import Enum |
| | from typing import Optional |
| |
|
| | import torch |
| | from transformers import AutoTokenizer |
| |
|
| | from ai.training.defense_mechanisms.constants import DEFENSE_LABELS |
| | from ai.training.defense_mechanisms.dataset import format_dialogue |
| | from ai.training.defense_mechanisms.model import DefenseClassifier, DefensePrediction |
| |
|
logger = logging.getLogger(__name__)

# Accepted keys for Plutchik emotion scores: the eight primary emotions of
# Plutchik's wheel.
PLUTCHIK_EMOTIONS = frozenset(
    (
        "anger",
        "anticipation",
        "disgust",
        "fear",
        "joy",
        "sadness",
        "surprise",
        "trust",
    )
)

# Accepted keys for Big Five (OCEAN) personality trait scores.
OCEAN_TRAITS = frozenset(
    (
        "openness",
        "conscientiousness",
        "extraversion",
        "agreeableness",
        "neuroticism",
    )
)
| |
|
| |
|
class CrisisLevel(str, Enum):
    """
    Behavioral risk level derived from the fused Gestalt state.

    Members are declared from least to most severe. Because the enum mixes in
    ``str``, each member compares equal to its plain string value
    (``CrisisLevel.HIGH == "high"``) and serializes as that string.
    """

    # No fused signal indicates risk.
    NONE = "none"
    # Mild escalation.
    ELEVATED = "elevated"
    # High-risk defensive pattern.
    HIGH = "high"
    # Most severe level.
    ACUTE = "acute"
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@dataclass
class GestaltState:
    """
    Unified emotional-psychological state for a single dialogue turn.

    Fields are grouped by the signal they come from: the defense classifier,
    the Plutchik emotion wheel, the OCEAN personality model, and the fused
    outputs computed from all three.

    Attributes:
        defense_label: Predicted defense-mechanism label index.
        defense_label_name: Human-readable name of ``defense_label``.
        defense_confidence: Classifier confidence for the predicted label.
        defense_maturity: Maturity score of the predicted defense, or ``None``
            when no maturity score is available.
        defense_probabilities: Per-label probabilities keyed by label name.
        plutchik_scores: Scores for the Plutchik emotions.
        dominant_emotion: Highest-scoring emotion.
        dominant_emotion_intensity: Score of ``dominant_emotion``.
        ocean_scores: Big Five trait scores.
        crisis_level: Fused behavioral risk level.
        behavioral_prediction: Human-readable summary of the fused state.
        persona_directive: System-prompt clause for persona injection; may be
            an empty string.
        breakthrough_score: Size of the positive maturity shift relative to
            the previous turn.
        raw_metadata: Free-form extra data; omitted from ``repr``.
    """

    # --- Defense mechanism classifier output ---
    defense_label: int
    defense_label_name: str
    defense_confidence: float
    defense_maturity: Optional[float]
    defense_probabilities: dict[str, float]

    # --- Plutchik emotion wheel ---
    plutchik_scores: dict[str, float]
    dominant_emotion: str
    dominant_emotion_intensity: float

    # --- OCEAN personality traits ---
    ocean_scores: dict[str, float]

    # --- Fused outputs ---
    crisis_level: CrisisLevel
    behavioral_prediction: str
    persona_directive: str
    breakthrough_score: float

    # Each instance gets its own dict; hidden from repr to keep logs compact.
    raw_metadata: dict = field(default_factory=dict, repr=False)
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Defense-mechanism label indices referenced by the fusion rules.
# NOTE(review): these must agree with DEFENSE_LABELS in
# ai.training.defense_mechanisms.constants — confirm when labels change.
_ACTION_DEFENSE_LABEL = 1
_MAJOR_IMAGE_DISTORTING_LABEL = 2
_DISAVOWAL_LABEL = 3
_HIGH_ADAPTIVE_LABEL = 7

# Dominant emotions that can raise the crisis level.
_CRISIS_AMPLIFYING_EMOTIONS = frozenset(("anger", "disgust", "fear", "sadness"))

# Maturity cut-offs: persona directives are injected only below the injection
# threshold; maturity at or above the breakthrough threshold is reported as
# adaptive coping.
_INJECTION_MATURITY_THRESHOLD = 0.43
_BREAKTHROUGH_MATURITY_THRESHOLD = 0.71
| |
|
| |
|
def _dominant_emotion(plutchik: dict[str, float]) -> tuple[str, float]:
    """
    Return ``(emotion, intensity)`` for the highest-scoring emotion.

    Ties are broken alphabetically, so the result never depends on dict
    iteration order. That order varies between processes when the dict is
    built from a ``frozenset`` of strings, because string hashing is randomized.

    An empty mapping, or one whose best score is not positive, carries no
    emotional signal and yields ``("unknown", 0.0)``. This keeps an all-zero
    score set from being reported as some arbitrary (possibly
    crisis-amplifying) emotion.
    """
    if not plutchik:
        return "unknown", 0.0
    # Highest score first; among equal scores, the alphabetically first name.
    dominant = min(plutchik, key=lambda name: (-plutchik[name], name))
    intensity = plutchik[dominant]
    if intensity <= 0.0:
        return "unknown", 0.0
    return dominant, intensity
| |
|
| |
|
def _compute_crisis_level(
    defense_label: int,
    defense_maturity: Optional[float],
    dominant_emotion: str,
    dominant_intensity: float,
    ocean_neuroticism: float,
) -> CrisisLevel:
    """
    Map the fused defense/emotion/personality signals to a ``CrisisLevel``.

    Rules, checked in priority order:

    1. Action defense: ACUTE when a crisis-amplifying emotion dominates with
       intensity > 0.6, or when neuroticism > 0.75; otherwise HIGH.
    2. Major image-distorting defense with a crisis-amplifying emotion: HIGH.
    3. Maturity < 0.3 (a missing maturity counts as 0.5) with a
       crisis-amplifying emotion: ELEVATED.
    4. Otherwise NONE.
    """
    crisis_emotion = dominant_emotion in _CRISIS_AMPLIFYING_EMOTIONS

    if defense_label == _ACTION_DEFENSE_LABEL:
        # ACUTE if (intense crisis emotion) OR (high neuroticism).
        intense_crisis_emotion = crisis_emotion and dominant_intensity > 0.6
        if intense_crisis_emotion or ocean_neuroticism > 0.75:
            return CrisisLevel.ACUTE
        return CrisisLevel.HIGH

    if crisis_emotion and defense_label == _MAJOR_IMAGE_DISTORTING_LABEL:
        return CrisisLevel.HIGH

    # Unknown maturity is treated as mid-scale so it never triggers rule 3.
    effective_maturity = 0.5 if defense_maturity is None else defense_maturity
    if crisis_emotion and effective_maturity < 0.3:
        return CrisisLevel.ELEVATED

    return CrisisLevel.NONE
| |
|
| |
|
def _behavioral_prediction(
    defense_label_name: str,
    dominant_emotion: str,
    crisis_level: CrisisLevel,
    defense_maturity: Optional[float],
) -> str:
    """
    Produce a concise human-readable behavioral prediction string.

    The result has the form ``"<Emotion> + <defense> (<maturity>) → <advice>"``.
    Crisis levels are checked from most to least severe. With no crisis, a
    maturity at or above ``_BREAKTHROUGH_MATURITY_THRESHOLD`` is reported as
    adaptive coping; anything else is reported as neutral/intermediate.
    """
    maturity_str = (
        f"maturity={defense_maturity:.2f}" if defense_maturity is not None else "N/A"
    )
    base = f"{dominant_emotion.capitalize()} + {defense_label_name} ({maturity_str})"

    if crisis_level == CrisisLevel.ACUTE:
        return f"{base} → acute distress risk. Therapist must de-escalate immediately."
    if crisis_level == CrisisLevel.HIGH:
        return (
            f"{base} → high-risk defensive pattern. "
            "Therapist should slow pace and validate."
        )
    if crisis_level == CrisisLevel.ELEVATED:
        return (
            f"{base} → elevated defensiveness. "
            "Therapist should use reflective listening."
        )

    if (
        defense_maturity is not None
        and defense_maturity >= _BREAKTHROUGH_MATURITY_THRESHOLD
    ):
        return (
            f"{base} → adaptive coping detected. "
            "Therapist can introduce deeper reframing."
        )
    return f"{base} → neutral or intermediate defense. Continue with exploration."
| |
|
| |
|
def _persona_directive(
    defense_label: int,
    defense_label_name: str,
    defense_maturity: Optional[float],
) -> str:
    """
    Build the PIX-148 system-prompt injection clause.

    Returns an empty string when maturity is unknown or at/above
    ``_INJECTION_MATURITY_THRESHOLD``. Otherwise it returns a label-specific
    directive for the action, major image-distorting and disavowal labels,
    and a generic "maintain your defense" clause for any other label.
    """
    if defense_maturity is None:
        return ""
    if defense_maturity >= _INJECTION_MATURITY_THRESHOLD:
        return ""

    if defense_label == _ACTION_DEFENSE_LABEL:
        return (
            "[System: You are in an Action Defense state. "
            "Act out, change subject. Do NOT acknowledge emotions directly.]"
        )
    if defense_label == _MAJOR_IMAGE_DISTORTING_LABEL:
        return (
            "[System: You are in a Major Image-Distorting state. "
            "Split people into all-good or all-bad.]"
        )
    if defense_label == _DISAVOWAL_LABEL:
        return "[System: You are in a Disavowal state. Deny emotional reality.]"

    return f"[System: Maintain your '{defense_label_name}' defense mechanism.]"
| |
|
| |
|
def _breakthrough_score(
    defense_maturity: Optional[float],
    previous_maturity: Optional[float],
) -> float:
    """
    Score the size of an upward maturity shift between consecutive turns.

    Returns ``defense_maturity - previous_maturity`` clamped to ``[0, 1]``.
    A drop in maturity scores 0. If either value is ``None``, the result is 0.0.
    """
    both_known = defense_maturity is not None and previous_maturity is not None
    if not both_known:
        return 0.0
    return max(0.0, min(1.0, defense_maturity - previous_maturity))
| |
|
| |
|
def _validate_scores(
    scores: dict[str, float],
    expected_keys: frozenset[str],
    label: str,
) -> dict[str, float]:
    """
    Validate a mapping of normalized scores and return a float-valued copy.

    Args:
        scores: Name -> score mapping. Every value must be a real number in
            ``[0, 1]``. Any ``numbers.Real`` is accepted, including NumPy
            scalars such as ``np.float32`` that model outputs often produce.
            ``bool`` is rejected because it is a caller mistake, not a score.
        expected_keys: Known key names. Unknown keys are logged as a warning
            (not rejected) and are kept in the returned dict.
        label: Argument name used in error and log messages.

    Returns:
        A new dict with the same keys and every value converted to ``float``.

    Raises:
        ValueError: If a value is non-numeric, boolean, or outside ``[0, 1]``.
            NaN fails the range check.
    """
    import numbers

    for key, value in scores.items():
        if isinstance(value, bool) or not isinstance(value, numbers.Real):
            raise ValueError(f"{label}['{key}'] must be numeric")
        if not 0.0 <= float(value) <= 1.0:
            raise ValueError(f"{label}['{key}'] score {value} out of range [0, 1]")

    if unknown := set(scores) - expected_keys:
        logger.warning("%s contains unknown keys %s", label, sorted(unknown))

    return {k: float(v) for k, v in scores.items()}
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class GestaltEngine:
    """
    Real-time Gestalt Fusion Engine for the Empathy Gym™.

    The engine wraps the PsyDefDetect defense classifier. It fuses the
    classifier output with externally supplied Plutchik and OCEAN scores
    into a ``GestaltState``.

    The engine keeps per-conversation state: the last observed defense
    maturity, which is the baseline for breakthrough scoring. Use one
    instance per conversation, or call ``reset_session()`` between
    conversations.
    """

    def __init__(self) -> None:
        # Set by load_defense_model(); None means "not initialized".
        self._defense_model = None
        # Only populated when a local checkpoint is loaded (used by
        # _legacy_inference).
        self._defense_tokenizer = None
        # Maturity of the most recent turn that produced one; baseline for the
        # next turn's breakthrough score.
        self._previous_maturity: Optional[float] = None

    def load_defense_model(
        self,
        checkpoint_path: Optional[str] = None,
        device: str = "cpu",
    ) -> None:
        """
        Load the PsyDefDetect model (optional if using NIM).

        Args:
            checkpoint_path: Path to a local DeBERTa checkpoint. When falsy, a
                default ``DefenseClassifier()`` (NIM-backed) is used.
            device: torch device string used when loading a local checkpoint.

        If loading a local checkpoint fails for any reason, the error is
        logged with its traceback and the engine falls back to a default
        ``DefenseClassifier()`` instead of raising.
        """
        if not checkpoint_path:
            logger.info("GestaltEngine: No checkpoint provided, using NIM by default.")
            self._defense_model = DefenseClassifier()
            return

        try:
            self._initialize_local_model(checkpoint_path, device)
        except Exception as exc:
            # Deliberate best-effort: keep the engine usable via NIM, but keep
            # the traceback so broken checkpoints remain diagnosable.
            logger.exception(
                "GestaltEngine: Failed to load checkpoint %s: %s", checkpoint_path, exc
            )
            logger.info("GestaltEngine: Falling back to NIM-based DefenseClassifier.")
            self._defense_model = DefenseClassifier()

    def _initialize_local_model(self, checkpoint_path: str, device: str) -> None:
        """
        Initialize the legacy DeBERTa model from a local checkpoint.

        Reads ``config`` (``base_model``, ``num_labels``) and
        ``model_state_dict`` from the checkpoint, then loads the matching
        tokenizer. Any exception propagates to the caller. The tokenizer and
        model attributes are only assigned after every load step succeeds.
        """
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
        # objects and can execute code embedded in the file. Only load trusted
        # checkpoints; consider weights_only=True if the checkpoint holds only
        # tensors and plain containers.
        checkpoint = torch.load(
            checkpoint_path, map_location=device, weights_only=False
        )
        config = checkpoint.get("config", {})
        model_name = config.get("base_model", "microsoft/deberta-v3-base")

        model = DefenseClassifier(
            model_name=model_name,
            num_labels=config.get("num_labels", 9),
            r_drop_enabled=False,
        )

        # NOTE(review): guarded because DefenseClassifier may not be a torch
        # module; .to()/.eval() are assumed to share this guard — confirm.
        if hasattr(model, "load_state_dict"):
            # strict=False silently skips mismatched keys; report them so a
            # partially initialized classifier is visible in the logs.
            load_result = model.load_state_dict(
                checkpoint["model_state_dict"], strict=False
            )
            missing = list(getattr(load_result, "missing_keys", None) or ())
            unexpected = list(getattr(load_result, "unexpected_keys", None) or ())
            if missing or unexpected:
                logger.warning(
                    "GestaltEngine: checkpoint %s loaded with %d missing and "
                    "%d unexpected keys (first missing=%s, first unexpected=%s)",
                    checkpoint_path,
                    len(missing),
                    len(unexpected),
                    missing[:10],
                    unexpected[:10],
                )
            model.to(device)
            model.eval()

        self._defense_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._defense_model = model
        logger.info("GestaltEngine: model loaded from %s", checkpoint_path)

    @property
    def defense_model_loaded(self) -> bool:
        """True once a defense model (local or NIM-backed) has been set."""
        return self._defense_model is not None

    def reset_session(self) -> None:
        """Forget the previous turn's maturity so breakthrough scoring restarts."""
        self._previous_maturity = None

    @staticmethod
    def _context_turns(
        dialogue: list[dict[str, str]],
        target_utterance: str,
        max_turns: int,
    ) -> list[dict[str, str]]:
        """
        Select and normalize the dialogue window fed to the classifier.

        Keeps the last ``max_turns`` turns as ``{"speaker", "text"}`` dicts,
        defaulting a missing speaker to "Unknown" and a missing text to "".
        If no kept turn matches ``target_utterance`` (ignoring case and
        surrounding whitespace), the target is appended as a "User" turn.

        A ``max_turns`` of zero or less keeps no prior turns. A plain
        ``dialogue[-0:]`` slice would instead return the *entire* dialogue.
        """
        recent = dialogue[-max_turns:] if max_turns > 0 else []
        turns = [
            {"speaker": t.get("speaker", "Unknown"), "text": t.get("text", "")}
            for t in recent
        ]

        target_normalized = target_utterance.strip().lower()
        if all(t["text"].strip().lower() != target_normalized for t in turns):
            turns.append({"speaker": "User", "text": target_utterance})
        return turns

    def _classify_defense(
        self,
        dialogue: list[dict[str, str]],
        target_utterance: str,
        max_turns: int = 40,
    ) -> tuple[int, str, float, Optional[float], dict[str, float]]:
        """
        Run PsyDefDetect inference on the recent dialogue window.

        Args:
            dialogue: Prior turns as dicts with optional ``speaker``/``text`` keys.
            target_utterance: The utterance being classified.
            max_turns: Number of most recent turns to keep as context.

        Returns:
            ``(label, label_name, confidence, maturity_score, probabilities)``,
            where ``probabilities`` maps label names (or the stringified index
            for unknown labels) to values rounded to 4 decimal places.

        Raises:
            RuntimeError: If ``load_defense_model()`` has not been called.
        """
        if self._defense_model is None:
            raise RuntimeError(
                "GestaltEngine: Defense model not initialized. "
                "Call load_defense_model() first."
            )

        turns = self._context_turns(dialogue, target_utterance, max_turns)
        formatted = format_dialogue(turns, target_utterance, max_turns)

        # NIM-backed classifiers take raw text; the local torch model needs
        # tokenized tensors.
        if hasattr(self._defense_model, "nim"):
            pred = self._defense_model.predict([formatted])[0]
        else:
            pred = self._legacy_inference(formatted)

        # float() guards against tensor/NumPy scalar probabilities.
        prob_dict = {
            DEFENSE_LABELS.get(i, str(i)): round(float(p), 4)
            for i, p in enumerate(pred.probabilities)
        }
        return (
            pred.label,
            pred.label_name,
            pred.confidence,
            pred.maturity_score,
            prob_dict,
        )

    def _legacy_inference(self, formatted_text: str) -> DefensePrediction:
        """
        Run inference using the local PyTorch DeBERTa model.

        The text is tokenized into a single sequence padded/truncated to 512
        tokens, and the first prediction is returned.

        Raises:
            RuntimeError: If no tokenizer is loaded, i.e. the model did not
                come from a local checkpoint.
        """
        if self._defense_tokenizer is None:
            raise RuntimeError("GestaltEngine: No tokenizer for PyTorch model.")

        encoding = self._defense_tokenizer(
            formatted_text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Put inputs on whatever device the model's weights live on.
        device = next(self._defense_model.parameters()).device
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        # Inference only: never record an autograd graph, even if predict()
        # does not disable gradients itself.
        with torch.no_grad():
            return self._defense_model.predict(input_ids, attention_mask)[0]

    def analyze_gestalt(
        self,
        dialogue: list[dict[str, str]],
        target_utterance: str,
        plutchik_scores: dict[str, float],
        ocean_scores: dict[str, float],
        max_turns: int = 40,
    ) -> GestaltState:
        """
        Fuse defense, emotion and personality signals into a ``GestaltState``.

        Args:
            dialogue: Prior turns as dicts with optional ``speaker``/``text`` keys.
            target_utterance: The utterance being analyzed.
            plutchik_scores: Emotion scores in ``[0, 1]``. Missing emotions
                default to 0.0.
            ocean_scores: Trait scores in ``[0, 1]``. Missing traits default
                to 0.5.
            max_turns: Number of most recent turns given to the classifier.

        Returns:
            The fused ``GestaltState`` for this turn.

        Raises:
            ValueError: If any score is invalid. This is raised before the
                classifier runs.
            RuntimeError: If the defense model has not been loaded.

        Side effect: when this turn yields a maturity score, it becomes the
        baseline for the next call's breakthrough score.
        """
        # Validate first so bad inputs fail before any model inference.
        plutchik = _validate_scores(
            plutchik_scores, PLUTCHIK_EMOTIONS, "plutchik_scores"
        )
        ocean = _validate_scores(ocean_scores, OCEAN_TRAITS, "ocean_scores")

        # Complete every known key. sorted() gives a stable key order, so the
        # output dicts and any order-dependent tie-breaking are identical
        # across processes. Plain frozenset iteration order varies with
        # string-hash randomization.
        full_plutchik = {e: plutchik.get(e, 0.0) for e in sorted(PLUTCHIK_EMOTIONS)}
        full_ocean = {t: ocean.get(t, 0.5) for t in sorted(OCEAN_TRAITS)}

        (def_l, def_n, def_c, def_m, def_p) = self._classify_defense(
            dialogue, target_utterance, max_turns
        )

        dom_e, dom_i = _dominant_emotion(full_plutchik)
        neuro = full_ocean.get("neuroticism", 0.5)

        crisis = _compute_crisis_level(def_l, def_m, dom_e, dom_i, neuro)
        behavioral = _behavioral_prediction(def_n, dom_e, crisis, def_m)
        directive = _persona_directive(def_l, def_n, def_m)
        # Compare against the previous turn before overwriting the baseline.
        breakthrough = _breakthrough_score(def_m, self._previous_maturity)

        if def_m is not None:
            self._previous_maturity = def_m

        return GestaltState(
            defense_label=def_l,
            defense_label_name=def_n,
            defense_confidence=def_c,
            defense_maturity=def_m,
            defense_probabilities=def_p,
            plutchik_scores=full_plutchik,
            dominant_emotion=dom_e,
            dominant_emotion_intensity=dom_i,
            ocean_scores=full_ocean,
            crisis_level=crisis,
            behavioral_prediction=behavioral,
            persona_directive=directive,
            breakthrough_score=breakthrough,
        )
| |
|