| |
| """ |
| Gradio Interface for Multimodal Chat with SSH Tunnel Keepalive, GPU Monitoring, and API Fallback |
| |
| This application provides a Gradio web interface for multimodal chat with a |
| local vLLM model. It establishes SSH tunnels to a local vLLM server and |
| the nvidia-smi monitoring endpoint, with fallback to Hyperbolic API if needed. |
| """ |
|
|
import base64
import html
import json
import logging
import mimetypes
import os
import threading
import time
from io import BytesIO

import gradio as gr
import requests
from openai import OpenAI

from ssh_tunneler import SSHTunnel
|
|
| |
# Logging: one shared logger for the whole application.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('app')


# SSH connection and port-forwarding settings, all read from the environment.
SSH_HOST = os.environ.get('SSH_HOST')
SSH_PORT = int(os.environ.get('SSH_PORT', 22))
SSH_USERNAME = os.environ.get('SSH_USERNAME')
SSH_PASSWORD = os.environ.get('SSH_PASSWORD')
REMOTE_PORT = int(os.environ.get('REMOTE_PORT', 8000))
LOCAL_PORT = int(os.environ.get('LOCAL_PORT', 8020))
# The GPU monitoring ports follow the same pattern as the API ports: they can
# be overridden from the environment and default to the previous fixed values.
GPU_REMOTE_PORT = int(os.environ.get('GPU_REMOTE_PORT', 5000))
GPU_LOCAL_PORT = int(os.environ.get('GPU_LOCAL_PORT', 5020))
VLLM_MODEL = os.environ.get('MODEL_NAME', 'google/gemma-3-27b-it')
HYPERBOLIC_KEY = os.environ.get('HYPERBOLIC_XYZ_KEY')
FALLBACK_MODEL = 'Qwen/Qwen2.5-VL-72B-Instruct'


# Concurrency limit handed to Gradio when the UI is built.
MAX_CONCURRENT = int(os.environ.get('MAX_CONCURRENT', 3))


# Endpoints: the vLLM and GPU endpoints are reached through the local ends of
# the SSH tunnels.
VLLM_ENDPOINT = f"http://localhost:{LOCAL_PORT}/v1"
HYPERBOLIC_ENDPOINT = "https://api.hyperbolic.xyz/v1"
GPU_JSON_ENDPOINT = f"http://localhost:{GPU_LOCAL_PORT}/gpu/json"
GPU_TXT_ENDPOINT = f"http://localhost:{GPU_LOCAL_PORT}/gpu/txt"


# Mutable module state shared between the UI callbacks and background threads.
api_tunnel = None
gpu_tunnel = None
use_fallback = False
api_tunnel_status = {"is_running": False, "message": "Initializing API tunnel..."}
gpu_tunnel_status = {"is_running": False, "message": "Initializing GPU monitoring tunnel..."}
gpu_data = {"timestamp": "", "gpus": [], "processes": [], "success": False}
gpu_monitor_thread = None
gpu_monitor_running = False
|
|
def _open_tunnel(remote_port, local_port):
    """Build (without starting) an SSHTunnel between the given remote and local ports."""
    return SSHTunnel(
        ssh_host=SSH_HOST,
        ssh_port=SSH_PORT,
        username=SSH_USERNAME,
        password=SSH_PASSWORD,
        remote_port=remote_port,
        local_port=local_port,
        reconnect_interval=30,
        keep_alive_interval=15
    )


def start_ssh_tunnels():
    """
    Start the API and GPU monitoring SSH tunnels and record their status.

    When SSH credentials are missing, no tunnel is attempted and the
    application switches to the fallback API. Otherwise each tunnel is
    started independently: a failure (or exception) while starting one tunnel
    only affects that tunnel's status, so a problem with GPU monitoring does
    not push chat traffic onto the fallback API, and a problem with the API
    tunnel does not prevent GPU monitoring from starting.
    """
    global api_tunnel, gpu_tunnel, use_fallback, api_tunnel_status, gpu_tunnel_status

    if not all([SSH_HOST, SSH_USERNAME, SSH_PASSWORD]):
        logger.error("Missing SSH connection details. Falling back to Hyperbolic API.")
        use_fallback = True
        api_tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
        gpu_tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
        return

    # API tunnel: any failure here switches chat to the fallback API.
    try:
        logger.info("Starting API SSH tunnel...")
        api_tunnel = _open_tunnel(REMOTE_PORT, LOCAL_PORT)

        if api_tunnel.start():
            logger.info("API SSH tunnel started successfully")
            api_tunnel_status = {"is_running": True, "message": "Connected"}
        else:
            logger.warning("Failed to start API SSH tunnel. Falling back to Hyperbolic API.")
            use_fallback = True
            api_tunnel_status = {"is_running": False, "message": "Connection failed"}
    except Exception as e:
        logger.error(f"Error starting API SSH tunnel: {str(e)}")
        use_fallback = True
        api_tunnel_status = {"is_running": False, "message": "Connection error"}

    # GPU monitoring tunnel: failures only affect the GPU status panel.
    try:
        logger.info("Starting GPU monitoring SSH tunnel...")
        gpu_tunnel = _open_tunnel(GPU_REMOTE_PORT, GPU_LOCAL_PORT)

        if gpu_tunnel.start():
            logger.info("GPU monitoring SSH tunnel started successfully")
            gpu_tunnel_status = {"is_running": True, "message": "Connected"}
            start_gpu_monitoring()
        else:
            logger.warning("Failed to start GPU monitoring SSH tunnel.")
            gpu_tunnel_status = {"is_running": False, "message": "Connection failed"}
    except Exception as e:
        logger.error(f"Error starting GPU monitoring SSH tunnel: {str(e)}")
        gpu_tunnel_status = {"is_running": False, "message": "Connection error"}
|
|
def check_vllm_api_health():
    """
    Probe the vLLM API by requesting its /models endpoint.

    Returns:
        tuple: (is_healthy, message). is_healthy is True for any HTTP 200
            reply whose body could be inspected, even if it lists no models.
    """
    try:
        reply = requests.get(f"{VLLM_ENDPOINT}/models", timeout=5)
        if reply.status_code != 200:
            return False, f"API returned status code: {reply.status_code}"

        # Any failure while decoding or inspecting the body is reported as
        # invalid JSON rather than as a failed request.
        try:
            payload = reply.json()
            has_models = 'data' in payload and len(payload['data']) > 0
            if not has_models:
                return True, "API is healthy but no models found"
            first_model = payload['data'][0].get('id', 'Unknown model')
            return True, f"API is healthy. Available model: {first_model}"
        except Exception as parse_error:
            return False, f"API returned 200 but invalid JSON: {str(parse_error)}"
    except Exception as request_error:
        return False, f"API request failed: {str(request_error)}"
|
|
def fetch_gpu_info():
    """
    Request the GPU report (JSON) from the GPU monitoring endpoint.

    Returns:
        dict: The decoded JSON body on HTTP 200; otherwise a dict with
            "success": False, an "error" description, a local timestamp and
            empty "gpus" / "processes" lists.
    """
    def _failure(reason):
        # Same shape for every failure so callers can treat errors uniformly.
        return {
            "success": False,
            "error": reason,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "gpus": [],
            "processes": []
        }

    try:
        reply = requests.get(GPU_JSON_ENDPOINT, timeout=5)
        if reply.status_code != 200:
            logger.warning(f"Error fetching GPU info: HTTP {reply.status_code}")
            return _failure(f"HTTP Error: {reply.status_code}")
        return reply.json()
    except Exception as exc:
        logger.warning(f"Error fetching GPU info: {str(exc)}")
        return _failure(str(exc))
|
|
def fetch_gpu_text():
    """
    Request the plain-text GPU report from the GPU monitoring endpoint.

    Returns:
        str: The response body on HTTP 200, otherwise an error description.
    """
    try:
        reply = requests.get(GPU_TXT_ENDPOINT, timeout=5)
    except Exception as exc:
        return f"Error fetching GPU info: {str(exc)}"

    if reply.status_code != 200:
        return f"Error fetching GPU info: HTTP {reply.status_code}"
    return reply.text
|
|
def start_gpu_monitoring():
    """
    Launch the background thread that refreshes the shared GPU snapshot.

    Does nothing when monitoring has already been started.
    """
    global gpu_monitor_thread, gpu_monitor_running

    if gpu_monitor_running:
        return
    gpu_monitor_running = True

    def _poll():
        global gpu_data
        # Refresh the snapshot every two seconds for as long as the flag is set.
        while gpu_monitor_running:
            try:
                gpu_data = fetch_gpu_info()
            except Exception as exc:
                logger.error(f"Error in GPU monitoring loop: {str(exc)}")
                gpu_data = dict(
                    success=False,
                    error=str(exc),
                    timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
                    gpus=[],
                    processes=[],
                )
            time.sleep(2)

    # Daemon thread: it must not keep the process alive on shutdown.
    gpu_monitor_thread = threading.Thread(target=_poll, daemon=True)
    gpu_monitor_thread.start()
    logger.info("GPU monitoring thread started")
|
|
def _image_data_url(path, encoded):
    """Build a base64 data URL for an uploaded file, guessing the image MIME type from its name."""
    mime, _ = mimetypes.guess_type(str(path))
    if not mime or not mime.startswith("image/"):
        # Keep the previous hard-coded type whenever no image type can be guessed.
        mime = "image/jpeg"
    return f"data:{mime};base64,{encoded}"


def _build_openai_messages(history, files, base64_images):
    """Convert chat history to OpenAI messages and attach this turn's images."""
    messages = []
    for entry in history:
        if entry["role"] == "user":
            # Tuple content marks an uploaded file shown in the chat window;
            # it carries no text and is skipped here.
            if isinstance(entry["content"], tuple):
                continue
            messages.append({"role": "user", "content": entry["content"]})
        elif entry["role"] == "assistant":
            messages.append({"role": "assistant", "content": entry["content"]})

    if not base64_images:
        return messages

    image_parts = [
        {"type": "image_url", "image_url": {"url": _image_data_url(path, encoded)}}
        for path, encoded in zip(files, base64_images)
    ]

    if messages and messages[-1]["role"] == "user":
        last_msg = messages[-1]
        parts = []
        if last_msg["content"]:
            parts.append({"type": "text", "text": last_msg["content"]})
        last_msg["content"] = parts + image_parts
    else:
        # Image-only turn: there is no text message to attach to, so the
        # images get their own user message instead of being dropped.
        messages.append({"role": "user", "content": image_parts})
    return messages


def _stream_reply(client, model, messages, history, empty_reply):
    """Stream one completion, yielding history snapshots, then the final history."""
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True
    )

    reply = ""
    for chunk in stream:
        # Some chunks carry no choices at all; skip them.
        if not chunk.choices:
            continue
        piece = getattr(chunk.choices[0].delta, 'content', None)
        if piece is None:
            continue
        reply += piece
        yield history + [{"role": "assistant", "content": reply}]

    # history is only mutated once the stream has completed, so a failure
    # part-way through leaves it clean for a fallback attempt.
    history.append({"role": "assistant", "content": reply or empty_reply})
    yield history


def process_chat(message_dict, history):
    """
    Process a user message and stream the model's reply.

    This is a generator: every value shown in the chat window, including the
    final history and any error message, is yielded. The primary API is tried
    first; if it raises and the fallback API is not already in use, the
    request is retried against the fallback API.

    Args:
        message_dict (dict): User message containing "text" and "files"
        history (list): Chat history as role/content dicts

    Yields:
        list: Chat history including the assistant reply received so far
    """
    global use_fallback

    text = message_dict.get("text", "") or ""
    files = message_dict.get("files", []) or []

    if not history:
        history = []

    for file in files:
        history.append({"role": "user", "content": (file,)})

    if text.strip():
        history.append({"role": "user", "content": text})
    elif not files:
        history.append({"role": "user", "content": ""})

    base64_images = convert_files_to_base64(files)
    openai_messages = _build_openai_messages(history, files, base64_images)

    try:
        yield from _stream_reply(
            get_openai_client(),
            get_model_name(),
            openai_messages,
            history,
            "No response received from the model.",
        )
        return
    except Exception as primary_error:
        logger.error(f"Primary API error: {str(primary_error)}")

    if use_fallback:
        # The failed request already went to the fallback API; nothing left to try.
        history.append({"role": "assistant", "content": "An error occurred with the model service."})
        yield history
        return

    try:
        logger.info("Falling back to Hyperbolic API")
        yield from _stream_reply(
            get_openai_client(use_fallback_api=True),
            get_model_name(use_fallback_api=True),
            openai_messages,
            history,
            "No response received from the fallback model.",
        )
        use_fallback = True
    except Exception as fallback_error:
        logger.error(f"Fallback API error: {str(fallback_error)}")
        history.append({
            "role": "assistant",
            "content": "Error connecting to both primary and fallback APIs."
        })
        yield history
|
|
def monitor_tunnels():
    """
    Poll both SSH tunnels forever and refresh the shared status globals.

    Every five seconds the API tunnel is checked and, when it is up, the vLLM
    API health as well; use_fallback is set from the outcome. The GPU tunnel
    status is refreshed too, and GPU monitoring is started if it is not
    running yet. This function never returns.
    """
    global api_tunnel, gpu_tunnel, use_fallback, api_tunnel_status, gpu_tunnel_status

    logger.info("Starting tunnel monitoring thread")

    while True:
        try:
            # --- API tunnel ---
            if api_tunnel is None:
                use_fallback = True
                api_tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}
            else:
                api_state = api_tunnel.check_status()
                if not api_state["is_running"]:
                    logger.error(f"API SSH tunnel disconnected: {api_state.get('error', 'Unknown error')}")
                    use_fallback = True
                    api_tunnel_status = {
                        "is_running": False,
                        "message": "Disconnected - Check server status"
                    }
                else:
                    # The tunnel being up does not mean the server behind it answers.
                    healthy, detail = check_vllm_api_health()
                    if healthy:
                        use_fallback = False
                        api_tunnel_status = {
                            "is_running": True,
                            "message": f"Connected and healthy. {detail}"
                        }
                    else:
                        use_fallback = True
                        api_tunnel_status = {
                            "is_running": False,
                            "message": "Tunnel connected but vLLM API unhealthy"
                        }

            # --- GPU monitoring tunnel ---
            if gpu_tunnel is None:
                gpu_tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}
            else:
                gpu_state = gpu_tunnel.check_status()
                if not gpu_state["is_running"]:
                    logger.error(f"GPU SSH tunnel disconnected: {gpu_state.get('error', 'Unknown error')}")
                    gpu_tunnel_status = {
                        "is_running": False,
                        "message": "Disconnected - Check server status"
                    }
                else:
                    gpu_tunnel_status = {
                        "is_running": True,
                        "message": "Connected"
                    }
                    if not gpu_monitor_running:
                        start_gpu_monitoring()

        except Exception as exc:
            logger.error(f"Error monitoring tunnels: {str(exc)}")
            use_fallback = True
            api_tunnel_status = {"is_running": False, "message": "Monitoring error"}
            gpu_tunnel_status = {"is_running": False, "message": "Monitoring error"}

        time.sleep(5)
|
|
def get_openai_client(use_fallback_api=None):
    """
    Build an OpenAI client pointed at either the local vLLM or the Hyperbolic endpoint.

    Args:
        use_fallback_api (bool): True selects the Hyperbolic API, False the
            local vLLM API, None defers to the global use_fallback flag.

    Returns:
        OpenAI: Client configured for the selected endpoint
    """
    wants_fallback = use_fallback if use_fallback_api is None else use_fallback_api

    if not wants_fallback:
        logger.info("Using local vLLM API")
        return OpenAI(api_key="EMPTY", base_url=VLLM_ENDPOINT)

    logger.info("Using Hyperbolic API")
    return OpenAI(api_key=HYPERBOLIC_KEY, base_url=HYPERBOLIC_ENDPOINT)
|
|
def get_model_name(use_fallback_api=None):
    """
    Pick the model name that matches the API in use.

    Args:
        use_fallback_api (bool): True selects the fallback model, False the
            vLLM model, None defers to the global use_fallback flag.

    Returns:
        str: Model name
    """
    wants_fallback = use_fallback if use_fallback_api is None else use_fallback_api
    if wants_fallback:
        return FALLBACK_MODEL
    return VLLM_MODEL
|
|
def convert_files_to_base64(files):
    """
    Read each file and return its contents base64-encoded.

    Args:
        files (list): List of file paths

    Returns:
        list: One base64 string per input file, in the same order
    """
    def _encode(path):
        with open(path, "rb") as handle:
            return base64.b64encode(handle.read()).decode("utf-8")

    return [_encode(path) for path in files]
|
|
def _format_metric(value, width, precision):
    """Format value as a fixed-point number, or right-aligned "N/A" if it is not numeric."""
    try:
        return f"{float(value):{width}.{precision}f}"
    except (TypeError, ValueError):
        return "N/A".rjust(width)


def format_simplified_gpu_data(gpu_data):
    """
    Format GPU data into a simplified, focused display.

    Metric values that are missing default to 0; values that are present but
    not numeric (for example None) are shown as "N/A" instead of raising.
    An empty or None payload is reported as an unknown error.

    Args:
        gpu_data (dict): GPU data in JSON format

    Returns:
        str: Formatted GPU data, or an error line when the data is not marked
            successful
    """
    if not gpu_data or not gpu_data.get("success", False):
        reason = gpu_data.get('error', 'Unknown error') if gpu_data else 'Unknown error'
        return f"Error fetching GPU data: {reason}"

    output = [f"Last updated: {gpu_data.get('timestamp', 'Unknown')}"]

    for i, gpu in enumerate(gpu_data.get("gpus") or []):
        memory_used = _format_metric(gpu.get('memory_used', 0), 6, 0)
        memory_total = _format_metric(gpu.get('memory_total', 0), 6, 0)
        memory_pct = _format_metric(gpu.get('memory_utilization', 0), 5, 1)
        power_draw = _format_metric(gpu.get('power_draw', 0), 5, 1)
        power_limit = _format_metric(gpu.get('power_limit', 0), 5, 1)

        output.append(f"GPU {gpu.get('index', i)}: {gpu.get('name', 'Unknown')}")
        output.append(f" Memory: {memory_used} MB / {memory_total} MB ({memory_pct}%)")
        output.append(f" Power: {power_draw}W / {power_limit}W")
        # The fan line is only shown when the report includes the key at all.
        if 'fan_speed' in gpu:
            output.append(f" Fan: {_format_metric(gpu.get('fan_speed', 0), 5, 1)}%")
        output.append(f" Temp: {_format_metric(gpu.get('temperature', 0), 5, 1)}°C")
        output.append("")

    return "\n".join(output)
|
|
def update_gpu_status():
    """
    Produce the GPU status text for the UI.

    Returns:
        str: The formatted GPU snapshot when the GPU tunnel is marked as
            running, otherwise a notice that the tunnel is not connected.
    """
    if gpu_tunnel_status["is_running"]:
        return format_simplified_gpu_data(gpu_data)
    return "GPU monitoring tunnel is not connected."
|
|
def get_tunnel_status_message():
    """
    Build the multi-line status summary shown in the UI.

    The API tunnel is shown green only when it is running and the fallback
    API is not in use.
    """
    api_ok = api_tunnel_status["is_running"] and not use_fallback
    api_light = "🟢" if api_ok else "🔴"
    gpu_light = "🟢" if gpu_tunnel_status["is_running"] else "🔴"

    if use_fallback:
        active_api = "Hyperbolic API"
    else:
        active_api = "Local vLLM API"

    summary = [
        f"{api_light} API Tunnel: {api_tunnel_status['message']}",
        f"{gpu_light} GPU Tunnel: {gpu_tunnel_status['message']}",
        f"Current API: {active_api}",
        f"Current Model: {get_model_name()}",
        f"Concurrent Requests: {MAX_CONCURRENT}",
    ]
    return "\n".join(summary)
|
|
def get_gpu_json():
    """
    Serialize the current GPU snapshot as indented JSON for debugging.
    """
    snapshot = gpu_data
    return json.dumps(snapshot, indent=2)
|
|
def toggle_api():
    """
    Flip the global use_fallback flag and describe the newly selected API.
    """
    global use_fallback
    use_fallback = not use_fallback

    if use_fallback:
        selected = "Hyperbolic API"
    else:
        selected = "Local vLLM API"
    return f"Switched to {selected} using {get_model_name()}"
|
|
def update_concurrency(new_value):
    """
    Update the MAX_CONCURRENT value.

    Anything that cannot be converted to an integer, including None, is
    rejected with an error message instead of raising.

    Args:
        new_value (str): New concurrency value, normally as a string

    Returns:
        str: Status message describing the new value or why it was rejected
    """
    global MAX_CONCURRENT
    try:
        value = int(new_value)
    except (TypeError, ValueError, OverflowError):
        # TypeError covers None and other non-numeric objects, OverflowError
        # covers infinite floats.
        return f"Error: Invalid number. Keeping current value: {MAX_CONCURRENT}"

    if value < 1:
        return f"Error: Concurrency must be at least 1. Keeping current value: {MAX_CONCURRENT}"

    MAX_CONCURRENT = value
    return f"Concurrency updated to {MAX_CONCURRENT}. You may need to refresh the page for all changes to take effect."
|
|
| |
if __name__ == "__main__":
    # Bring the tunnels up before building the UI so the initial status text
    # reflects the real state, then keep re-checking them in the background.
    start_ssh_tunnels()
    monitor_thread = threading.Thread(target=monitor_tunnels, daemon=True)
    monitor_thread.start()

    def _empty_textbox():
        """Value used to reset the multimodal textbox after a message is sent."""
        return {"text": "", "files": []}

    def _gpu_status_html():
        """Render the GPU status as preformatted HTML.

        The text is escaped because it can contain error messages with
        characters such as "<" that would otherwise be parsed as markup.
        """
        return (
            "<pre style='font-family: monospace; white-space: pre; overflow: auto;'>"
            f"{html.escape(update_gpu_status())}</pre>"
        )

    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("# Multimodal Chat Interface")

        chatbot = gr.Chatbot(
            label="Conversation",
            type="messages",
            show_copy_button=True,
            avatar_images=("👤", "🗣️"),
            height=400
        )

        with gr.Row():
            textbox = gr.MultimodalTextbox(
                file_types=["image", "video"],
                file_count="multiple",
                placeholder="Type your message here and/or upload images...",
                label="Message",
                show_label=False,
                scale=9
            )
            submit_btn = gr.Button("Send", size="sm", scale=1)

        clear_btn = gr.Button("Clear Chat")

        # Pressing Enter in the textbox and clicking "Send" are wired
        # identically: run the chat, then clear the textbox.
        for trigger in (textbox.submit, submit_btn.click):
            trigger(
                fn=process_chat,
                inputs=[textbox, chatbot],
                outputs=chatbot,
                concurrency_limit=MAX_CONCURRENT
            ).then(
                fn=_empty_textbox,
                inputs=None,
                outputs=textbox
            )

        clear_btn.click(lambda: [], None, chatbot)

        # Only offer examples whose image file sits next to this script.
        example_images = {
            "dog_pic.jpg": "What breed is this?",
            "ghostimg.png": "What's in this image?",
            "newspaper.png": "Provide a python list of dicts about everything on this page."
        }
        examples = []
        for img_name, prompt_text in example_images.items():
            img_path = os.path.join(os.path.dirname(__file__), img_name)
            if os.path.exists(img_path):
                examples.append([{"text": prompt_text, "files": [img_path]}])
        if examples:
            gr.Examples(
                examples=examples,
                inputs=textbox
            )

        status_text = gr.Textbox(
            label="Tunnel and API Status",
            value=get_tunnel_status_message(),
            interactive=False
        )

        with gr.Accordion("GPU Status", open=False):
            # Passing a callable with every=2 asks Gradio to re-evaluate it
            # every two seconds.
            gpu_status = gr.HTML(
                value=_gpu_status_html,
                every=2
            )

        with gr.Row():
            refresh_btn = gr.Button("Refresh Status")
            toggle_api_btn = gr.Button("Toggle API")

        refresh_btn.click(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

        # The toggle confirmation is shown first and then replaced by the
        # full status summary.
        toggle_api_btn.click(
            fn=toggle_api,
            inputs=None,
            outputs=status_text
        ).then(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

        demo.load(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

    demo.queue(default_concurrency_limit=MAX_CONCURRENT)
    demo.launch()