// ace-step-webgpu: _source/src/lm-worker.js
// Dedicated worker for the 5Hz LM. Isolated WASM heap lets the 1.77 GB model
// load without competing with DiT + encoders in the main worker.
import { AutoTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime-web/webgpu";
const MODEL_REPO = "shreyask/ACE-Step-v1.5-ONNX";
const MODEL_REVISION = "bdabfb5684fd70fcc76f98cbb51bb9ebc47ee342";
const ONNX_BASE = `https://huggingface.co/${MODEL_REPO}/resolve/${MODEL_REVISION}/onnx`;
const LM_TOKENIZER_REPO = "ACE-Step/acestep-5Hz-lm-0.6B";
const CACHE_NAME = "ace-step-onnx-v12";
const NUM_KV_LAYERS = 28;
const NUM_KV_HEADS = 8;
const KV_HEAD_DIM = 128;
const VOCAB_SIZE = 217204;
const NUM_CODES = 64000;
const POOL_WINDOW = 5;
const EOS_ID = 151645;
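// One <|audio_code_N|> token covers POOL_WINDOW latent frames, pooling the
// 25 Hz latent sequence down to the LM's 5 Hz code rate; a 60 s clip needs
// roughly 60 * 5 = 300 codes. Code IDs are clamped to [0, NUM_CODES).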
let tokenizer = null;
let session = null;
function post(type, data = {}) {
self.postMessage({ type, ...data });
}
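// Stream a download with byte-level progress posts, then persist the bytes in
// the Cache Storage API so reloads are served locally. Cache writes are
// best-effort: quota errors are swallowed and the fetch still succeeds.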
async function fetchBuffer(url, label) {
const cache = await caches.open(CACHE_NAME);
const cached = await cache.match(url);
if (cached) {
post("progress", { label, loaded: 1, total: 1, percent: 100 });
return await cached.arrayBuffer();
}
const response = await fetch(url);
if (!response.ok) throw new Error(`Fetching ${label} failed: HTTP ${response.status}`);
const total = parseInt(response.headers.get("content-length") || "0", 10);
const reader = response.body.getReader();
const chunks = [];
let loaded = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
loaded += value.length;
if (total > 0) post("progress", { label, loaded, total, percent: (loaded / total) * 100 });
}
const buf = new Uint8Array(loaded);
let offset = 0;
for (const c of chunks) { buf.set(c, offset); offset += c.length; }
try {
await cache.put(url, new Response(buf.buffer.slice(0), { headers: { "Content-Type": "application/octet-stream" } }));
} catch (_) {}
return buf.buffer;
}
function tensor(data, dims, type = "float32") {
return new ort.Tensor(type, data, dims);
}
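// Download tokenizer, graph, and q4 weights, then build the ORT session.
// The weights live in a separate .onnx.data file referenced by the graph,
// so they are handed to ORT through the externalData session option.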
async function loadModel() {
ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = true;
post("status", { message: "Loading LM tokenizer..." });
tokenizer = await AutoTokenizer.from_pretrained(LM_TOKENIZER_REPO);
post("status", { message: "Loading LM graph..." });
const graphBuf = await fetchBuffer(`${ONNX_BASE}/lm_kv_q4.onnx`, "LM graph");
post("status", { message: "Loading LM weights (1.24 GB q4)..." });
const weightsBuf = await fetchBuffer(`${ONNX_BASE}/lm_kv_q4.onnx.data`, "LM weights");
post("status", { message: "Creating LM session..." });
// Try WebGPU first (faster); fall back to WASM if the graph hits unsupported ops
try {
session = await ort.InferenceSession.create(graphBuf, {
executionProviders: ["webgpu"],
externalData: [{ path: "lm_kv_q4.onnx.data", data: weightsBuf }],
});
post("status", { message: "LM on WebGPU" });
} catch (err) {
console.warn("LM WebGPU failed, falling back to WASM:", err.message);
session = await ort.InferenceSession.create(graphBuf, {
executionProviders: ["wasm"],
externalData: [{ path: "lm_kv_q4.onnx.data", data: weightsBuf }],
});
post("status", { message: "LM on WASM (WebGPU unsupported)" });
}
post("status", { message: "LM ready" });
post("loaded");
}
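// First-run KV inputs: zero-length tensors along the sequence axis (dim 2)
// satisfy the graph's past_key_values.* inputs before any cache exists.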
function createEmptyKV() {
const kv = {};
for (let i = 0; i < NUM_KV_LAYERS; i++) {
kv[`past_key_values.${i}.key`] = tensor(new Float32Array(0), [1, NUM_KV_HEADS, 0, KV_HEAD_DIM]);
kv[`past_key_values.${i}.value`] = tensor(new Float32Array(0), [1, NUM_KV_HEADS, 0, KV_HEAD_DIM]);
}
return kv;
}
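// Rewire each run's present.* outputs as the next run's past_key_values.*
// inputs: the standard KV-cache handoff for incremental decoding.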
function extractKV(outputs) {
const kv = {};
for (let i = 0; i < NUM_KV_LAYERS; i++) {
kv[`past_key_values.${i}.key`] = outputs[`present.${i}.key`];
kv[`past_key_values.${i}.value`] = outputs[`present.${i}.value`];
}
return kv;
}
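// Sampling pipeline, applied in order: repetition penalty over a sliding
// window of recent tokens, temperature, top-k, stable softmax, top-p
// (nucleus) cutoff, then a multinomial draw from the surviving mass.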
function sampleToken(logits, recentTokens, { temperature = 0.8, topK = 200, topP = 0.95, repetitionPenalty = 1.05, repWindow = 64 } = {}) {
const V = logits.length;
const scores = new Float32Array(V);
scores.set(logits);
// Repetition penalty
if (repetitionPenalty !== 1.0 && recentTokens.length > 0) {
const window = recentTokens.slice(-repWindow);
const seen = new Set(window);
for (const tok of seen) {
if (tok >= 0 && tok < V) {
scores[tok] = scores[tok] > 0 ? scores[tok] / repetitionPenalty : scores[tok] * repetitionPenalty;
}
}
}
// Temperature
if (temperature !== 1.0 && temperature > 0) {
const invT = 1.0 / temperature;
for (let i = 0; i < V; i++) scores[i] *= invT;
}
// Top-K via full sort (good enough — sort overhead << LM forward pass)
const k = Math.min(topK, V);
const idx = new Array(V);
for (let i = 0; i < V; i++) idx[i] = i;
idx.sort((a, b) => scores[b] - scores[a]);
const topIdx = idx.slice(0, k);
// Numerically stable softmax over the top-k: subtract the max before exponentiating
let maxS = -Infinity;
for (const i of topIdx) if (scores[i] > maxS) maxS = scores[i];
const exps = new Float64Array(k);
let sumE = 0;
for (let i = 0; i < k; i++) {
const e = Math.exp(scores[topIdx[i]] - maxS);
exps[i] = e; sumE += e;
}
const probs = new Float64Array(k);
for (let i = 0; i < k; i++) probs[i] = exps[i] / sumE;
// Top-P (nucleus)
let cum = 0, nuc = k;
for (let i = 0; i < k; i++) {
cum += probs[i];
if (cum >= topP) { nuc = i + 1; break; }
}
// Multinomial sample within nucleus
let nSum = 0;
for (let i = 0; i < nuc; i++) nSum += probs[i];
const r = Math.random() * nSum;
let acc = 0;
for (let i = 0; i < nuc; i++) {
acc += probs[i];
if (r < acc) return topIdx[i];
}
return topIdx[nuc - 1]; // fallback for floating-point rounding at the nucleus edge
}
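// Assemble the ChatML-style prompt (caption, lyrics or "[instrumental]",
// and duration/language metas) with the assistant turn left open so the
// model continues from "<|im_start|>assistant\n".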
function buildPrompt(caption, lyrics, duration, language = "en") {
const instruction = "Generate audio semantic tokens based on the given conditions";
const lyricsSection = lyrics.trim()
? `# Languages\n${language}\n\n# Lyrics\n${lyrics}`
: "# Lyrics\n[instrumental]";
const userPrompt = `# Instruction\n${instruction}\n\n# Caption\n${caption}\n\n${lyricsSection}\n\n# Metas\n- language: ${language}\n- duration: ${duration} seconds\n<|endoftext|>\n`;
return `<|im_start|>user\n${userPrompt}<|im_end|>\n<|im_start|>assistant\n`;
}
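// End-to-end generation: one batched prefill over the full prompt, then a
// token-at-a-time decode loop that reuses the KV cache and stops at EOS or
// once enough audio codes have been emitted.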
async function generate({ caption, lyrics, duration, numLatentFrames }) {
const numCodes5Hz = Math.ceil(numLatentFrames / POOL_WINDOW);
post("status", { message: `LM: generating ~${numCodes5Hz} codes...` });
const prompt = buildPrompt(caption, lyrics, Math.round(duration));
const encoded = tokenizer(prompt);
const promptIds = Array.from(encoded.input_ids.data, Number);
// CoT metadata ~150 tokens + numCodes5Hz audio codes + some slack
const maxNewTokens = Math.min(numCodes5Hz + 250, 600);
const audioCodeTokenRegex = /<\|audio_code_\d+\|>/; // non-global: simple per-token presence test
const startTime = performance.now();
const allIds = [...promptIds];
// Prefill
post("status", { message: `LM prefill (${promptIds.length} tokens)...` });
const prefillIds = new BigInt64Array(promptIds.map(BigInt));
const prefillMask = new BigInt64Array(promptIds.length).fill(1n);
const prefillPos = new BigInt64Array(promptIds.map((_, i) => BigInt(i)));
let outputs = await session.run({
input_ids: tensor(prefillIds, [1, promptIds.length], "int64"),
attention_mask: tensor(prefillMask, [1, promptIds.length], "int64"),
position_ids: tensor(prefillPos, [1, promptIds.length], "int64"),
...createEmptyKV(),
});
let kv = extractKV(outputs);
let lastLogits = outputs.logits.data.slice((promptIds.length - 1) * VOCAB_SIZE, promptIds.length * VOCAB_SIZE);
let nextToken = sampleToken(lastLogits, allIds);
allIds.push(nextToken);
// Decode loop — exit early once we have enough audio codes
let codesSoFar = 0;
for (let step = 0; step < maxNewTokens - 1; step++) {
if (nextToken === EOS_ID) break;
if (codesSoFar >= numCodes5Hz) break; // have enough codes, stop early
if (step % 20 === 0) {
const elapsedSec = (performance.now() - startTime) / 1000;
const tps = (step / Math.max(elapsedSec, 0.1)).toFixed(1);
post("status", { message: `LM: ${step} tokens, ${codesSoFar}/${numCodes5Hz} codes (${tps} tok/s)` });
}
const seqLen = allIds.length;
outputs = await session.run({
input_ids: tensor(new BigInt64Array([BigInt(nextToken)]), [1, 1], "int64"),
attention_mask: tensor(new BigInt64Array(seqLen).fill(1n), [1, seqLen], "int64"),
position_ids: tensor(new BigInt64Array([BigInt(seqLen - 1)]), [1, 1], "int64"),
...kv,
});
kv = extractKV(outputs);
lastLogits = outputs.logits.data.slice(0, VOCAB_SIZE);
nextToken = sampleToken(lastLogits, allIds);
allIds.push(nextToken);
// Streaming decode: count audio-code tokens as they arrive so we can stop early
const tokText = tokenizer.decode([nextToken], { skip_special_tokens: false });
if (audioCodeTokenRegex.test(tokText)) codesSoFar++;
}
const elapsed = ((performance.now() - startTime) / 1000).toFixed(1);
const generatedIds = allIds.slice(promptIds.length);
const outputText = tokenizer.decode(generatedIds, { skip_special_tokens: false });
console.log(`[lm] ${generatedIds.length} tokens in ${elapsed}s`);
// Find end of thinking
const thinkEnd = outputText.indexOf("</think>");
console.log("[lm] CoT length:", thinkEnd >= 0 ? thinkEnd : "no </think> found");
console.log("[lm] preview (CoT):", thinkEnd >= 0 ? outputText.slice(0, thinkEnd + 10) : outputText.slice(0, 500));
console.log("[lm] preview (after think):", thinkEnd >= 0 ? outputText.slice(thinkEnd, thinkEnd + 500) : "(n/a)");
const audioCodes = [];
for (const m of outputText.matchAll(/<\|audio_code_(\d+)\|>/g)) {
audioCodes.push(Math.min(Math.max(parseInt(m[1]), 0), NUM_CODES - 1));
}
console.log(`[lm] extracted ${audioCodes.length} audio codes, first 10:`, audioCodes.slice(0, 10));
// Truncate if too many but DON'T zero-pad — main worker uses last-frame padding in 25Hz space (matches MLX port)
const codes = new Int32Array(audioCodes.slice(0, numCodes5Hz));
post("audio_codes", { codes, elapsed, tokenCount: generatedIds.length });
}
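// Message protocol: accepts { type: "load" } and { type: "generate", ... };
// emits "progress", "status", "loaded", "audio_codes", and "error".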
self.onmessage = async (e) => {
const { type, ...data } = e.data;
try {
if (type === "load") await loadModel();
else if (type === "generate") await generate(data);
} catch (err) {
post("error", { message: err.message, stack: err.stack });
}
};
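// Host-side usage sketch (hypothetical wiring; the real host code lives
// elsewhere under _source/src, and `handleCodes` is illustrative only):
//
//   const lm = new Worker(new URL("./lm-worker.js", import.meta.url), { type: "module" });
//   lm.onmessage = ({ data }) => {
//     if (data.type === "audio_codes") handleCodes(data.codes);
//     else if (data.type === "error") console.error(data.message);
//   };
//   lm.postMessage({ type: "load" });
//   lm.postMessage({ type: "generate", caption, lyrics, duration, numLatentFrames });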