深入理解大模型分片优化:Late Chunking 技术解析

深入理解大模型分片优化:Late Chunking 技术解析

📌 背景:为什么需要文本分片(Chunking)

在使用大语言模型(如 Transformer)进行文本嵌入(Embedding)时,输入长度往往受限(例如 512 或 2048 token),而现实中的文本往往远超此限制。为解决这一问题,文本分片(Chunking) 技术被广泛使用。


🧠 什么是 Early Chunking vs Late Chunking

🧱 Early Chunking

在编码前,将长文本切分成多个小段(如按固定 token 数、换行符、标点等),每段单独送入模型,单独编码。

  • 优点:实现简单
  • 缺点
    • 每个 chunk 独立编码,上下文信息丢失
    • 重复计算(重叠窗口)浪费资源

🧠 Late Chunking

先将完整文本一次性输入模型,获取全局 token 嵌入后,再根据预定义的 token 区间进行切片与聚合。

  • 优点
    • 上下文信息保留(单次前向传播)
    • 聚合逻辑灵活,切分方式可动态调整
    • 减少模型调用次数,提高效率

🔍 使用场景

  • 对大段文本做语义嵌入(如句子/段落级别表示)
  • 信息抽取与摘要任务的输入预处理
  • 文档级搜索与向量检索(RAG)中构建 chunk embedding

🧪 代码实践解析:chunk_by_sentences

以下是基于句号(.)分句的分片函数,提取字符级到 token 级的 span:

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)

def chunk_by_sentences(input_text: str, tokenizer: callable):
    inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
    punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
    sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
    token_offsets = inputs['offset_mapping'][0]
    token_ids = inputs['input_ids'][0]

    chunk_positions = [
        (i, int(start + 1))
        for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))
        if token_id == punctuation_mark_id
        and (
            token_offsets[i + 1][0] - token_offsets[i][1] > 0
            or token_ids[i + 1] == sep_id
        )
    ]

    chunks = [
        input_text[x[1] : y[1]]
        for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    span_annotations = [
        (x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    return chunks, span_annotations

🔎 注意

  • [(1, 0)] + chunk_positions[:-1] 这一写法的意义是构造 (start, end) 对,每一段的开始位置是前一个 . 之后的 token。
  • 分片虽然以句号为界,但不会立即执行切分,而是记录每段的 token 索引区间

🧠 为什么 chunks 能从字符 0 开始?

虽然 span_annotations 的起点是 token 编号(如 (1, 13)),但在构造 chunks 时,我们是基于 offset 映射(字符位置)来提取文本的:

chunks = [
    input_text[x[1]: y[1]]
]

x[1] 是字符起始位置,y[1] 是字符结束位置,因此文本从头截取是合理的,即使 token 从 1 开始。


🧪 late_chunking 函数详解

late_chunking 根据已有的 token span(如上面得到的 span_annotations),在模型输出的 embedding 上做 pooling(通常是均值)。

def late_chunking(model_output: 'BatchEncoding', span_annotation: list, max_length=None):
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if max_length is not None:
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]

        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs

🧩 流程总结

  1. 对于每个样本中每段的 (start, end),提取对应 token 的 embedding;
  2. 对该区间执行 mean pooling
  3. 将结果转换为 numpy 向量,方便存储/检索。

🖼️ Early vs Late Chunking 对比图

左侧为 Early Chunking,右侧为 Late Chunking,展示了上下文保留与效率差异。


🧠 总结

特性Early ChunkingLate Chunking
模型调用次数多次一次
上下文信息丢失保留全局上下文
实现复杂度简单稍复杂(需记录 span)
聚合策略灵活性固定切片高度灵活,支持任意切分逻辑
适合场景简单 pipeline,token 限制较紧向量检索、摘要、多模态、RAG 等复杂任务

📚 延伸阅读

  • 《Scaling Transformer Embeddings via Late Interaction (Dense Retriever)》
  • HuggingFace offset_mapping 的使用指南
  • 向量数据库(如 FAISS, Qdrant)中的 Chunking 策略

参考说明
作者:易迟
链接:https://zhuanlan.zhihu.com/p/885347223

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值