AuriStream-1B

AuriStream is a biologically inspired, GPT-style autoregressive Transformer trained to predict cochlear tokens over long speech contexts. Cochlear tokens are discrete codes produced by a companion “WavCoch” tokenizer (through transformation imitation). AuriStream uses a long context window (~20 s, ~4096 tokens) and is trained on LibriLight (~60k h) for ~500k steps. It learns rich, time‑aligned representations (useful for linear probing) and can roll out future tokens to generate speech continuations. Inputs are token IDs; use it with a WavCoch quantizer for audio->tokens and with the built-in vocoder for tokens->audio.
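
The figures above imply a token rate of roughly 4096 / 20 ≈ 205 tokens per second. The sketch below works through that arithmetic, assuming the approximate numbers quoted in the description. The constants and helper names are illustrative and not part of the model's API. It shows how many tokens a given duration corresponds to and how a longer token sequence could be split into window-sized chunks.

```python
# Approximate figures taken from the description; treat both as assumptions.
CONTEXT_TOKENS = 4096
CONTEXT_SECONDS = 20.0


def tokens_per_second(context_tokens=CONTEXT_TOKENS, context_seconds=CONTEXT_SECONDS):
    """Token rate implied by the context window size and its duration."""
    return context_tokens / context_seconds


def seconds_to_tokens(seconds):
    """Approximate number of tokens covering `seconds` of audio."""
    return int(round(seconds * tokens_per_second()))


def chunk_bounds(n_tokens, window=CONTEXT_TOKENS):
    """Return (start, end) index pairs covering n_tokens in non-overlapping windows."""
    return [(s, min(s + window, n_tokens)) for s in range(0, n_tokens, window)]


print(tokens_per_second())   # 204.8
print(seconds_to_tokens(3))  # 614
print(chunk_bounds(10000))   # [(0, 4096), (4096, 8192), (8192, 10000)]
```

In practice the rate is better measured from the tokenizer output itself (token count divided by clip duration), as Use Case 2 does, since the numbers here are rounded.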


Installation

pip install -U torch torchaudio transformers

This model uses custom code; when loading from Hugging Face, pass trust_remote_code=True.


Use Case 1) get hidden‑state embeddings from a WAV

import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load the WavCoch tokenizer (audio -> token IDs)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# 2) Load the AuriStream LM (tokens -> hidden states / next-token preds)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_40Pred_librilight_500k", trust_remote_code=True
).to(device).eval()

# 3) Read an audio file (mono, 16 kHz recommended)
wav, sr = torchaudio.load("sample.wav")
if wav.size(0) > 1:  # stereo -> mono
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000

# 4) Quantize to cochlear token IDs
with torch.no_grad():
    # wav is (1, T) here; unsqueeze(0) adds a batch dim -> (1, 1, T). Returns LongTensor (B, L)
    token_ids = quantizer.quantize(wav.unsqueeze(0).to(device))  # (1, L)

# 5) Forward pass with hidden states
with torch.no_grad():
    out = lm(token_ids, output_hidden_states=True)
    last_layer = out["hidden_states"][-1]   # (1, T, D)
    clip_embedding = last_layer.mean(dim=1)  # time mean-pool -> (1, D)

print("Pooled embedding shape:", clip_embedding.shape)

Notes

  • output_hidden_states=True returns all layers; choose a layer or pool over time.
  • For word/phone segments, slice the time axis before pooling.
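
The second note can be sketched as index arithmetic: convert a segment's start and end times into positions on the time axis, then pool only that slice. The frame rate is estimated from the clip itself (number of time steps divided by clip duration). The helper name segment_to_slice and the example timings are hypothetical, and the sketch assumes the hidden-state time axis has one step per token.

```python
def segment_to_slice(start_s, end_s, n_frames, clip_seconds):
    """Map a [start_s, end_s) time span to a slice over n_frames time steps."""
    if not 0 <= start_s < end_s:
        raise ValueError("need 0 <= start_s < end_s")
    frames_per_sec = n_frames / clip_seconds
    # Clamp so the slice is never empty and never runs past the clip.
    lo = min(int(start_s * frames_per_sec), n_frames - 1)
    hi = min(max(int(round(end_s * frames_per_sec)), lo + 1), n_frames)
    return slice(lo, hi)


# Example: a 2 s clip with 400 time steps; a word spoken from 0.50 s to 0.85 s.
word = segment_to_slice(0.50, 0.85, n_frames=400, clip_seconds=2.0)
print(word)  # slice(100, 170, None)

# With the tensors from the snippet above (assumes last_layer has shape (1, T, D)):
#   word_embedding = last_layer[:, word, :].mean(dim=1)   # (1, D)
#   layer_k = out["hidden_states"][k]                     # pick layer index k instead
```

Segment boundaries would typically come from a forced aligner or existing annotations; that step is outside the scope of this sketch.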

Use Case 2) generate a speech continuation (token rollout)

import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# WavCoch tokenizer (audio->tokens, tokens->cochleagram->audio)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# AuriStream LM (tokens->next tokens)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_40Pred_librilight_500k", trust_remote_code=True
).to(device).eval()

# Load & prep a short prompt (e.g., 3s of audio at 16 kHz)
wav, sr = torchaudio.load("prompt.wav")
if wav.size(0) > 1:
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000
prompt_seconds = 3
wav = wav[:, : sr * prompt_seconds]

# Quantize prompt to token IDs
with torch.no_grad():
    prompt_tokens = quantizer.quantize(wav.unsqueeze(0).to(device))  # (1, L)

# Decide how many future tokens to generate
tokens_per_sec = prompt_tokens.size(1) / float(prompt_seconds)
rollout_seconds = 3
rollout_steps = int(round(tokens_per_sec * rollout_seconds))

# Roll out future tokens
with torch.no_grad():
    # returns (pred_tokens, pred_logits); temp/top_k/top_p/seed are optional sampling controls
    pred_tokens, _ = lm.generate(
        prompt_tokens, rollout_steps, temp=0.7, top_k=50, top_p=0.95, seed=0
    )
    full_tokens = torch.cat([prompt_tokens, pred_tokens], dim=1)  # (1, L+K)
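
This card does not show how lm.generate applies temp, top_k, and top_p internally. As a rough illustration only, the sketch below implements the conventional meaning of these controls for a single step over a toy logit vector, in pure Python: temperature scaling, then top-k truncation, then nucleus (top-p) truncation. The function name sample_next and the toy logits are hypothetical, and the model's custom code may order or implement these steps differently.

```python
import math
import random


def sample_next(logits, temp=0.7, top_k=50, top_p=0.95, seed=0):
    """Sample one token index from raw logits (conventional temp/top-k/top-p)."""
    rng = random.Random(seed)
    # Temperature: divide logits; a lower temp sharpens the distribution.
    scaled = [x / temp for x in logits]
    # Top-k: keep only the k highest-scoring candidates, best first.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # Softmax over the survivors (subtract the max for numerical stability).
    m = scaled[order[0]]
    weights = [math.exp(scaled[i] - m) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cum = [], 0.0
    for idx, p in zip(order, probs):
        kept.append((idx, p))
        cum += p
        if cum >= top_p:
            break
    # Sample from the kept candidates; random.choices renormalizes the weights.
    ids, ps = zip(*kept)
    return rng.choices(ids, weights=ps, k=1)[0]


toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]
print(sample_next(toy_logits))           # an index drawn from the kept candidates
print(sample_next(toy_logits, top_k=1))  # 0: top_k=1 reduces to greedy decoding
```

Lower temp and smaller top_k or top_p make continuations more conservative; fixing the seed makes a rollout repeatable under this sketch.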

Citation

If you use this model, please cite:

@misc{tuckute2025cochleartokens,
  title = {Representing Speech Through Autoregressive Prediction of Cochlear Tokens},
  author = {Tuckute, Greta and Kotar, Klemen and Fedorenko, Evelina and Yamins, Daniel L. K.},
  year = {2025},
  eprint = {2508.11598},
  archivePrefix = {arXiv},
  url = {https://arxiv.org/abs/2508.11598}
}