// Dedicated worker for the 5Hz LM. Isolated WASM heap lets the 1.77 GB model
// load without competing with DiT + encoders in the main worker.
import { AutoTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime-web/webgpu";

const MODEL_REPO = "shreyask/ACE-Step-v1.5-ONNX";
const MODEL_REVISION = "bdabfb5684fd70fcc76f98cbb51bb9ebc47ee342";
const ONNX_BASE = `https://huggingface.co/${MODEL_REPO}/resolve/${MODEL_REVISION}/onnx`;
const LM_TOKENIZER_REPO = "ACE-Step/acestep-5Hz-lm-0.6B";
const CACHE_NAME = "ace-step-onnx-v12";

// Decoder dimensions for the 0.6B LM's KV cache.
const NUM_KV_LAYERS = 28;
const NUM_KV_HEADS = 8;
const KV_HEAD_DIM = 128;
const VOCAB_SIZE = 217204;
const NUM_CODES = 64000;
const POOL_WINDOW = 5; // 25Hz latent frames per 5Hz code
const EOS_ID = 151645; // <|im_end|>

let tokenizer = null;
let session = null;

// Send a typed message to the main thread.
function post(type, data = {}) {
  self.postMessage({ type, ...data });
}

// Streaming fetch with progress reporting, backed by the Cache Storage API so
// repeat visits skip the multi-GB download.
async function fetchBuffer(url, label) {
  const cache = await caches.open(CACHE_NAME);
  const cached = await cache.match(url);
  if (cached) {
    post("progress", { label, loaded: 1, total: 1, percent: 100 });
    return await cached.arrayBuffer();
  }

  const response = await fetch(url);
  const total = parseInt(response.headers.get("content-length") || "0");
  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    if (total > 0) post("progress", { label, loaded, total, percent: (loaded / total) * 100 });
  }

  const buf = new Uint8Array(loaded);
  let offset = 0;
  for (const c of chunks) {
    buf.set(c, offset);
    offset += c.length;
  }

  // Best-effort cache write; ignore quota errors.
  try {
    await cache.put(url, new Response(buf.buffer.slice(0), {
      headers: { "Content-Type": "application/octet-stream" },
    }));
  } catch (_) {}
  return buf.buffer;
}

function tensor(data, dims, type = "float32") {
  return new ort.Tensor(type, data, dims);
}

async function loadModel() {
  ort.env.wasm.numThreads = 1;
  ort.env.wasm.simd = true;

  post("status", { message: "Loading LM tokenizer..." });
  tokenizer = await AutoTokenizer.from_pretrained(LM_TOKENIZER_REPO);

  post("status", { message: "Loading LM graph..." });
  const graphBuf = await fetchBuffer(`${ONNX_BASE}/lm_kv_q4.onnx`, "LM graph");

  post("status", { message: "Loading LM weights (1.24 GB q4)..." });
  const weightsBuf = await fetchBuffer(`${ONNX_BASE}/lm_kv_q4.onnx.data`, "LM weights");

  post("status", { message: "Creating LM session..." });
  // Try WebGPU first (faster), fall back to WASM if unsupported ops
  try {
    session = await ort.InferenceSession.create(graphBuf, {
      executionProviders: ["webgpu"],
      externalData: [{ path: "lm_kv_q4.onnx.data", data: weightsBuf }],
    });
    post("status", { message: "LM on WebGPU" });
  } catch (err) {
    console.warn("LM WebGPU failed, falling back to WASM:", err.message);
    session = await ort.InferenceSession.create(graphBuf, {
      executionProviders: ["wasm"],
      externalData: [{ path: "lm_kv_q4.onnx.data", data: weightsBuf }],
    });
    post("status", { message: "LM on WASM (WebGPU unsupported)" });
  }

  post("status", { message: "LM ready" });
  post("loaded");
}

// Zero-length KV tensors for the prefill step.
function createEmptyKV() {
  const kv = {};
  for (let i = 0; i < NUM_KV_LAYERS; i++) {
    kv[`past_key_values.${i}.key`] = tensor(new Float32Array(0), [1, NUM_KV_HEADS, 0, KV_HEAD_DIM]);
    kv[`past_key_values.${i}.value`] = tensor(new Float32Array(0), [1, NUM_KV_HEADS, 0, KV_HEAD_DIM]);
  }
  return kv;
}

// Feed each step's `present.*` outputs back in as the next step's `past_key_values.*`.
function extractKV(outputs) {
  const kv = {};
  for (let i = 0; i < NUM_KV_LAYERS; i++) {
    kv[`past_key_values.${i}.key`] = outputs[`present.${i}.key`];
    kv[`past_key_values.${i}.value`] = outputs[`present.${i}.value`];
  }
  return kv;
}

// Repetition-penalized top-K/top-P sampling over raw logits.
function sampleToken(logits, recentTokens, {
  temperature = 0.8,
  topK = 200,
  topP = 0.95,
  repetitionPenalty = 1.05,
  repWindow = 64,
} = {}) {
  const V = logits.length;
  const scores = new Float32Array(V);
  scores.set(logits);

  // Repetition penalty
  if (repetitionPenalty !== 1.0 && recentTokens.length > 0) {
    const window = recentTokens.slice(-repWindow);
    const seen = new Set(window);
    for (const tok of seen) {
      if (tok >= 0 && tok < V) {
        scores[tok] = scores[tok] > 0 ? scores[tok] / repetitionPenalty : scores[tok] * repetitionPenalty;
      }
    }
  }

  // Temperature
  if (temperature !== 1.0 && temperature > 0) {
    const invT = 1.0 / temperature;
    for (let i = 0; i < V; i++) scores[i] *= invT;
  }

  // Top-K via full sort (good enough — sort overhead << LM forward pass)
  const k = Math.min(topK, V);
  const idx = new Array(V);
  for (let i = 0; i < V; i++) idx[i] = i;
  idx.sort((a, b) => scores[b] - scores[a]);
  const topIdx = idx.slice(0, k);

  // Softmax over the top-K (max-subtracted for numerical stability)
  let maxS = -Infinity;
  for (const i of topIdx) if (scores[i] > maxS) maxS = scores[i];
  const exps = new Float64Array(k);
  let sumE = 0;
  for (let i = 0; i < k; i++) {
    const e = Math.exp(scores[topIdx[i]] - maxS);
    exps[i] = e;
    sumE += e;
  }
  const probs = new Float64Array(k);
  for (let i = 0; i < k; i++) probs[i] = exps[i] / sumE;

  // Top-P (nucleus)
  let cum = 0, nuc = k;
  for (let i = 0; i < k; i++) {
    cum += probs[i];
    if (cum >= topP) { nuc = i + 1; break; }
  }

  // Multinomial sample within nucleus
  let nSum = 0;
  for (let i = 0; i < nuc; i++) nSum += probs[i];
  const r = Math.random() * nSum;
  let acc = 0;
  for (let i = 0; i < nuc; i++) {
    acc += probs[i];
    if (r < acc) return topIdx[i];
  }
  return topIdx[nuc - 1];
}

// Build the ChatML-style prompt: a user turn with instruction/caption/lyrics/metas,
// then an open assistant turn for the model to complete.
function buildPrompt(caption, lyrics, duration, language = "en") {
  const instruction = "Generate audio semantic tokens based on the given conditions";
  const lyricsSection = lyrics.trim()
    ? `# Languages\n${language}\n\n# Lyrics\n${lyrics}`
    : "# Lyrics\n[instrumental]";
  const userPrompt = `# Instruction\n${instruction}\n\n# Caption\n${caption}\n\n${lyricsSection}\n\n# Metas\n- language: ${language}\n- duration: ${duration} seconds\n<|endoftext|>\n`;
  return `<|im_start|>user\n${userPrompt}<|im_end|>\n<|im_start|>assistant\n`;
}

// Prefill the prompt, then decode with a KV cache until we have enough 5Hz codes.
async function generate({ caption, lyrics, duration, numLatentFrames }) {
  const numCodes5Hz = Math.ceil(numLatentFrames / POOL_WINDOW);
  post("status", { message: `LM: generating ~${numCodes5Hz} codes...` });

  const prompt = buildPrompt(caption, lyrics, Math.round(duration));
  const encoded = tokenizer(prompt);
  const promptIds = Array.from(encoded.input_ids.data, Number);

  // CoT metadata ~150 tokens + numCodes5Hz audio codes + some slack
  const maxNewTokens = Math.min(numCodes5Hz + 250, 600);
  const audioCodeTokenRegex = /<\|audio_code_(\d+)\|>/g;
  const startTime = performance.now();
  const allIds = [...promptIds];

  // Prefill
  post("status", { message: `LM prefill (${promptIds.length} tokens)...` });
  const prefillIds = new BigInt64Array(promptIds.map(BigInt));
  const prefillMask = new BigInt64Array(promptIds.length).fill(1n);
  const prefillPos = new BigInt64Array(promptIds.map((_, i) => BigInt(i)));
  let outputs = await session.run({
    input_ids: tensor(prefillIds, [1, promptIds.length], "int64"),
    attention_mask: tensor(prefillMask, [1, promptIds.length], "int64"),
    position_ids: tensor(prefillPos, [1, promptIds.length], "int64"),
    ...createEmptyKV(),
  });
  let kv = extractKV(outputs);
  // Logits at the last prompt position seed the first sampled token.
  let lastLogits = outputs.logits.data.slice((promptIds.length - 1) * VOCAB_SIZE, promptIds.length * VOCAB_SIZE);
  let nextToken = sampleToken(lastLogits, allIds);
  allIds.push(nextToken);

  // Decode loop — exit early once we have enough audio codes
  let codesSoFar = 0;
  for (let step = 0; step < maxNewTokens - 1; step++) {
    if (nextToken === EOS_ID) break;
    if (codesSoFar >= numCodes5Hz) break; // have enough codes, stop early

    if (step % 20 === 0) {
      const elapsed = ((performance.now() - startTime) / 1000).toFixed(1);
      const tps = (step / Math.max(parseFloat(elapsed), 0.1)).toFixed(1);
      post("status", { message: `LM: ${step} tokens, ${codesSoFar}/${numCodes5Hz} codes (${tps} tok/s)` });
    }

    const seqLen = allIds.length;
    outputs = await session.run({
      input_ids: tensor(new BigInt64Array([BigInt(nextToken)]), [1, 1], "int64"),
      attention_mask: tensor(new BigInt64Array(seqLen).fill(1n), [1, seqLen], "int64"),
      position_ids: tensor(new BigInt64Array([BigInt(seqLen - 1)]), [1, 1], "int64"),
      ...kv,
    });
    kv = extractKV(outputs);
    lastLogits = outputs.logits.data.slice(0, VOCAB_SIZE);
    nextToken = sampleToken(lastLogits, allIds);
    allIds.push(nextToken);

    // Streaming decode — check if this token is an audio code
    const tokText = tokenizer.decode([nextToken], { skip_special_tokens: false });
    if (audioCodeTokenRegex.test(tokText)) codesSoFar++;
    audioCodeTokenRegex.lastIndex = 0;
  }

  const elapsed = ((performance.now() - startTime) / 1000).toFixed(1);
  const generatedIds = allIds.slice(promptIds.length);
  const outputText = tokenizer.decode(generatedIds, { skip_special_tokens: false });
  console.log(`[lm] ${generatedIds.length} tokens in ${elapsed}s`);

  // Find end of thinking
  const thinkEnd = outputText.indexOf("</think>");
  console.log("[lm] CoT length:", thinkEnd >= 0 ? thinkEnd : "not found");
  console.log("[lm] preview (CoT):", thinkEnd >= 0 ? outputText.slice(0, thinkEnd + 10) : outputText.slice(0, 500));
  console.log("[lm] preview (after think):", thinkEnd >= 0 ? outputText.slice(thinkEnd, thinkEnd + 500) : "(n/a)");

  const audioCodes = [];
  for (const m of outputText.matchAll(/<\|audio_code_(\d+)\|>/g)) {
    audioCodes.push(Math.min(Math.max(parseInt(m[1]), 0), NUM_CODES - 1));
  }
  console.log(`[lm] extracted ${audioCodes.length} audio codes, first 10:`, audioCodes.slice(0, 10));

  // Truncate if too many but DON'T zero-pad — main worker uses last-frame padding in 25Hz space (matches MLX port)
  const codes = new Int32Array(audioCodes.slice(0, numCodes5Hz));
  post("audio_codes", { codes, elapsed, tokenCount: generatedIds.length });
}

self.onmessage = async (e) => {
  const { type, ...data } = e.data;
  try {
    if (type === "load") await loadModel();
    else if (type === "generate") await generate(data);
  } catch (err) {
    post("error", { message: err.message, stack: err.stack });
  }
};
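
// ----------------------------------------------------------------------------
// Illustrative usage sketch (not part of this worker): one way the main thread
// could drive the load/generate protocol above. The worker filename and the
// `updateBar` / `runDiT` handlers are hypothetical placeholders; the message
// types match the handlers defined in this file. `{ type: "module" }` is
// required because this worker uses ES module imports.
//
//   const lm = new Worker(new URL("./lm-worker.js", import.meta.url), { type: "module" });
//   lm.onmessage = ({ data }) => {
//     switch (data.type) {
//       case "progress":    updateBar(data.label, data.percent); break;
//       case "status":      console.log(data.message); break;
//       case "loaded":      lm.postMessage({ type: "generate", caption, lyrics, duration, numLatentFrames }); break;
//       case "audio_codes": runDiT(data.codes); break; // hand the 5Hz codes to the main worker
//       case "error":       console.error(data.message, data.stack); break;
//     }
//   };
//   lm.postMessage({ type: "load" });
// ----------------------------------------------------------------------------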