
Conversational Retrieval

Overview

Retrieval uses the following components (a combined sketch follows the list):

  1. Document loaders
  2. Document transformers
  3. Text embedding models
  4. Vector stores
  5. Retrievers
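A minimal end-to-end sketch of how these five components fit together (a sketch only; OpenAIEmbeddings, the file path, and the DeepLake dataset path are illustrative assumptions, not this project's configuration):

    from langchain.document_loaders import TextLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.vectorstores import DeepLake

    docs = TextLoader("./notes.txt").load()                  # 1. document loader
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=100
    ).split_documents(docs)                                  # 2. document transformer
    embeddings = OpenAIEmbeddings()                          # 3. text embedding model
    vectorstore = DeepLake(
        dataset_path="./deeplake", embedding_function=embeddings
    )                                                        # 4. vector store
    vectorstore.add_documents(chunks)
    retriever = vectorstore.as_retriever()                   # 5. retriever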

ConversationalRetrievalChain uses its question_generator (an LLMChain) to generate a question from the given question and the chat_history, and sends that instead. This lets the question carry the context of the past chat_history; on the other hand, the question can then run into the token limit.

Types

  1. BaseConversationalRetrievalChain: Chain for chatting with an index. This is the base class that the concrete chains below inherit from.
  2. ConversationalRetrievalChain(BaseConversationalRetrievalChain): Chain for having a conversation based on retrieved documents.
    This chain takes in chat history (a list of messages) and new questions,
    and then returns an answer to that question.
    The algorithm for this chain consists of three parts:
     1. Use the chat history and the new question to create a "standalone question".
    This is done so that this question can be passed into the retrieval step to fetch
    relevant documents. If only the new question was passed in, then relevant context
    may be lacking. If the whole conversation was passed into retrieval, there may
    be unnecessary information there that would distract from retrieval.
     2. This new question is passed to the retriever and relevant documents are
    returned.
     3. The retrieved documents are passed to an LLM along with either the new question
    (default behavior) or the original question and chat history to generate a final
    response.
    
  3. ChatVectorDBChain(BaseConversationalRetrievalChain): Chain for chatting with a vector database.

Implementation

  • _call: the main function that is executed when the chain is invoked

    qa = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever)
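    Invoking the chain is what triggers _call. A minimal invocation sketch (assuming an existing llm and retriever; the chain takes question and chat_history as inputs):

      chat_history = []
      result = qa({"question": "What is Deep Lake?", "chat_history": chat_history})
      chat_history.append(("What is Deep Lake?", result["answer"]))
      # The follow-up is first condensed into a standalone question by question_generator.
      result = qa({"question": "How do I query it?", "chat_history": chat_history})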
    

Components

1. RetrievalQA

  1. RetrievalQA(BaseRetrievalQA)
    1. Initialization: initialize via from_llm, passing a retriever (RetrievalQA holds retriever: BaseRetriever = Field(exclude=True)); see the sketch after this list
    2. Query: run as qa(query), which calls BaseRetrievalQA._call
    3. self._get_docs(question) fetches the relevant Documents via retriever.get_relevant_documents
      def _get_docs(
          self,
          question: str,
          *,
          run_manager: CallbackManagerForChainRun,
      ) -> List[Document]:
          """Get docs."""
          return self.retriever.get_relevant_documents(
              question, callbacks=run_manager.get_child()
          )
      
    4. Inside _call, after _get_docs, the docs are handed to combine_documents_chain, which produces the final answer from the Docs and the Question.
      answer = self.combine_documents_chain.run(
          input_documents=docs, question=question, callbacks=_run_manager.get_child()
      )
      
    5. PromptTemplate
  2. ConversationalRetrievalChain
    1. Steps:
      1. Create a standalone question from the Chat History
      2. Use that question to fetch the relevant Documents from the VectorStore
      3. Pass the retrieved Documents to the LLM (together with either the new question or the original question plus the ChatHistory) to produce the final answer
    2. template: CONDENSE_QUESTION_PROMPT
      _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
      
      Chat History:
      {chat_history}
      Follow Up Input: {question}
      Standalone question:"""
      

2. Retriever

RetrievalQA uses a Retriever to search for Documents.

  1. A VectorStore hands out a VectorStoreRetriever via its as_retriever() method
  2. The retriever supports three search types: similarity, similarity_score_threshold, and mmr (TODO: compare when to use which); see the sketch after this list

  3. At query time, RetrievalQA._get_docs calls retriever.get_relevant_documents

  4. retriever._get_relevant_documents is then called

    1. Internally, one of the following is called on self.vectorstore:
      1. similarity_search(query, **self.search_kwargs)
      2. similarity_search_with_relevance_scores(query, **self.search_kwargs)
      3. max_marginal_relevance_search(query, **self.search_kwargs)
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs
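The search type and its parameters are fixed when the retriever is created; the sketch referenced above (k and fetch_k are the usual search_kwargs, and fetch_k only matters for MMR):

    retriever = vectorstore.as_retriever(
        search_type="mmr",  # or "similarity" / "similarity_score_threshold"
        search_kwargs={"k": 4, "fetch_k": 20},
    )
    docs = retriever.get_relevant_documents("What is Deep Lake?")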
    

For the part that actually fetches from the vectorstore, see the next section, Vectorstore.

3. Vectorstore

The Vectorstore is DeepLake as wrapped by LangChain. To use all of DeepLake's features, however, you need to understand the internals and extend the wrapper as needed. By default no score is returned, so the code has to be changed so that DeepLake's own VectorStore can be accessed directly.
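One workaround sketch: call the wrapper's score-returning method directly instead of going through the retriever (this assumes the DeepLake wrapper's similarity_search_with_score; for full control, drop down to the underlying deeplake VectorStore as described below):

    docs_and_scores = vectorstore.similarity_search_with_score("What is Deep Lake?", k=4)
    for doc, score in docs_and_scores:
        print(score, doc.page_content[:80])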

  1. LangChain

    1. DeepLake.similarity_search

      def similarity_search(
          self,
          query: str,
          k: int = 4,
          **kwargs: Any,
      ) -> List[Document]:
          return self._search(
              query=query,
              k=k,
              use_maximal_marginal_relevance=False,
              return_score=False,
              **kwargs,
          )
      

      _search computes an embedding from the query using embedding_function and then calls self.vectorstore.search.

      result = self.vectorstore.search(
          embedding=embedding,
          k=fetch_k if use_maximal_marginal_relevance else k,
          distance_metric=distance_metric,
          filter=filter,
          exec_option=exec_option,
          return_tensors=["embedding", "metadata", "text", "id"],
          deep_memory=deep_memory,
      )
      
      1. VectorStore.similarity_search_with_relevance_scores
      def similarity_search_with_relevance_scores(
          self,
          query: str,
          k: int = 4,
          **kwargs: Any,
      ) -> List[Tuple[Document, float]]:
          docs_and_similarities = self._similarity_search_with_relevance_scores(
              query, k=k, **kwargs
          )
          ...
          return docs_and_similarities
      
      _similarity_search_with_relevance_scores does the actual computation, but a subclass can swap the scoring formula via self._select_relevance_score_fn().
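      For example, a subclass override might look like this (MyDeepLake is a hypothetical name; the hook returns a callable that maps a raw distance to a relevance score):

        class MyDeepLake(DeepLake):
            def _select_relevance_score_fn(self):
                # e.g. map a cosine distance in [0, 2] to a relevance in [0, 1]
                return lambda distance: 1.0 - distance / 2.0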

    2. DeepLake.max_marginal_relevance_search

      return self._search(
          query=query,
          k=k,
          fetch_k=fetch_k,
          use_maximal_marginal_relevance=True,
          lambda_mult=lambda_mult,
          exec_option=exec_option,
          embedding_function=embedding_function,  # type: ignore
          **kwargs,
      )
      

      maximal_marginal_relevance is then computed over the fetch_k documents that were fetched. (TODO: check the details)

Maximal Marginal Relevance: The idea behind using MMR is that it tries to reduce redundancy and increase diversity in the result and is used in text summarization. (Maximal Marginal Relevance to Re-rank results in Unsupervised KeyPhrase Extraction)
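For reference, a minimal sketch of the MMR re-ranking idea itself (plain numpy with cosine similarity; lambda_mult mirrors the LangChain parameter name):

    import numpy as np

    def cosine_sim(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    def mmr(query_emb, doc_embs, k=4, lambda_mult=0.5):
        """Greedily pick k docs, trading off query relevance vs. redundancy."""
        selected, candidates = [], list(range(len(doc_embs)))
        while candidates and len(selected) < k:
            def score(i):
                relevance = cosine_sim(query_emb, doc_embs[i])
                redundancy = max(
                    (cosine_sim(doc_embs[i], doc_embs[j]) for j in selected),
                    default=0.0,
                )
                return lambda_mult * relevance - (1 - lambda_mult) * redundancy
            best = max(candidates, key=score)
            selected.append(best)
            candidates.remove(best)
        return selected  # indices into doc_embs, in selection order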

  1. DeepLake: libs/langchain/langchain/vectorstores/deeplake.py
    1. Update dataset
    2. Official Docs
    3. Deep Lake Vector Store API

4. Datastore

  1. SQLite: for permanent data (chat history, document source data (Confluence, Directory, Text, etc.), and imported documents)
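    A sketch of what such a persistent chat-history table could look like (a hypothetical schema, plain sqlite3; not this project's actual schema):

      import sqlite3

      conn = sqlite3.connect("app.db")
      conn.execute(
          "CREATE TABLE IF NOT EXISTS chat_history ("
          " id INTEGER PRIMARY KEY, session TEXT, question TEXT, answer TEXT)"
      )
      conn.execute(
          "INSERT INTO chat_history (session, question, answer) VALUES (?, ?, ?)",
          ("default", "What is Deep Lake?", "..."),
      )
      conn.commit()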

5. Text Splitter

  1. Check: https://langchain-text-splitter.streamlit.app/
  2. MarkdownHeaderTextSplitter
  3. RecursiveCharacterTextSplitter (a combined sketch of both splitters follows this list)
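A combined sketch of the two splitters above (md_text is an assumed Markdown string; the header tuples and chunk sizes are illustrative):

    from langchain.text_splitter import (
        MarkdownHeaderTextSplitter,
        RecursiveCharacterTextSplitter,
    )

    headers = [("#", "h1"), ("##", "h2")]
    md_docs = MarkdownHeaderTextSplitter(headers_to_split_on=headers).split_text(md_text)
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=100
    ).split_documents(md_docs)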

6. Document Loader

  1. TextLoader for Directory import
  2. ConfluenceLoader for Confluence (Source); see the sketch after this list
    1. https://developer.atlassian.com/server/confluence/confluence-server-rest-api/
    2. https://python.langchain.com/docs/integrations/document_loaders/confluence
  3. GoogleDriveLoader
    1. google.auth package added to langchain in langchain#6035
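A loading sketch for the Confluence source, as referenced above (the URL, credentials, and space key are placeholders):

    from langchain.document_loaders import ConfluenceLoader

    loader = ConfluenceLoader(
        url="https://example.atlassian.net/wiki",
        username="me@example.com",
        api_key="...",
    )
    docs = loader.load(space_key="ENG", limit=50)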

FAQ

  1. ConversationalRetrievalChain vs. ChatVectorDBChain?
    1. ChatVectorDBChain is tied directly to a vector database, whereas ConversationalRetrievalChain works with any BaseRetriever; ChatVectorDBChain is the older chain and is deprecated in favor of ConversationalRetrievalChain.
  2. How do I show source_documents when using RetrievalQA as an agent Tool?
    1. A Tool's input and output must be strings, so wrap the chain in a function that returns the answer and the sources together, as below (return_direct=True)
      def run_qa_chain(question):
          # Return the whole result dict (answer + sources) as a single string.
          results = qa_chain({"question": question}, return_only_outputs=True)
          return str(results)
      
    2. Extend _run to include the sources (return_direct=True)
      def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
          """Use the tool."""
          retrieval_chain = self.create_retrieval_chain()
          result = retrieval_chain(query)  # run the chain once and reuse the result
          answer = result['answer']
          sources = '\n'.join(result['sources'].split(', '))
          return f'{answer}\nSources:\n{sources}'
      

Ref

  1. RetrievalQA chain return source_documents when using it as a Tool for an Agent #5097
  2. Vector store agent with sources #4187