// Dedicated worker for the 5Hz LM. Isolated WASM heap lets the 1.77 GB model
// load without competing with DiT + encoders in the main worker.
import { AutoTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime-web/webgpu";

const MODEL_REPO = "shreyask/ACE-Step-v1.5-ONNX";
const MODEL_REVISION = "bdabfb5684fd70fcc76f98cbb51bb9ebc47ee342";
const ONNX_BASE = `https://huggingface.co/${MODEL_REPO}/resolve/${MODEL_REVISION}/onnx`;
const LM_TOKENIZER_REPO = "ACE-Step/acestep-5Hz-lm-0.6B";
const CACHE_NAME = "ace-step-onnx-v12";
const NUM_KV_LAYERS = 28;  // transformer layers in the 0.6B LM
const NUM_KV_HEADS = 8;    // KV heads per layer (grouped-query attention)
const KV_HEAD_DIM = 128;   // per-head dimension of the KV cache
const VOCAB_SIZE = 217204; // text vocab + audio-code + special tokens
const NUM_CODES = 64000;   // audio codebook size; code ids clamp to [0, NUM_CODES - 1]
const POOL_WINDOW = 5;     // 25Hz latent frames per 5Hz code
const EOS_ID = 151645;     // <|im_end|> in the Qwen-style chat template

let tokenizer = null;
let session = null;

function post(type, data = {}) {
  self.postMessage({ type, ...data });
}
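
// Download a URL into an ArrayBuffer, streaming progress to the main thread
// and memoizing the bytes in the Cache API so reloads skip the network.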
async function fetchBuffer(url, label) {
  const cache = await caches.open(CACHE_NAME);
  const cached = await cache.match(url);
  if (cached) {
    post("progress", { label, loaded: 1, total: 1, percent: 100 });
    return await cached.arrayBuffer();
  }
  const response = await fetch(url);
  if (!response.ok) throw new Error(`Fetch failed (${response.status}) for ${url}`);
  const total = parseInt(response.headers.get("content-length") || "0", 10);
  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    if (total > 0) post("progress", { label, loaded, total, percent: (loaded / total) * 100 });
  }
  const buf = new Uint8Array(loaded);
  let offset = 0;
  for (const c of chunks) { buf.set(c, offset); offset += c.length; }
  try {
    await cache.put(url, new Response(buf.buffer.slice(0), { headers: { "Content-Type": "application/octet-stream" } }));
  } catch (_) {} // cache write can fail (e.g. storage quota); non-fatal
  return buf.buffer;
}

function tensor(data, dims, type = "float32") {
  return new ort.Tensor(type, data, dims);
}
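
// Load tokenizer + ONNX session. The q4 model ships as a small graph file
// plus a separate external-data file (ONNX external tensors), so both are
// fetched and the weights handed to ORT via `externalData`.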
async function loadModel() {
  ort.env.wasm.numThreads = 1; // single thread: multithreaded WASM needs cross-origin isolation
  ort.env.wasm.simd = true;
| post("status", { message: "Loading LM tokenizer..." }); | |
| tokenizer = await AutoTokenizer.from_pretrained(LM_TOKENIZER_REPO); | |
| post("status", { message: "Loading LM graph..." }); | |
| const graphBuf = await fetchBuffer(`${ONNX_BASE}/lm_kv_q4.onnx`, "LM graph"); | |
| post("status", { message: "Loading LM weights (1.24 GB q4)..." }); | |
| const weightsBuf = await fetchBuffer(`${ONNX_BASE}/lm_kv_q4.onnx.data`, "LM weights"); | |
| post("status", { message: "Creating LM session..." }); | |
| // Try WebGPU first (faster), fall back to WASM if unsupported ops | |
| try { | |
| session = await ort.InferenceSession.create(graphBuf, { | |
| executionProviders: ["webgpu"], | |
| externalData: [{ path: "lm_kv_q4.onnx.data", data: weightsBuf }], | |
| }); | |
| post("status", { message: "LM on WebGPU" }); | |
| } catch (err) { | |
| console.warn("LM WebGPU failed, falling back to WASM:", err.message); | |
| session = await ort.InferenceSession.create(graphBuf, { | |
| executionProviders: ["wasm"], | |
| externalData: [{ path: "lm_kv_q4.onnx.data", data: weightsBuf }], | |
| }); | |
| post("status", { message: "LM on WASM (WebGPU unsupported)" }); | |
| } | |
| post("status", { message: "LM ready" }); | |
| post("loaded"); | |
| } | |
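
// The exported graph requires past_key_values inputs even for the prefill
// pass, so seed each layer with zero-length tensors along the sequence axis.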
function createEmptyKV() {
  const kv = {};
  for (let i = 0; i < NUM_KV_LAYERS; i++) {
    kv[`past_key_values.${i}.key`] = tensor(new Float32Array(0), [1, NUM_KV_HEADS, 0, KV_HEAD_DIM]);
    kv[`past_key_values.${i}.value`] = tensor(new Float32Array(0), [1, NUM_KV_HEADS, 0, KV_HEAD_DIM]);
  }
  return kv;
}
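
// Rename each layer's present.* outputs to the past_key_values.* input names
// expected by the next session.run call.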
function extractKV(outputs) {
  const kv = {};
  for (let i = 0; i < NUM_KV_LAYERS; i++) {
    kv[`past_key_values.${i}.key`] = outputs[`present.${i}.key`];
    kv[`past_key_values.${i}.value`] = outputs[`present.${i}.value`];
  }
  return kv;
}
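
// Sample one token id from raw logits: repetition penalty over a sliding
// window of recent tokens, then temperature scaling, top-k truncation,
// stable softmax, a top-p nucleus cut, and a multinomial draw within the
// nucleus.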
function sampleToken(logits, recentTokens, { temperature = 0.8, topK = 200, topP = 0.95, repetitionPenalty = 1.05, repWindow = 64 } = {}) {
  const V = logits.length;
  const scores = new Float32Array(V);
  scores.set(logits);
  // Repetition penalty
  if (repetitionPenalty !== 1.0 && recentTokens.length > 0) {
    const window = recentTokens.slice(-repWindow);
    const seen = new Set(window);
    for (const tok of seen) {
      if (tok >= 0 && tok < V) {
        scores[tok] = scores[tok] > 0 ? scores[tok] / repetitionPenalty : scores[tok] * repetitionPenalty;
      }
    }
  }
  // Temperature
  if (temperature !== 1.0 && temperature > 0) {
    const invT = 1.0 / temperature;
    for (let i = 0; i < V; i++) scores[i] *= invT;
  }
  // Top-K via full sort (good enough — sort overhead << LM forward pass)
  const k = Math.min(topK, V);
  const idx = new Array(V);
  for (let i = 0; i < V; i++) idx[i] = i;
  idx.sort((a, b) => scores[b] - scores[a]);
  const topIdx = idx.slice(0, k);
  // Softmax with max subtraction for numerical stability
  let maxS = -Infinity;
  for (const i of topIdx) if (scores[i] > maxS) maxS = scores[i];
  const exps = new Float64Array(k);
  let sumE = 0;
  for (let i = 0; i < k; i++) {
    const e = Math.exp(scores[topIdx[i]] - maxS);
    exps[i] = e; sumE += e;
  }
  const probs = new Float64Array(k);
  for (let i = 0; i < k; i++) probs[i] = exps[i] / sumE;
  // Top-P (nucleus)
  let cum = 0, nuc = k;
  for (let i = 0; i < k; i++) {
    cum += probs[i];
    if (cum >= topP) { nuc = i + 1; break; }
  }
  // Multinomial sample within nucleus
  let nSum = 0;
  for (let i = 0; i < nuc; i++) nSum += probs[i];
  const r = Math.random() * nSum;
  let acc = 0;
  for (let i = 0; i < nuc; i++) {
    acc += probs[i];
    if (r < acc) return topIdx[i];
  }
  return topIdx[nuc - 1];
}
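
// Build the ChatML-style prompt: instruction, caption, lyrics (or
// "[instrumental]"), and duration/language metas, wrapped in
// <|im_start|>/<|im_end|> turns with the assistant turn left open.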
function buildPrompt(caption, lyrics, duration, language = "en") {
  const instruction = "Generate audio semantic tokens based on the given conditions";
  const lyricsSection = lyrics.trim()
    ? `# Languages\n${language}\n\n# Lyrics\n${lyrics}`
    : "# Lyrics\n[instrumental]";
  const userPrompt = `# Instruction\n${instruction}\n\n# Caption\n${caption}\n\n${lyricsSection}\n\n# Metas\n- language: ${language}\n- duration: ${duration} seconds\n<|endoftext|>\n`;
  return `<|im_start|>user\n${userPrompt}<|im_end|>\n<|im_start|>assistant\n`;
}
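
// End-to-end generation: one batched prefill pass over the full prompt, then
// a token-at-a-time decode loop that re-feeds the KV cache, stopping at EOS
// or once numCodes5Hz audio-code tokens have been seen. Parsed code ids are
// posted back to the main worker as an Int32Array.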
async function generate({ caption, lyrics, duration, numLatentFrames }) {
  const numCodes5Hz = Math.ceil(numLatentFrames / POOL_WINDOW);
  post("status", { message: `LM: generating ~${numCodes5Hz} codes...` });
  const prompt = buildPrompt(caption, lyrics, Math.round(duration));
  const encoded = tokenizer(prompt);
  const promptIds = Array.from(encoded.input_ids.data, Number);
  // CoT metadata ~150 tokens + numCodes5Hz audio codes + some slack
  const maxNewTokens = Math.min(numCodes5Hz + 250, 600);
  const audioCodeTokenRegex = /<\|audio_code_(\d+)\|>/;
  const startTime = performance.now();
  const allIds = [...promptIds];
  // Prefill
  post("status", { message: `LM prefill (${promptIds.length} tokens)...` });
  const prefillIds = new BigInt64Array(promptIds.map(BigInt));
  const prefillMask = new BigInt64Array(promptIds.length).fill(1n);
  const prefillPos = new BigInt64Array(promptIds.map((_, i) => BigInt(i)));
  let outputs = await session.run({
    input_ids: tensor(prefillIds, [1, promptIds.length], "int64"),
    attention_mask: tensor(prefillMask, [1, promptIds.length], "int64"),
    position_ids: tensor(prefillPos, [1, promptIds.length], "int64"),
    ...createEmptyKV(),
  });
  let kv = extractKV(outputs);
  let lastLogits = outputs.logits.data.slice((promptIds.length - 1) * VOCAB_SIZE, promptIds.length * VOCAB_SIZE);
  let nextToken = sampleToken(lastLogits, allIds);
  allIds.push(nextToken);
  // Decode loop — exit early once we have enough audio codes
  let codesSoFar = 0;
  for (let step = 0; step < maxNewTokens - 1; step++) {
    if (nextToken === EOS_ID) break;
    if (codesSoFar >= numCodes5Hz) break; // have enough codes, stop early
    if (step % 20 === 0) {
      const elapsed = ((performance.now() - startTime) / 1000).toFixed(1);
      const tps = (step / Math.max(parseFloat(elapsed), 0.1)).toFixed(1);
      post("status", { message: `LM: ${step} tokens, ${codesSoFar}/${numCodes5Hz} codes (${tps} tok/s)` });
    }
    const seqLen = allIds.length;
    outputs = await session.run({
      input_ids: tensor(new BigInt64Array([BigInt(nextToken)]), [1, 1], "int64"),
      attention_mask: tensor(new BigInt64Array(seqLen).fill(1n), [1, seqLen], "int64"),
      position_ids: tensor(new BigInt64Array([BigInt(seqLen - 1)]), [1, 1], "int64"),
      ...kv,
    });
    kv = extractKV(outputs);
    lastLogits = outputs.logits.data.slice(0, VOCAB_SIZE);
    nextToken = sampleToken(lastLogits, allIds);
    allIds.push(nextToken);
    // Streaming decode — check if this token is an audio code
    const tokText = tokenizer.decode([nextToken], { skip_special_tokens: false });
    if (audioCodeTokenRegex.test(tokText)) codesSoFar++;
  }
  const elapsed = ((performance.now() - startTime) / 1000).toFixed(1);
  const generatedIds = allIds.slice(promptIds.length);
  const outputText = tokenizer.decode(generatedIds, { skip_special_tokens: false });
  console.log(`[lm] ${generatedIds.length} tokens in ${elapsed}s`);
  // Find end of thinking
  const thinkEnd = outputText.indexOf("</think>");
  console.log("[lm] CoT length:", thinkEnd >= 0 ? thinkEnd : "no </think> found");
  console.log("[lm] preview (CoT):", thinkEnd >= 0 ? outputText.slice(0, thinkEnd + 10) : outputText.slice(0, 500));
  console.log("[lm] preview (after think):", thinkEnd >= 0 ? outputText.slice(thinkEnd, thinkEnd + 500) : "(n/a)");
  const audioCodes = [];
  for (const m of outputText.matchAll(/<\|audio_code_(\d+)\|>/g)) {
    audioCodes.push(Math.min(Math.max(parseInt(m[1], 10), 0), NUM_CODES - 1));
  }
  console.log(`[lm] extracted ${audioCodes.length} audio codes, first 10:`, audioCodes.slice(0, 10));
  // Truncate if too many but DON'T zero-pad — main worker uses last-frame padding in 25Hz space (matches MLX port)
  const codes = new Int32Array(audioCodes.slice(0, numCodes5Hz));
  post("audio_codes", { codes, elapsed, tokenCount: generatedIds.length });
}

self.onmessage = async (e) => {
  const { type, ...data } = e.data;
  try {
    if (type === "load") await loadModel();
    else if (type === "generate") await generate(data);
  } catch (err) {
    post("error", { message: err.message, stack: err.stack });
  }
};
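
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of this worker): how a main thread
// might drive the protocol above. The worker filename and the 30s/750-frame
// values are assumptions; the message types match what this file posts.
//
//   const lm = new Worker(new URL("./lm-worker.js", import.meta.url), { type: "module" });
//   lm.onmessage = ({ data }) => {
//     if (data.type === "progress") console.log(`${data.label}: ${data.percent.toFixed(0)}%`);
//     else if (data.type === "status") console.log(data.message);
//     else if (data.type === "loaded")
//       // 30 s at 25Hz latents = 750 frames -> 150 codes at 5Hz
//       lm.postMessage({ type: "generate", caption: "warm lo-fi piano", lyrics: "", duration: 30, numLatentFrames: 750 });
//     else if (data.type === "audio_codes") console.log(`got ${data.codes.length} codes in ${data.elapsed}s`);
//     else if (data.type === "error") console.error(data.message);
//   };
//   lm.postMessage({ type: "load" });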