pixelated-training-code / core /gestalt_engine.py
oneblackmage's picture
Upload folder using huggingface_hub
ac9bb45 verified
"""
PIX-149: Gestalt Fusion Engine
Unifies the three core emotional intelligence models of Pixelated Empathy
into a single, real-time inference call:
1. PsyDefDetect — DMRS Defense Mechanism Classifier (DeBERTa)
2. Plutchik — 8-emotion wheel scoring (passed in externally)
3. OCEAN — Big Five personality trait scoring (passed in externally)
The fused ``GestaltState`` dataclass powers:
- PIX-147: WebSocket "Live X-Ray" resistance monitor
- PIX-148: Adversarial Persona Injection (dynamic defense-aware prompts)
- PIX-150: Empathy PQ metric validation (trainee scoring)
"""
from __future__ import annotations

import logging
import numbers
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

import torch
from transformers import AutoTokenizer

from ai.training.defense_mechanisms.constants import DEFENSE_LABELS
from ai.training.defense_mechanisms.dataset import format_dialogue
from ai.training.defense_mechanisms.model import DefenseClassifier, DefensePrediction
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Domain enumerations
# ---------------------------------------------------------------------------

# The eight primary emotions of Plutchik's wheel. Keys of the plutchik_scores
# mapping given to GestaltEngine.analyze_gestalt are checked against this set.
PLUTCHIK_EMOTIONS = frozenset(
    (
        "joy",
        "trust",
        "fear",
        "surprise",
        "sadness",
        "disgust",
        "anger",
        "anticipation",
    )
)

# The Big Five (OCEAN) personality dimensions expected in ocean_scores.
OCEAN_TRAITS = frozenset(
    (
        "openness",
        "conscientiousness",
        "extraversion",
        "agreeableness",
        "neuroticism",
    )
)
class CrisisLevel(str, Enum):
    """Risk level assigned to a turn by the Gestalt fusion rules.

    Members are listed from least to most severe. Because the enum mixes in
    ``str``, each member compares equal to its plain string value
    (e.g. ``CrisisLevel.HIGH == "high"``).
    """

    NONE = "none"  # no crisis rule fired
    ELEVATED = "elevated"
    HIGH = "high"
    ACUTE = "acute"
# ---------------------------------------------------------------------------
# Fused output dataclass
# ---------------------------------------------------------------------------
@dataclass
class GestaltState:
    """Fused emotional/psychological snapshot for a single dialogue turn.

    Bundles the defense-classifier output, the Plutchik emotion profile, the
    OCEAN personality profile, and the behavioral outputs derived from them.
    Field order defines the positional constructor signature; do not reorder.
    """

    # Defense mechanism (PsyDefDetect).
    defense_label: int  # predicted class index
    defense_label_name: str  # human-readable name of defense_label
    defense_confidence: float
    defense_maturity: Optional[float]  # None when no maturity score is available
    defense_probabilities: dict[str, float]  # label name -> probability

    # Emotion (Plutchik).
    plutchik_scores: dict[str, float]
    dominant_emotion: str
    dominant_emotion_intensity: float

    # Personality (OCEAN / Big Five).
    ocean_scores: dict[str, float]

    # Fused behavioral outputs.
    crisis_level: CrisisLevel
    behavioral_prediction: str
    persona_directive: str  # system-prompt clause; may be the empty string
    breakthrough_score: float

    # Free-form extra data; omitted from repr().
    raw_metadata: dict = field(default_factory=dict, repr=False)
# ---------------------------------------------------------------------------
# Crisis & behavioral prediction logic
# ---------------------------------------------------------------------------
_ACTION_DEFENSE_LABEL = 1
_HIGH_ADAPTIVE_LABEL = 7
_DISAVOWAL_LABEL = 3
_MAJOR_IMAGE_DISTORTING_LABEL = 2
_CRISIS_AMPLIFYING_EMOTIONS = frozenset({"sadness", "fear", "anger", "disgust"})
_INJECTION_MATURITY_THRESHOLD = 0.43
_BREAKTHROUGH_MATURITY_THRESHOLD = 0.71
def _dominant_emotion(plutchik: dict[str, float]) -> tuple[str, float]:
"""Return the emotion with the highest intensity score."""
if not plutchik:
return "unknown", 0.0
dominant = max(plutchik, key=lambda k: plutchik[k])
return dominant, plutchik[dominant]
def _compute_crisis_level(
    defense_label: int,
    defense_maturity: Optional[float],
    dominant_emotion: str,
    dominant_intensity: float,
    ocean_neuroticism: float,
) -> CrisisLevel:
    """Map the fused defense/emotion/personality signals to a CrisisLevel.

    Rules, evaluated in order:

    1. Action defense: ACUTE if a crisis-amplifying dominant emotion has
       intensity above 0.6, or if neuroticism is above 0.75; otherwise HIGH.
    2. No crisis-amplifying dominant emotion: NONE.
    3. Major image-distorting defense: HIGH.
    4. Maturity below 0.3: ELEVATED. A missing maturity is treated as 0.5.
    5. Otherwise: NONE.
    """
    amplifying = dominant_emotion in _CRISIS_AMPLIFYING_EMOTIONS

    if defense_label == _ACTION_DEFENSE_LABEL:
        intense_emotion = amplifying and dominant_intensity > 0.6
        if intense_emotion or ocean_neuroticism > 0.75:
            return CrisisLevel.ACUTE
        return CrisisLevel.HIGH

    # Every remaining escalation rule requires a crisis-amplifying emotion.
    if not amplifying:
        return CrisisLevel.NONE

    if defense_label == _MAJOR_IMAGE_DISTORTING_LABEL:
        return CrisisLevel.HIGH

    maturity = 0.5 if defense_maturity is None else defense_maturity
    return CrisisLevel.ELEVATED if maturity < 0.3 else CrisisLevel.NONE
def _behavioral_prediction(
    defense_label_name: str,
    dominant_emotion: str,
    crisis_level: CrisisLevel,
    defense_maturity: Optional[float],
) -> str:
    """Produce a concise human-readable behavioral prediction string.

    Format: ``"<Emotion> + <defense> (<maturity>) → <guidance>"``.

    The guidance is chosen as follows:
    - If there is a crisis, it is chosen by crisis level (ACUTE, then HIGH,
      then ELEVATED).
    - Otherwise, it reports adaptive coping when maturity is at or above
      ``_BREAKTHROUGH_MATURITY_THRESHOLD``.
    - Otherwise, it gives a neutral "continue with exploration" hint.
    """
    maturity_str = (
        f"maturity={defense_maturity:.2f}" if defense_maturity is not None else "N/A"
    )
    base = f"{dominant_emotion.capitalize()} + {defense_label_name} ({maturity_str})"
    if crisis_level == CrisisLevel.ACUTE:
        return f"{base} → acute distress risk. Therapist must de-escalate immediately."
    if crisis_level == CrisisLevel.HIGH:
        return (
            f"{base} → high-risk defensive pattern. "
            "Therapist should slow pace and validate."
        )
    if crisis_level == CrisisLevel.ELEVATED:
        return (
            f"{base} → elevated defensiveness. "
            "Therapist should use reflective listening."
        )
    if (
        defense_maturity is not None
        and defense_maturity >= _BREAKTHROUGH_MATURITY_THRESHOLD
    ):
        # Fixed grammar: previously read "Therapist can introduced ...".
        return (
            f"{base} → adaptive coping detected. "
            "Therapist can introduce deeper reframing."
        )
    return f"{base} → neutral or intermediate defense. Continue with exploration."
def _persona_directive(
    defense_label: int,
    defense_label_name: str,
    defense_maturity: Optional[float],
) -> str:
    """Build the PIX-148 system-prompt clause for an immature defense.

    Returns ``""`` unless maturity is known and not at or above
    ``_INJECTION_MATURITY_THRESHOLD``. The action, major image-distorting,
    and disavowal labels get dedicated instructions. Any other label gets a
    generic clause naming ``defense_label_name``.
    """
    # Two separate guards (rather than "maturity < threshold") keep NaN
    # maturity falling through to a directive, exactly as before.
    if defense_maturity is None:
        return ""
    if defense_maturity >= _INJECTION_MATURITY_THRESHOLD:
        return ""

    if defense_label == _ACTION_DEFENSE_LABEL:
        return (
            "[System: You are in an Action Defense state. "
            "Act out, change subject. Do NOT acknowledge emotions directly.]"
        )
    if defense_label == _MAJOR_IMAGE_DISTORTING_LABEL:
        return (
            "[System: You are in a Major Image-Distorting state. "
            "Split people into all-good or all-bad.]"
        )
    if defense_label == _DISAVOWAL_LABEL:
        return "[System: You are in a Disavowal state. Deny emotional reality.]"

    return f"[System: Maintain your '{defense_label_name}' defense mechanism.]"
def _breakthrough_score(
defense_maturity: Optional[float],
previous_maturity: Optional[float],
) -> float:
"""Score the magnitude of a positive maturity shift."""
if defense_maturity is None or previous_maturity is None:
return 0.0
delta = defense_maturity - previous_maturity
return max(0.0, min(1.0, delta))
def _validate_scores(
scores: dict[str, float],
expected_keys: frozenset[str],
label: str,
) -> dict[str, float]:
"""Validate normalized scores."""
for key, value in scores.items():
if not isinstance(value, (int, float)):
raise ValueError(f"{label}['{key}'] must be numeric")
if not 0.0 <= float(value) <= 1.0:
raise ValueError(f"{label}['{key}'] score {value} out of range [0, 1]")
if unknown := set(scores) - expected_keys:
logger.warning(f"{label} contains unknown keys {sorted(unknown)}")
return {k: float(v) for k, v in scores.items()}
# ---------------------------------------------------------------------------
# GestaltEngine
# ---------------------------------------------------------------------------
class GestaltEngine:
    """
    Real-time Gestalt Fusion Engine for the Empathy Gym™.

    The engine holds a PsyDefDetect defense classifier, which is one of:
    - the default ``DefenseClassifier()``, used when no checkpoint is given
      (logged as "NIM");
    - a DeBERTa model restored from a local checkpoint.

    Each turn's defense prediction is fused with caller-supplied Plutchik
    and OCEAN scores into a :class:`GestaltState`.

    The engine is stateful. It remembers the most recent non-``None``
    defense maturity, so consecutive :meth:`analyze_gestalt` calls can score
    breakthroughs. Call :meth:`reset_session` when a new session starts.
    """

    def __init__(self) -> None:
        # Classifier instance; None until load_defense_model() is called.
        self._defense_model = None
        # Tokenizer; only set on the local-checkpoint path.
        self._defense_tokenizer = None
        # Last non-None maturity seen, used for breakthrough scoring.
        self._previous_maturity: Optional[float] = None

    def load_defense_model(
        self,
        checkpoint_path: Optional[str] = None,
        device: str = "cpu",
    ) -> None:
        """Load the PsyDefDetect model (optional if using NIM).

        Args:
            checkpoint_path: Path to a local checkpoint saved with
                ``torch.save``. When falsy, a default ``DefenseClassifier()``
                (the NIM path) is installed instead.
            device: Torch device string the local model is loaded onto.

        A failure while loading a local checkpoint is logged with its
        traceback, and the engine falls back to ``DefenseClassifier()``.
        """
        if not checkpoint_path:
            logger.info("GestaltEngine: No checkpoint provided, using NIM by default.")
            self._use_nim_classifier()
            return
        try:
            self._initialize_local_model(checkpoint_path, device)
        except Exception as exc:
            # Deliberate best-effort: a bad checkpoint degrades to NIM instead
            # of failing start-up, but keep the traceback for diagnosis.
            logger.exception(
                "GestaltEngine: Failed to load checkpoint %s: %s", checkpoint_path, exc
            )
            logger.info("GestaltEngine: Falling back to NIM-based DefenseClassifier.")
            self._use_nim_classifier()

    def _use_nim_classifier(self) -> None:
        """Install the default classifier and drop any stale local tokenizer."""
        self._defense_model = DefenseClassifier()
        self._defense_tokenizer = None

    def _initialize_local_model(self, checkpoint_path: str, device: str) -> None:
        """Initialize the legacy DeBERTa model from a local checkpoint.

        The checkpoint is read as a mapping with a ``"model_state_dict"``
        entry and an optional ``"config"`` mapping (``base_model``,
        ``num_labels``).
        """
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects
        # and can execute code embedded in the file. Only load trusted
        # checkpoints; consider weights_only=True if "config" holds only
        # primitive types.
        checkpoint = torch.load(
            checkpoint_path, map_location=device, weights_only=False
        )
        config = checkpoint.get("config") or {}
        model_name = config.get("base_model", "microsoft/deberta-v3-base")
        model = DefenseClassifier(
            model_name=model_name,
            num_labels=config.get("num_labels", 9),
            r_drop_enabled=False,
        )
        # Handle PyTorch-based DefenseClassifier if checkpoint is provided
        if hasattr(model, "load_state_dict"):
            result = model.load_state_dict(
                checkpoint["model_state_dict"], strict=False
            )
            # strict=False tolerates mismatched keys silently. Surface them, so
            # a partially initialized classifier is not mistaken for a
            # trained one.
            missing = list(getattr(result, "missing_keys", None) or [])
            unexpected = list(getattr(result, "unexpected_keys", None) or [])
            if missing:
                logger.warning(
                    "GestaltEngine: checkpoint is missing %d keys (first: %s)",
                    len(missing),
                    missing[:10],
                )
            if unexpected:
                logger.warning(
                    "GestaltEngine: checkpoint has %d unexpected keys (first: %s)",
                    len(unexpected),
                    unexpected[:10],
                )
            model.to(device)
            model.eval()
        self._defense_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._defense_model = model
        logger.info("GestaltEngine: model loaded from %s", checkpoint_path)

    @property
    def defense_model_loaded(self) -> bool:
        """True once a defense classifier has been installed."""
        return self._defense_model is not None

    def reset_session(self) -> None:
        """Forget the previous turn's maturity (start a new session)."""
        self._previous_maturity = None

    def _classify_defense(
        self,
        dialogue: list[dict[str, str]],
        target_utterance: str,
        max_turns: int = 40,
    ) -> tuple[int, str, float, Optional[float], dict[str, float]]:
        """Run PsyDefDetect inference on the target utterance.

        Args:
            dialogue: Prior turns as dicts with optional ``"speaker"`` and
                ``"text"`` keys. Only the last ``max_turns`` are kept.
            target_utterance: The utterance to classify. It is appended as a
                ``"User"`` turn when no kept turn matches it (ignoring case
                and surrounding whitespace).
            max_turns: Context window, also passed to ``format_dialogue``.

        Returns:
            ``(label, label_name, confidence, maturity, probabilities)``.
            *probabilities* maps label names (or the index as a string when
            unnamed) to scores rounded to 4 decimals.

        Raises:
            RuntimeError: If :meth:`load_defense_model` has not been called.
        """
        if self._defense_model is None:
            raise RuntimeError(
                "GestaltEngine: Defense model not initialized. "
                "Call load_defense_model() first."
            )
        turns = [
            # "or" also maps an explicit None text to "", so .strip() is safe.
            {"speaker": t.get("speaker", "Unknown"), "text": t.get("text") or ""}
            for t in dialogue[-max_turns:]
        ]
        # Ensure the target utterance is in the sequence for format_dialogue to mark it
        target_normalized = target_utterance.strip().lower()
        if all(t["text"].strip().lower() != target_normalized for t in turns):
            turns.append({"speaker": "User", "text": target_utterance})
        formatted = format_dialogue(turns, target_utterance, max_turns)
        # Handle the new text-based NIM classifier
        if hasattr(self._defense_model, "nim"):
            pred = self._defense_model.predict([formatted])[0]
        else:
            pred = self._legacy_inference(formatted)
        prob_dict = {
            DEFENSE_LABELS.get(i, str(i)): round(p, 4)
            for i, p in enumerate(pred.probabilities)
        }
        return (
            pred.label,
            pred.label_name,
            pred.confidence,
            pred.maturity_score,
            prob_dict,
        )

    def _legacy_inference(self, formatted_text: str) -> DefensePrediction:
        """Run inference using the local PyTorch DeBERTa model.

        Raises:
            RuntimeError: If no tokenizer was loaded (non-local model).
        """
        if self._defense_tokenizer is None:
            raise RuntimeError("GestaltEngine: No tokenizer for PyTorch model.")
        encoding = self._defense_tokenizer(
            formatted_text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        device = next(self._defense_model.parameters()).device
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        # Pure inference: disable autograd so no graph is built or retained.
        with torch.no_grad():
            return self._defense_model.predict(input_ids, attention_mask)[0]

    def analyze_gestalt(
        self,
        dialogue: list[dict[str, str]],
        target_utterance: str,
        plutchik_scores: dict[str, float],
        ocean_scores: dict[str, float],
        max_turns: int = 40,
    ) -> GestaltState:
        """Fuse defense, emotion and personality signals into a GestaltState.

        Args:
            dialogue: Conversation history (see :meth:`_classify_defense`).
            target_utterance: The turn being analysed.
            plutchik_scores: Emotion -> intensity in [0, 1]. Missing emotions
                default to 0.0; unknown keys are warned about and dropped.
            ocean_scores: Trait -> score in [0, 1]. Missing traits default
                to 0.5; unknown keys are warned about and dropped.
            max_turns: Context window for the defense classifier.

        Returns:
            The fused :class:`GestaltState` for this turn. As a side effect,
            a non-None maturity from this turn is stored for the next
            breakthrough score.

        Raises:
            ValueError: If a score is non-numeric or outside [0, 1].
            RuntimeError: If no defense model has been loaded.
        """
        plutchik = _validate_scores(
            plutchik_scores, PLUTCHIK_EMOTIONS, "plutchik_scores"
        )
        ocean = _validate_scores(ocean_scores, OCEAN_TRAITS, "ocean_scores")
        # Iterate in sorted order. Frozenset iteration order of strings varies
        # between processes (hash randomization), which would make the output
        # dict order and dominant-emotion tie-breaks nondeterministic.
        full_plutchik = {e: plutchik.get(e, 0.0) for e in sorted(PLUTCHIK_EMOTIONS)}
        full_ocean = {t: ocean.get(t, 0.5) for t in sorted(OCEAN_TRAITS)}
        (def_l, def_n, def_c, def_m, def_p) = self._classify_defense(
            dialogue, target_utterance, max_turns
        )
        dom_e, dom_i = _dominant_emotion(full_plutchik)
        neuro = full_ocean.get("neuroticism", 0.5)
        crisis = _compute_crisis_level(def_l, def_m, dom_e, dom_i, neuro)
        behavioral = _behavioral_prediction(def_n, dom_e, crisis, def_m)
        directive = _persona_directive(def_l, def_n, def_m)
        breakthrough = _breakthrough_score(def_m, self._previous_maturity)
        if def_m is not None:
            self._previous_maturity = def_m
        return GestaltState(
            defense_label=def_l,
            defense_label_name=def_n,
            defense_confidence=def_c,
            defense_maturity=def_m,
            defense_probabilities=def_p,
            plutchik_scores=full_plutchik,
            dominant_emotion=dom_e,
            dominant_emotion_intensity=dom_i,
            ocean_scores=full_ocean,
            crisis_level=crisis,
            behavioral_prediction=behavioral,
            persona_directive=directive,
            breakthrough_score=breakthrough,
        )