mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor: replace sync requests with async httpx in Jina AI client (#1603)
* refactor: replace sync requests with async httpx in Jina AI client Replace synchronous `requests.post()` with `httpx.AsyncClient` in JinaClient.crawl() and make web_fetch_tool async. This is part of the planned async concurrency optimization for the agent hot path (see docs/TODO.md). * fix: address Copilot review feedback on async Jina client - Short-circuit error strings in web_fetch_tool before passing to ReadabilityExtractor, preventing misleading extraction results - Log missing JINA_API_KEY warning only once per process to reduce noise under concurrent async fetching - Use logger.exception instead of logger.error in crawl exception handler to preserve stack traces for debugging - Add async web_fetch_tool tests and warn-once coverage * fix: mock get_app_config in web_fetch_tool tests for CI The web_fetch_tool tests failed in CI because get_app_config requires a config.yaml file that isn't present in the test environment. Mock the config loader to remove the filesystem dependency. --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
52c8c06cf2
commit
2f3744f807
@ -1,13 +1,16 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_api_key_warned = False
|
||||
|
||||
|
||||
class JinaClient:
|
||||
def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
||||
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
||||
global _api_key_warned
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Return-Format": return_format,
|
||||
@ -15,11 +18,13 @@ class JinaClient:
|
||||
}
|
||||
if os.getenv("JINA_API_KEY"):
|
||||
headers["Authorization"] = f"Bearer {os.getenv('JINA_API_KEY')}"
|
||||
else:
|
||||
elif not _api_key_warned:
|
||||
_api_key_warned = True
|
||||
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
||||
data = {"url": url}
|
||||
try:
|
||||
response = requests.post("https://r.jina.ai/", headers=headers, json=data)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_message = f"Jina API returned status {response.status_code}: {response.text}"
|
||||
@ -34,5 +39,5 @@ class JinaClient:
|
||||
return response.text
|
||||
except Exception as e:
|
||||
error_message = f"Request to Jina API failed: {str(e)}"
|
||||
logger.error(error_message)
|
||||
logger.exception(error_message)
|
||||
return f"Error: {error_message}"
|
||||
|
||||
@ -8,7 +8,7 @@ readability_extractor = ReadabilityExtractor()
|
||||
|
||||
|
||||
@tool("web_fetch", parse_docstring=True)
|
||||
def web_fetch_tool(url: str) -> str:
|
||||
async def web_fetch_tool(url: str) -> str:
|
||||
"""Fetch the contents of a web page at a given URL.
|
||||
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
||||
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
||||
@ -23,6 +23,8 @@ def web_fetch_tool(url: str) -> str:
|
||||
config = get_app_config().get_tool_config("web_fetch")
|
||||
if config is not None and "timeout" in config.model_extra:
|
||||
timeout = config.model_extra.get("timeout")
|
||||
html_content = jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
||||
return html_content
|
||||
article = readability_extractor.extract_article(html_content)
|
||||
return article.to_markdown()[:4096]
|
||||
|
||||
177
backend/tests/test_jina_client.py
Normal file
177
backend/tests/test_jina_client.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Tests for JinaClient async crawl method."""
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
import deerflow.community.jina_ai.jina_client as jina_client_module
|
||||
from deerflow.community.jina_ai.jina_client import JinaClient
|
||||
from deerflow.community.jina_ai.tools import web_fetch_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jina_client():
|
||||
return JinaClient()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_success(jina_client, monkeypatch):
|
||||
"""Test successful crawl returns response text."""
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
return httpx.Response(200, text="<html><body>Hello</body></html>", request=httpx.Request("POST", url))
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
result = await jina_client.crawl("https://example.com")
|
||||
assert result == "<html><body>Hello</body></html>"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_non_200_status(jina_client, monkeypatch):
|
||||
"""Test that non-200 status returns error message."""
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
return httpx.Response(429, text="Rate limited", request=httpx.Request("POST", url))
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
result = await jina_client.crawl("https://example.com")
|
||||
assert result.startswith("Error:")
|
||||
assert "429" in result
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_empty_response(jina_client, monkeypatch):
|
||||
"""Test that empty response returns error message."""
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
return httpx.Response(200, text="", request=httpx.Request("POST", url))
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
result = await jina_client.crawl("https://example.com")
|
||||
assert result.startswith("Error:")
|
||||
assert "empty" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_whitespace_only_response(jina_client, monkeypatch):
|
||||
"""Test that whitespace-only response returns error message."""
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
return httpx.Response(200, text=" \n ", request=httpx.Request("POST", url))
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
result = await jina_client.crawl("https://example.com")
|
||||
assert result.startswith("Error:")
|
||||
assert "empty" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_network_error(jina_client, monkeypatch):
|
||||
"""Test that network errors are handled gracefully."""
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
result = await jina_client.crawl("https://example.com")
|
||||
assert result.startswith("Error:")
|
||||
assert "failed" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_passes_headers(jina_client, monkeypatch):
|
||||
"""Test that correct headers are sent."""
|
||||
captured_headers = {}
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
captured_headers.update(kwargs.get("headers", {}))
|
||||
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
await jina_client.crawl("https://example.com", return_format="markdown", timeout=30)
|
||||
assert captured_headers["X-Return-Format"] == "markdown"
|
||||
assert captured_headers["X-Timeout"] == "30"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_includes_api_key_when_set(jina_client, monkeypatch):
|
||||
"""Test that Authorization header is set when JINA_API_KEY is available."""
|
||||
captured_headers = {}
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
captured_headers.update(kwargs.get("headers", {}))
|
||||
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
monkeypatch.setenv("JINA_API_KEY", "test-key-123")
|
||||
await jina_client.crawl("https://example.com")
|
||||
assert captured_headers["Authorization"] == "Bearer test-key-123"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_warns_once_when_api_key_missing(jina_client, monkeypatch, caplog):
|
||||
"""Test that the missing API key warning is logged only once."""
|
||||
jina_client_module._api_key_warned = False
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
monkeypatch.delenv("JINA_API_KEY", raising=False)
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="deerflow.community.jina_ai.jina_client"):
|
||||
await jina_client.crawl("https://example.com")
|
||||
await jina_client.crawl("https://example.com")
|
||||
|
||||
warning_count = sum(1 for record in caplog.records if "Jina API key is not set" in record.message)
|
||||
assert warning_count == 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_crawl_no_auth_header_without_api_key(jina_client, monkeypatch):
|
||||
"""Test that no Authorization header is set when JINA_API_KEY is not available."""
|
||||
jina_client_module._api_key_warned = False
|
||||
captured_headers = {}
|
||||
|
||||
async def mock_post(self, url, **kwargs):
|
||||
captured_headers.update(kwargs.get("headers", {}))
|
||||
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
monkeypatch.delenv("JINA_API_KEY", raising=False)
|
||||
await jina_client.crawl("https://example.com")
|
||||
assert "Authorization" not in captured_headers
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
|
||||
"""Test that web_fetch_tool short-circuits and returns the error string when crawl fails."""
|
||||
|
||||
async def mock_crawl(self, url, **kwargs):
|
||||
return "Error: Jina API returned status 429: Rate limited"
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_tool_config.return_value = None
|
||||
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
|
||||
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
|
||||
result = await web_fetch_tool.ainvoke("https://example.com")
|
||||
assert result.startswith("Error:")
|
||||
assert "429" in result
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
|
||||
"""Test that web_fetch_tool returns extracted markdown on successful crawl."""
|
||||
|
||||
async def mock_crawl(self, url, **kwargs):
|
||||
return "<html><body><p>Hello world</p></body></html>"
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.get_tool_config.return_value = None
|
||||
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
|
||||
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
|
||||
result = await web_fetch_tool.ainvoke("https://example.com")
|
||||
assert "Hello world" in result
|
||||
assert not result.startswith("Error:")
|
||||
Loading…
x
Reference in New Issue
Block a user