Qwen3 Reranker Test

Environment setup

See: Qwen3 Embedding Test

Code and explanation

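# Import the required libraries
import torch  # PyTorch deep learning framework
from transformers import AutoTokenizer, AutoModelForCausalLM  # Hugging Face transformers library

# Load the Qwen3-Reranker-0.6B model and tokenizer
# padding_side='left' pads on the left, so the last position of every row is a real token;
# this matters because the relevance score is read from the logits at the last position
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side='left')
# Load the model and switch it to evaluation mode
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()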

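This code initializes the model and tokenizer used by the reranking steps below. Qwen3-Reranker is a model designed specifically for document reranking: it is a causal language model that scores a query-document pair by predicting whether the answer is "yes" or "no".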

# Instruction describing the reranking task
task = 'Given a web search query, retrieve relevant passages that answer the query'

# Example queries
queries = [
    "What is the capital of China?",
    "Explain gravity",
]

# Documents, one per query (the i-th document is paired with the i-th query)
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

# Format the instruction, query and document into the model's input template
def format_instruction(instruction, query, doc):
    if instruction is None:
        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(instruction=instruction, query=query, doc=doc)
    return output

# Build one formatted input per query-document pair
pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)]
pairs

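This code prepares the input data. It defines the queries and their matching documents, then uses format_instruction to put each pair into a structure the model understands. Each formatted input has three parts (instruction, query and document), which tell the model explicitly that it must judge whether the document satisfies the query.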
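In an actual reranking step there is usually one query and many candidate documents, rather than one document per query as in the example above. The sketch below pairs a single query with every candidate using the same template as format_instruction; the helper name `build_pairs_for_query` is hypothetical and not part of the original code.

```python
from typing import List, Optional

DEFAULT_TASK = "Given a web search query, retrieve relevant passages that answer the query"


def format_instruction(instruction: Optional[str], query: str, doc: str) -> str:
    # Same template as format_instruction in the article
    if instruction is None:
        instruction = DEFAULT_TASK
    return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"


def build_pairs_for_query(query: str, candidates: List[str],
                          instruction: Optional[str] = None) -> List[str]:
    # Hypothetical helper: pair one query with every candidate document
    return [format_instruction(instruction, query, doc) for doc in candidates]


candidates = [
    "Paris is the capital of France and one of the most populous cities in Europe.",
    "The capital of China is Beijing.",
]
pairs = build_pairs_for_query("What is the capital of China?", candidates)
print(pairs[1])
```

The resulting list can be passed to process_inputs and compute_logits exactly like the pairs built above.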

# Look up the token IDs of "no" and "yes" in the vocabulary
token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")
# Maximum total sequence length
max_length = 8192

# System prompt prefix: asks the model for a binary yes/no judgment
prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
# Suffix: closes the user turn and opens the assistant turn (with an empty think block), so the next token is the answer
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Convert the prefix and suffix into token IDs
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

# Total number of tokens taken up by the prefix and suffix
len(prefix_tokens) + len(suffix_tokens)

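This code sets up the input format and parameters. It looks up the token IDs of "yes" and "no", which are later used to extract the relevant probabilities from the model output. It also defines the system-prompt prefix and the suffix, which turn reranking into a binary classification problem: does the document satisfy the query or not?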
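To inspect the exact text the model receives, the prefix, a formatted pair and the suffix can be concatenated as plain strings. process_inputs does this concatenation at the token-ID level instead, so that truncation only shortens the pair and never cuts into the prefix or suffix. The string version below is only a sketch for inspection, and `build_prompt` is a hypothetical name.

```python
PREFIX = (
    "<|im_start|>system\nJudge whether the Document meets the requirements based on "
    "the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
    "<|im_end|>\n<|im_start|>user\n"
)
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"


def build_prompt(pair: str) -> str:
    # Hypothetical helper: system turn, user turn containing the pair, then an assistant
    # turn with an empty think block, so the next generated token should be "yes" or "no"
    return PREFIX + pair + SUFFIX


pair = ("<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n"
        "<Query>: What is the capital of China?\n"
        "<Document>: The capital of China is Beijing.")
print(build_prompt(pair))
```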

# Tokenize, wrap and pad the formatted pairs
def process_inputs(pairs):
    # Tokenize without padding; truncate so that prefix + pair + suffix fits within max_length
    inputs = tokenizer(
        pairs, padding=False, truncation='longest_first',
        return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
    )
    # Add the prefix and suffix tokens to every sequence
    for i, ele in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
    # Pad all sequences to the same length (on the left) and convert them to PyTorch tensors
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    # Move the inputs to the model's device
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs

# Process the example inputs
inputs = process_inputs(pairs)
# Inspect the shape of the processed input
inputs["input_ids"].shape

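This code defines process_inputs, which tokenizes the text, adds the prefix and suffix, and finally pads the batch and converts it into PyTorch tensors. The resulting input shape is [2, 104]: 2 samples, each padded to the length of the longest one, 104 tokens.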
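Left padding matters because compute_logits reads the logits at position -1, so the last column has to be a real token in every row. The pure-Python sketch below mimics what padding on the left does; `left_pad` is a hypothetical helper and pad ID 0 is an illustrative assumption (the real tokenizer uses its own `pad_token_id`).

```python
from typing import List, Tuple


def left_pad(seqs: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    # Hypothetical helper; pad_id=0 is an illustrative assumption
    max_len = max(len(s) for s in seqs)
    input_ids, attention_mask = [], []
    for s in seqs:
        n_pad = max_len - len(s)
        # Padding goes in front, real tokens stay at the end
        input_ids.append([pad_id] * n_pad + s)
        # 0 marks padding, 1 marks real tokens
        attention_mask.append([0] * n_pad + [1] * len(s))
    return input_ids, attention_mask


ids, mask = left_pad([[11, 12, 13], [21, 22, 23, 24, 25]])
print(ids)   # [[0, 0, 11, 12, 13], [21, 22, 23, 24, 25]]
print(mask)  # [[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]
# The last column holds each sequence's final real token
print([row[-1] for row in ids])  # [13, 25]
```

With right padding, the first row would end in the pad ID, and the logits at position -1 would belong to a padding token instead of the end of the prompt.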

# torch.no_grad() disables gradient tracking during inference, which saves memory
@torch.no_grad()
def compute_logits(inputs, **kwargs):
    # Run the model and get the output logits, shape [batch_size, seq_len, vocab_size]
    res = model(**inputs).logits
    # Take the vocabulary logits at the last position, where the answer token is predicted
    batch_scores = res[:, -1, :]
    # Logit of the "yes" token
    true_vector = batch_scores[:, token_true_id]
    # Logit of the "no" token
    false_vector = batch_scores[:, token_false_id]
    # Stack the "no" and "yes" logits into a [batch_size, 2] tensor
    batch_scores = torch.stack([false_vector, true_vector], dim=1)
    # Apply log_softmax over the two classes to get log-probabilities
    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)

    # Take the "yes" log-probability, exponentiate it back to a probability, and convert to a Python list
    scores = batch_scores[:, 1].exp().tolist()
    return scores

# Compute relevance scores for the example inputs
scores = compute_logits(inputs)
scores

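This code defines compute_logits, which computes a relevance score for each query-document pair. The core idea is to turn reranking into a binary classification problem:

  1. Run the model on the inputs.
  2. Take the vocabulary logits at the last position (the answer position).
  3. Keep only the logits of the "yes" and "no" tokens.
  4. Stack the two values and apply log_softmax to get log-probabilities.
  5. Exponentiate the "yes" log-probability and use that probability as the relevance score.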
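Because only two logits are compared, steps 4 and 5 reduce to a sigmoid of their difference: score = exp(l_yes) / (exp(l_no) + exp(l_yes)) = sigmoid(l_yes - l_no). The pure-Python check below uses made-up logit values (not model outputs) to show that both forms agree.

```python
import math


def score_via_softmax(logit_no: float, logit_yes: float) -> float:
    # Mirrors compute_logits: log_softmax over [no, yes], then exp of the "yes" entry
    m = max(logit_no, logit_yes)  # subtract the max for numerical stability
    log_denom = m + math.log(math.exp(logit_no - m) + math.exp(logit_yes - m))
    return math.exp(logit_yes - log_denom)


def score_via_sigmoid(logit_no: float, logit_yes: float) -> float:
    # Equivalent closed form: sigmoid of the logit difference
    return 1.0 / (1.0 + math.exp(-(logit_yes - logit_no)))


# Made-up (no, yes) logit pairs for illustration
for no, yes in [(0.0, 0.0), (1.0, 3.0), (5.0, -2.0)]:
    a = score_via_softmax(no, yes)
    b = score_via_sigmoid(no, yes)
    print(f"no={no:+.1f} yes={yes:+.1f} -> softmax {a:.6f}, sigmoid {b:.6f}")
```

Read the other way round, a score of 0.9995 corresponds to a "yes" logit about ln(0.9995 / 0.0005), roughly 7.6, above the "no" logit.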

The two examples score 0.9995 and 0.9994, meaning the model considers both documents highly relevant to their queries.

More examples

Relevant documents

# Add more test cases
task = 'Given a web search query, retrieve relevant passages that answer the query'

# Example 1: more query-document pairs
queries = [
    "What is the capital of China?",
    "Explain gravity",
    "How does photosynthesis work?",  # new query
    "What are the benefits of exercise?",  # new query
]

documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    "Photosynthesis is the process by which green plants and some other organisms use sunlight to synthesize foods with carbon dioxide and water. It converts light energy into chemical energy, releasing oxygen as a byproduct.",  # relevant document
    "Regular physical activity can improve muscle strength, boost endurance, deliver oxygen and nutrients to tissues, and help your cardiovascular system work more efficiently.",  # relevant document
]

# Build all query-document pairs
pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)]

# Process the inputs and compute relevance scores
inputs = process_inputs(pairs)
scores = compute_logits(inputs)

# Print the relevance score of each query-document pair
for i, (query, doc, score) in enumerate(zip(queries, documents, scores)):
    print(f"Query {i+1}: {query}")
    print(f"Document: {doc[:50]}..." if len(doc) > 50 else f"Document: {doc}")
    print(f"Relevance Score: {score:.6f}")
    print("-" * 50)

Output:
Query 1: What is the capital of China?
Document: The capital of China is Beijing.
Relevance Score: 0.999498
--------------------------------------------------
Query 2: Explain gravity
Document: Gravity is a force that attracts two bodies toward...
Relevance Score: 0.999362
--------------------------------------------------
Query 3: How does photosynthesis work?
Document: Photosynthesis is the process by which green plant...
Relevance Score: 0.998967
--------------------------------------------------
Query 4: What are the benefits of exercise?
Document: Regular physical activity can improve muscle stren...
Relevance Score: 0.981571
--------------------------------------------------

Irrelevant documents

# Test mismatched (irrelevant) query-document pairs
mismatched_queries = [
    "What is the capital of China?",
    "How does photosynthesis work?",
]

mismatched_documents = [
    "Paris is the capital of France and one of the most populous cities in Europe.",  # irrelevant document
    "Machine learning is a branch of artificial intelligence that focuses on building systems that learn from data.",  # irrelevant document
]

# Format the mismatched query-document pairs
mismatched_pairs = [format_instruction(task, query, doc) for query, doc in zip(mismatched_queries, mismatched_documents)]

# Process the inputs and compute relevance scores
mismatched_inputs = process_inputs(mismatched_pairs)
mismatched_scores = compute_logits(mismatched_inputs)

# Print the relevance score of each mismatched query-document pair
for i, (query, doc, score) in enumerate(zip(mismatched_queries, mismatched_documents, mismatched_scores)):
    print(f"Query {i+1}: {query}")
    print(f"Document: {doc[:50]}..." if len(doc) > 50 else f"Document: {doc}")
    print(f"Relevance Score: {score:.6f}")
    print("-" * 50)

Output:
Query 1: What is the capital of China?
Document: Paris is the capital of France and one of the most...
Relevance Score: 0.000153
--------------------------------------------------
Query 2: How does photosynthesis work?
Document: Machine learning is a branch of artificial intelli...
Relevance Score: 0.000047
--------------------------------------------------
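In a retrieval pipeline these scores are typically used to reorder the candidates for one query and, optionally, to drop weak ones. The sketch below reuses the two scores printed above for the query "What is the capital of China?". Since each pair is scored independently, scores from separate runs can be compared this way. The helper `rerank` and the 0.5 threshold are illustrative assumptions, not values recommended by Qwen.

```python
from typing import List, Optional, Tuple


def rerank(docs: List[str], scores: List[float], top_k: Optional[int] = None,
           threshold: float = 0.0) -> List[Tuple[str, float]]:
    # Hypothetical helper: sort documents by descending relevance score
    ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
    # Drop documents whose score falls below the threshold
    ranked = [(d, s) for d, s in ranked if s >= threshold]
    # Keep at most top_k results when a limit is given
    return ranked if top_k is None else ranked[:top_k]


docs = [
    "Paris is the capital of France and one of the most populous cities in Europe.",
    "The capital of China is Beijing.",
]
# Scores reported above for these two documents against "What is the capital of China?"
scores = [0.000153, 0.999498]

# threshold=0.5 is an illustrative assumption
for doc, score in rerank(docs, scores, threshold=0.5):
    print(f"{score:.6f}  {doc}")
```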