deer-flow/src/rag/dify.py
Xun 0e64c52975
refactor: Refactors the retriever function to use async/await (#821)
* refactor: Refactors the retriever function to use async/await
2026-01-20 19:56:26 +08:00

152 lines
4.7 KiB
Python

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import asyncio
import os
from urllib.parse import urlparse
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
class DifyProvider(Retriever):
    """
    DifyProvider is a provider that uses dify to retrieve documents.

    Configuration is read from the environment when the provider is created:

    * ``DIFY_API_URL`` -- base URL of the Dify API; endpoint paths such as
      ``/datasets`` are appended to it directly.
    * ``DIFY_API_KEY`` -- key sent as a bearer token on every request.
    """

    api_url: str
    api_key: str
    # Seconds to wait for Dify before ``requests`` raises a Timeout. Without a
    # timeout a stalled connection blocks forever, which also pins the worker
    # thread used by the ``*_async`` wrappers. Override per instance/subclass.
    request_timeout: float = 60.0

    def __init__(self):
        """
        Read the Dify endpoint and credentials from the environment.

        Raises:
            ValueError: If ``DIFY_API_URL`` or ``DIFY_API_KEY`` is unset or empty.
        """
        api_url = os.getenv("DIFY_API_URL")
        if not api_url:
            raise ValueError("DIFY_API_URL is not set")
        self.api_url = api_url

        api_key = os.getenv("DIFY_API_KEY")
        if not api_key:
            raise ValueError("DIFY_API_KEY is not set")
        self.api_key = api_key

    def _headers(self) -> dict[str, str]:
        """Return the auth and content-type headers sent with every Dify request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def query_relevant_documents(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """
        Retrieve chunks matching ``query`` from each dataset in ``resources``.

        Each resource URI is passed to ``parse_uri``; only the dataset id is
        used (the URI fragment is ignored). Every dataset is searched with
        hybrid search (keyword weight 0.3, vector weight 0.7), top 3 results,
        and a score threshold of 0.5. Returned chunks are grouped into one
        Document per Dify document id, in first-seen order.

        Args:
            query: Search text sent to Dify.
            resources: Datasets to search. ``None`` or an empty list returns
                ``[]`` without contacting Dify.

        Returns:
            Documents, each carrying the chunks that matched the query.

        Raises:
            ValueError: If a resource URI is rejected by ``parse_uri``.
            Exception: If Dify responds with a status other than 200.
            requests.RequestException: On connection failures or timeouts.
        """
        if not resources:
            return []

        headers = self._headers()
        # The retrieval settings do not depend on the dataset, so build once.
        payload = {
            "query": query,
            "retrieval_model": {
                "search_method": "hybrid_search",
                "reranking_enable": False,
                "weights": {
                    "weight_type": "customized",
                    "keyword_setting": {"keyword_weight": 0.3},
                    "vector_setting": {"vector_weight": 0.7},
                },
                "top_k": 3,
                "score_threshold_enabled": True,
                "score_threshold": 0.5,
            },
        }

        all_documents: dict[str, Document] = {}
        for resource in resources:
            dataset_id, _ = parse_uri(resource.uri)
            response = requests.post(
                f"{self.api_url}/datasets/{dataset_id}/retrieve",
                headers=headers,
                json=payload,
                timeout=self.request_timeout,
            )
            if response.status_code != 200:
                raise Exception(f"Failed to query documents: {response.text}")

            result = response.json()
            # `or []` also covers an explicit JSON null for "records".
            for record in result.get("records") or []:
                segment = record.get("segment")
                if not segment:
                    continue
                document_info = segment.get("document")
                if not document_info:
                    continue
                doc_id = document_info.get("id")
                doc_name = document_info.get("name")
                # Skip records that cannot be attributed to a named document.
                if not doc_id or not doc_name:
                    continue

                if doc_id not in all_documents:
                    all_documents[doc_id] = Document(
                        id=doc_id, title=doc_name, chunks=[]
                    )
                all_documents[doc_id].chunks.append(
                    Chunk(
                        content=segment.get("content", ""),
                        similarity=record.get("score", 0.0),
                    )
                )

        return list(all_documents.values())

    async def query_relevant_documents_async(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """
        Asynchronous version of query_relevant_documents.
        Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop.
        """
        return await asyncio.to_thread(
            self.query_relevant_documents, query, resources
        )

    def list_resources(self, query: str | None = None) -> list[Resource]:
        """
        List the Dify datasets visible to the configured API key.

        Args:
            query: Optional keyword sent to Dify as the ``keyword`` parameter;
                a falsy value lists without a keyword filter.

        Returns:
            One Resource per dataset, with URI ``rag://dataset/<id>``.

        Raises:
            Exception: If Dify responds with a status other than 200.
            requests.RequestException: On connection failures or timeouts.
        """
        # NOTE(review): no page/limit parameters are sent, so only whatever
        # Dify returns by default is listed -- confirm whether paging is needed.
        params = {"keyword": query} if query else {}
        response = requests.get(
            f"{self.api_url}/datasets",
            headers=self._headers(),
            params=params,
            timeout=self.request_timeout,
        )
        if response.status_code != 200:
            raise Exception(f"Failed to list resources: {response.text}")

        result = response.json()
        # `or []` also covers an explicit JSON null for "data".
        return [
            Resource(
                uri=f"rag://dataset/{dataset.get('id')}",
                title=dataset.get("name", ""),
                description=dataset.get("description", ""),
            )
            for dataset in result.get("data") or []
        ]

    async def list_resources_async(self, query: str | None = None) -> list[Resource]:
        """
        Asynchronous version of list_resources.
        Wraps the synchronous implementation in asyncio.to_thread() to avoid blocking the event loop.
        """
        return await asyncio.to_thread(self.list_resources, query)
def parse_uri(uri: str) -> tuple[str, str]:
    """
    Split a ``rag://`` resource URI into its dataset id and fragment.

    The dataset id is the first path segment; the fragment is whatever follows
    ``#`` (empty string if absent). For example,
    ``rag://dataset/123#doc456`` -> ``("123", "doc456")``.

    Raises:
        ValueError: If the scheme is not ``rag`` or the URI has no non-empty
            first path segment.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")
    # For "rag://dataset/<id>" the path is "/<id>", so index 0 is the empty
    # string before the leading slash. Guard the lookup so malformed URIs
    # raise ValueError rather than IndexError or yield an empty id.
    segments = parsed.path.split("/")
    if len(segments) < 2 or not segments[1]:
        raise ValueError(f"Invalid URI: {uri}")
    return segments[1], parsed.fragment