简介
LLM大模型一般训练的数据都是滞后的,这是就需要用到RAG知识库,RAG知识库可以降低大模型在输出答案时的幻觉,也能够让大模型知识拓展。
知识库架构知识
检索流程图
- 用户输入 (User Query)
- |
- v
- +-----------------------+
- | 提示词 (Prompt) |
- +-----------------------+
- |
- | (1) 转化为向量 (Embedding)
- v
- +-----------------------+
- | 文字向量模型 (EMB) |
- +-----------------------+
- |
- | (2) 相似度检索 (Search)
- v
- +-----------------------+ +-----------------------+
- | RAG 向量数据库 (DB) | <-----> | 本地知识库/文档集 |
- +-----------------------+ +-----------------------+
- |
- | (3) 召回相关片段 (Context)
- v
- +-----------------------+
- | 大语言模型 (LLM) | <--- (将提示词与背景片段拼接)
- +-----------------------+
- |
- | (4) 最终生成 (Generation)
- v
- +-----------------------+
- | 输出结果 |
- +-----------------------+
复制代码 代码解释
代码的流程如下:
- 初始化RAG:文本切片->文本向量模型->构建向量数据库
- 询问ai:提示词->文本向量模型->向量数据库检索->组合prompt->喂给LLM->回答问题->记录历史对话
文本切片
- from operator import itemgetter
- import os
- from langchain_community.chat_models.tongyi import ChatTongyi
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import DashScopeEmbeddings
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
- from langchain_core.documents import Document
- from langchain_core.runnables import RunnablePassthrough
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.messages import HumanMessage, AIMessage
- from sqlalchemy import true
-
- #apikey
- os.environ["DASHSCOPE_API_KEY"] = "你的apikey"
- FAISS_PATH = "my_faiss_index"
- raw_text = """
- 【公司内部机密:代号“深蓝计划”】
- 1. 项目目标:开发一款能听懂猫语的翻译器。
- 2. 核心技术:基于Transformer的“喵声波”分析算法。
- 3. 团队暗号:如果有人问“今天天气怎么样?”,必须回答“我想吃鱼”。
- 4. 截止日期:2026年12月31日。
- 5. 经费预算:仅剩50元人民币,主要用于购买猫条。
- """
- #初始化文本
- docs = [Document(page_content=raw_text)]
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
- split_docs = text_splitter.split_documents(docs)
-
- #初始化模型
- llm = ChatTongyi(model="qwen-plus")
- embeddings = DashScopeEmbeddings(model="text-embedding-v1")
- #创建向量数据库
- if os.path.exists(FAISS_PATH):
- print("向量数据库已存在")
- ragdb = FAISS.load_local(FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
- else:
- print("创建向量数据库")
- ragdb = FAISS.from_documents(split_docs, embeddings)
- ragdb.save_local(FAISS_PATH)
- #构建提示词chain
- def format_docs(docs):
- return "\n\n".join(doc.page_content for doc in docs)
- final_prompt = ChatPromptTemplate.from_messages([
- ("system", """
- 你是一个专业的问答助手,你的任务是根据上下文简洁的回答用户的问题。
- <context>
- {context}
- </context>
- """),
- MessagesPlaceholder(variable_name="history"),
- ("human", "{input}")
- ])
- chain = (
- #查询rag
- RunnablePassthrough.assign(
- context = itemgetter("input") | ragdb.as_retriever() | format_docs
- )
- | RunnablePassthrough.assign(
- answer = {"input":itemgetter("input"), "context":itemgetter("context"), "history":itemgetter("history")} | final_prompt | llm | StrOutputParser()
- )
- )
- history = []
- while true:
- input_q = input("我:")
-
- respond = chain.invoke({
- "input": input_q,
- "history": history})
- print("answer:" + respond["answer"])
- print("=="*30)
-
- history.append(HumanMessage(content=input_q))
- history.append(AIMessage(content=respond['answer']))
复制代码 这里利用了langchain提供的文本分词器RecursiveCharacterTextSplitter(递归分词)
构建向量数据库
- docs = [Document(page_content=raw_text)]
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
- split_docs = text_splitter.split_documents(docs)
复制代码 这部分要注意:新版FAISS读取现有数据库要设置:allow_dangerous_deserialization=True,不然会报错
提示词模板
- llm = ChatTongyi(model="qwen-plus")
- embeddings = DashScopeEmbeddings(model="text-embedding-v1")
- #创建向量数据库
- if os.path.exists(FAISS_PATH):
- print("向量数据库已存在")
- ragdb = FAISS.load_local(FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
- else:
- print("创建向量数据库")
- ragdb = FAISS.from_documents(split_docs, embeddings)
- ragdb.save_local(FAISS_PATH)
复制代码 之前没有讲到历史对话记录,这次补充下:
MessagesPlaceholder这个是langchain框架的占位符(其实是框架写好了prompt模板,告诉ai这个是历史对话),使用时将历史对话记录的数组放在这里设置的字段中,在添加历史对话时要使用相关的类进行声明对话(告诉ai这句话是ai说的还是用户说的)- final_prompt = ChatPromptTemplate.from_messages([
- ("system", """
- 你是一个专业的问答助手,你的任务是根据上下文简洁的回答用户的问题。
- <context>
- {context}
- </context>
- """),
- MessagesPlaceholder(variable_name="history"),
- ("human", "{input}")
- ])
复制代码 Chain链的解释(核心逻辑)
- history.append(HumanMessage(content=input_q))
- history.append(AIMessage(content=respond['answer']))
复制代码 Chain链流程:
- 查询RAG的chain:获取input字段->内容交给向量数据库检索->将检索的内容(数组)转换为字符串格式->保存到context字段并传递给下一个任务
- 询问LLM的chain:获取input,context,history字段->填充上面定义的prompt模板->喂给LLM模型->解析成文本并保存在answer字段
itemgetter是获取上一个任务传递过来的字段内容。
如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |