中国有色金属建设股份有限公司网站,wordpress+免费模版,东莞58同城做网站电话,百度网站的建设讲了很多理论#xff0c;最后来一篇实践作为结尾。本次案例根据阿里云的博金大模型挑战赛的题目以及数据集做一次实践。 完整代码地址#xff1a;https://github.com/forever1986/finrag.git 本次实践代码有参考#xff1a;https://github.com/Tongyi-EconML/FinQwen/ 目录 …讲了很多理论最后来一篇实践作为结尾。本次案例根据阿里云的博金大模型挑战赛的题目以及数据集做一次实践。 完整代码地址https://github.com/forever1986/finrag.git 本次实践代码有参考https://github.com/Tongyi-EconML/FinQwen/ 目录 1 题目内容1.1 数据集说明 2 设计思路2.1 总体思路2.2 RAG应用点2.3 代码地址 3 实现过程3.1 问题路由3.2 文本理解3.2.1 总体设计思路3.2.2 文档抽取3.2.3 文档分块3.2.4 文档检索重排 3.3 NL2SQL3.3.1 SQL生成3.3.2 结果生成 3.4 agenttool方式 4 提高召回率5 总结 1 题目内容
根据原先的挑战赛总结题目要求如下 1题目要求基于大模型构建一个问答系统 2问答系统数据来源包括pdf文档和关系型数据库 3回答内容可能是通过pdf获得内容也可能需要先查询数据库再根据获得的内容得到最终回答
该案例原先设计是为了“通义千问金融大模型”我们这里只是为了展现一下RAG系统构建实战过程因此不会一定使用“通义千问金融大模型”。
1.1 数据集说明
数据集下载地址https://www.modelscope.cn/datasets/BJQW14B/bs_challenge_financial_14b_dataset/files 主要下载3部分
pdf中的所有pdf文件dataset中的“博金杯比赛数据.db”question.json这个是测试集问题
简单来说就是回答question.json中的问题问题的答案包括在pdf和db中通过RAG形式获取最终答案。
2 设计思路
2.1 总体思路 总体设计思路如下
问题路由从question.json可以得出问题的答案要么在PDF中要么在DB中因此要优先判断问题是查询PDF还是DB文本理解如果问题的答案来自PDF那么就是走查询PDF的路径SQL查询如果问题的答案来自DB那么就走NL2SQL的路径最终答案根据查询结果让大模型得出想要的答案格式
2.2 RAG应用点
文档处理本次应用中需要读取PDF数据并进行检索。这里包括解析、分块、embedding、检索等。查询结构内容本次应用中需要从DB数据库中进行SQL查询因此包括Text-to-SQL等路由本次应用中需要将问题分类到PDF或者DB事实上就使用到了RAG的路由模块。重排本次应用中为了提高准确率通过检索得到的结果进行重排后扔给大模型
2.3 代码地址
本次实践的代码地址已经上传githubhttps://github.com/forever1986/finrag.git
3 实现过程
3.1 问题路由
从question.json中将问题做一个路由。我们从检索增强生成RAG系列5–RAG提升之路由routing中总结的2种方式Logical routing和Semantic routing本案例中2种方式都可以采用。下面演示采用Logical routing的方式。 Logical routing其实就是采用prompt的方式让大模型给出一个路由结果这里我们也有2种方式可以选择
提示词当你的大模型参数量或者推理能力较强的时候可以直接使用promptfew shot方式指令微调通过给出一定数量500个指令数据对模型进行微调比如通过公司名、问题模板等方式进行指令微调让大模型具备分类能力
下面通过提示词和该案例的特点进行问题路由。
这里采用智谱AI的API接口因此可以先去申请一个API KEY当然你使用其它模型也可以目前智谱AI的GLM4送token就拿它来试验吧提取pdf的公司名称该案例特点就是pdf主要是公司的招股文书而question.json中问题提及到公司名称因此可以通过给prompt加上公司名称来提示大模型进行准确回答
import os
import config
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI# 初始化模型
llm ChatOpenAI(temperature0.95,modelglm-4,openai_api_key你的API KEY,openai_api_basehttps://open.bigmodel.cn/api/paas/v4/
)df pd.DataFrame(columns[filename, company])
i 1
for filename in os.listdir(config.text_files_path):if filename.endswith(.txt):file_path os.path.join(config.text_files_path, filename)with open(file_path, r, encodingutf-8) as file:content file.read()template ChatPromptTemplate.from_template(你是一个能精准提取信息的AI。我会给你一篇招股说明书请输出此招股说明书的主体是哪家公司若无法查询到则输出无。\n{t}\n\n请指出以上招股说明书属于哪家公司请只输出公司名。)chain template | llmresponse chain.invoke({t: content[:3000]})print(response.content)df.at[i, filename] filenamedf.at[i, company] response.contenti 1
df.to_csv(config.company_save_path)
下面通过自定义agent和tool的方式进行问题路由关键设计在于prompt中增加公司名称和few-shot方式下面只是贴出主要流程的代码全部代码可以下载全部代码。 其中config、util.instances和util.prompts都是基础类pdf_retrieve_chain和sql_retrieve_chain是自定义的tool的function import re
from typing import Sequence, Unionimport pandas as pd
from langchain.agents import AgentExecutor, AgentOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.tools.render import render_text_description
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool, Toolimport config
from SQL_retrieve_chain import sql_retrieve_chain
from util.instances import LLM
from pdf_retrieve_chain import
from util import promptsdef create_react_my_agent(llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: BasePromptTemplate
) - Runnable:# noqa: E501missing_vars {tools, tool_names, agent_scratchpad}.difference(prompt.input_variables)if missing_vars:raise ValueError(fPrompt missing required variables: {missing_vars})# 读取公司名称df pd.read_csv(config.company_save_path)company_list df[company]company_content for company in company_list:company_content company_content \n company# print(company_content)prompt prompt.partial(toolsrender_text_description(list(tools)),tool_names, .join([t.name for t in tools]),companycompany_content)llm_with_stop llm.bind(stop[\n观察])temp_agent (RunnablePassthrough.assign(agent_scratchpadlambda x: format_log_to_str(x[intermediate_steps]),)| prompt| llm_with_stop| MyReActSingleInputOutputParser())return temp_agentclass MyReActSingleInputOutputParser(AgentOutputParser):def get_format_instructions(self) - str:return FORMAT_INSTRUCTIONSdef parse(self, text: str) - Union[AgentAction, AgentFinish]:FINAL_ANSWER_ACTION Final Answer:FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE (Parsing LLM output produced both a final answer and a parse-able action:)includes_answer FINAL_ANSWER_ACTION in textregex (rAction\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*))action_match re.search(regex, text, re.DOTALL)if action_match:action action_match.group(1).strip()action_input action_match.group(2)tool_input action_input.strip( )tool_input tool_input.strip()return AgentAction(action, tool_input, text)elif includes_answer:return AgentFinish({output: text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text)else:return AgentFinish({output: text}, text)propertydef _type(self) - str:return react-single-inputauto_tools [Tool(name招股说明书,funcpdf_retrieve_chain,description招股说明书检索,),Tool(name查询数据库,funcsql_retrieve_chain,description查询数据库检索结果,),
]
tmp_prompt ChatPromptTemplate.from_template(prompts.AGENT_CLASSIFY_PROMPT_TEMPLATE)
agent create_react_my_agent(LLM, auto_tools, prompttmp_prompt)agent_executor AgentExecutor(agentagent, toolsauto_tools, verboseTrue)
result agent_executor.invoke({question: 报告期内华瑞电器股份有限公司人工成本占主营业务成本的比例分别为多少})
# result agent_executor.invoke({question: 请帮我计算在20210105中信行业分类划分的一级行业为综合金融行业中涨跌幅最大股票的股票代码是涨跌幅是多少百分数保留两位小数。股票涨跌幅定义为收盘价 - 前一日收盘价 / 前一日收盘价* 100%。})
print(result[output])
3.2 文本理解
这部分分为2个阶段第一个阶段是文档处理第二部分是检索排序。在设计该模块时我们在检索增强生成RAG系列3–RAG优化之文档处理中讲过解析、分块、embedding、向量数据库都对最终结果的准确度或者召回率会产生较大影响。但是实际实践中如果按照普通方式进行解析、分块、embedding最终检索的准确率一定不高因此在不同场景的应用中需要做一些技巧性从而提高最终检索召回率。
3.2.1 总体设计思路 本案例是一个金融招股书的检索每一份招股书都是对应一个公司而question.json中对于检索招股书都会涉及公司名称因此该部分的设计可以利用该特性进行设计
对pdf文档进行解析为txt并以对应公司名称进行存储分块对文档进行2个层次分块先进行较大长度分块然后通过将较大长度的分块进行细分块这样公司-大分块-小分块的映射关系在检索的时候可以通过公司进行匹配在embedding中可以通过小分块匹配后找到大分块这样增加上下文内容从而提高召回率通过问题与公司之间的匹配度获得公司名称通过双链路检索稀疏BM25检索和密集embedding相似度检索增加检索结果的准确率通过重排将2种查询结果进行重排增加检索结果的准确率
3.2.2 文档抽取
对于本案例中pdf的格式大致相同而且主要包括文字和表格。在本次案例中尝试了一些开源的pdfplumber、pdfminer、gptpdf、RAGFlow等说一下总结
pdfplumber、pdfminer虽然能解析表格但是对于一些特别的表格比如该案例中一些没有左右边框的表格解析不好另外一些换行也需要自己处理一般都比较难处理好。RAGFlow解析效果不错特别是表格和自动换行但是也会出现部分问题解析错误的但是整体效果比pdfplumber、pdfminer好很多。gptpdf通过截图大模型的方式进行解析首先是需要费用其次尝试过chatgpt之外的模型使用其它大模型需要改提示语、agent等效果也很不好另外它将图片、表格圈出来再让大模型去识别图片和表格经常会将表格上下的文本也圈进去。
下面代码是本次实践中使用pdfplumber方式进行解析大概原理如下
通过pdfplumber的find_tables获取表格循环表格获取表格之上的文字获取表格的markdown格式最后一个表格时获取表格之下的文字存在问题部分没有左右边框的表格处理不好没有实现较好的换行页眉页尾等不相关内容未做处理
import re
import pdfplumber# 通过表格的top和bottom来读取页面的文章通过3种情况
# 1 第一种情况top和bottom为空则代表纯文本
# 2 第二种情况top为空bottom不为空则代表处理最后一个表格下面的文本
# 3 第三种情况top和bottom不为空则代表处理表格上面的文本
def check_lines(page, top, bottom):try:# 获取文本框lines page.extract_words()except Exception as e:print(f页码: {page.page_number}, 抽取文本异常异常信息: {e})return # empty utilcheck_re (?:。||单位元|单位万元|币种人民币)$page_top_re (招股意向书(?:全文)?(?:修订版|修订稿|更正后)?)text last_top 0last_check 0if top and bottom :if len(lines) 0:print(f{page.page_number}页无数据, 请检查)return for l in range(len(lines)):each_line lines[l]# 第一种情况top和bottom为空则代表纯文本if top and bottom :if abs(last_top - each_line[top]) 2:text text each_line[text]elif last_check 0 and (page.height * 0.9 - each_line[top]) 0 and not re.search(check_re, text):if \n not in text and re.search(page_top_re, text):text text \n each_line[text]else:text text each_line[text]else:if text :text each_line[text]else:text text \n each_line[text]# 第二种情况top为空bottom不为空则代表处理最后一个表格下面的文本elif top :if each_line[top] bottom:if abs(last_top - each_line[top]) 2:text text each_line[text]elif last_check 0 and (page.height * 0.85 - each_line[top]) 0 and not re.search(check_re, text):if \n not in text and re.search(page_top_re, text):text text \n each_line[text]else:text text each_line[text]else:if text :text each_line[text]else:text text \n each_line[text]# 第三种情况top和bottom不为空则代表处理表格上面的文本else:if top each_line[top] bottom:if abs(last_top - each_line[top]) 2:text text each_line[text]elif last_check 0 and (page.height * 0.85 - each_line[top]) 0 and not re.search(check_re, text):if \n not in text and re.search(page_top_re, text):text text \n each_line[text]else:text text each_line[text]else:if text :text each_line[text]else:text text \n each_line[text]last_top each_line[top]last_check each_line[x1] - page.width * 0.83return text# 删除没有数据的列
def drop_empty_cols(data):# 删除所有列为空数据的列transposed_data list(map(list, zip(*data)))filtered_data [col for col in transposed_data if not all(cell for cell in col)]result list(map(list, zip(*filtered_data)))return result# 通过判断页面是否有表格
# 1 如果没有表格则按照读取文本处理
# 2 如果有表格则获取每个表格的top坐标和bottom坐标按照表格顺序先读取表格之上的文字在使用markdown读取表格
# 3 不断循环2等到最后一个表格只需要读取表格之下的文字即可
def extract_text_and_tables(page):all_text bottom 0try:tables page.find_tables()except:tables []if len(tables) 1:count len(tables)for table in tables:# 判断表格底部坐标是否小于0if table.bbox[3] bottom:passelse:count - 1# 获取表格顶部坐标top table.bbox[1]text check_lines(page, top, bottom)text_list text.split(\n)for _t in range(len(text_list)):all_text text_list[_t] \nbottom table.bbox[3]new_table table.extract()r_count 0for r in range(len(new_table)):row new_table[r]if row[0] is None:r_count 1for c in range(len(row)):if row[c] is not None and row[c] not in [, ]:if new_table[r - r_count][c] is None:new_table[r - r_count][c] row[c]else:new_table[r - r_count][c] row[c]new_table[r][c] Noneelse:r_count 0end_table []for row in new_table:if row[0] is not None:cell_list []cell_check Falsefor cell in row:if cell is not None:cell cell.replace(\n, )else:cell if cell ! :cell_check Truecell_list.append(cell)if cell_check:end_table.append(cell_list)end_table drop_empty_cols(end_table)markdown_table # 存储当前表格的Markdown表示for i, row in enumerate(end_table):# 移除空列这里假设空列完全为空根据实际情况调整row [cell for cell in row if cell is not None and cell ! ]# 转换每个单元格内容为字符串并用竖线分隔processed_row [str(cell).strip() if cell is not None else for cell in row]markdown_row | | .join(processed_row) |\nmarkdown_table markdown_row# 对于表头下的第一行添加分隔线if i 0:separators [:--- if cell.isdigit() else --- for cell in row]markdown_table | | .join(separators) |\nall_text markdown_table \nif count 0:text check_lines(page, , bottom)text_list text.split(\n)for _t in range(len(text_list)):all_text text_list[_t] \nelse:text check_lines(page, , )text_list text.split(\n)for _t in range(len(text_list)):all_text text_list[_t] \nreturn all_textdef extract_text(pdf_path):with pdfplumber.open(pdf_path) as pdf:all_text for i, page in enumerate(pdf.pages):all_text extract_text_and_tables(page)return all_textif __name__ __main__:# 使用示例test_pdf_path data/pdf/3e0ded8afa8f8aa952fd8179b109d6e67578c2dd.pdfextracted_text extract_text(test_pdf_path)pdf_save_path data/pdf_txt_file2/宁波华瑞电器股份有限公司.txtwith open(pdf_save_path, w, encodingutf-8) as file:file.write(extracted_text)
3.2.3 文档分块
通过将3.2.1中得到的txt文档进行分块分块步骤如下
进行大的分块然后将大分块再次进行小分块将小分块做2部分存储一部分存储pkl文件是用于bm25检索一部分存储在faiss向量数据库用于向量检索将文档–大分块–小分块的映射关系进行存储每个招股文件存储为一个pkl文件
import os
import faiss
import numpy
import pickle
import config
from tqdm import tqdm
from util.instances import BEG_MODEL
from langchain.text_splitter import RecursiveCharacterTextSplitter# 将每个公司的txt文件进行分块并将分别存储在本地文件和本地向量数据库
# 本地文件存为pkl用于bm25的相似度查询
# 本地向量数据库用于embedding的相似度查询
def splitter_doc(txt_file, model, splitterFalse, doc_chunk_size800, doc_chunk_overlap100,sub_chunk_size150, sub_chunk_overlap50):if not splitter:pkl_save_path os.path.join(config.pkl_save_path, txt_file.split(.)[0] .pkl)if os.path.exists(pkl_save_path):print(当前文件已经初始化完成无需再次初始化如希望重新写入则将参数splitter设为True)return# 第一步读取txt文件cur_file_path os.path.join(data/pdf_txt_file2, txt_file)with open(cur_file_path, r, encodingutf-8) as file:file_doc file.read()# 第二步先将文档切块text_splitter RecursiveCharacterTextSplitter(chunk_sizedoc_chunk_size, chunk_overlapdoc_chunk_overlap,separators[\n], keep_separatorTrue, length_functionlen)parent_docs text_splitter.split_text(file_doc)print(len(parent_docs))# 第三步将切块再次切分小文本cur_text []child_parent_dict {} # 子模块与父模块的dictfor doc in parent_docs:text_splitter RecursiveCharacterTextSplitter(chunk_sizesub_chunk_size, chunk_overlapsub_chunk_overlap,separators[\n, ], keep_separatorTrue, length_functionlen)child_docs text_splitter.split_text(doc)for child_doc in child_docs:child_parent_dict[child_doc] doccur_text child_docs# 第四步将文本向量化返回一个key为文本value为embedding的dictresult_dict dict()for doc in tqdm(cur_text):result_dict[doc] numpy.array(model.encode(doc))# 第五步将dict存储为.pkl文件用于bm25相似度查询pkl_save_path os.path.join(config.pkl_save_path, txt_file.split(.)[0] .pkl)if os.path.exists(pkl_save_path):os.remove(pkl_save_path)print(存在旧版本pkl文件进行先删除后创建)with open(pkl_save_path, wb) as file:pickle.dump(result_dict, file)print(完成pkl数据存储, pkl_save_path)pkl_dict_save_path os.path.join(config.pkl_save_path, txt_file.split(.)[0] _dict .pkl)if os.path.exists(pkl_dict_save_path):os.remove(pkl_dict_save_path)print(存在旧版本pkl dict文件进行先删除后创建)with open(pkl_dict_save_path, wb) as file:pickle.dump(child_parent_dict, file)print(完成pkl dict数据存储, pkl_dict_save_path)# 第六步将dict中的向量化数据存储到faiss数据库result_vectors numpy.array(list(result_dict.values()))dim result_vectors.shape[1]index faiss.IndexFlatIP(dim)faiss.normalize_L2(result_vectors)index.add(result_vectors)faiss_save_path os.path.join(config.faiss_save_path, txt_file.replace(txt, faiss))if os.path.exists(faiss_save_path):os.remove(faiss_save_path)print(存在旧版本faiss索引文件进行先删除后创建)faiss.write_index(index, faiss_save_path)print(完成faiss向量存储, faiss_save_path)if __name__ __main__:txt_file_name 宁波华瑞电器股份有限公司.txt# 存储数据splitter_doc(txt_file_name, BEG_MODEL)
3.2.4 文档检索重排
关于向量搜索能否取代传统的一些文本搜索的问题相信网上已经做了很多的讨论。我想说的是做过真正实践的人就不会问出这样的问题。这里采用的就是BM25向量检索的双重。并根据检索增强生成RAG系列7–RAG提升之高级阶段中的重排BGE-reranker模型进行重排。 注意这里面有个rerank_api方法调用bge的rerank需要下载bge-reranker-base并启动一个api服务。这里只是贴出主要流程代码全代码参考github import os
import json
import faiss
import numpy
import config
import pickle
import requests
import pandas as pd
from util import prompts
from rank_bm25 import BM25Okapi
from requests.adapters import HTTPAdapter
from util.instances import LLM, BEG_MODEL
from langchain_core.prompts import ChatPromptTemplateclass Query:def __init__(self, question, docs, top_k5):super().__init__()self.question questionself.docs docsself.top_k top_kdef to_dict(self):return {question: self.question,docs: self.docs,top_k: self.top_k}# 使用bm25进行检索
def bm25_retrieve(query, contents):bm25 BM25Okapi(contents)# 对于每个文档计算结合BM25bm25_scores bm25.get_scores(query)# 根据得分排序文档sorted_docs sorted(zip(contents, bm25_scores), keylambda x: x[1], reverseTrue)# print(通过bm25检索结果查到相关文本数量, len(sorted_docs))return sorted_docs# 使用faiss向量数据库的索引进行查询
def embedding_retrieve(query, txt_file, model):embed_select_docs []faiss_save_path os.path.join(data/embedding_index, txt_file.faiss)if os.path.exists(faiss_save_path):index faiss.read_index(faiss_save_path)query_embedding numpy.array(model.encode(query))_, search_result index.search(query_embedding.reshape(1, -1), 5)pkl_save_path os.path.join(config.pkl_save_path, txt_file.split(.)[0] .pkl)with open(pkl_save_path, rb) as file:docs_dict pickle.load(file)chunk_docs list(docs_dict.keys())embed_select_docs [chunk_docs[i] for i in search_result[0]] # 存储为列表# print(通过embedding检索结果查到相关文本数量, len(embed_select_docs))else:print(找不到对于的faiss文件请确认是否已经进行存储)return embed_select_docsdef search(query, model, llm, top_k5):# 读取公司名称列表df pd.read_csv(config.company_save_path)company_list df[company].to_numpy()# 使用大模型获得最终公司的名称prompt ChatPromptTemplate.from_template(prompts.COMPANY_PROMPT_TEMPLATE)chain prompt | llmresponse chain.invoke({company: company_list, question: query})# print(response.content)company_name response.contentfor name in company_list:if name in company_name:company_name namebreak# print(company_name)# 通过bm25获取相似度最高的chunkpkl_file os.path.join(config.pkl_save_path, company_name .pkl)with open(pkl_file, rb) as file:docs_dict pickle.load(file)chunk_docs list(docs_dict.keys())bm25_chunks [docs_tuple[0] for docs_tuple in bm25_retrieve(query, chunk_docs)[:top_k]]# 通过embedding获取相似度最高的chunkembedding_chunks embedding_retrieve(query, company_name, model)# 重排chunks list(set(bm25_chunks embedding_chunks))# print(通过双路检索结果, len(chunks))arg Query(questionquery, docschunks, top_ktop_k)chunk_similarity rerank_api(arg)# for r in chunk_similarity.items():# print(r)# 获取父文本块result_docs []pkl_dict_file os.path.join(config.pkl_save_path, company_name _dict .pkl)with open(pkl_dict_file, rb) as file:child_parent_dict pickle.load(file)for key, _ in sorted(chunk_similarity.items(), keylambda x: x[1], reverseTrue):for child_txt, parent_txt in child_parent_dict.items(): # 遍历父文本块if key child_txt: # 根据匹配的子文本块找到父文本result_docs.append(parent_txt)# print(最终结果)# for d in result_docs:# print(d)return result_docsdef rerank_api(query, urlhttp://127.0.0.1:8000/bge_rerank):headers {Content-Type: application/json}data json.dumps(query.__dict__)s requests.Session()s.mount(http://, HTTPAdapter(max_retries3))try:res s.post(url, datadata, headersheaders, timeout600)if res.status_code 200:return res.json()else:return Noneexcept requests.exceptions.RequestException as e:print(e)return Noneif __name__ __main__:user_query 报告期内华瑞电器股份有限公司人工成本占主营业务成本的比例分别为多少# 检索search(user_query, BEG_MODEL, LLM)
3.3 NL2SQL
本案例中一部分问题是需要通过查询DB获取结果的。在检索增强生成RAG系列6–RAG提升之查询结构内容Query Construction中讨论过几种不同的查询结构内容而本案例中就需要Text-to-SQL。Text-to-SQL需要3个步骤
将问题转换为SQL语句也就是SQL的生成执行SQL语句这个主要是执行DB的查询并获得查询结果生成最终结果
3.3.1 SQL生成
关于SQL的生成有几种不同的方法有的利用prompt有的利用微调有的利用特殊模型等等这方面的具体可以自行研究该案例中通过某一个通用大模型来实现因此可以采用以下2种方式
提示词直接使用promptfew shot方式指令微调通过给出一定数量500指令数据对模型进行微调比如通过表名、字段名等方式进行指令微调让大模型具备特定场景下生成SQL能力
无论使用上面哪一种最终你需要一些few shot或者一些指令数据这方面也是可以通过2种方式进行获得
人工编辑ChatGPT生成通过算法聚类
该案例中是将question.json中关于需要生成SQL的问题进行整理组成demo数据ICL_EXP.csv来自比赛团队中整理好的现成数据并使用Jaccard对问题与demo中的问题进行相似度计算获取几条相似度靠前的demo然后通过promptfew-shot方式进行SQL生成。
import csv
import re
import copy
import config
import pandas as pdfrom util.instances import TOKENIZER, LLM
from util import prompts
from langchain_core.prompts import ChatPromptTemplatedef generate_sql(question, llm, example_question_list, example_sql_list, tmp_example_token_list, example_num5):pattern1 r\d{8} # 过滤掉一些数字的正则表达式sql_pattern_start sqlsql_pattern_end temp_question question# 提取数字date_list re.findall(pattern1, temp_question)temp_question2_for_search temp_question# 将数字都替换为空格for t_date in date_list:temp_question2_for_search.replace(t_date, )temp_tokens TOKENIZER(temp_question2_for_search)temp_tokens temp_tokens[input_ids]# 计算与已有问题的相似度--使用Jaccard进行相似度计算similarity_list list()for cyc2 in range(len(tmp_example_token_list)):similarity_list.append(len(set(temp_tokens) set(tmp_example_token_list[cyc2]))/ (len(set(temp_tokens)) len(set(tmp_example_token_list[cyc2]))))# 求与第X个问题相似的问题t copy.deepcopy(similarity_list)# 求m个最大的数值及其索引max_index []for _ in range(example_num):number max(t)index t.index(number)t[index] 0max_index.append(index)# 防止提示语过长temp_length_test short_index_list list() # 匹配到的问题下标for index in max_index:temp_length_test temp_length_test example_question_list[index]temp_length_test temp_length_test example_sql_list[index]if len(temp_length_test) 2000:breakshort_index_list.append(index)# print(找到相似的模板, short_index_list)# 组装promptprompt ChatPromptTemplate.from_template(prompts.GENERATE_SQL_TEMPLATE)examples for index in short_index_list:examples examples 问题 example_question_list[index] \nexamples examples SQL example_sql_list[index] \nchain prompt | llmresponse chain.invoke({examples: examples, table_info: prompts.TABLE_INFO, question: temp_question})# print(问题, temp_question)# print(SQL, response.content)sql response.contentstart_index sql.find(sql_pattern_start) len(sql_pattern_start)end_index -1if start_index 0:end_index sql[start_index:].find(sql_pattern_end) start_indexif start_index end_index:sql sql[start_index:end_index]return prompt.invoke({examples: examples, table_info: prompts.TABLE_INFO, question: temp_question}), sqlelse:print(generate sql error:, temp_question)return error, errorif __name__ __main__:# 第一步读取问题和SQL模板使用tokenizer进行token化sql_examples_file pd.read_csv(config.sql_examples_path, delimiter,, header0)g_example_question_list list()g_example_sql_list list()g_example_token_list list()for cyc in range(len(sql_examples_file)):g_example_question_list.append(sql_examples_file[cyc:cyc 1][问题][cyc])g_example_sql_list.append(sql_examples_file[cyc:cyc 1][SQL][cyc])tokens TOKENIZER(sql_examples_file[cyc:cyc 1][问题][cyc])tokens tokens[input_ids]g_example_token_list.append(tokens)# 第二步测试问题及结果文件question_csv_file pd.read_csv(config.question_classify_path, delimiter,, header0)question_sql_file open(config.question_sql_path, w, newline, encodingutf-8-sig)csvwriter csv.writer(question_sql_file)csvwriter.writerow([问题id, 问题, SQL, prompt])# 第三步循环问题使用Jaccard进行相似度计算问题与模板中的问题相似度最高的几条记录for cyc in range(len(question_csv_file)):if question_csv_file[分类][cyc] 查询数据库:result_prompt, result generate_sql(question_csv_file[问题][cyc], LLM, g_example_question_list,g_example_sql_list, g_example_token_list)csvwriter.writerow([str(question_csv_file[cyc:(cyc 1)][问题id][cyc]),str(question_csv_file[cyc:(cyc 1)][问题][cyc]),result, result_prompt])else:print(pass question:, question_csv_file[问题][cyc])pass
3.3.2 结果生成
由于SQL查询结果一般是一个json格式或者数组格式的一个数据还需要通过大模型将数据转换成最终自然语言的结果。同样也是具备多种方式而本案例中可以采用如下
提示词直接使用promptfew shot方式指令微调通过给出一定数量500指令数据对模型进行微调。
本次演示跟SQL生成一样也是采用promptfew-shot方式其中demo数据ICL_EXP.csv来自比赛团队中整理好的现成数据并使用Jaccard对问题与demo中的问题进行相似度计算。
import csv
import re
import copy
import config
import pandas as pdfrom util.instances import LLM, TOKENIZER
from util import prompts
from langchain_core.prompts import ChatPromptTemplatedef generate_answer(question, fa, llm, example_question_list, example_info_list, example_fa_list,tmp_example_token_list, example_num5):pattern1 r\d{8} # 过滤掉一些数字的正则表达式temp_question question# 提取数字date_list re.findall(pattern1, temp_question)temp_question2_for_search temp_question# 将数字都替换为空格for t_date in date_list:temp_question2_for_search.replace(t_date, )temp_tokens TOKENIZER(temp_question2_for_search)temp_tokens temp_tokens[input_ids]# 计算与已有问题的相似度--使用Jaccard进行相似度计算similarity_list list()for cyc2 in range(len(tmp_example_token_list)):similarity_list.append(len(set(temp_tokens) set(tmp_example_token_list[cyc2]))/ (len(set(temp_tokens)) len(set(tmp_example_token_list[cyc2]))))# 求与第X个问题相似的问题t copy.deepcopy(similarity_list)# 求m个最大的数值及其索引max_index []for _ in range(example_num):number max(t)index t.index(number)t[index] 0max_index.append(index)# 防止提示语过长temp_length_test short_index_list list() # 匹配到的问题下标for index in max_index:temp_length_test temp_length_test example_question_list[index]temp_length_test temp_length_test example_fa_list[index]if len(temp_length_test) 2000:breakshort_index_list.append(index)# print(找到相似的模板, short_index_list)# 组装promptprompt ChatPromptTemplate.from_template(prompts.ANSWER_TEMPLATE)examples for index in short_index_list:examples examples 问题 example_question_list[index] \nexamples examples 资料 example_info_list[index] \nexamples examples 答案 example_fa_list[index] \nchain prompt | llmresponse chain.invoke({examples: examples, FA: fa, question: temp_question})# print(答案, response.content)return response.contentif __name__ __main__:# 第一步读取问题和FA模板使用tokenizer进行token化sql_examples_file pd.read_csv(config.sql_examples_path, delimiter,, header0)g_example_question_list list()g_example_info_list list()g_example_fa_list list()g_example_token_list list()for cyc in range(len(sql_examples_file)):g_example_question_list.append(sql_examples_file[cyc:cyc 1][问题][cyc])g_example_info_list.append(sql_examples_file[cyc:cyc 1][资料][cyc])g_example_fa_list.append(sql_examples_file[cyc:cyc 1][FA][cyc])tokens TOKENIZER(sql_examples_file[cyc:cyc 1][问题][cyc])tokens tokens[input_ids]g_example_token_list.append(tokens)# 第二步拿到答案result_csv_file pd.read_csv(config.question_sql_check_path, delimiter,, header0)answer_file open(config.answer_path, w, newline, encodingutf-8-sig)csvwriter csv.writer(answer_file)csvwriter.writerow([问题id, 问题, 资料, FA])# 第三步循环问题使用Jaccard进行相似度计算问题与模板中的问题相似度最高的几条记录for cyc in range(len(result_csv_file)):if result_csv_file[flag][cyc] 1:result generate_answer(result_csv_file[问题][cyc], result_csv_file[执行结果][cyc], LLM,g_example_question_list, g_example_info_list, g_example_fa_list,g_example_token_list)csvwriter.writerow([str(result_csv_file[cyc:(cyc 1)][问题id][cyc]),str(result_csv_file[cyc:(cyc 1)][问题][cyc]),str(result_csv_file[cyc:(cyc 1)][执行结果][cyc]),result])
3.4 agenttool方式
通过自定义agent和tool方式将整个流程串联起来
import re
from typing import Sequence, Unionimport pandas as pd
from langchain.agents import AgentExecutor, AgentOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.tools.render import render_text_description
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool, Toolimport config
from SQL_retrieve_chain import sql_retrieve_chain
from util.instances import LLM
from pdf_retrieve_chain import pdf_retrieve_chain
from util import promptsdef create_react_my_agent(llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: BasePromptTemplate
) - Runnable:# noqa: E501missing_vars {tools, tool_names, agent_scratchpad}.difference(prompt.input_variables)if missing_vars:raise ValueError(fPrompt missing required variables: {missing_vars})# 读取公司名称df pd.read_csv(config.company_save_path)company_list df[company]company_content for company in company_list:company_content company_content \n company# print(company_content)prompt prompt.partial(toolsrender_text_description(list(tools)),tool_names, .join([t.name for t in tools]),companycompany_content)llm_with_stop llm.bind(stop[\n观察])temp_agent (RunnablePassthrough.assign(agent_scratchpadlambda x: format_log_to_str(x[intermediate_steps]),)| prompt| llm_with_stop| MyReActSingleInputOutputParser())return temp_agentclass MyReActSingleInputOutputParser(AgentOutputParser):def get_format_instructions(self) - str:return FORMAT_INSTRUCTIONSdef parse(self, text: str) - Union[AgentAction, AgentFinish]:FINAL_ANSWER_ACTION Final Answer:FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE (Parsing LLM output produced both a final answer and a parse-able action:)includes_answer FINAL_ANSWER_ACTION in textregex (rAction\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*))action_match re.search(regex, text, re.DOTALL)if action_match:action action_match.group(1).strip()action_input action_match.group(2)tool_input action_input.strip( )tool_input tool_input.strip()return AgentAction(action, tool_input, text)elif includes_answer:return AgentFinish({output: text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text)else:return AgentFinish({output: text}, text)propertydef _type(self) - str:return react-single-inputauto_tools [Tool(name招股说明书,funcpdf_retrieve_chain,description招股说明书检索,),Tool(name查询数据库,funcsql_retrieve_chain,description查询数据库检索结果,),
]
tmp_prompt ChatPromptTemplate.from_template(prompts.AGENT_CLASSIFY_PROMPT_TEMPLATE)
agent create_react_my_agent(LLM, auto_tools, prompttmp_prompt)agent_executor AgentExecutor(agentagent, toolsauto_tools, verboseTrue)
result agent_executor.invoke({question: 报告期内华瑞电器股份有限公司人工成本占主营业务成本的比例分别为多少})
# result agent_executor.invoke({question: 请帮我计算在20210105中信行业分类划分的一级行业为综合金融行业中涨跌幅最大股票的股票代码是涨跌幅是多少百分数保留两位小数。股票涨跌幅定义为收盘价 - 前一日收盘价 / 前一日收盘价* 100%。})
print(result[output])
4 提高召回率
本次案例中虽然简单实现了功能过程还需要在不同环节中提高其召回率才能达到真正RAG业务使用级别。这里总结一下本次实践中还需要哪些提升以及方案中存在哪些问题
问题路由采用的是promptfew-shot方式缺点的过于依赖prompt文档解析采用pdfplumber进行解析在本案例中的效果其实一般部分表格没有解析得很好另外换行也是有待提高。因此这部分可以做改进文档分块虽然采用2层方式进行分块增加了召回上下文大小但是整体召回率还是不高需要不断优化分块大小通过调试获得最终的结果文档检索通过BM25和向量检索的结合但是实践中2种也不一定能很好的召回相关性最高的内容还是要结合其它传统检索方式比如ES等获得更为精确的召回结果SQL生成通过模板few-shot的方式缺点就是依赖于demo库需要比较大的人工整理也依赖于demo库中的样例丰富性。更为通用的方式是采用专业SQL生成大模型会得到更好的准确率问题生成本案例中也是通过demo库提供few-shot方式如果通过一定指令微调可能更为适应其泛化能力
5 总结
本次通过一次实践过程给大家演示一下RAG的落地过程。我们可以发现虽然前面2~7中讲了很多理论在实际过程中算是入门的应用过程中针对具体场景我们还是需要做其他大量工作特别是数据处理、寻找更高召回率的步骤慢慢探索。