一架梯子,一头程序猿,仰望星空!
LangChain教程(Python版本) > 内容正文

LangChain 检索器(Retriever)


1. Retrievers(检索器)

检索器(retriever)是LangChain封装的一个接口,它可以根据非结构化查询返回相关文档。检索器设计的目的是为了方便查询本地数据。向量存储可以用作检索器的底层实现,LangChain支持多种retriever接口的底层实现。

2. Retriever入门

2.1. 安装

为了演示如何获取检索器,我们以Qdrant向量数据库为例进行介绍。

%pip install --upgrade --quiet  qdrant-client

2.2. 获取OpenAI API密钥

在使用OpenAIEmbeddings之前,我们需要获取OpenAI API密钥。

import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:")

2.3. 导入文档数据和获取Qdrant客户端

以下代码演示了如何导入文档数据并获取Qdrant客户端以创建检索器:

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter

# 加载本地文档
loader = TextLoader("../../modules/state_of_the_union.txt")
documents = loader.load()
# 切割文档,每块文档大小是1000
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# 定义embedding模型,这里使用openai的模型
embeddings = OpenAIEmbeddings()

# 将处理后的文档,导入到向量数据库
qdrant = Qdrant.from_documents(
    docs,
    embeddings,
    path="/tmp/local_qdrant",
    collection_name="my_documents",
)

2.4. 获取检索器

通过以下代码演示了如何从Qdrant获取检索器:

retriever = qdrant.as_retriever()
retriever

可以通过问题查询跟问题相关的文档

docs = retriever.get_relevant_documents("what did he say about ketanji brown jackson")

也可以设置检索器的相似度阈值

# 文本相似度分数大于0.5分才会返回数据
retriever = db.as_retriever(
    search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.5}
)

也可以设置检索器,相似度最高的前面K条记录,下面定义就是返回相似度最高的2条记录

retriever = db.as_retriever(search_kwargs={"k": 2})

2.5. 在LCEL中使用检索器

由于检索器是Runnable对象,我们可以轻松地将它们与其他Runnable对象组合在一起,编排工作流:

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# 定义提示词模板(prompt)
template = """仅根据以下上下文回答问题:

{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
# 定义模型
model = ChatOpenAI()

def format_docs(docs):
    return "\n\n".join([d.page_content for d in docs])

# 通过LCEL表达式编排工作流,这里定义了一条流程
chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

# 调用chain
chain.invoke("关于科技,总统说了什么?")

流程说明:

  • step1: 目的是生成一个包含contextquestion两个属性的字典,用于为提示模板(prompt template)准备参数, context参数由retriever检索器根据invoke方法传入的参数关于科技,总统说了什么?查询跟问题相似的文档,然后将文档数组传递给format_docs函数进行格式化,最后复制给context属性,RunnablePassthrough函数将调用chain的参数(用户输入的问题)复制给question属性。
  • step2: 将第一步生成的字典传递给prompt模板进行提示模板格式化
  • step3: 将prompt模板格式化后的prompt传递给模型(model)
  • step4: 将模型调用结果传递给输出解析器StrOutputParser

3. 自定义检索器

4.1. 检索器接口简介

检索器接口非常简单,我们可以很容易地编写自定义的检索器。

4.2. 自定义检索器示例

下面是一个自定义检索器的示例,展示了如何编写自定义检索器并使用它获取相关文档:

from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from typing import List

class CustomRetriever(BaseRetriever):

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # 在这里实现查询功能,例如查询你本地数据库数据
        return [Document(page_content=query]

retriever = CustomRetriever()

retriever.get_relevant_documents("bar")

通过以上章节的学习,您将对检索器的概念、获取方法以及自定义方式有更深入的了解。



关联主题