BPE分词器训练优化:初版实现与后续方向分析 2026-05-31阅读 0热度 0 其他 好的,没问题。作为常年跟BPE分词器较劲的工程师,这篇文档我来重写——去掉那些通用套话,换成干活时踩坑、总结出来的硬核经验。 --- > 如果你还没摸透 `cs336_assignment1_basics.pdf` 里BPE的原理,建议先翻 [[assigment1_overview&bpe_basics]] 看我们做的翻译和细节拆解。我们在那篇里提到过,朴素BPE的时间复杂度高得离谱;后来在 [[train_bpe_naive#测试]] 里也验证了它那令人捉急的训练速度,根本扛不住作业要求。所以当务之急,是在这个朴素版基础上动手做优化。 ### 朴素版详细分析 先复盘一下朴素版的瓶颈到底卡在哪。每一轮合并的核心操作就三步: 1. **`count_pairs`**:遍历所有 word 编码里的相邻字节对,时间复杂度 O(N),N 是所有 token 的总数。 2. **`max()`**:从统计结果里揪出频率最高的那个 pair。 3. **`merge_encoding`**:再次遍历所有 word 编码,执行替换。 我之前提过,问题全在第一站。为什么?因为每轮合并之后,绝大多数 word 的编码压根没变,但我们每次还是得把它们全部重新扫一遍。拿文档里的 `bpe_example` 来说: ```text low low low low low lower lower widest widest widest newest newest newest newest newest newest ``` 第一次合并的是 `(s, t)`,受影响的 word 只有3个。然而到第二轮,我们仍然要对所有 word 扫一遍。这种“无差别”的做法,效率可想而知。  ### 优化:告别全量扫描,拥抱增量更新 分析到这,优化思路已经很直白了:维护一个全局的 `pair_counts` 字典,每次合并只做增量更新,不再全部重算。 核心逻辑:合并 `(A, B) → AB` 时,`pair_counts` 的变化只跟两边的邻居有关。比如序列 `... X A B Y ...` 合并后变成 `... X AB Y ...`。那么: - 旧的相邻对 `(X, A)` 和 `(B, Y)` 会消失,需要减去它们的计数。 - 新的相邻对 `(X, AB)` 和 `(AB, Y)` 出现了,需要加上计数。 - 被合并的 `(A, B)` 这个 pair,计数自然清零。 这个优化的精髓在于: `每轮 merge: 直接从 pair_counts 字典里找最大值,复杂度 O(unique pairs)。只需扫描【含有 (p0,p1) 的 word】执行替换,复杂度 O(受影响的 word)。然后针对这些 word:旧邻对计数 -1,新邻对计数 +1,完成增量更新。` 这里有个关键点:合并 `(p0, p1) → new_id` 后,`pair_counts` 的变化完全由新旧 `encoding` 的差异决定。我们只需对旧 `encoding` 的相邻对逐个减计数,对新 `encoding` 的相邻对逐个加计数就行。这样做的好处是不用去分类讨论左右邻居的复杂情况,逻辑清晰,不容易出 bug。 #### 优化后的实现 ```python # 4. BPE training loop num_merges = vocab_size - size pair_counts = defaultdict(int, self.count_pairs(word_counts, word_encodings)) for merge_idx in range(num_merges): if not pair_counts: print("No more pairs to be merged, quit.") break # a. Find the max frequency pair to be merged merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], self.vocab[x[0]], self.vocab[x[1]])) # b. Merge and update the word encodings token_id = size for word in word_encodings: old_encoding = word_encodings[word] # skip if the encoding not changed if not any((old_encoding[i], old_encoding[i+1]) == merge_pair for i in range(len(old_encoding) - 1)): continue # update the encoding new_encoding = self.merge_encoding(old_encoding, merge_pair, token_id) # update count for pair cnt = word_counts[word] for i in range(len(old_encoding) - 1): pair_counts[(old_encoding[i], old_encoding[i+1])] -= cnt for i in range(len(new_encoding) - 1): pair_counts[(new_encoding[i], new_encoding[i+1])] += cnt # update word encoding word_encodings[word] = new_encoding # clear the pairs whose count is no more than 0 del_keys = [k for k, v in pair_counts.items() if v <= 0] for k in del_keys: del pair_counts[k] ``` 除了手写时的一些笔误,实现过程中还踩了几个坑,值得记一笔。 **1. `any()` 的用法失误** 一开始我是这样写的: ```python for i in range(len(old_encoding) - 1): if not any((old_encoding[i], old_encoding[i+1]) == merge_pair): continue ``` 一跑测试就收到 `TypeError`,大意是 `'bool' object is not iterable`。查了一下 `any()` 的用法才知道问题所在。 ```python def any(iterable): for element in iterable: if element: return True return False ``` `any()` 接收的参数必须是可迭代对象(比如列表、生成器),而我直接传了一个布尔值进去,自然不对。修正方案就是让 `any()` 接收一个生成器表达式。 **2. 清理失效pair的时机** 另一个Bug出现在清理某些不再存在的字节对时。我一开始的做法是在扣除旧 `encoding` 中pair频率的循环里,发现某个pair计数归零就立刻删除: ```python for i in range(len(old_encoding) - 1): pair_counts[(old_encoding[i], old_encoding[i+1])] -= cnt # clear immediately if pair_counts[(old_encoding[i], old_encoding[i+1])] <= 0: del pair_counts[(old_encoding[i], old_encoding[i+1])] ``` 这么干会碰到 `KeyError`。问题在于,同一个 word 的 `encoding` 里,同一个pair可能出现多次。假设出现了两次,第一次扣除后归零,我们立刻清除了这个 key,那第二次再扣的时候,这个 key 已经不存在了。再说,扣除和加回是分阶段操作的,一个 pair 可能先被扣到0,但同一轮又被加回来,这就打断了正常的更新流程。所以,正确的做法是等整个 word 的所有 pair 更新完毕(扣除和加回都做完)之后,再统一判断哪些 pair 需要清理。 **3. `pair_counts` 没初始化就用了** 这个 Bug 找起来费了点功夫。测试报错信息是: ```python for i in range(len(new_encoding) - 1): > pair_counts[(new_encoding[i], new_encoding[i+1])] += cnt ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E KeyError: (116, 257) cs336_basics/train_bpe.py:77: KeyError ``` 我反复检查了 `merge_encoding` 和 `count_pairs` 的逻辑,都没问题。后来往上一看,原因很直接:`self.count_pairs` 返回的是一个普通的 `dict`,而普通 `dict` 访问不存在的 key 会抛 `KeyError`。而如果用 `defaultdict(int)`,访问不存在的 key 时会自动初始化为 `int()`(也就是0),再执行 `+=` 操作就没事了。所以,正确的做法是像下面这样,或者先创建 `defaultdict`,再调用 `update` 方法。 ```python pair_counts = defaultdict(int, self.count_pairs(word_counts, word_encodings)) ``` ### 测试 三个测试点都顺利通过了。相比朴素版,效率的提升是实打实的。 ### 在 TinyStories 数据集上训练 这个题目有两个要求: - a. 词汇表大小最大为10,000,必须确保特殊token `"<|endoftext|>"` 被加入到词汇表中。资源要求是:训练时长 ≤ 30分钟(不使用GPU),占用内存 ≤ 30GB RAM。提示说,如果想在2分钟内完成,可以考虑多线程处理预分词。 - b. 分析tokenizer训练过程中,哪一部分最耗时。 按照作业要求,我分了三步走: 1. 编写训练脚本:包括加载训练数据、训练、保存模型、统计时间和内存。 2. 运行时性能分析(Profiling):找到瓶颈所在。 3. 检查结果:找出最长的Token。 #### 训练脚本 **数据** - **检查数据**:用 `head -n 5 data/TinyStoriesV2-GPT4-train.txt` 先看看数据长啥样,确认跟测试数据格式一致。 - **加载数据**:为了读写文件方便,工程里通常会先获取项目根路径。 ```python sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) ``` - **获取运行时内存** ```python def get_memory_usage_mb(): """Get current process memory usage in MB""" process = psutil.Process(os.getpid()) return process.memory_info().rss / 1024 / 1024 ``` - **保存模型**:训练结束后,要把得到的 `vocabulary` 和 `merges` 规则持久化到磁盘。`vocab` 是个字典,保存成 `json` 格式,为了便于人工阅读,把字节串显示出来,对于无法显示的字节,保留 `repr` 形式。`merges` 是个列表,保存成文本文件。这里参考了测试文件 `tests/test_train_bpe.py` 的保存形式,也用了里面的 `gpt2_bytes_to_unicode` 函数。当然,为了简单,也可以直接用 `pickle` 保存。 ```python # sa ve vocab and merges to disk from tests.common import gpt2_bytes_to_unicode def sa ve_vocab_and_merges(vocab, merges, output_dir='results'): Path(output_dir).mkdir(exist_ok=True, parents=True) byte_encoder = gpt2_bytes_to_unicode() # {int: str},每个字节映射到可打印字符串 # Sa ve vocab:把每个token的字节转换成可打印字符串 vocab_str = {} for idx, token_bytes in vocab.items(): vocab_str[idx] = ''.join(byte_encoder[b] for b in token_bytes) with open(f'{output_dir}/vocab.json', 'w', encoding='utf-8') as f: json.dump(vocab_str, f, ensure_ascii=False, indent=2) # Sa ve merges:两个token用空格分隔,每个字节转为可打印字符串 with open(f'{output_dir}/merges.txt', 'w', encoding='utf-8') as f: for p1, p2 in merges: t1 = ''.join(byte_encoder[b] for b in p1) t2 = ''.join(byte_encoder[b] for b in p2) f.write(f'{t1} {t2}\n') ``` - **训练主函数** ```python # main function def run_training(input_path, vocab_size, special_tokens, output_dir='results'): """Run training""" # 记录训练开始前的内存 print(f'Initial Memory: {get_memory_usage_mb():.2f} MB') # 初始化BPE训练器 trainer = train_bpe.BPETrainer() # 开始训练,记录时间和内存 start_time = time.time() print(f'Starting training on {input_path}...') vocab, merges = trainer.train(input_path, vocab_size, special_tokens) end_time = time.time() duration = end_time - start_time peak_memory = get_memory_usage_mb() print('-' * 100) print('Training Complete.') print(f'Time Taken: {duration:.2f} seconds ({duration/60:.2f} minutes)') print(f'Final Memory: {peak_memory:.2f} MB') print('-' * 100) sa ve_vocab_and_merges(vocab, merges, output_dir) # 输出统计信息 print("\n=== Statistics (Problem b) ===") # 1. Longest token longest_token_bytes = max(vocab.values(), key=len) try: longest_token_str = longest_token_bytes.decode('utf-8') except: longest_token_str = str(longest_token_bytes) print(f"Longest Token: {longest_token_str!r}") print(f"Length in bytes: {len(longest_token_bytes)}") # 2. 关于“数据集中最频 token”的说明:BPE通常不会保留最终词汇表的完整频率,除非我们重新分词。 # 这里我们打印最后一次合并的pair,它代表那一步最频繁的pair。 print(f"Total Merges: {len(merges)}") print('-' * 100) ``` - **执行训练**:在 `main` 函数中定义好参数,然后调用 `run_training`,最后在终端执行。 ```bash uv run python ./train_bpe_tinystories.py --input_path data/TinyStoriesV2-GPT4-train.txt --vocab_size 10000 --profile ``` ### 训练与分析 #### 训练结果 - 训练过程中,顺便看了下CPU和内存。我那本就有些捉襟见肘的内存直接被干满了,CPU也确实只有一个核在跑,符合单线程的预期。这也印证了后面优化的方向:用多线程处理预分词。 - 训练结果截图在这里,从结果看:训练花了28分钟,内存占用2302.11 MB。这个数据跟训练集大小是吻的。训练集大小如下: ``` 2.1G data/TinyStoriesV2-GPT4-train.txt 22M data/TinyStoriesV2-GPT4-valid.txt 9.8G data/owt_train.txt 4.3G data/owt_train.txt.gz ``` - `vocabulary` 和 `merges` 已经保存到磁盘。 - 性能分析数据也保存到了 `training.prof`。 根据上面的统计信息和数据集大小,我的电脑显然没法直接在OpenWebText数据集上训练。没办法,必须得分块加载到内存。既然都分块了,那自然可以考虑多块并行执行预分词,就像文档里提示的那样,用多进程来优化。 #### 性能分析 启动 `snakeviz` 来分析 `training.prof`。 1. 安装 `snakeviz`:从 `toml` 文件看,它不在依赖里,需要单独安装:`uv pip install snakeviz`。 2. 启动服务:`uv run snakeviz --server training.prof`。 3. 在浏览器里打开 `snakeviz` 的可视化界面。 下面是分析得到的瓶颈所在以及优化方向。 **时间分布总览** 总耗时1687秒,主要分为两大块: | 函数 | 耗时 | 占比 | | :--- | :--- | :--- | | `pretokenize` | 619秒 | 37% | | `train` 自身(合并循环) | 416秒 | 25% | | `any` + genexpr(快速跳过) | 441秒 | 26% | | `max`(选最优pair) | 161秒 | 10% | - **瓶颈1:`pretokenize` 占了37%,而且有严重的重复编译问题。** `_compile` 被调用了271万次,耗时30秒。原因很直接:每次调用 `re.finditer` 时,Python都会重新编译一次正则表达式。而 `pretokenize` 对每个chunk都调用一次 `finditer`——语料被分成了271万个chunk(也就是271万个文档)。 ```python chunks = re.split(escape_special_tokens, text) # 当前做法:每个 chunk 都触发一次编译检查 for chunk in chunks: for match in re.finditer(self.pattern, chunk): # 每次都经过编译缓存查找 ``` 优化方向很明确:提前把正则表达式编译好。 ```python compiled_pattern = re.compile(self.pattern) ``` 同样,那个用来切分的正则也建议提前编译。另外,读文件花了68秒(`read` 35秒 + `utf_8_decode` 32秒),对于一个2GB的文件来说,这个时间是正常的,没啥优化空间。 - **瓶颈2:用来“快速跳过”的逻辑,反而成了最大的瓶颈,占了26%。** 这是最反直觉的地方。`any` 被调用了5.84亿次,其内部的生成器表达式被调用了18.3亿次,加在一起花了441秒。 对应的是代码里这一行(第69行): ```python if not any((old_encoding[i], old_encoding[i+1]) == merge_pair for i in range(len(old_encoding) - 1)): continue ``` 这个“快速跳过”不但没有帮上忙,反而拖了后腿。原因有三: 1. 每个word、每轮合并都要执行这个检查。 2. Python里面生成器表达式的调用开销相当高。 3. `len()` 被调用了5.88亿次。 对于大多数word来说,这个检查确实让它们跳过了后续的替换操作,但检查本身的开销,已经比被它跳过所节省的开销要大得多了。 一个更快的做法是建立索引:维护一个 `pair_to_words` 字典,记录每个pair出现在了哪些word里。这样合并的时候,直接查这个表就行了,完全不用遍历所有word。 ```python # 初始化时建立索引 pair_to_words = defaultdict(set) for word in word_encodings: enc = word_encodings[word] for i in range(len(enc) - 1): pair_to_words[(enc[i], enc[i+1])].add(word) # 合并时只处理受影响的 word for word in pair_to_words[merge_pair]: # 直接拿到受影响的 word,不用遍历全部 ... ``` - **瓶颈3:`max()` 每轮都要遍历整个 `pair_counts` 字典,占了10%。** `max` 被调用了9746次(每轮一次),每次都要遍历整个 `pair_counts`。随着合并的进行,`pair_counts` 字典会越来越大,所以每次 `max` 的代价也会越来越高。 这里可以用堆(`heapq`)来替代 `max`,这样能把每轮选最优pair的复杂度从O(n)降到O(log n)。不过,用堆有一个复杂性:更新计数的时候需要处理“失效条目”(lazy deletion)。这个改动比之前提到的 `pair_to_words` 索引要复杂一些,建议先做索引优化,再考虑堆优化。 ### 优化方向总结 按预期收益从高到低排列: 1. **建立 `pair_to_words` 索引**:直接消灭5.84亿次 `any` 调用,预计可节省400秒以上。 2. **预编译正则表达式**:消灭271万次重复编译,预计可节省约30秒。 3. **用堆优化 `max`**:预计可节省100秒以上,但实现起来会稍微复杂一些。 具体的优化实现,我们留到下篇文章再详细讲解。