深度学习系列80Pike-RAG解析
深度学习系列80:Pike-RAG解析
github地址在https://github.com/microsoft/PIKE-RAG
1. chunking
需要examples/chunking.py和examples/biology/configs/chunking.yml
1.1 chunking.yml文件
先看下配置文件:
- 去github上下载book并放在这个文件夹
input_doc_setting:
doc_dir: data/biology/contents
后面chunk完的结果会放在下面:
output_doc_setting:
doc_dir: data/biology/chunks
- 需要自定义个LLM,参考pikerag/llm_client下的文件,然后修改yml文件中的llm_client部分。
- 定义splitter,这里用的是pikerag.document_transformers.pikerag.document_transformers,其实就是RecursiveCharacterTextSplitter再加上一个chunk_summary。其中chunk_summary相关代码在prompts/chunking目录下,其中一个template如下:
chunk_summary_template_Chinese = MessageTemplate(
template=[
("system", "You are a helpful AI assistant good at document summarization."),
("user", """
# 原文来源
原文来自 {source} 的政策文档 {filename}。
# 原文
“部分原文”:
{content}
# 任务要求
你的任务是输出以上“部分原文”的总结。
# 输出
只输出内容总结,不要添加其他任何内容。
""".strip())
],
input_variables=["source", "filename", "content"],
)
splitter出来的Document会加上如下信息:
chunk_meta.update({"summary": chunk_summary})
1.2 chunking.py
首先加载文件:
from pikerag.document_loaders import get_loader
,会推断文件类型,支持:
class DocumentType(Enum):
csv = ["csv"]
excel = ["xlsx", "xls"]
markdown = ["md"]
pdf = ["pdf"]
text = ["txt"]
word = ["docx", "doc"]
增加metadata:
doc.metadata.update({"filename": doc_name})
分割文件并保存:
chunk_docs = self._splitter.transform_documents(docs)
2. qa
需要examples/qa.py和examples/biology/configs/qa.yml
2.1 qa.yml文件
调用模块都写在yml文件里面了:
workflow:
module_path: pikerag.workflows.qa
class_name: QaWorkflow
qa_protocol:
module_path: pikerag.prompts.qa
# available attr_name: multiple_choice_qa_protocol, multiple_choice_qa_with_reference_protocol
attr_name: multiple_choice_qa_with_reference_protocol
template_partial:
knowledge_domain: biological
llm_client:
module_path: pikerag.llm_client
# available class_name: AzureMetaLlamaClient, AzureOpenAIClient, HFMetaLlamaClient
class_name: AzureOpenAIClient
retriever:
module_path: pikerag.knowledge_retrievers
class_name: QaChunkRetriever
retrieval_query:
module_path: pikerag.knowledge_retrievers.query_parsers
func_name: question_plus_options_as_query
vector_store:
collection_name: biology_book
id_document_loading:
module_path: biology.utils
func_name: load_ids_and_chunks
args:
chunk_file_dir: data/biology/chunks
embedding_setting:
args:
model_name: "FremyCompany/BioLORD-2023"
model_kwargs:
device: "cuda:0"
encode_kwargs:
normalize_embeddings: True
2.2 模块分析
qa.py的代码非常简单:
workflow = workflow_class(yaml_config)
workflow.run()
这里顺着模块来梳理。qa的workflow在pikerag.workflows.qa下。
QaWorkflow类的初始化模块包括初始化logger、testing_suite、agent(包含了protocol、retriever、llm_client)、evaluator。
然后run函数调用_single_thread_run()或者_multiple_threads_run()。这里我们看单线程的版本,简单来说就是循环调用anser函数,那我们首先看下answer函数。
# 步骤:使用retriver获取相关chunks、将问题和chunks组合成message、调用LLM回答问题
def answer(self, qa: BaseQaData, question_idx: int) -> dict:
reference_chunks = self._retriever.retrieve_contents(qa, retrieve_id=f"Q{question_idx:03}")
messages = self._qa_protocol.process_input(content=qa.question, references=reference_chunks, **qa.as_dict())
response = self._client.generate_content_with_messages(messages, **self.llm_config)
output_dict: dict = self._qa_protocol.parse_output(response, **qa.as_dict())
……
2.3 retriever部分
核心代码为pikerage/workflows/qa.py中的answer函数.
2.3.1 retriever获取chunks
这里的retriever使用的是chroma_qa_retriever.py中的QaChunkRetriever类(同时使用到了mixins文件夹下的chroma_mixin.py)
代码为:
reference_chunks: List[str] = self._retriever.retrieve_contents(qa, retrieve_id=f"Q{question_idx:03}")
。
retriever初始化用到的embedding函数是FremyCompany/BioLORD-2023。我们可以把模型下载到本地加载,也可以把它修改为本地自定义embedding。
这里调用了biology.utils下的load_ids_and_chunks函数,使用pickle.load()读取chunks,数据读取地址为data/biology/chunks(跟chunking.yml保持一致)。
2.3.2 protocol构造prompt
接下来是使用_qa_protocol.process_input处理问题和chunks,代码为:
messages = self._qa_protocol.process_input(content=qa.question, references=reference_chunks, **qa.as_dict())
这里是pikerag.prompts.qa.multiple_choice.py中的MultipleChoiceQaWithReferenceParser函数,使用的template为multiple_choice_qa_with_reference_template。函数的功能就是用’\n’把多选项和references连接起来,和content一起填充到template中:
[
("system", "You are an helpful assistant good at {knowledge_domain} knowledge that can help people answer {knowledge_domain} questions."),
("user", """
# Task
Your task is to think step by step and then choose the correct option from the given options, the chosen option should be correct and the most suitable one to answer the given question. You can refer to the references provided when thinking and answering. Please note that the references may or may not be relevant to the question. If you don't have sufficient information to determine, randomly choose one option from the given options.
# Output format
The output should strictly follow the format below, do not add any redundant information.
<result>
<thinking>Your thinking for the given question.</thinking>
<answer>
<mask>The chosen option mask. Please note that only one single mask is allowable.</mask>
<option>The option detail corresponds to the chosen option mask.</option>
</answer>
</result>
# Question
{content}
# Options
{options_str}
# References
{references_str}
# Thinking and Answer
""".strip()),
]
2.3.3 LLM获取response
代码为:
response = self._client.generate_content_with_messages(messages, **self.llm_config)
实际调用的是_get_response_with_messages函数。我们以pikerag/llm_client/hf_meta_llama_client.py为例,其代码为:
self._client = AutoModelForCausalLM.from_pretrained(self._model_id, **kwargs)
input_ids = self._tokenizer.apply_chat_template(messages,add_generation_prompt=True, return_tensors="pt",).to(self._client.device)
outputs = self._client.generate(input_ids,pad_token_id=self._tokenizer.eos_token_id,**llm_config,)
response = outputs[0][input_ids.shape[-1]:]
最后再调用下面的代码解析结果:
output_dict: dict = self._qa_protocol.parse_output(response, **qa.as_dict())