| """ |
| Byte-Pair Encoding trainer and codec optimized for JSON value strings. |
| |
| Uses incremental pair counting with pair→word index for fast merges. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import re |
| from collections import defaultdict |
| from typing import Optional |
|
|
|
|
| def _bytes_to_unicode() -> dict[int, str]: |
| """Map bytes 0-255 to unicode chars, avoiding control/whitespace collisions.""" |
| bs = ( |
| list(range(ord("!"), ord("~") + 1)) |
| + list(range(ord("¡"), ord("¬") + 1)) |
| + list(range(ord("®"), ord("ÿ") + 1)) |
| ) |
| cs = bs[:] |
| n = 0 |
| for b in range(2**8): |
| if b not in bs: |
| bs.append(b) |
| cs.append(2**8 + n) |
| n += 1 |
| return {b: chr(c) for b, c in zip(bs, cs)} |
|
|
|
|
# Module-wide byte <-> printable-symbol tables shared by trainer and codec.
BYTE_ENCODER = _bytes_to_unicode()
# Inverse mapping: printable symbol -> original byte value (used on decode).
BYTE_DECODER = {v: k for k, v in BYTE_ENCODER.items()}
|
|
# GPT-2-style pre-tokenizer: English contractions, then runs of
# letters/underscores, digit runs, and punctuation runs (each optionally
# preceded by one space), then whitespace runs, then any single leftover
# character as a fallback.
_PRE_TOK_PAT = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?[a-zA-Z_]+| ?[0-9]+| ?[^\s\w]+|\s+|."""
)
|
|
|
|
class BPETrainer:
    """Train a BPE vocabulary from a corpus of JSON value strings.

    Follows the GPT-2 recipe: text is pre-tokenized with a regex, each
    pre-token is mapped to printable byte symbols, and the most frequent
    adjacent symbol pair is merged repeatedly.  Training keeps an
    incremental pair->count table plus a pair->word inverted index so each
    merge only touches the words that actually contain the pair.
    """

    def __init__(self, vocab_size: int = 4096, min_frequency: int = 2):
        # Target vocabulary size, including base byte symbols and <UNK>.
        self.vocab_size = vocab_size
        # A pair must occur at least this often to be merged.
        self.min_frequency = min_frequency
        # Learned merges, in the order they were applied.
        self.merges: list[tuple[str, str]] = []
        # token string -> integer id
        self.vocab: dict[str, int] = {}
        # Lazily built inverse of `vocab`; see id_to_token().
        self._id_to_tok: dict[int, str] | None = None

    def _pre_tokenize(self, text: str) -> list[str]:
        """Split *text* into pre-tokens using the module-level regex."""
        return _PRE_TOK_PAT.findall(text)

    def _text_to_bytes(self, text: str) -> tuple[str, ...]:
        """Encode *text* as UTF-8 and map each byte to its printable symbol."""
        return tuple(BYTE_ENCODER[b] for b in text.encode("utf-8"))

    def train(self, texts: list[str]) -> None:
        """Train BPE with a pair->word index for O(affected) merges.

        Args:
            texts: corpus of strings to learn merges from.

        Side effects: rebuilds ``self.merges`` and ``self.vocab`` from
        scratch and invalidates the cached id->token map.
        """
        # Re-training must not accumulate merges from a previous run.
        self.merges = []

        # 1) Frequency of each unique pre-tokenized word.
        word_freqs: dict[tuple[str, ...], int] = {}
        for text in texts:
            for word in self._pre_tokenize(text):
                bw = self._text_to_bytes(word)
                word_freqs[bw] = word_freqs.get(bw, 0) + 1

        # 2) Base vocabulary = every byte symbol seen in the corpus.
        base_vocab: set[str] = set()
        for word in word_freqs:
            base_vocab.update(word)

        # One slot is reserved for the <UNK> token.
        num_merges = self.vocab_size - len(base_vocab) - 1

        # Mutable word representations with a parallel frequency list.
        words: list[list[str]] = []
        freqs: list[int] = []
        for w, f in word_freqs.items():
            words.append(list(w))
            freqs.append(f)

        # 3) Initial pair statistics and inverted index (pair -> word ids).
        pair_counts: dict[tuple[str, str], int] = defaultdict(int)
        pair_to_words: dict[tuple[str, str], set[int]] = defaultdict(set)
        for idx, (w, f) in enumerate(zip(words, freqs)):
            for i in range(len(w) - 1):
                p = (w[i], w[i + 1])
                pair_counts[p] += f
                pair_to_words[p].add(idx)

        # 4) Greedy merge loop.
        for _ in range(max(0, num_merges)):
            if not pair_counts:
                break

            best_pair = max(pair_counts, key=pair_counts.__getitem__)
            if pair_counts[best_pair] < self.min_frequency:
                break

            a, b = best_pair
            merged = a + b
            self.merges.append(best_pair)

            # Re-tokenize only the words containing the pair.  For each
            # affected word we retract ALL of its pair contributions, apply
            # the merge, then re-add the new pairs.  This costs O(len(word))
            # — the same as a positional neighbor update — but cannot drop a
            # word from the inverted index while a neighbor pair still
            # occurs elsewhere in that word (the bug positional updates with
            # a blind `discard(idx)` are prone to, which silently skips
            # words in later merges).
            for idx in list(pair_to_words.get(best_pair, ())):
                w = words[idx]
                f = freqs[idx]

                # Retract this word's current pair contributions.
                for i in range(len(w) - 1):
                    p = (w[i], w[i + 1])
                    pair_counts[p] -= f
                    if pair_counts[p] <= 0:
                        pair_counts.pop(p, None)
                    pair_to_words[p].discard(idx)

                # Apply the merge greedily left-to-right.
                new_w: list[str] = []
                i = 0
                while i < len(w):
                    if i < len(w) - 1 and w[i] == a and w[i + 1] == b:
                        new_w.append(merged)
                        i += 2
                    else:
                        new_w.append(w[i])
                        i += 1

                # Re-add contributions for the merged word.
                for i in range(len(new_w) - 1):
                    p = (new_w[i], new_w[i + 1])
                    pair_counts[p] += f
                    pair_to_words[p].add(idx)

                words[idx] = new_w

            # The merged pair is fully consumed; drop any residual entries.
            pair_counts.pop(best_pair, None)
            pair_to_words.pop(best_pair, None)

        # 5) Final vocabulary: base symbols (sorted), merges in learned
        # order, then <UNK> last.
        self.vocab = {}
        idx = 0
        for ch in sorted(base_vocab):
            self.vocab[ch] = idx
            idx += 1
        for merge in self.merges:
            m = merge[0] + merge[1]
            if m not in self.vocab:
                self.vocab[m] = idx
                idx += 1
        self.vocab["<UNK>"] = idx
        self._id_to_tok = None

    def _apply_merge(self, word: tuple[str, ...], pair: tuple[str, str]) -> tuple[str, ...]:
        """Return *word* with every adjacent occurrence of *pair* fused."""
        new: list[str] = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
                new.append(pair[0] + pair[1])
                i += 2
            else:
                new.append(word[i])
                i += 1
        return tuple(new)

    def encode_word(self, word: str) -> list[str]:
        """Encode a single pre-token by replaying all learned merges in order."""
        bw = self._text_to_bytes(word)
        if len(bw) == 1:
            return [bw[0]]
        for merge in self.merges:
            bw = self._apply_merge(bw, merge)
        return list(bw)

    def encode(self, text: str) -> list[str]:
        """Encode *text* into a flat list of token strings."""
        tokens: list[str] = []
        for word in self._pre_tokenize(text):
            tokens.extend(self.encode_word(word))
        return tokens

    def encode_to_ids(self, text: str) -> list[int]:
        """Encode *text* to vocabulary ids; unknown tokens map to <UNK>."""
        tokens = self.encode(text)
        unk_id = self.vocab.get("<UNK>", 0)
        return [self.vocab.get(t, unk_id) for t in tokens]

    def id_to_token(self, token_id: int) -> str:
        """Look up the token string for *token_id* (cached inverse vocab)."""
        if self._id_to_tok is None:
            self._id_to_tok = {v: k for k, v in self.vocab.items()}
        return self._id_to_tok.get(token_id, "<UNK>")

    def decode_ids(self, ids: list[int]) -> str:
        """Decode a list of vocabulary ids back to text."""
        return self.decode_tokens([self.id_to_token(i) for i in ids])

    def decode_tokens(self, tokens: list[str]) -> str:
        """Decode token strings to text via the inverse byte map.

        Characters outside the byte table (e.g. from "<UNK>") fall back to
        their own code point; invalid UTF-8 is replaced, never raised.
        """
        byte_str = "".join(tokens)
        return bytearray(BYTE_DECODER.get(c, ord(c)) for c in byte_str).decode("utf-8", errors="replace")

    def save(self, path: str) -> None:
        """Serialize config, merges, and vocab to a JSON file at *path*."""
        # Explicit encoding so output is identical regardless of platform.
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "version": "json-tokenizer-bpe-v1",
                "vocab_size": self.vocab_size,
                "min_frequency": self.min_frequency,
                "merges": [list(m) for m in self.merges],
                "vocab": self.vocab,
            }, f, indent=2)

    @classmethod
    def load(cls, path: str) -> "BPETrainer":
        """Reconstruct a trainer from a file written by :meth:`save`."""
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        t = cls(vocab_size=data["vocab_size"], min_frequency=data["min_frequency"])
        t.merges = [tuple(m) for m in data["merges"]]
        t.vocab = data["vocab"]
        t._id_to_tok = None
        return t
|
|