深入理解大模型分片优化:Late Chunking 技术解析
📌 背景:为什么需要文本分片(Chunking)
在使用大语言模型(如 Transformer)进行文本嵌入(Embedding)时,输入长度往往受限(例如 512 或 2048 token),而现实中的文本往往远超此限制。为解决这一问题,文本分片(Chunking) 技术被广泛使用。
🧠 什么是 Early Chunking vs Late Chunking
🧱 Early Chunking
在编码前,将长文本切分成多个小段(如按固定 token 数、换行符、标点等),每段单独送入模型,单独编码。
- 优点:实现简单
- 缺点:
- 每个 chunk 独立编码,上下文信息丢失
- 重复计算(重叠窗口)浪费资源
🧠 Late Chunking
先将完整文本一次性输入模型,获取全局 token 嵌入后,再根据预定义的 token 区间进行切片与聚合。
- 优点:
- 上下文信息保留(单次前向传播)
- 聚合逻辑灵活,切分方式可动态调整
- 减少模型调用次数,提高效率
🔍 使用场景
- 对大段文本做语义嵌入(如句子/段落级别表示)
- 信息抽取与摘要任务的输入预处理
- 文档级搜索与向量检索(RAG)中构建 chunk embedding
🧪 代码实践解析:chunk_by_sentences
以下是基于句号(.
)分句的分片函数,提取字符级到 token 级的 span:
from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
def chunk_by_sentences(input_text: str, tokenizer: callable):
inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
token_offsets = inputs['offset_mapping'][0]
token_ids = inputs['input_ids'][0]
chunk_positions = [
(i, int(start + 1))
for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))
if token_id == punctuation_mark_id
and (
token_offsets[i + 1][0] - token_offsets[i][1] > 0
or token_ids[i + 1] == sep_id
)
]
chunks = [
input_text[x[1] : y[1]]
for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
]
span_annotations = [
(x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
]
return chunks, span_annotations
🔎 注意
[(1, 0)] + chunk_positions[:-1]
这一写法的意义是构造(start, end)
对,每一段的开始位置是前一个.
之后的 token。- 分片虽然以句号为界,但不会立即执行切分,而是记录每段的 token 索引区间。
🧠 为什么 chunks 能从字符 0 开始?
虽然 span_annotations 的起点是 token 编号(如 (1, 13)),但在构造 chunks
时,我们是基于 offset 映射(字符位置)来提取文本的:
chunks = [
input_text[x[1]: y[1]]
]
x[1]
是字符起始位置,y[1]
是字符结束位置,因此文本从头截取是合理的,即使 token 从 1 开始。
🧪 late_chunking
函数详解
late_chunking
根据已有的 token span(如上面得到的 span_annotations
),在模型输出的 embedding 上做 pooling(通常是均值)。
def late_chunking(model_output: 'BatchEncoding', span_annotation: list, max_length=None):
token_embeddings = model_output[0]
outputs = []
for embeddings, annotations in zip(token_embeddings, span_annotation):
if max_length is not None:
annotations = [
(start, min(end, max_length - 1))
for (start, end) in annotations
if start < (max_length - 1)
]
pooled_embeddings = [
embeddings[start:end].sum(dim=0) / (end - start)
for start, end in annotations
if (end - start) >= 1
]
pooled_embeddings = [
embedding.detach().cpu().numpy() for embedding in pooled_embeddings
]
outputs.append(pooled_embeddings)
return outputs
🧩 流程总结
- 对于每个样本中每段的
(start, end)
,提取对应 token 的 embedding; - 对该区间执行
mean pooling
; - 将结果转换为 numpy 向量,方便存储/检索。
🖼️ Early vs Late Chunking 对比图
左侧为 Early Chunking,右侧为 Late Chunking,展示了上下文保留与效率差异。
🧠 总结
特性 | Early Chunking | Late Chunking |
---|---|---|
模型调用次数 | 多次 | 一次 |
上下文信息 | 丢失 | 保留全局上下文 |
实现复杂度 | 简单 | 稍复杂(需记录 span) |
聚合策略灵活性 | 固定切片 | 高度灵活,支持任意切分逻辑 |
适合场景 | 简单 pipeline,token 限制较紧 | 向量检索、摘要、多模态、RAG 等复杂任务 |
📚 延伸阅读
- 《Scaling Transformer Embeddings via Late Interaction (Dense Retriever)》
- HuggingFace
offset_mapping
的使用指南 - 向量数据库(如 FAISS, Qdrant)中的 Chunking 策略
参考说明
作者:易迟
链接:https://zhuanlan.zhihu.com/p/885347223