File size: 2,938 Bytes
3ffd86f 42d6456 3ffd86f ea13192 87285ff d4f85e2 3ffd86f 2c973dd 3ffd86f 87285ff 3ffd86f ea13192 3ffd86f 87285ff 2c973dd 87285ff 3ffd86f 98dddc9 3ffd86f 87285ff 3ffd86f 98dddc9 3ffd86f 87285ff 3ffd86f ea13192 3ffd86f 05a8b3a 3ffd86f 2c973dd 87285ff 3ffd86f 87285ff 3ffd86f 87285ff 3ffd86f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
# https://github.com/langchain-ai/langchain/issues/8623
import pandas as pd
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from typing import List
from pydantic import Field
class ClimateQARetriever(BaseRetriever):
    """Retriever that mixes summary ("SPM") passages with full-report passages.

    For each query two similarity searches are run against ``vectorstore``:
    one restricted to documents whose ``report_type`` is ``"SPM"`` and one
    restricted to every other ``report_type``. Summary hits are kept only when
    their score is above ``threshold``; the second search then fills the
    remaining slots up to ``k_total``.

    Attributes are declared as pydantic fields (``BaseRetriever`` is a
    pydantic model), so they can be overridden at construction time.
    """

    # Vector store queried through ``similarity_search_with_score``.
    vectorstore: VectorStore
    # Allowed values for the "source" metadata filter (used when ``reports`` is empty).
    sources: list = ["IPCC", "IPBES", "IPOS"]
    # Optional list of report short names; when non-empty it replaces the source filter.
    reports: list = []
    # Summary hits must score strictly above this value to be kept.
    threshold: float = 0.6
    # Number of candidates requested from the summary search.
    k_summary: int = 3
    # Upper bound on the number of candidates requested over both searches.
    k_total: int = 10
    # NOTE: these two defaults previously ended with a trailing comma, which
    # turned them into the tuples ("vectors",) and (200,). Defaults are not
    # validated by pydantic unless explicitly requested, so the tuple leaked
    # through and the ``len(...) > self.min_size`` comparison below raised
    # TypeError whenever ``min_size`` was not passed explicitly.
    # ``namespace`` is not read by the method defined in this class.
    namespace: str = "vectors"
    # Documents whose page content is not longer than this many characters are dropped.
    min_size: int = 200

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents retrieved for ``query``.

        Args:
            query: Text passed to both similarity searches.
            run_manager: Callback manager supplied by the retriever
                framework; it is not used by this implementation.

        Returns:
            Summary documents first, then full-report documents, each in the
            order returned by the vector store. Every document has its
            ``page_content`` stripped of ``"\\r\\n"`` sequences and its
            metadata extended with ``similarity_score`` and ``content``;
            ``page_number`` is converted to ``int`` and incremented by one.

        Raises:
            AssertionError: If ``sources`` is not a list, contains a value
                other than "IPCC", "IPBES" or "IPOS", or if ``k_total`` is
                not strictly greater than ``k_summary``.
        """
        # Every requested source must be one of the known corpora.
        assert isinstance(self.sources, list)
        assert all([x in ["IPCC", "IPBES", "IPOS"] for x in self.sources])
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Base metadata filter: explicit report names take precedence over sources.
        filters = {}
        if len(self.reports) > 0:
            filters["short_name"] = {"$in": self.reports}
        else:
            filters["source"] = {"$in": self.sources}

        # First pass: up to k_summary candidates from the summaries (SPM) only.
        filters_summaries = {
            **filters,
            "report_type": {"$in": ["SPM"]},
        }
        docs_summaries = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_summaries, k=self.k_summary
        )
        # Keep only summary hits scoring above the threshold
        # (assumes a higher score means a closer match -- TODO confirm for the store in use).
        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]

        # Second pass: fill the remaining slots from everything that is not a summary.
        filters_full = {
            **filters,
            "report_type": {"$nin": ["SPM"]},
        }
        k_full = self.k_total - len(docs_summaries)
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_full, k=k_full
        )

        # Summaries come first, followed by full-report passages.
        docs = docs_summaries + docs_full

        # Drop passages that are too short. Note that full-report hits are
        # NOT filtered by score; only the summary hits were thresholded above.
        docs = [x for x in docs if len(x[0].page_content) > self.min_size]

        # Normalise the text and attach the score to the metadata.
        results = []
        for doc, score in docs:
            doc.page_content = doc.page_content.replace("\r\n", " ")
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            # Shift the stored page number by one
            # (presumably the stored value is 0-based -- verify against the index).
            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
            results.append(doc)

        # Results are intentionally left in retrieval order (not re-sorted by score).
        return results
|