Tokenization
Intro
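Tokenization is the process of encoding strings into sequences of tokens (integer indices). A Tokenizer is a class that implements two methods, encode (string -> indices) and decode (indices -> string), so that decoding an encoded string gives back the original. Schematically (the indices here are illustrative; with the GPT-2 tokenizer, [15496, 11, 995, 0] is actually the encoding of "Hello, world!"):
assert [15496, 11, 995, 0] == Tokenizer.encode("Hello, 🌍! 你好!")
assert "Hello, 🌍! 你好!" == Tokenizer.decode([15496, 11, 995, 0])
- Vocabulary size: the number of possible tokens.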
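- Spaces are also part of tokens. For example, "hello world, hello world" -> ["hello", " world", ",", " hello", " world"] -> [24912, 2375, 11, 40617, 2375]; note that "hello" (24912) and " hello" (40617) are different tokens.
- In GPT-2-style BPE tokenizers, the pre-tokenization step deliberately keeps the space in front of a word as part of that word's piece.
- Compression ratio = number of UTF-8 bytes / number of tokens; a higher ratio means each token covers more text on average.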
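A minimal sketch of the last three points, assuming a simplified pre-tokenization regex (`PRETOKEN_PATTERN` and `pretokenize` are hypothetical; GPT-2's actual pattern also special-cases contractions, digits, and runs of whitespace) and a hypothetical `get_compression_ratio(string, indices)` helper. The token indices are the ones from the example above.

```python
import re

# Simplified stand-in for GPT-2-style pre-tokenization (an assumption, not the
# real pattern): a word or a run of punctuation may carry one leading space.
PRETOKEN_PATTERN = re.compile(r" ?\w+| ?[^\w\s]+|\s+")


def pretokenize(string: str) -> list[str]:
    """Split `string` into pieces, keeping a leading space attached to each word."""
    return PRETOKEN_PATTERN.findall(string)


def get_compression_ratio(string: str, indices: list[int]) -> float:
    """Number of UTF-8 bytes divided by number of tokens."""
    return len(string.encode("utf-8")) / len(indices)


string = "hello world, hello world"
print(pretokenize(string))  # ['hello', ' world', ',', ' hello', ' world']
print(get_compression_ratio(string, [24912, 2375, 11, 40617, 2375]))  # 4.8 (24 bytes / 5 tokens)
```

Because at most one leading space is attached to each piece, "hello" at the start of the string and " hello" after the comma come out as different pieces, which is why they can receive different token indices.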
Character-based tokenization
Map each Unicode character to its code point (an integer).
- Problem 1: the vocabulary is very large (around 150K Unicode characters).
- Problem 2: many characters are quite rare, so much of that large vocabulary is spent on tokens the model will seldom see.
- The compression ratio is also low: about 1.54 for the example string "Hello, 🌍! 你好!" (20 bytes / 13 characters). It only exceeds 1 because non-ASCII characters take 2, 3, or 4 bytes in UTF-8 (e.g., "你" is 3 bytes).
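A minimal character-level tokenizer, using Python's built-in `ord` and `chr` to map between characters and code points. `CharacterTokenizer` and `get_compression_ratio` are hypothetical names chosen for this sketch.

```python
def get_compression_ratio(string: str, indices: list[int]) -> float:
    """Number of UTF-8 bytes per token (hypothetical helper)."""
    return len(string.encode("utf-8")) / len(indices)


class CharacterTokenizer:
    """Hypothetical character-level tokenizer: one token per Unicode code point."""

    def encode(self, string: str) -> list[int]:
        return [ord(ch) for ch in string]  # code point of each character

    def decode(self, indices: list[int]) -> str:
        return "".join(chr(index) for index in indices)  # chr() inverts ord()


string = "Hello, 🌍! 你好!"
tokenizer = CharacterTokenizer()
indices = tokenizer.encode(string)
assert tokenizer.decode(indices) == string
print(indices[7])  # 127757, the code point of "🌍"
print(len("你".encode("utf-8")))  # 3
print(round(get_compression_ratio(string, indices), 2))  # 1.54
```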
Byte-based tokenization
Encode the string as bytes (typically UTF-8) and use each byte as a token, an integer between 0 and 255.
- The vocabulary size is 256, which is very small.
- Problem: the compression ratio is exactly 1, which is terrible: sequences become very long, and attention cost grows quadratically with sequence length.
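A minimal byte-level tokenizer, assuming UTF-8 as the encoding; `ByteTokenizer` is a hypothetical name. Iterating over a `bytes` object already yields integers in [0, 255], so no extra conversion is needed.

```python
class ByteTokenizer:
    """Hypothetical byte-level tokenizer: one token per UTF-8 byte."""

    def encode(self, string: str) -> list[int]:
        return list(string.encode("utf-8"))  # each byte becomes an int in [0, 255]

    def decode(self, indices: list[int]) -> str:
        return bytes(indices).decode("utf-8")


string = "Hello, 🌍! 你好!"
tokenizer = ByteTokenizer()
indices = tokenizer.encode(string)
assert tokenizer.decode(indices) == string
print(tokenizer.encode("🌍"))  # [240, 159, 140, 141]: one character, four tokens
print(len(string.encode("utf-8")) / len(indices))  # 1.0
```

The 13-character example string becomes 20 tokens here, compared with 13 for character-level tokenization.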
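Word-based tokenization
Split the string into words and map each distinct word to an integer.
Problems:
- The vocabulary is unbounded: new words keep appearing as the corpus grows, so there is no natural fixed vocabulary size.
- Many words are rare, so the model won't learn much about them.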
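A minimal word-level sketch makes both problems visible. It assumes a simplified splitting regex and, as one common workaround that the text above does not specify, reserves index 0 for an `<unk>` token covering words never seen in training; `WordTokenizer` and `<unk>` are illustrative assumptions.

```python
import re


class WordTokenizer:
    """Hypothetical word-level tokenizer. The vocabulary is whatever words appear
    in a training corpus; index 0 is reserved for unseen words (<unk>)."""

    # Simplified splitting rule (an assumption): runs of word characters,
    # or single punctuation marks.
    PATTERN = re.compile(r"\w+|[^\w\s]")

    def __init__(self, corpus: str):
        self.id_to_word = ["<unk>"]
        self.word_to_id: dict[str, int] = {}
        for word in self.PATTERN.findall(corpus):
            if word not in self.word_to_id:
                self.word_to_id[word] = len(self.id_to_word)
                self.id_to_word.append(word)

    def encode(self, string: str) -> list[int]:
        # Any word missing from the training corpus collapses to index 0.
        return [self.word_to_id.get(word, 0) for word in self.PATTERN.findall(string)]

    def decode(self, indices: list[int]) -> str:
        # Lossy: original spacing is not recorded, and <unk> cannot be undone.
        return " ".join(self.id_to_word[index] for index in indices)


tokenizer = WordTokenizer("the cat sat on the mat")
print(tokenizer.encode("the dog sat"))  # [1, 0, 3]
print(tokenizer.decode(tokenizer.encode("the dog sat")))  # the <unk> sat
```

Every new corpus can add words the vocabulary has never seen, and a word seen only once in training gets its own index with almost no data behind it.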
Byte Pair Encoding (BPE)
- Basic idea: train the tokenizer on raw text to automatically determine the vocabulary.
- Intuition: common sequences of characters are represented by a single token; rare sequences are represented by many tokens.
- Sketch: start with each byte as a token, then repeatedly merge the most common pair of adjacent tokens into a new token.
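One step of that sketch, worked on the toy string "aaabdaaabac" (chosen here for illustration). `merge` is a hypothetical helper that replaces occurrences of a pair left to right without overlap; the new token gets index 256, the first index after the 256 byte values.

```python
from collections import Counter


def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Replace each occurrence of `pair` with `new_index`, scanning left to right."""
    merged, i = [], 0
    while i < len(indices):
        if i + 1 < len(indices) and (indices[i], indices[i + 1]) == pair:
            merged.append(new_index)
            i += 2  # skip both tokens of the merged pair
        else:
            merged.append(indices[i])
            i += 1
    return merged


indices = list("aaabdaaabac".encode("utf-8"))  # a=97, b=98, c=99, d=100
counts = Counter(zip(indices, indices[1:]))  # adjacent pairs, overlaps included
pair = max(counts, key=counts.get)
print(pair, counts[pair])  # (97, 97) 4
indices = merge(indices, pair, 256)
print(indices)  # [256, 97, 98, 100, 256, 97, 98, 97, 99]
```

The pair "aa" is counted 4 times because overlapping occurrences all count, yet only 2 replacements happen: each "aaa" can absorb just one non-overlapping "aa".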
The code below is adapted from stanford-cs336/lectures.
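# Usage, once the definitions below have been run (num_merges=3 is an arbitrary small value):
string = "the quick brown fox"
params = train_bpe(string, num_merges=3)
tokenizer = BPETokenizer(params)
indices = tokenizer.encode(string)
reconstructed_string = tokenizer.decode(indices)
assert string == reconstructed_string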
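from collections import Counter

def count_adjacent_pairs(indices: list[int]) -> Counter:
    """Count occurrences of each pair of adjacent tokens (minimal implementation)."""
    return Counter(zip(indices, indices[1:]))

def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Return `indices` with every occurrence of `pair` (scanning left to right) replaced by `new_index`."""
    new_indices = []
    i = 0
    while i < len(indices):
        if i + 1 < len(indices) and (indices[i], indices[i + 1]) == pair:
            new_indices.append(new_index)
            i += 2
        else:
            new_indices.append(indices[i])
            i += 1
    return new_indices

def train_bpe(string: str, num_merges: int) -> "BPETokenizerParams":
    # Start with the list of UTF-8 bytes of `string`.
    indices = list(string.encode("utf-8"))
    merges: dict[tuple[int, int], int] = {}  # index1, index2 => merged index
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}  # index -> bytes
    for i in range(num_merges):
        counts = count_adjacent_pairs(indices)  # occurrences of each adjacent pair
        if not counts:  # fewer than two tokens left: nothing to merge
            break
        pair = max(counts, key=counts.get)  # most common pair
        new_index = 256 + i  # the merged token gets the next unused index
        merges[pair] = new_index
        vocab[new_index] = vocab[pair[0]] + vocab[pair[1]]
        indices = merge(indices, pair, new_index)
    return BPETokenizerParams(vocab=vocab, merges=merges)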
from dataclasses import dataclass

@dataclass(frozen=True)
class BPETokenizerParams:
    """All you need to specify a BPETokenizer."""
    vocab: dict[int, bytes]  # index -> bytes
    merges: dict[tuple[int, int], int]  # index1, index2 -> new_index
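from abc import ABC, abstractmethod

class Tokenizer(ABC):
    """Minimal stand-in for the Tokenizer interface from the intro."""
    @abstractmethod
    def encode(self, string: str) -> list[int]: ...
    @abstractmethod
    def decode(self, indices: list[int]) -> str: ...

class BPETokenizer(Tokenizer):
    """BPE tokenizer given a set of merges and a vocabulary."""
    def __init__(self, params: BPETokenizerParams):
        self.params = params

    def encode(self, string: str) -> list[int]:
        indices = list(string.encode("utf-8"))
        # Replay the merges in the order they were learned (dicts keep insertion order).
        # Note: this is a very slow implementation (one full pass per learned merge).
        for pair, new_index in self.params.merges.items():
            indices = merge(indices, pair, new_index)
        return indices

    def decode(self, indices: list[int]) -> str:
        bytes_list = [self.params.vocab[index] for index in indices]  # KeyError on unknown index
        return b"".join(bytes_list).decode("utf-8")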
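Replaying every learned merge is slow because it makes one pass over the sequence per merge, even for merges that never apply. A sketch of the alternative used by GPT-2's reference encoder, assuming the `merges` dict keeps the order in which merges were learned: repeatedly merge whichever adjacent pair has the lowest rank (earliest learned), and stop once no adjacent pair has a rank. `encode_by_replay` and `encode_by_rank` are hypothetical names, and the merges dict is written out by hand for illustration.

```python
def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Replace each occurrence of `pair` with `new_index`, scanning left to right."""
    merged, i = [], 0
    while i < len(indices):
        if i + 1 < len(indices) and (indices[i], indices[i + 1]) == pair:
            merged.append(new_index)
            i += 2
        else:
            merged.append(indices[i])
            i += 1
    return merged


def encode_by_replay(string: str, merges: dict[tuple[int, int], int]) -> list[int]:
    """Slow strategy: one pass over the sequence per learned merge."""
    indices = list(string.encode("utf-8"))
    for pair, new_index in merges.items():
        indices = merge(indices, pair, new_index)
    return indices


def encode_by_rank(string: str, merges: dict[tuple[int, int], int]) -> list[int]:
    """Lowest-rank-first strategy: only pairs actually present get merged."""
    rank = {pair: r for r, pair in enumerate(merges)}  # dict order = order learned
    indices = list(string.encode("utf-8"))
    while len(indices) >= 2:
        pairs = set(zip(indices, indices[1:]))
        best = min(pairs, key=lambda p: rank.get(p, float("inf")))
        if best not in rank:
            break  # no adjacent pair corresponds to a learned merge
        indices = merge(indices, best, merges[best])
    return indices


# Hand-written merges for illustration: "aa" -> 256, then 256+"a" -> 257, then 257+"b" -> 258.
merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258}
print(encode_by_rank("aaabdaaabac", merges))  # [258, 100, 258, 97, 99]
assert encode_by_rank("aaabdaaabac", merges) == encode_by_replay("aaabdaaabac", merges)
```

The two strategies agree because merging a pair only creates new adjacencies that involve the newly created token, whose index is higher than every token in an earlier-learned pair; so an earlier pair that is absent can never reappear, and the rank-based loop just skips the merges that would have been no-ops.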