# Hugging Face Space status at capture time: Runtime error
| # app.py | |
| import os | |
| import subprocess | |
| import traceback | |
| import gradio as gr | |
# Where the model lives on the Hugging Face Hub and where its companion
# source repository is fetched from / checked out to.
MODEL_ID = "CADCODER/CAD-Coder"  # HF model id
REPO_GIT = "https://github.com/CADCODER/CAD-Coder.git"
REPO_DIR = "CAD-Coder"

# 1) One-time checkout: an existing REPO_DIR directory is treated as
#    "already cloned" and left alone. Any failure is reported and ignored so
#    the rest of the app can still start.
if not os.path.isdir(REPO_DIR):
    try:
        print("Cloning CAD-Coder repo...")
        # check_call raises on a non-zero git exit status.
        subprocess.check_call(["git", "clone", REPO_GIT, REPO_DIR])
    except Exception as clone_error:
        print("Could not clone repository:", clone_error)
# 2) Prepare model loader with graceful fallback to HF Inference API
# Optional Hub token: HF_TOKEN wins, HF_HUB_API_TOKEN is the alternative name.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_API_TOKEN")

# Generation backends. Each one stays None unless its setup succeeds.
local_generate = None
api_generate = None

# Try to load model locally (8-bit if possible)
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    try:
        import bitsandbytes as bnb  # noqa: F401 - optional; enables 8-bit loading
        has_bnb = True
    except Exception:
        # Broad on purpose: any import-time failure just disables 8-bit loading.
        has_bnb = False

    print("Loading tokenizer...")
    # NOTE(review): `use_auth_token` is deprecated in recent transformers
    # releases in favour of `token` - confirm against the pinned version
    # before switching.
    # NOTE(review): trust_remote_code=True allows Python shipped in the model
    # repo to run in this process - confirm the repo is trusted.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, use_auth_token=hf_token, trust_remote_code=True
    )

    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    if has_bnb:
        print("bitsandbytes available — will attempt 8-bit load (saves memory).")
        # NOTE(review): newer transformers prefer
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True) - confirm.
        load_kwargs.update({"load_in_8bit": True, "torch_dtype": torch.float16})
    elif torch.cuda.is_available():
        # No quantization available: use fp16 only when a GPU is present.
        load_kwargs["torch_dtype"] = torch.float16

    print("Loading model (this can take a while)...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, use_auth_token=hf_token, **load_kwargs
    )
    # Inputs are sent to the device that holds the model's first parameter.
    device = next(model.parameters()).device
    print("Model loaded on device:", device)

    def local_generate_fn(prompt, max_new_tokens=512):
        """Generate text for *prompt* with the locally loaded model.

        Args:
            prompt: Natural-language request to tokenize and feed to the model.
            max_new_tokens: Upper bound on the number of tokens to generate.

        Returns:
            The decoded continuation only, with special tokens removed.
        """
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        gen = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # For a causal LM the returned sequence starts with the prompt tokens;
        # drop them so the UI shows generated code rather than an echo of the
        # request followed by the code.
        prompt_length = inputs["input_ids"].shape[-1]
        return tokenizer.decode(gen[0][prompt_length:], skip_special_tokens=True)

    local_generate = local_generate_fn
except Exception:
    # Missing packages, missing weights, insufficient memory reported as an
    # exception, etc.: report it and leave local_generate as None.
    print("Local model load failed or not feasible in this environment.")
    traceback.print_exc()
# Fallback: Hugging Face Inference API (works without loading weights locally)
def _extract_generated_text(out):
    """Pull the generated string out of a legacy Inference API response.

    Args:
        out: Parsed response; may be a list, a dict, or any other value.

    Returns:
        The ``generated_text`` field when present, otherwise ``str()`` of the
        (first) payload.

    Raises:
        RuntimeError: If the response is an empty list or carries an
            ``error`` field, so the caller reports a failure instead of
            displaying the payload as generated code.
    """
    if isinstance(out, list):
        if not out:
            raise RuntimeError("Inference API returned an empty response.")
        out = out[0]
    if isinstance(out, dict):
        if out.get("error"):
            raise RuntimeError(f"Inference API error: {out['error']}")
        return out.get("generated_text") or str(out)
    return str(out)


if local_generate is None:
    try:
        # Prefer the current client; the legacy InferenceApi is deprecated
        # upstream and is only used when InferenceClient cannot be imported.
        try:
            from huggingface_hub import InferenceClient
        except ImportError:
            InferenceClient = None

        if InferenceClient is not None:
            print("Setting up HF Inference API client as fallback...")
            api = InferenceClient(model=MODEL_ID, token=hf_token)

            def api_generate_fn(prompt, max_new_tokens=512):
                """Generate text for *prompt* through the hosted endpoint.

                Returns the generated string; client errors propagate as
                exceptions.
                """
                return api.text_generation(prompt, max_new_tokens=max_new_tokens)
        else:
            from huggingface_hub import InferenceApi

            print("Setting up HF Inference API client as fallback...")
            api = InferenceApi(repo_id=MODEL_ID, token=hf_token)

            def api_generate_fn(prompt, max_new_tokens=512):
                """Generate text for *prompt* through the legacy hosted client.

                Returns the generated string; raises RuntimeError on an empty
                or error response.
                """
                # call the hosted inference endpoint
                out = api(inputs=prompt, params={"max_new_tokens": max_new_tokens})
                # Response can be a dict or list depending on pipeline.
                return _extract_generated_text(out)

        api_generate = api_generate_fn
        print("Inference API fallback ready.")
    except Exception as e:
        # Best effort: without a client the app still starts and reports the
        # missing backend at generation time.
        print("HF Inference API not available:", e)
        traceback.print_exc()
# Final generate function: prefer local, otherwise API fallback, otherwise error
def generate(prompt, max_new_tokens=512):
    """Route *prompt* to whichever generation backend is available.

    The locally loaded model takes priority over the hosted API. When neither
    backend was set up, an explanatory error string is returned instead of
    raising.
    """
    backend = local_generate or api_generate
    if not backend:
        return "ERROR: No model loaded and no API fallback available. Check HF_TOKEN and Space hardware."
    return backend(prompt, max_new_tokens=max_new_tokens)
# Gradio UI
def run_prompt(prompt, max_tokens=512):
    """Handle a click on the Generate button.

    Args:
        prompt: Text typed by the user; may be empty or None.
        max_tokens: Token budget; converted with int() before use.

    Returns:
        A usage hint for a blank prompt, the generated text on success, or a
        "Generation error: ..." message when generation raises.
    """
    is_blank = not prompt or not prompt.strip()
    if is_blank:
        return "Enter a prompt describing the CAD sketch you want (e.g., 'rectangle width 10 height 5 with hole radius 1')."
    try:
        result = generate(prompt, max_new_tokens=int(max_tokens))
    except Exception as err:
        # Log the full traceback, but show the user only the message.
        traceback.print_exc()
        return f"Generation error: {err}"
    return result
# Page layout: the prompt box and the token slider feed run_prompt, and its
# return value is written to the output box.
with gr.Blocks() as demo:
    gr.Markdown("# CAD-Coder (Text → CadQuery code)")
    prompt = gr.Textbox(
        lines=4,
        label="Natural language prompt",
        placeholder="e.g. 'create a rectangular plate 100x50 with a centered 10mm hole'...",
    )
    max_tokens = gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=64,
        maximum=2048,
        step=64,
    )
    out = gr.Textbox(lines=18, label="Generated CadQuery code")
    btn = gr.Button("Generate")
    btn.click(fn=run_prompt, inputs=[prompt, max_tokens], outputs=out)
if __name__ == "__main__":
    # Listen on every interface; the PORT environment variable overrides the
    # default port 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)