Text Generation
Transformers
PyTorch
English
experimental
research
bit-level
transformer
reversible
safety
telemetry
language-modeling
Instructions to use WCNegentropy/BitTransformerLM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use WCNegentropy/BitTransformerLM with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="WCNegentropy/BitTransformerLM")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("WCNegentropy/BitTransformerLM", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use WCNegentropy/BitTransformerLM with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "WCNegentropy/BitTransformerLM"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "WCNegentropy/BitTransformerLM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker
docker model run hf.co/WCNegentropy/BitTransformerLM
- SGLang
How to use WCNegentropy/BitTransformerLM with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "WCNegentropy/BitTransformerLM" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "WCNegentropy/BitTransformerLM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker images
```bash
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "WCNegentropy/BitTransformerLM" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "WCNegentropy/BitTransformerLM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
- Docker Model Runner
How to use WCNegentropy/BitTransformerLM with Docker Model Runner:
docker model run hf.co/WCNegentropy/BitTransformerLM
| #!/usr/bin/env python3 | |
| """ | |
| BitTransformerLM Gradio Dashboard | |
| ================================= | |
| Comprehensive Gradio interface for BitTransformerLM with full feature parity to the Flask dashboard. | |
| Supports both local deployment and HuggingFace Spaces integration while maintaining MCP server compatibility. | |
| """ | |
| import io | |
| import json | |
| import os | |
| import sys | |
| import traceback | |
| import warnings | |
| from typing import Any, Dict, List, Optional, Union, Tuple | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| matplotlib.use('Agg') # Use non-interactive backend | |
| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| import numpy as np | |
| from pathlib import Path | |
| import threading | |
| import time | |
| import requests | |
| from concurrent.futures import ThreadPoolExecutor | |
| import uuid | |
| # Add BitTransformerLM to path | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| # BitTransformerLM imports | |
| from bit_transformer.model import BitTransformerLM, infer_long_sequence | |
| from bit_transformer.optimization import configure_optimizer | |
| from bit_transformer.collapse import collapse_submodel | |
| from bit_transformer.dashboard import plot_telemetry | |
| from bit_transformer.scale import expand_model | |
| from bit_transformer.bit_io import text_to_bits, bits_to_text | |
| from bit_transformer.safety import hil_safe_inference | |
| from bit_transformer.compression import model_output_decompress, compress_bits | |
| from bit_transformer.distributed import wrap_fsdp | |
| from bit_transformer.training import train_loop | |
| from bit_transformer.telemetry import detect_metric_drift | |
| from bit_transformer.quantization import prepare_qat_fx, convert_qat_fx | |
| from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint | |
| from bit_transformer.dataset_builder import BitTransformerDatasetBuilder, create_bittransformerlm_dataset | |
# Global state management
class GradioModelManager:
    """Shared model state and operations behind the Gradio dashboard.

    Holds the current BitTransformerLM instance and its configuration, the
    telemetry history shown on the monitoring tab, the safety floors used by
    gated inference, the feature toggles set from the UI, and the optimizer
    used by ``train_step``. ``self.lock`` guards the operations that replace
    ``self.model`` (init, scale, collapse, download).
    """

    def __init__(self):
        self.model = None
        self.config = {}
        # Per-step telemetry history, passed as-is to ``plot_telemetry``.
        self.telemetry_log = {
            "negentropy": [],
            "lz_complexity": [],
            "symbiosis_score": [],
            "steps": []
        }
        # Safety floors forwarded to ``hil_safe_inference``.
        self.c_floor = 0.3
        self.s_floor = 0.5
        self.lambda_weights = {"K": 1.0, "C": 1.0, "S": 1.0}
        self.compression_enabled = False
        self.qat_enabled = False
        self.diffusion_enabled = False
        self.gpu_enabled = False
        # Optimizer used by ``train_step``. It is rebuilt automatically
        # whenever ``self.model`` is replaced by a different object
        # (init, QAT prepare, scale, collapse, download).
        self.learning_rate = 1e-3
        self.optimizer = None
        self._optimizer_model = None
        # Background job management
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.jobs = {}
        self.mcp_server_addr = os.getenv("MCP_SERVER_ADDR")
        # Thread safety
        self.lock = threading.Lock()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _parse_bits(bits_input):
        """Convert UI input into a list of bits.

        Strings are parsed either as a JSON array (``"[0,1,0,1]"``) or as
        whitespace-separated integers (``"0 1 0 1"``). Any non-string value
        is returned unchanged.
        """
        if isinstance(bits_input, str):
            text = bits_input.strip()
            if text.startswith('['):
                return json.loads(text)
            return [int(x) for x in text.split()]
        return bits_input

    def _to_device(self, tensor):
        """Move ``tensor`` to CUDA when GPU mode is on and CUDA is available."""
        if self.gpu_enabled and torch.cuda.is_available():
            return tensor.cuda()
        return tensor

    def _get_optimizer(self):
        """Return an AdamW optimizer bound to the current model, creating it if needed."""
        if self.optimizer is None or self._optimizer_model is not self.model:
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
            self._optimizer_model = self.model
        return self.optimizer

    @staticmethod
    def _telemetry_scalar(telemetry_dict, keys):
        """Return the mean of the first present key in ``keys`` as a float, or 0.0.

        Accepts tensors (any device, with or without grad) and plain
        numbers/lists.
        """
        for key in keys:
            value = telemetry_dict.get(key)
            if value is None:
                continue
            if isinstance(value, torch.Tensor):
                value = value.detach()
            return float(torch.as_tensor(value, dtype=torch.float32).mean().item())
        return 0.0

    @staticmethod
    def _message_figure(title, message):
        """Build a placeholder figure showing ``message`` centred under ``title``."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title)
        # Unregister from pyplot so the periodic refresh does not accumulate
        # open figures. The Figure object itself can still be rendered by the
        # caller.
        plt.close(fig)
        return fig

    # --------------------------------------------------------------- operations

    def init_model(self, model_config: dict):
        """Build a new BitTransformerLM from ``model_config`` and install it.

        ``None`` and ``""`` entries are dropped before the config is passed
        to the constructor. QAT preparation and GPU placement are applied
        according to the current toggles. The new model replaces the current
        one only if every step succeeds, so a failure leaves the previous
        model in place.

        Returns:
            A status message string.
        """
        with self.lock:
            try:
                # Clean config - remove None values
                clean_config = {k: v for k, v in model_config.items() if v is not None and v != ""}
                model = BitTransformerLM(**clean_config)
                # Apply transformations
                if self.qat_enabled:
                    model = prepare_qat_fx(model)
                if self.gpu_enabled and torch.cuda.is_available():
                    model = model.cuda()
                self.model = model
                self.config = clean_config
                return f"โ Model initialized with config: {clean_config}"
            except Exception as e:
                return f"โ Model initialization failed: {str(e)}"

    def train_step(self, bits_input, epochs=1):
        """Train the current model on one bit sequence for ``epochs`` passes.

        Each pass runs forward, computes the loss, backpropagates, takes an
        optimizer step and records telemetry.

        Args:
            bits_input: A bit list, or a string in either format accepted by
                ``_parse_bits``.
            epochs: Number of passes. It is coerced to ``int`` because
                ``gr.Number`` delivers floats, and clamped to at least 1.

        Returns:
            ``(message, average_loss, compression_ratio)``. The last two are
            ``None`` on failure.
        """
        if self.model is None:
            return "โ Model not initialized", None, None
        try:
            bits = self._parse_bits(bits_input)
            epochs = max(1, int(epochs))
            compression_ratio = 1.0
            if self.compression_enabled:
                # NOTE(review): assumes compress_bits returns (bits, ratio) and
                # that its output is still a 0/1 sequence usable as targets —
                # confirm against bit_transformer.compression.
                bits, compression_ratio = compress_bits(bits)
            if len(bits) < 2:
                raise ValueError("need at least 2 bits for next-bit prediction")
            # Build the batch once, after any compression, so GPU placement
            # also applies to compressed input.
            batch = self._to_device(torch.as_tensor(bits, dtype=torch.long)).unsqueeze(0)
            optimizer = self._get_optimizer()
            self.model.train()
            total_loss = 0.0
            for _ in range(epochs):
                optimizer.zero_grad()
                output, telemetry = self.model(batch)
                if output.dim() == 3:
                    # Next-bit objective: logits at position t are scored
                    # against bit t+1. Assumes a causal model emitting one
                    # logit row per input position — TODO confirm.
                    logits = output[:, :-1, :].reshape(-1, output.size(-1))
                    targets = batch[:, 1:].reshape(-1)
                    loss = F.cross_entropy(logits, targets)
                else:
                    loss = F.cross_entropy(output, batch)
                loss.backward()
                optimizer.step()
                self._update_telemetry(telemetry)
                total_loss += loss.item()
            avg_loss = total_loss / epochs
            return f"โ Training completed. Average Loss: {avg_loss:.4f}", avg_loss, compression_ratio
        except Exception as e:
            return f"โ Training failed: {str(e)}", None, None

    def inference(self, bits_input, long_inference=False, ctx_bits=4096, overlap=256):
        """Run inference on a bit sequence.

        Uses ``infer_long_sequence`` when ``long_inference`` is set or the
        input is longer than ``ctx_bits``. Otherwise it uses the
        safety-gated ``hil_safe_inference`` with the current floors.
        ``ctx_bits`` and ``overlap`` are coerced to ``int``.

        Floating-point outputs with a trailing class axis are reduced with
        ``argmax``, so the caller receives one predicted bit per position
        instead of raw logits.

        Returns:
            ``(message, output_bits)``; ``output_bits`` is ``None`` on failure.
        """
        if self.model is None:
            return "โ Model not initialized", None
        try:
            bits = self._parse_bits(bits_input)
            ctx_bits = int(ctx_bits)
            overlap = int(overlap)
            batch = self._to_device(torch.tensor(bits, dtype=torch.long)).unsqueeze(0)
            self.model.eval()
            with torch.inference_mode():
                if long_inference or len(bits) > ctx_bits:
                    # Long sequence inference
                    output, telemetry = infer_long_sequence(
                        self.model, batch,
                        ctx_bits=ctx_bits, overlap=overlap
                    )
                else:
                    # Standard inference with safety gates
                    output, telemetry = hil_safe_inference(
                        self.model, batch,
                        c_floor=self.c_floor, s_floor=self.s_floor
                    )
                self._update_telemetry(telemetry)
            output = output.squeeze(0)
            if output.is_floating_point() and output.dim() == 2:
                output = output.argmax(dim=-1)
            output_bits = output.cpu().tolist()
            return f"โ Inference completed. Output length: {len(output_bits)}", output_bits
        except Exception as e:
            return f"โ Inference failed: {str(e)}", None

    def text_inference(self, text_input):
        """Convert text to bits, run ``inference``, and decode the result to text.

        Returns:
            ``(message, output_text)``. If decoding fails, the raw bit list
            is returned as a string instead. ``output_text`` is ``None`` on
            failure.
        """
        try:
            bits = text_to_bits(text_input)
            result, output_bits = self.inference(bits)
            if output_bits is None:
                return result, None
            try:
                output_text = bits_to_text(output_bits)
                return "โ Text inference completed.", output_text
            except Exception as e:
                return f"โ Inference completed, but text conversion failed: {str(e)}", str(output_bits)
        except Exception as e:
            return f"โ Text inference failed: {str(e)}", None

    def scale_model(self, width_multiplier):
        """Replace the model with ``expand_model(model, width_multiplier)``. Returns a status message."""
        if self.model is None:
            return "โ Model not initialized"
        try:
            with self.lock:
                self.model = expand_model(self.model, width_multiplier)
            return f"โ Model scaled by factor {width_multiplier}"
        except Exception as e:
            return f"โ Model scaling failed: {str(e)}"

    def collapse_model(self, cluster_bits, target_params, width_scale=1.0):
        """Replace the model with ``collapse_submodel`` applied to the given clusters.

        ``cluster_bits`` and ``target_params`` may be JSON strings or
        already-parsed objects. Returns a status message.
        """
        if self.model is None:
            return "โ Model not initialized"
        try:
            clusters = json.loads(cluster_bits) if isinstance(cluster_bits, str) else cluster_bits
            params = json.loads(target_params) if isinstance(target_params, str) else target_params
            with self.lock:
                self.model = collapse_submodel(
                    self.model, clusters, params, width_scale
                )
            return "โ Model collapsed successfully"
        except Exception as e:
            return f"โ Model collapse failed: {str(e)}"

    def get_model_status(self):
        """Return a JSON string with parameter count, config and feature toggles."""
        if self.model is None:
            return "โ No model initialized"
        try:
            param_count = sum(p.numel() for p in self.model.parameters())
            status = {
                "initialized": True,
                "parameters": param_count,
                "config": self.config,
                "gpu_enabled": self.gpu_enabled,
                "qat_enabled": self.qat_enabled,
                "compression_enabled": self.compression_enabled,
                "diffusion_enabled": self.diffusion_enabled,
            }
            return json.dumps(status, indent=2)
        except Exception as e:
            return f"โ Status check failed: {str(e)}"

    def get_telemetry_plot(self):
        """Return a matplotlib Figure of the telemetry history.

        Returns a placeholder figure when no data has been logged yet, or an
        error figure if plotting fails. Every figure is closed in pyplot
        before being returned, so repeated refreshes do not leak open
        figures.
        """
        try:
            if not any(self.telemetry_log.values()):
                return self._message_figure('Telemetry Metrics', 'No telemetry data yet')
            fig, _axes = plot_telemetry(
                self.telemetry_log,
                k_floor=0.5,  # Negentropy floor
                c_floor=self.c_floor,
                s_floor=self.s_floor
            )
            plt.close(fig)
            return fig
        except Exception as e:
            return self._message_figure('Telemetry Metrics - Error', f'Plot error: {str(e)}')

    def _update_telemetry(self, telemetry_dict):
        """Append one step of K/C/S metrics taken from a model telemetry dict.

        Each metric is the mean of its tensor/number value, or 0.0 if the key
        is absent. ``negentropy_logits`` is accepted as a fallback for
        ``negentropy``, by analogy with ``lz_complexity_logits``. NOTE(review):
        confirm which key BitTransformerLM actually emits.
        """
        if not telemetry_dict:
            return
        self.telemetry_log["steps"].append(len(self.telemetry_log["steps"]))
        self.telemetry_log["negentropy"].append(
            self._telemetry_scalar(telemetry_dict, ("negentropy", "negentropy_logits"))
        )
        self.telemetry_log["lz_complexity"].append(
            self._telemetry_scalar(telemetry_dict, ("lz_complexity_logits",))
        )
        self.telemetry_log["symbiosis_score"].append(
            self._telemetry_scalar(telemetry_dict, ("symbiosis_score",))
        )

    def huggingface_upload(self, repo_id, hf_token=None):
        """Log in (if a token is given) and push the model and config to ``repo_id``."""
        if self.model is None:
            return "โ Model not initialized"
        try:
            if hf_token:
                hf_login(hf_token)
            save_checkpoint(self.model, repo_id, self.config)
            return f"โ Model uploaded to {repo_id}"
        except Exception as e:
            return f"โ HF upload failed: {str(e)}"

    def huggingface_download(self, repo_id, hf_token=None):
        """Log in (if a token is given), download ``repo_id`` and install it as the model."""
        try:
            if hf_token:
                hf_login(hf_token)
            with self.lock:
                model, config = download_checkpoint(repo_id)
                self.model = model
                self.config = config
            return f"โ Model downloaded from {repo_id}"
        except Exception as e:
            return f"โ HF download failed: {str(e)}"

    def mcp_request(self, endpoint, data=None, method="POST"):
        """Send a request to ``MCP_SERVER_ADDR + endpoint`` (30 s timeout) and summarise the reply."""
        if not self.mcp_server_addr:
            return "โ MCP server not configured"
        try:
            url = self.mcp_server_addr.rstrip("/") + endpoint
            if method == "POST":
                resp = requests.post(url, json=data, timeout=30)
            else:
                resp = requests.get(url, timeout=30)
            resp.raise_for_status()
            if resp.headers.get("Content-Type", "").startswith("image/"):
                return "โ MCP request completed (binary data)"
            return f"โ MCP request completed: {resp.json()}"
        except Exception as e:
            return f"โ MCP request failed: {str(e)}"


# Global manager instance
manager = GradioModelManager()
def create_gradio_interface():
    """Build the BitTransformerLM dashboard as an (unlaunched) ``gr.Blocks`` app.

    Every callback delegates to the module-level ``manager``. Numeric form
    fields are coerced to ``int`` where the model or the manager needs
    integers, because ``gr.Number`` yields floats (or ``None`` when
    cleared). The System Settings checkboxes are wired to the manager's
    feature toggles.

    Returns:
        The ``gr.Blocks`` interface.
    """
    # Helper functions for Gradio callbacks
    def _as_int(value, default):
        """Coerce a gr.Number value to int; return ``default`` for None/empty."""
        if value is None or value == "":
            return default
        return int(value)

    def _file_to_bits(file_input):
        """Read an uploaded file into a bit list.

        ``.txt``/``.md`` files (case-insensitive) are decoded as UTF-8 and
        converted with ``text_to_bits``. Any other file is read as raw bytes
        and expanded most-significant bit first, 8 bits per byte.
        """
        # Older Gradio versions pass a tempfile wrapper exposing ``.name``;
        # newer ones (type="filepath") pass the path string itself.
        path = str(getattr(file_input, "name", file_input))
        if path.lower().endswith(('.txt', '.md')):
            with open(path, 'r', encoding='utf-8') as f:
                return text_to_bits(f.read())
        with open(path, 'rb') as f:
            data = f.read()
        return [(byte >> (7 - i)) & 1 for byte in data for i in range(8)]

    def init_model_callback(d_model, nhead, num_layers, dim_feedforward, max_seq_len,
                            chunk_size, overlap, reversible, use_checkpoint, act_threshold,
                            c_floor, s_floor):
        """Initialize model with form parameters."""
        chunk = _as_int(chunk_size, 0)
        # Architecture sizes must be ints for the model constructor. None
        # values are dropped by manager.init_model, so the model defaults apply.
        config = {
            "d_model": _as_int(d_model, None),
            "nhead": _as_int(nhead, None),
            "num_layers": _as_int(num_layers, None),
            "dim_feedforward": _as_int(dim_feedforward, None),
            "max_seq_len": _as_int(max_seq_len, None),
            "chunk_size": chunk if chunk > 0 else None,
            "overlap": _as_int(overlap, None),
            "reversible": reversible,
            "use_checkpoint": use_checkpoint,
            "act_threshold": act_threshold
        }
        # Update safety floors (keep the previous value if a field was cleared).
        if c_floor is not None:
            manager.c_floor = c_floor
        if s_floor is not None:
            manager.s_floor = s_floor
        result = manager.init_model(config)
        status = manager.get_model_status()
        plot = manager.get_telemetry_plot()
        return result, status, plot

    def train_callback(bits_input, epochs, file_input):
        """Training callback with file upload support."""
        epochs = max(1, _as_int(epochs, 1))
        if file_input is not None:
            try:
                bits = _file_to_bits(file_input)
                result, _loss, ratio = manager.train_step(bits, epochs)
            except Exception as e:
                result = f"โ File processing failed: {str(e)}"
                ratio = None
        else:
            result, _loss, ratio = manager.train_step(bits_input, epochs)
        status = manager.get_model_status()
        plot = manager.get_telemetry_plot()
        return result, status, plot, f"Compression Ratio: {ratio:.2f}" if ratio else ""

    def inference_callback(bits_input, file_input):
        """Standard inference callback."""
        if file_input is not None:
            try:
                bits = _file_to_bits(file_input)
                result, output_bits = manager.inference(bits)
            except Exception as e:
                result = f"โ File processing failed: {str(e)}"
                output_bits = None
        else:
            result, output_bits = manager.inference(bits_input)
        return result, str(output_bits) if output_bits else ""

    def long_inference_callback(bits_input, ctx_bits, overlap):
        """Long sequence inference callback."""
        result, output_bits = manager.inference(bits_input, long_inference=True,
                                                ctx_bits=_as_int(ctx_bits, 4096),
                                                overlap=_as_int(overlap, 256))
        return result, str(output_bits) if output_bits else ""

    def text_inference_callback(text_input):
        """Text-to-text inference callback."""
        result, output_text = manager.text_inference(text_input)
        return result, output_text if output_text else ""

    # System settings callbacks
    def update_gpu_setting(enabled):
        manager.gpu_enabled = enabled
        return f"GPU/FSDP: {'Enabled' if enabled else 'Disabled'}"

    def update_qat_setting(enabled):
        manager.qat_enabled = enabled
        return f"QAT: {'Enabled' if enabled else 'Disabled'}"

    def update_compression_setting(enabled):
        manager.compression_enabled = enabled
        return f"Compression: {'Enabled' if enabled else 'Disabled'}"

    def update_diffusion_setting(enabled):
        manager.diffusion_enabled = enabled
        return f"Diffusion: {'Enabled' if enabled else 'Disabled'}"

    # Create Gradio interface
    with gr.Blocks(title="BitTransformerLM Dashboard",
                   theme=gr.themes.Soft()) as interface:
        gr.Markdown("# ๐ค BitTransformerLM Interactive Dashboard")
        gr.Markdown("*Experimental bit-native transformer with comprehensive training and inference capabilities*")

        with gr.Tab("๐๏ธ Model Configuration"):
            gr.Markdown("## Initialize BitTransformerLM")
            with gr.Row():
                with gr.Column():
                    d_model = gr.Number(label="d_model", value=64, info="Model width")
                    nhead = gr.Number(label="nhead", value=4, info="Attention heads")
                    num_layers = gr.Number(label="num_layers", value=2, info="Transformer layers")
                    dim_feedforward = gr.Number(label="dim_feedforward", value=256, info="FFN dimension")
                with gr.Column():
                    max_seq_len = gr.Number(label="max_seq_len", value=512, info="Max sequence length")
                    chunk_size = gr.Number(label="chunk_size", value=0, info="Chunk size (0=auto)")
                    overlap = gr.Number(label="overlap", value=64, info="Sliding window overlap")
                    act_threshold = gr.Number(label="act_threshold", value=0.95, info="ACT halt threshold")
            with gr.Row():
                reversible = gr.Checkbox(label="Reversible Layers", value=False)
                use_checkpoint = gr.Checkbox(label="Gradient Checkpointing", value=True)
            with gr.Row():
                c_floor = gr.Number(label="c_floor", value=0.3, info="Complexity safety floor")
                s_floor = gr.Number(label="s_floor", value=0.5, info="Symbiosis safety floor")
            init_btn = gr.Button("๐ Initialize Model", variant="primary")
            init_output = gr.Textbox(label="Initialization Result", interactive=False)

        with gr.Tab("๐ฏ Training"):
            gr.Markdown("## Train BitTransformerLM")
            with gr.Row():
                with gr.Column():
                    train_bits = gr.Textbox(
                        label="Bit Input",
                        placeholder="0 1 0 1 or [0,1,0,1] or upload file",
                        lines=3
                    )
                    train_file = gr.File(label="Upload Training File", file_types=[".txt", ".md", ".bin"])
                    train_epochs = gr.Number(label="Epochs", value=1, minimum=1)
                with gr.Column():
                    train_btn = gr.Button("๐ Start Training", variant="primary")
                    train_output = gr.Textbox(label="Training Result", interactive=False)
                    compression_output = gr.Textbox(label="Compression Info", interactive=False)

        with gr.Tab("๐ง Inference"):
            with gr.Tab("Standard Inference"):
                gr.Markdown("## Standard Inference")
                with gr.Row():
                    with gr.Column():
                        infer_bits = gr.Textbox(
                            label="Bit Input",
                            placeholder="0 1 0 1 or [0,1,0,1]",
                            lines=3
                        )
                        infer_file = gr.File(label="Upload Inference File")
                    with gr.Column():
                        infer_btn = gr.Button("๐ฏ Run Inference", variant="primary")
                        infer_result = gr.Textbox(label="Result", interactive=False)
                        infer_output = gr.Textbox(label="Output Bits", lines=5, interactive=False)
            with gr.Tab("Long Sequence Inference"):
                gr.Markdown("## Long Sequence Inference")
                with gr.Row():
                    with gr.Column():
                        long_bits = gr.Textbox(
                            label="Long Bit Sequence",
                            lines=5,
                            placeholder="Long sequence of bits..."
                        )
                        long_ctx_bits = gr.Number(label="Context Bits", value=4096)
                        long_overlap = gr.Number(label="Overlap", value=256)
                    with gr.Column():
                        long_infer_btn = gr.Button("๐ Run Long Inference", variant="primary")
                        long_result = gr.Textbox(label="Result", interactive=False)
                        long_output = gr.Textbox(label="Output Bits", lines=5, interactive=False)
            with gr.Tab("Text Inference"):
                gr.Markdown("## Text-to-Text Inference")
                with gr.Row():
                    with gr.Column():
                        text_input = gr.Textbox(
                            label="Input Text",
                            placeholder="Enter text to process...",
                            lines=3
                        )
                        text_infer_btn = gr.Button("๐ Process Text", variant="primary")
                    with gr.Column():
                        text_result = gr.Textbox(label="Result", interactive=False)
                        text_output = gr.Textbox(
                            label="Output Text",
                            lines=5,
                            interactive=False
                        )

        with gr.Tab("โ๏ธ Model Operations"):
            with gr.Tab("Scale Model"):
                gr.Markdown("## Scale Model Width")
                with gr.Row():
                    width_mult = gr.Number(label="Width Multiplier", value=1.5, step=0.1)
                    scale_btn = gr.Button("๐ Scale Model", variant="secondary")
                scale_output = gr.Textbox(label="Scaling Result", interactive=False)
            with gr.Tab("Collapse Model"):
                gr.Markdown("## Collapse Submodel")
                with gr.Row():
                    with gr.Column():
                        cluster_bits = gr.Textbox(
                            label="Cluster Bits (JSON)",
                            placeholder='[[0,1,0,1],[1,1,0,0]]',
                            lines=3
                        )
                        target_params = gr.Textbox(
                            label="Target Parameters (JSON)",
                            placeholder='{"d_model":32,"nhead":4,"num_layers":1}',
                            lines=3
                        )
                        width_scale = gr.Number(label="Width Scale", value=1.0, step=0.1)
                    with gr.Column():
                        collapse_btn = gr.Button("๐๏ธ Collapse Model", variant="secondary")
                        collapse_output = gr.Textbox(label="Collapse Result", interactive=False)

        with gr.Tab("๐ Monitoring"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Model Status")
                    status_output = gr.Code(label="Current Status", language="json")
                    refresh_btn = gr.Button("๐ Refresh Status")
                with gr.Column():
                    gr.Markdown("## System Settings")
                    with gr.Row():
                        gpu_checkbox = gr.Checkbox(label="๐ฅ Enable GPU/FSDP", value=False)
                        qat_checkbox = gr.Checkbox(label="โก Enable 4-bit QAT", value=False)
                    with gr.Row():
                        compression_checkbox = gr.Checkbox(label="๐๏ธ Enable Compression", value=False)
                        diffusion_checkbox = gr.Checkbox(label="๐ Enable Diffusion Mode", value=False)
                    settings_output = gr.Textbox(label="Settings Status", interactive=False)
            gr.Markdown("## ๐ Telemetry Metrics")
            telemetry_plot = gr.Plot(label="K/C/S Metrics Over Time")

        with gr.Tab("โ๏ธ HuggingFace Integration"):
            gr.Markdown("## HuggingFace Model Hub")
            with gr.Row():
                with gr.Column():
                    hf_repo_id = gr.Textbox(label="Repository ID", placeholder="username/model-name")
                    hf_token = gr.Textbox(label="HF Token (optional)", type="password")
                with gr.Column():
                    with gr.Row():
                        hf_upload_btn = gr.Button("โฌ๏ธ Upload to HF", variant="secondary")
                        hf_download_btn = gr.Button("โฌ๏ธ Download from HF", variant="secondary")
                    hf_result = gr.Textbox(label="HuggingFace Result", interactive=False)

        # Event handlers
        init_btn.click(
            init_model_callback,
            inputs=[d_model, nhead, num_layers, dim_feedforward, max_seq_len,
                    chunk_size, overlap, reversible, use_checkpoint, act_threshold,
                    c_floor, s_floor],
            outputs=[init_output, status_output, telemetry_plot]
        )
        train_btn.click(
            train_callback,
            inputs=[train_bits, train_epochs, train_file],
            outputs=[train_output, status_output, telemetry_plot, compression_output]
        )
        infer_btn.click(
            inference_callback,
            inputs=[infer_bits, infer_file],
            outputs=[infer_result, infer_output]
        )
        long_infer_btn.click(
            long_inference_callback,
            inputs=[long_bits, long_ctx_bits, long_overlap],
            outputs=[long_result, long_output]
        )
        text_infer_btn.click(
            text_inference_callback,
            inputs=[text_input],
            outputs=[text_result, text_output]
        )
        scale_btn.click(
            manager.scale_model,
            inputs=[width_mult],
            outputs=[scale_output]
        )
        collapse_btn.click(
            manager.collapse_model,
            inputs=[cluster_bits, target_params, width_scale],
            outputs=[collapse_output]
        )
        refresh_btn.click(
            manager.get_model_status,
            outputs=[status_output]
        )
        hf_upload_btn.click(
            manager.huggingface_upload,
            inputs=[hf_repo_id, hf_token],
            outputs=[hf_result]
        )
        hf_download_btn.click(
            manager.huggingface_download,
            inputs=[hf_repo_id, hf_token],
            outputs=[hf_result]
        )
        # Wire the System Settings checkboxes to the manager's feature toggles.
        # Previously the update_* handlers existed but were never attached.
        for checkbox, handler in (
            (gpu_checkbox, update_gpu_setting),
            (qat_checkbox, update_qat_setting),
            (compression_checkbox, update_compression_setting),
            (diffusion_checkbox, update_diffusion_setting),
        ):
            checkbox.change(handler, inputs=[checkbox], outputs=[settings_output])

        # Auto-refresh telemetry every 10 seconds
        interface.load(
            manager.get_telemetry_plot,
            outputs=[telemetry_plot],
            every=10
        )
        # Load initial status
        interface.load(
            manager.get_model_status,
            outputs=[status_output]
        )
    return interface
def run_gradio_server(host="127.0.0.1", port=7860, share=False):
    """Build the dashboard and launch it on ``host:port``.

    Prints the server URL, plus the MCP server address when the
    ``MCP_SERVER_ADDR`` environment variable is set. It then launches the
    interface with ``share`` passed through and ``show_error``/``debug``
    enabled.
    """
    app = create_gradio_interface()
    print("๐ Starting BitTransformerLM Gradio Dashboard...")
    print(f"๐ Server will be available at: http://{host}:{port}")
    mcp_addr = os.getenv("MCP_SERVER_ADDR")
    if mcp_addr:
        print(f"๐ MCP Server configured at: {mcp_addr}")
    launch_options = {
        "server_name": host,
        "server_port": port,
        "share": share,
        "show_error": True,
        "debug": True,
    }
    app.launch(**launch_options)
if __name__ == "__main__":
    # Support both local development and HF Spaces: Spaces sets SPACE_ID and
    # manages host/port itself, while local runs take CLI flags.
    if not os.getenv("SPACE_ID"):
        # Local development
        import argparse
        parser = argparse.ArgumentParser(description="BitTransformerLM Gradio Dashboard")
        parser.add_argument("--host", default="127.0.0.1", help="Host address")
        parser.add_argument("--port", type=int, default=7860, help="Port number")
        parser.add_argument("--share", action="store_true", help="Enable sharing")
        args = parser.parse_args()
        run_gradio_server(args.host, args.port, args.share)
    else:
        # Running on HuggingFace Spaces
        print("๐ค Running on HuggingFace Spaces")
        interface = create_gradio_interface()
        interface.launch()