"""
GestaltSimulator - Offline Batch Mode Simulator

Runs the GestaltEngine over existing dialogue pairs and uses the PersonaManager
to rewrite the AI patient's responses to be more human and defense-aware via
an LLM (e.g. Gemini 2.5 Flash / Pro).
"""
| |
|
| | import json |
| | import logging |
| | import time |
| | from typing import Any, Dict, List, Optional |
| |
|
| | from ai.utils.llm_capabilities import ensure_valid_key, get_best_available_gemini_model |
| |
|
| | try: |
| | from google import genai |
| | from google.genai import types |
| | except ImportError: |
| | genai = None |
| |
|
| | from ai.core.gestalt_engine import OCEAN_TRAITS, PLUTCHIK_EMOTIONS, GestaltEngine |
| | from ai.core.persona_manager import PersonaManager |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class GestaltSimulator:
    """Offline batch simulator for regenerating dialogues with Gestalt behaviors.

    Wraps a GestaltEngine (defense-mechanism analysis) and a PersonaManager
    (persona prompts and human-likeness validation), and uses a Gemini client
    to rewrite the AI patient's responses turn by turn. When no API key or
    `google-genai` package is available, generation is mocked with canned
    fallback responses so batch runs still complete.
    """

    def __init__(
        self,
        defense_model_path: Optional[str] = None,
        device: str = "cpu",
        api_key: Optional[str] = None,
    ):
        """Initialize engine, persona manager, and (optionally) the LLM client.

        Args:
            defense_model_path: Path to a trained defense model. If ``None``,
                the GestaltEngine is initialized with its NIM defaults.
            device: Torch-style device string for the defense model
                (e.g. ``"cpu"`` or ``"cuda"``).
            api_key: Gemini API key. Falls back to ``ensure_valid_key()``;
                if neither yields a key (or genai is not installed), the
                client stays ``None`` and generation is mocked.
        """
        self.gestalt_engine = GestaltEngine()
        if defense_model_path:
            logger.info("Loading defense model from %s", defense_model_path)
            try:
                self.gestalt_engine.load_defense_model(
                    defense_model_path, device=device
                )
            except Exception as exc:
                # Best-effort: keep running in mock mode rather than aborting
                # the whole batch because a checkpoint failed to load.
                logger.warning(
                    "Could not load defense model, running in dry-run/mock mode: %s",
                    exc,
                )
        else:
            logger.info(
                "No defense model path provided, initializing GestaltEngine with "
                "NIM defaults."
            )
            self.gestalt_engine.load_defense_model()

        self.persona_manager = PersonaManager()

        self.api_key = api_key or ensure_valid_key()
        if self.api_key and genai:
            self.client = genai.Client(api_key=self.api_key)
        else:
            self.client = None
            logger.warning(
                "Gemini API key not found or genai not installed. "
                "Generation will be mocked."
            )

    def _call_llm(
        self,
        system_prompt: str,
        conversation_history: List[Dict[str, str]],
        max_retries: int = 3,
    ) -> str:
        """Call the LLM to generate the next response.

        Retries up to ``max_retries`` times with exponential backoff on API
        errors, and re-generates when the output fails the human-likeness
        check. Returns a canned fallback string when no client is configured
        or all attempts fail.
        """
        if not self.client:
            return "I don't want to talk about it right now."

        # Gemini expects alternating "user"/"model" roles; map anything that
        # is not "user" to "model".
        contents = [
            types.Content(
                role="user" if msg["role"] == "user" else "model",
                parts=[types.Part.from_text(text=msg["content"])],
            )
            for msg in conversation_history
        ]

        config = types.GenerateContentConfig(
            system_instruction=system_prompt,
            temperature=0.7,
        )

        for attempt in range(max_retries):
            try:
                response = self.client.models.generate_content(
                    model=get_best_available_gemini_model(self.client),
                    contents=contents,
                    config=config,
                )
                text = response.text
                # response.text can be None (e.g. blocked/empty candidates);
                # treat that the same as a failed human-likeness check.
                if text and self.persona_manager.validate_human_likeness(text):
                    return text
                logger.debug(
                    "Generation failed human likeness check on attempt %d", attempt + 1
                )
            except Exception as exc:
                logger.error("LLM API error on attempt %d: %s", attempt + 1, exc)
                # Only back off when another attempt is coming; sleeping after
                # the final failure would just delay the fallback.
                if attempt + 1 < max_retries:
                    time.sleep(2**attempt)

        return "I guess I just don't have much to say about that."

    def simulate_turn(
        self,
        dialogue: List[Dict[str, str]],
        target_utterance: str,
        persona_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Simulate a single turn.

        1. Run GestaltEngine on the current dialogue.
        2. Get the persona directive.
        3. Inject directive into the system prompt, then generate a response.

        Args:
            dialogue: Prior turns as ``{"speaker": ..., "text": ...}`` dicts.
            target_utterance: The latest user/therapist utterance to respond to.
            persona_id: Specific persona to use; a random one if ``None``.

        Returns:
            Dict with the original utterance, the new response, the persona id,
            the directive used, and the gestalt state (or ``None``).
        """
        persona = (
            self.persona_manager.get_persona(persona_id)
            if persona_id
            else self.persona_manager.get_random_persona()
        )

        # Placeholder affect scores: mildly sad baseline, persona-driven OCEAN.
        # TODO(review): replace with real emotion/personality classifier output.
        mock_plutchik = {e: 0.2 for e in PLUTCHIK_EMOTIONS}
        mock_plutchik["sadness"] = 0.6
        mock_ocean = {t: persona.traits.get(t, 0.5) for t in OCEAN_TRAITS}

        if self.gestalt_engine.defense_model_loaded:
            gestalt_state = self.gestalt_engine.analyze_gestalt(
                dialogue=dialogue,
                target_utterance=target_utterance,
                plutchik_scores=mock_plutchik,
                ocean_scores=mock_ocean,
            )
            directive = gestalt_state.persona_directive
        else:
            logger.debug(
                "Defense model not loaded, using default persona defense directive."
            )
            directive = (
                f"[System: Maintain your '{persona.default_defense}' "
                "defense mechanism.]"
            )
            gestalt_state = None

        system_prompt = persona.generate_system_prompt()
        if directive:
            system_prompt += f"\n\nCRITICAL DIRECTIVE:\n{directive}"

        # Map engine-style speaker labels onto LLM chat roles. Both "client"
        # and "therapist" variants of the human side count as "user".
        llm_history = []
        for turn in dialogue:
            role = (
                "user"
                if turn.get("speaker", "user")
                in ("human", "user", "client", "therapist")
                else "assistant"
            )
            llm_history.append({"role": role, "content": turn.get("text", "")})
        llm_history.append({"role": "user", "content": target_utterance})

        new_response = self._call_llm(system_prompt, llm_history)

        return {
            "original_utterance": target_utterance,
            "new_response": new_response,
            "persona_id": persona.archetype_id,
            "directive_used": directive,
            "gestalt_state": gestalt_state.__dict__ if gestalt_state else None,
        }

    def process_batch(
        self, input_file: str, output_file: str, max_records: int = 5000
    ) -> int:
        """
        Process a JSONL file of dialogue pairs and rewrite the assistant responses.

        Only records whose last two messages are ``user`` then ``assistant``
        are rewritten; everything else is skipped. Per-record failures are
        logged and do not stop the batch.

        Returns the number of records written.
        """
        logger.info("Starting batch simulation from %s → %s", input_file, output_file)

        processed_count = 0
        with (
            open(input_file, "r", encoding="utf-8") as infile,
            open(output_file, "w", encoding="utf-8") as outfile,
        ):
            for line_no, line in enumerate(infile, start=1):
                if processed_count >= max_records:
                    break

                # Skip blank lines instead of logging them as JSON errors.
                if not line.strip():
                    continue

                try:
                    record = json.loads(line)
                    messages = record.get("messages", [])

                    if len(messages) < 3:
                        continue
                    if (
                        messages[-1]["role"] != "assistant"
                        or messages[-2]["role"] != "user"
                    ):
                        continue

                    target_user_utterance = messages[-2]["content"]

                    # Everything before the final user/assistant pair becomes
                    # engine history; system messages are dropped.
                    history_for_engine = [
                        {
                            "speaker": (
                                "therapist" if msg["role"] == "user" else "client"
                            ),
                            "text": msg["content"],
                        }
                        for msg in messages[:-2]
                        if msg["role"] != "system"
                    ]
                    result = self.simulate_turn(
                        history_for_engine, target_user_utterance
                    )

                    record["messages"][-1]["content"] = result["new_response"]
                    record.setdefault("metadata", {})["gestalt_simulation"] = {
                        "persona_id": result["persona_id"],
                        "directive": result["directive_used"],
                    }

                    outfile.write(json.dumps(record, ensure_ascii=False) + "\n")
                    processed_count += 1

                    if processed_count % 100 == 0:
                        logger.info("Processed %d records...", processed_count)

                except Exception as exc:
                    # Report the actual input line, not the written-record
                    # count, so failures can be located in the source file.
                    logger.error("Error processing input line %d: %s", line_no, exc)

        logger.info(
            "Batch complete. Wrote %d records to %s", processed_count, output_file
        )
        return processed_count
| |
|