一、Query 改写:精准捕捉用户深层需求
Query 改写的核心任务,是**在用户意图与检索系统之间构建语义桥梁**。用户输入的查询往往口语化、碎片化,而检索系统需要结构清晰、语义明确的“检索语言”。

1.1 多查询改写(Multi-Query Rewriting)
最直接的策略:单一查询覆盖不全,那就生成多个变体。
LangChain 的 MultiQueryRetriever 是这一思路的官方实现。其原理非常清晰——借助 LLM 将用户原始问题改写成 N 个不同表述的查询,分别执行检索,最后合并去重。
import os
from typing import List
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY", "you-deepseek-api-key")
os.environ["OPENAI_API_BASE"] = os.getenv("OPENAI_API_BASE", "https://api.deepseek.com")
docs = [
Document(page_content="RAG(检索增强生成)通过引入外部知识库来增强大语言模型的回答能力,有效缓解幻觉问题。", metadata={"doc_id": "1"}),
Document(page_content="幻觉问题是大语言模型的固有缺陷,RAG通过检索真实文档来提供事实依据,减少模型编造内容。", metadata={"doc_id": "2"}),
Document(page_content="BM25是一种基于词频和文档频率的稀疏检索算法,擅长精确关键词匹配。", metadata={"doc_id": "3"}),
Document(page_content="向量检索使用Embedding模型将文本转换为高维向量,通过向量相似度进行语义匹配。", metadata={"doc_id": "4"}),
Document(page_content="HyDE(Hypothetical Document Embedding)通过生成伪文档来扩展查询,提高召回率。", metadata={"doc_id": "5"}),
]
embeddings = HuggingFaceEmbeddings(model_name=r"E:LLM ProjectLocal Knowledge Base Q&A SystemmodelsXorbitsbge-large-zh-v1.5")
vector_store = FAISS.from_documents(docs, embeddings)
base_retriever = vector_store.as_retriever(search_kwargs={"k": 3})print(f"向量存储构建完成,包含 {len(docs)} 个文档")
from langchain_classic.retrievers.multi_query import MultiQueryRetriever, DEFAULT_QUERY_PROMPTllm = ChatOpenAI(temperature=0, model="deepseek-chat", api_key=os.environ["OPENAI_API_KEY"], base_url=os.environ["OPENAI_API_BASE"])
retriever = MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=llm,
parser_key="lines"
)
query = "什么是RAG系统的多路召回?"
print(f"原始查询: {query}")
prompt_text = DEFAULT_QUERY_PROMPT.format(question=query)
response = llm.invoke(prompt_text)
rewritten_queries = [q.strip() for q in response.content.strip().split("n") if q.strip()]
print(f"nLLM 生成的 {len(rewritten_queries)} 个改写查询:")
for i, q in enumerate(rewritten_queries, 1):
print(f" {i}. {q}")
docs = retriever.invoke(query)
print(f"n检索到 {len(docs)} 个不重复文档:")
for i, doc in enumerate(docs, 1):
print(f" {i}. [doc_id={doc.metadata['doc_id']}] {doc.page_content}")
实践中一个常见陷阱:对于简单查询,3-5 个改写足矣;而复杂的多跳查询可能需要 7-10 个。关键在于让模型自主判断——这正是 DMQR-RAG 提出的“自适应改写数量”机制。
1.2 HyDE:通过“伪文档”扩展查询
如果说 Multi-Query 是“换种问法”,那么 HyDE(Hypothetical Document Embedding)就是“**先预判理想答案的形态,再用它去匹配**”。
该思路源自 2022 年 CMU 与滑铁卢大学的论文 Precise Zero-Shot Dense Retrieval without Relevance Labels。核心流程仅三步:**生成伪文档**:让 LLM 基于 query 构造一个“假设的理想答案”;**嵌入伪文档**:将生成的伪文档转化为向量;**用伪文档向量检索**:由于伪文档包含了 query 的语义扩展,召回效果显著提升。
class HyDEEncoder:
def __init__(self, llm, encoder):
self.llm = llm
self.encoder = encoder
def generate_pseudo_doc(self, query: str) -> str:
"""生成伪文档"""
prompt = f"""请根据以下问题,生成一段可能包含答案的文档片段(100-200字):问题:{query}文档片段:"""
response = self.llm.invoke(prompt)
return response.content
def encode(self, query: str) -> List[float]:
"""生成伪文档并返回扩展后的向量"""
pseudo_doc = self.generate_pseudo_doc(query)
print(f"生成的伪文档:n{pseudo_doc}n")
embedding = self.encoder.embed_query(pseudo_doc)
return embedding
llm = ChatOpenAI(temperature=0.7, model="deepseek-chat", api_key=os.environ["OPENAI_API_KEY"], base_url=os.environ["OPENAI_API_BASE"])
hyde_encoder = HyDEEncoder(llm, embeddings)query = "RAG系统如何解决幻觉问题?"
print(f"查询: {query}n")hyde_embedding = hyde_encoder.encode(query)
results = vector_store.similarity_search_by_vector(hyde_embedding, k=5)print(f"n检索到 {len(results)} 个文档:")
for i, doc in enumerate(results, 1):
print(f" {i}. [doc_id={doc.metadata['doc_id']}] {doc.page_content}")
HyDE 的优势在于,生成的伪文档天然携带了**领域术语和上下文信息**——这些往往是原始 query 所欠缺的。例如,用户问“如何预防幻觉”,伪文档中可能包含“事实核查”、“源头追溯”、“检索增强”等专业表达,从而匹配到更精准的资料。
**但 HyDE 也有短板**:若 LLM 生成的伪文档本身存在幻觉(偏离主题),检索效果会适得其反。因此,落地时建议搭配**置信度过滤**或**多伪文档投票**策略。
1.3 查询分解:化繁为简
部分 query 天然属于“多跳”类型,例如:“**对比2023年和2024年特斯拉在中国的销量变化,并剖析背后的原因**”。
直接使用向量检索这类查询,召回结果通常碎片化,难以覆盖所有维度。查询分解(Query Decomposition)的策略是:**将复杂问题拆解为多个子问题,分别检索后再汇总**。
import re
from langchain_core.prompts import PromptTemplate
if 'llm' not in dir():
llm = ChatOpenAI(temperature=0, model="deepseek-chat", api_key=os.environ["OPENAI_API_KEY"], base_url=os.environ["OPENAI_API_BASE"])decompose_prompt = PromptTemplate(
input_variables=["question"],
template="""将以下复杂问题分解为2-4个简单的子问题。直接列出子问题,每行一个(不要加编号或"子问题"前缀,只写问题本身):问题:{question}子问题:"""
)def decompose_query(llm, complex_query: str) -> List[str]:
"""分解复杂查询,支持多种LLM输出格式的健壮解析"""
response = llm.invoke(decompose_prompt.format(question=complex_query))
raw_text = response.content.strip()
lines = raw_text.split("n") sub_questions = []
for line in lines:
line = line.strip()
if not line:
continue
cleaned = re.sub(r'^[d]+[.、)]s*', '', line)
cleaned = re.sub(r'^子问题[d]*[::]s*', '', cleaned)
cleaned = cleaned.strip()
if cleaned and len(cleaned) > 3:
sub_questions.append(cleaned)
if not sub_questions:
parts = re.split(r'(?:^|s)(?:d+[.、)]s*|子问题d+[::]s*)', raw_text)
sub_questions = [p.strip() for p in parts if p.strip() and len(p.strip()) > 3] return sub_questions
complex_query = "比较2023年和2024年特斯拉在中国的销量变化,并分析原因"
print(f"复杂查询: {complex_query}n")sub_questions = decompose_query(llm, complex_query)
print("分解后的子问题:")
if sub_questions:
for i, sq in enumerate(sub_questions, 1):
print(f"{i}. {sq}")
else:
print(" (未能分解出子问题,请检查LLM连接或尝试其他查询)")
2026 年的论文《The Impact of Query Decomposition and Cross-Encoder Reranking in Multi-Hop Retrieval-Augmented Generation》证实了这一策略的有效性。其结论是:**查询分解 + 重排序的“广撒网、精过滤”模式,在多跳问答场景中明显优于单次检索**。
1.4 Query2Doc:另一种伪文档方案
Query2Doc 与 HyDE 类似,同样通过生成伪文档来扩展查询。区别在于 Query2Doc 更侧重于**利用少量示例(few-shot)引导生成**,使伪文档更贴近真实文档的数据分布。
实际项目中,HyDE 和 Query2Doc 的效果高度依赖**生成模型的质量**。若使用 GPT-4 级别模型,生成的伪文档质量有保障;若采用 7B 级别本地模型,建议先用领域数据微调,否则伪文档可能“一本正经地胡说八道”。
二、多路召回:分散风险,提升覆盖
Query 改写解决的是“**查询不够丰富**”的问题,多路召回则应对“**单路检索存在盲区**”的挑战。
2.1 单路检索的局限性
向量检索与稀疏检索(BM25)各有优劣:
| 检索方式 | 优势 | 劣势 | 典型应用 |
|---|
| 向量检索(Dense) | 语义相似、同义词识别、概念匹配 | 精确匹配、罕见术语、ID/代码 | "RAG原理"匹配"检索增强生成" |
| 稀疏检索(BM25) | 精确匹配、关键词命中、术语 | 语义泛化、同义词 | "BERT"精确匹配含"BERT"的文档 |
| 关键词匹配 | 快速过滤、布尔逻辑 | 语义理解 | 必须包含/排除特定词汇 |
from langchain_community.retrievers import BM25Retriever
from rank_bm25 import BM25Okapi
bm25_retriever = BM25Retriever.from_documents(docs, k=3)print(f"BM25检索器构建完成,包含 {len(docs)} 个文档")
test_query = "幻觉问题"
bm25_results = bm25_retriever.invoke(test_query)
print(f"nBM25检索 '{test_query}' 的结果:")
for i, doc in enumerate(bm25_results, 1):
print(f"{i}. {doc.page_content[:80]}...")
test_query = "如何减少模型编造内容"
vector_results = base_retriever.invoke(test_query)
print(f"向量检索 '{test_query}' 的结果:")
for i, doc in enumerate(vector_results, 1):
print(f"{i}. {doc.page_content[:80]}...")
2026 年的基准测试报告《From BM25 to Corrective RAG: Benchmarking Retrieval Strategies for Text-and-Table Documents》对比了 10 余种检索组合,结论清晰:**混合检索(BM25 + 向量)在绝大多数场景下均是最优基线**。
2.2 并行召回 vs 串行召回
多路召回主要有两种架构:
**并行召回**(推荐方案):
用户Query → [向量检索] → 结果A
→ [BM25检索] → 结果B
→ [关键词过滤] → 结果C
→ RRF融合 → 重排序 → 最终结果
三路并行,互不阻塞,最终通过融合算法整合。优点是响应速度快,适合在线服务。
**串行召回**(级联过滤):
用户Query → [粗排:BM25快速过滤] → Top100
→ [精排:向量检索] → Top20
→ [重排序:Cross-Encoder] → Top5
先粗后精,逐步缩小候选集。优点是计算量小,适用于资源受限环境。
在工程实践中,**并行召回 + 融合 + 重排序**的架构更值得推荐,因为各路召回互补性最强,融合后的效果通常优于任何单一路径。
2.3 向量数据库的原生混合检索
好消息是,主流向量数据库已内置混合检索能力,无需从零手写:
**Qdrant**:
from qdrant_client import QdrantClientclient = QdrantClient("localhost", port=6333)
client.search(
collection_name="my_collection",
query_vector=("dense", [0.1, 0.2, ...]),
query_sparse_vector=("sparse", {0: 1.0, 5: 0.8}),
limit=10
)
**Milvus**:
from pymilvus import Collectioncollection = Collection("hybrid_collection")
search_params = {
"metric_type": "L2",
"params": {"nprobe": 128}
}
results = collection.hybrid_search(
reqs=[
AnnSearchRequest(data=[[0.1, 0.2, ...]], anns_field="dense_vector", param=search_params, limit=100),
AnnSearchRequest(data=[[...]], anns_field="sparse_vector", param={}, limit=100)
],
rerank=RRFRanker(k=60),
limit=10
)
**Elasticsearch**:
{
"query": {
"hybrid": {
"queries": [
{ "match": { "content": "RAG多路召回" } },
{ "knn": { "field": "vector", "query_vector": [0.1, 0.2], "k": 100 } }
]
}
}
}
三、RRF 融合:多路结果的“民主投票”
多路召回带来多组结果,如何合并?直接取并集会浪费排名信息;加权求和又该如何确定权重?
RRF(Reciprocal Rank Fusion,倒数排名融合)提供了优雅的解决方案——无需训练、无需调参,公式简洁至极:
RRF_score(d) = Σ 1 / (k + rank_i(d))
其中:
- rank_i(d):文档 d 在第 i 路召回中的排名
- k:常数,通常取 60(经验值,防止低排名文档分数过高)
import numpy as npdef rrf_fusion(results_lists: List[List[Document]], k: int = 60) -> List[Document]:
"""
多路召回结果的RRF融合
results_lists: [[doc1, doc2, ...], [doc1, doc3, ...], ...]
k: RRF常数,通常取60
"""
scores = {}
doc_map = {} for results in results_lists:
for rank, doc in enumerate(results, start=1):
doc_id = doc.metadata.get("doc_id", hash(doc.page_content))
doc_map[doc_id] = doc if doc_id not in scores:
scores[doc_id] = 0
scores[doc_id] += 1.0 / (k + rank)
sorted_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
return [doc_map[doc_id] for doc_id, _ in sorted_docs]
demo_docs = [
Document(page_content="RAG(检索增强生成)通过引入外部知识库来增强大语言模型的回答能力,有效缓解幻觉问题。", metadata={"doc_id": "1"}),
Document(page_content="幻觉问题是大语言模型的固有缺陷,RAG通过检索真实文档来提供事实依据,减少模型编造内容。", metadata={"doc_id": "2"}),
Document(page_content="BM25是一种基于词频和文档频率的稀疏检索算法,擅长精确关键词匹配。", metadata={"doc_id": "3"}),
Document(page_content="向量检索使用Embedding模型将文本转换为高维向量,通过向量相似度进行语义匹配。", metadata={"doc_id": "4"}),
Document(page_content="HyDE(Hypothetical Document Embedding)通过生成伪文档来扩展查询,提高召回率。", metadata={"doc_id": "5"}),
Document(page_content="2024年诺贝尔物理学奖授予了John Hopfield和Geoffrey Hinton,表彰他们在人工神经网络的基础性发现。", metadata={"doc_id": "6"}),
Document(page_content="Transformer架构自2017年提出以来,彻底改变了搜索与问答领域,BERT、GPT系列模型均基于此架构。", metadata={"doc_id": "7"}),
Document(page_content="Python是数据科学和机器学习领域最流行的编程语言,拥有NumPy、Pandas、PyTorch等丰富生态。", metadata={"doc_id": "8"}),
Document(page_content="语义搜索与传统关键词搜索的最大区别在于理解用户意图,而非仅仅匹配表面文字。", metadata={"doc_id": "9"}),
Document(page_content="多路召回结合了向量检索和稀疏检索的优势,通过融合算法实现1+1>2的检索效果。", metadata={"doc_id": "10"}),
]
demo_embeddings = HuggingFaceEmbeddings(model_name=r"E:LLM ProjectLocal Knowledge Base Q&A SystemmodelsXorbitsbge-large-zh-v1.5")
demo_vector_store = FAISS.from_documents(demo_docs, demo_embeddings)
demo_vector_retriever = demo_vector_store.as_retriever(search_kwargs={"k": 5})
demo_bm25_retriever = BM25Retriever.from_documents(demo_docs, k=5)print(f"演示语料库构建完成,共 {len(demo_docs)} 篇文档n")
query = "什么是检索增强生成中的语义搜索"
vector_results = demo_vector_retriever.invoke(query)
bm25_results = demo_bm25_retriever.invoke(query)
print(f"{'='*60}")
print(f"查询: {query}")
print(f"{'='*60}")print(f"n 向量检索 Top-{len(vector_results)}(语义匹配能力强,擅长同义词/概念):")
for i, doc in enumerate(vector_results, 1):
print(f" {i}. [doc_id={doc.metadata['doc_id']}] {doc.page_content}")print(f"n BM25检索 Top-{len(bm25_results)}(精确关键词匹配,擅长术语命中):")
for i, doc in enumerate(bm25_results, 1):
print(f" {i}. [doc_id={doc.metadata['doc_id']}] {doc.page_content}")
fused_results = rrf_fusion([vector_results, bm25_results], k=60)
vector_ids = {doc.metadata['doc_id'] for doc in vector_results}
bm25_ids = {doc.metadata['doc_id'] for doc in bm25_results}
only_vector = vector_ids - bm25_ids
only_bm25 = bm25_ids - vector_ids
both = vector_ids & bm25_idsprint(f"n{'='*60}")
print(f"RRF 融合结果统计:")
print(f" 向量独有: {len(only_vector)} 篇 | BM25独有: {len(only_bm25)} 篇 | 两路共有: {len(both)} 篇")
print(f" 向量检索: {len(vector_results)} → BM25检索: {len(bm25_results)} → RRF融合: {len(fused_results)} (去重+重排)")
print(f"{'='*60}")print(f"n RRF融合后 Top-{len(fused_results)}(排名融合 + 去重):")
for i, doc in enumerate(fused_results, 1):
did = doc.metadata['doc_id']
if did in only_vector:
source_tag = " 仅向量"
elif did in only_bm25:
source_tag = " 仅BM25"
else:
source_tag = " 两路共有"
print(f" {i}. [doc_id={did}] {source_tag}n {doc.page_content}")print(f"n RRF 的核心价值:两路检索结果存在差异时,融合后能同时保留")
print(f" 向量擅长的语义匹配结果 和 BM25擅长的关键词精确匹配结果,互为补充。")
RRF 的精妙之处在于:
1. **无参数**:无需训练,无需调整权重
2. **对排名敏感,对分数不敏感**:只关注“排第几”,不关心“具体得分”
3. **自然处理结果缺失**:若某路未召回某文档,则该路不贡献分数,不影响其他路
**但需注意**:RRF 假设各路召回的“排名质量”大致相当。若某路召回质量极差(如随机排序),会拖累整体效果。因此融合前应确保每路召回均达到“合格线以上”。
四、重排序:提升最后一公里精度
多路召回 + RRF 融合之后,效果已优于单路。但若追求极致精度,还需追加一道“重排序”(Reranking)工序。
4.1 重排序的必要性
召回阶段追求“**不漏**”(高召回率),可采用轻量级近似算法(ANN、BM25)。但召回结果的 Top-K 排名未必准确。
重排序阶段追求“**准**”(高精确率),使用更重的模型(Cross-Encoder)对召回结果进行精细排序,将真正相关的文档前置。
4.2 BGE-Reranker:工业级重排序方案
BAAI 的 BGE-Reranker 系列是目前中文场景最实用的重排序模型。与 Bi-Encoder(双塔模型,提前编码文档与 query)不同,Reranker 属于 Cross-Encoder(交叉编码器),将 query 与文档同时输入模型,输出相关性分数。
import os, time, socket
from FlagEmbedding import FlagRerankerLOCAL = r'E:LLM ProjectLocal Knowledge Base Q&A SystemmodelsXorbitsbge-reranker-large'
MODEL = 'BAAI/bge-reranker-base'
os.environ.setdefault('HF_HUB_DOWNLOAD_TIMEOUT', '30')
socket.setdefaulttimeout(30)def _load_reranker(path, retries=5, fp16=True):
NET = (TimeoutError, ConnectionError, ConnectionRefusedError,
ConnectionAbortedError, ConnectionResetError, BrokenPipeError)
for i in range(retries + 1):
try:
reranker = FlagReranker(path, use_fp16=fp16)
if i: print(f'[重试成功] 第{i}次重试后加载成功')
return reranker
except NET as e:
if i == retries: raise
wait = 2 ** i
print(f'[重试 {i+1}/{retries}] {type(e).__name__} — 等待 {wait}s')
time.sleep(wait)
except OSError as e:
if getattr(e, 'winerror', 0) in (10060, 10061, 10053, 10054):
if i == retries: raise
wait = 2 ** i
print(f'[重试 {i+1}/{retries}] WinError {e.winerror} — 等待 {wait}s')
time.sleep(wait)
else:
raisepath = LOCAL if os.path.isdir(LOCAL) else MODEL
if path == LOCAL:
print(f'[本地模型] {LOCAL}')
else:
print(f'[远程下载] ⏳ {MODEL}')
reranker = _load_reranker(path)
print('n BGE-Reranker模型已加载')
import types
from transformers import BatchEncodingdef _patch_prepare_for_model(tokenizer):
if hasattr(tokenizer, 'prepare_for_model'):
return
def _prepare_for_model(this, ids, pair_ids=None, max_length=None,
add_special_tokens=True, padding=False,
truncation='only_second', stride=0,
return_tensors=None, return_token_type_ids=None,
return_attention_mask=None, return_overflowing_tokens=False,
return_special_tokens_mask=False):
text_a = this.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
text_b = this.decode(pair_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) if pair_ids is not None else None
encoded = this(text_a, text_b,
truncation=truncation or True,
max_length=max_length or this.model_max_length,
padding=padding, add_special_tokens=add_special_tokens)
result = {'input_ids': encoded['input_ids']}
if return_attention_mask:
result['attention_mask'] = [1] * len(result['input_ids'])
if return_token_type_ids:
result['token_type_ids'] = [0] * len(result['input_ids'])
if return_special_tokens_mask:
result['special_tokens_mask'] = [0] * len(result['input_ids'])
if return_overflowing_tokens:
result['overflowing_tokens'] = []
return BatchEncoding(result)
tokenizer.prepare_for_model = types.MethodType(_prepare_for_model, tokenizer)
print('[补丁] prepare_for_model 已注入到 tokenizer')
def rerank_results(reranker, query, candidates, top_k=3):
"""使用 BGE-Reranker 做精排"""
if not candidates:
return []
_patch_prepare_for_model(reranker.tokenizer)
pairs = [[query, doc.page_content] for doc in candidates]
scores = reranker.compute_score(pairs)
reranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
return [doc for doc, _ in reranked[:top_k]]
query = "RAG系统如何解决幻觉问题?"
reranked_results = rerank_results(reranker, query, fused_results[:5], top_k=3)
print(f"查询: {query}nn重排序后的Top-3结果:")
for i, doc in enumerate(reranked_results, 1):
print(f"{i}. {doc.page_content[:80]}...")
Cross-Encoder 的优势在于**交互充分**——query 与文档的每个 token 均可互相 attention,精度远超 Bi-Encoder。代价是**计算量大**,无法提前缓存,只能对召回后的少量候选(通常 50-200 个)进行重排。
4.3 量化收益
掘金文章《RRF混合检索+BGE重排序:召回率从0.67到0.82的实战》给出了具体数据:
| 方案 | Recall@10 | 相对提升 |
|---|
| 纯向量检索 | 0.67 | 基准 |
| 向量 + BM25 混合 | 0.74 | +10.4% |
| 混合 + RRF 融合 | 0.78 | +16.4% |
| 混合 + RRF + BGE 重排 | 0.82 | +22.4% |
提升曲线清晰表明:**每增加一道工序,均有明确收益**。但边际收益递减——从 0.78 到 0.82 的提升,需要引入额外的重排序模型,使计算成本翻倍。实际项目中需权衡精度需求与延迟预算。
五、完整实战:构建 Query 改写 + 多路召回的 RAG 系统
以下提供可直接运行的完整代码,整合前述所有技术要点:
class AdvancedRAG:
"""完整的Query改写 + 多路召回 + RRF融合 + 重排序系统""" def __init__(self, documents: List[Document], use_llm: bool = True):
self.documents = documents
self.use_llm = use_llm
self.embeddings = HuggingFaceEmbeddings(model_name=r"E:LLM ProjectLocal Knowledge Base Q&A SystemmodelsXorbitsbge-large-zh-v1.5")
self.vector_store = FAISS.from_documents(documents, self.embeddings)
self.vector_retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
self.bm25_retriever = BM25Retriever.from_documents(documents, k=5)
model_path = r'E:LLM ProjectLocal Knowledge Base Q&A SystemmodelsXorbitsbge-reranker-large'
if os.path.isdir(model_path):
self.reranker = FlagReranker(model_path, use_fp16=True)
elif 'load_reranker_with_retry' in dir():
self.reranker = load_reranker_with_retry('BAAI/bge-reranker-base', use_fp16=True)
else:
self.reranker = FlagReranker('BAAI/bge-reranker-base', use_fp16=True)
if use_llm:
self.llm = ChatOpenAI(temperature=0, model="deepseek-chat", api_key=os.environ["OPENAI_API_KEY"], base_url=os.environ["OPENAI_API_BASE"])
else:
self.llm = None def query_rewrite(self, query: str, num_variants: int = 3) -> List[str]:
"""使用LLM生成多个查询改写"""
if not self.llm:
return [query] prompt = f"""请将以下问题改写为{num_variants}个不同表述的查询,每个查询单独一行,保持语义不变但用词不同:原问题:{query}改写查询:""" response = self.llm.invoke(prompt)
variants = [line.strip() for line in response.content.split("n") if line.strip()]
return [query] + variants[:num_variants] def hyde_expand(self, query: str) -> str:
"""HyDE:生成伪文档"""
if not self.llm:
return query prompt = f"""请根据以下问题,生成一段可能包含答案的文档片段(100-200字):问题:{query}文档片段:""" response = self.llm.invoke(prompt)
return response.content def multi_path_retrieve(self, query: str) -> List[List[Document]]:
"""多路召回:向量 + BM25"""
vector_results = self.vector_retriever.invoke(query)
bm25_results = self.bm25_retriever.invoke(query) return [vector_results, bm25_results] def rrf_fuse(self, results_lists: List[List[Document]], k: int = 60) -> List[Document]:
"""RRF融合多路召回结果"""
scores = {}
doc_map = {} for results in results_lists:
for rank, doc in enumerate(results, start=1):
doc_id = doc.metadata.get("doc_id", hash(doc.page_content))
doc_map[doc_id] = doc if doc_id not in scores:
scores[doc_id] = 0
scores[doc_id] += 1.0 / (k + rank) sorted_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
return [doc_map[doc_id] for doc_id, _ in sorted_docs] def rerank(self, query: str, candidates: List[Document], top_k: int = 3) -> List[Document]:
"""使用BGE-Reranker做精排(自动修复 prepare_for_model 兼容性)"""
if not candidates:
return []
if not hasattr(self.reranker.tokenizer, 'prepare_for_model'):
self._inject_prepare_for_model() pairs = [[query, doc.page_content] for doc in candidates]
scores = self.reranker.compute_score(pairs) reranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
return [doc for doc, score in reranked[:top_k]] def _inject_prepare_for_model(self):
"""为 XLMRobertaTokenizer 注入 prepare_for_model 兼容方法"""
import types
from transformers import BatchEncoding
t = self.reranker.tokenizer
def _fn(this, ids, pair_ids=None, max_length=None,
add_special_tokens=True, padding=False,
truncation='only_second', stride=0,
return_tensors=None, return_token_type_ids=None,
return_attention_mask=None, return_overflowing_tokens=False,
return_special_tokens_mask=False):
ta = this.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
tb = this.decode(pair_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) if pair_ids is not None else None
enc = this(ta, tb, truncation=truncation or True,
max_length=max_length or this.model_max_length,
padding=padding, add_special_tokens=add_special_tokens)
r = {'input_ids': enc['input_ids']}
if return_attention_mask: r['attention_mask'] = [1] * len(r['input_ids'])
if return_token_type_ids: r['token_type_ids'] = [0] * len(r['input_ids'])
if return_special_tokens_mask: r['special_tokens_mask'] = [0] * len(r['input_ids'])
if return_overflowing_tokens: r['overflowing_tokens'] = []
return BatchEncoding(r)
t.prepare_for_model = types.MethodType(_fn, t) def query(self, query: str, use_rewrite: bool = False, use_hyde: bool = False) -> dict:
"""完整查询流程"""
if use_rewrite and self.llm:
queries = self.query_rewrite(query)
print(f"查询改写结果:{queries}")
else:
queries = [query]
all_results = []
for q in queries:
results = self.multi_path_retrieve(q)
all_results.extend(results)
fused = self.rrf_fuse(all_results)
print(f"RRF融合后候选数:{len(fused)}")
final_results = self.rerank(query, fused[:10], top_k=3) return {
"query": query,
"rewrites": queries if use_rewrite else [],
"results": final_results,
"num_candidates": len(fused)
}
rag = AdvancedRAG(docs, use_llm=False)
print("AdvancedRAG系统已初始化(不使用LLM改写)")
query = <