mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor: replace sync requests with async httpx in Jina AI client (#1603)
* refactor: replace sync requests with async httpx in Jina AI client Replace synchronous `requests.post()` with `httpx.AsyncClient` in JinaClient.crawl() and make web_fetch_tool async. This is part of the planned async concurrency optimization for the agent hot path (see docs/TODO.md). * fix: address Copilot review feedback on async Jina client - Short-circuit error strings in web_fetch_tool before passing to ReadabilityExtractor, preventing misleading extraction results - Log missing JINA_API_KEY warning only once per process to reduce noise under concurrent async fetching - Use logger.exception instead of logger.error in crawl exception handler to preserve stack traces for debugging - Add async web_fetch_tool tests and warn-once coverage * fix: mock get_app_config in web_fetch_tool tests for CI The web_fetch_tool tests failed in CI because get_app_config requires a config.yaml file that isn't present in the test environment. Mock the config loader to remove the filesystem dependency. --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
52c8c06cf2
commit
2f3744f807
@ -1,13 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import requests
|
import httpx
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_api_key_warned = False
|
||||||
|
|
||||||
|
|
||||||
class JinaClient:
|
class JinaClient:
|
||||||
def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
||||||
|
global _api_key_warned
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"X-Return-Format": return_format,
|
"X-Return-Format": return_format,
|
||||||
@ -15,11 +18,13 @@ class JinaClient:
|
|||||||
}
|
}
|
||||||
if os.getenv("JINA_API_KEY"):
|
if os.getenv("JINA_API_KEY"):
|
||||||
headers["Authorization"] = f"Bearer {os.getenv('JINA_API_KEY')}"
|
headers["Authorization"] = f"Bearer {os.getenv('JINA_API_KEY')}"
|
||||||
else:
|
elif not _api_key_warned:
|
||||||
|
_api_key_warned = True
|
||||||
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
||||||
data = {"url": url}
|
data = {"url": url}
|
||||||
try:
|
try:
|
||||||
response = requests.post("https://r.jina.ai/", headers=headers, json=data)
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
error_message = f"Jina API returned status {response.status_code}: {response.text}"
|
error_message = f"Jina API returned status {response.status_code}: {response.text}"
|
||||||
@ -34,5 +39,5 @@ class JinaClient:
|
|||||||
return response.text
|
return response.text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Request to Jina API failed: {str(e)}"
|
error_message = f"Request to Jina API failed: {str(e)}"
|
||||||
logger.error(error_message)
|
logger.exception(error_message)
|
||||||
return f"Error: {error_message}"
|
return f"Error: {error_message}"
|
||||||
|
|||||||
@ -8,7 +8,7 @@ readability_extractor = ReadabilityExtractor()
|
|||||||
|
|
||||||
|
|
||||||
@tool("web_fetch", parse_docstring=True)
|
@tool("web_fetch", parse_docstring=True)
|
||||||
def web_fetch_tool(url: str) -> str:
|
async def web_fetch_tool(url: str) -> str:
|
||||||
"""Fetch the contents of a web page at a given URL.
|
"""Fetch the contents of a web page at a given URL.
|
||||||
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
||||||
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
||||||
@ -23,6 +23,8 @@ def web_fetch_tool(url: str) -> str:
|
|||||||
config = get_app_config().get_tool_config("web_fetch")
|
config = get_app_config().get_tool_config("web_fetch")
|
||||||
if config is not None and "timeout" in config.model_extra:
|
if config is not None and "timeout" in config.model_extra:
|
||||||
timeout = config.model_extra.get("timeout")
|
timeout = config.model_extra.get("timeout")
|
||||||
html_content = jina_client.crawl(url, return_format="html", timeout=timeout)
|
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||||
|
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
||||||
|
return html_content
|
||||||
article = readability_extractor.extract_article(html_content)
|
article = readability_extractor.extract_article(html_content)
|
||||||
return article.to_markdown()[:4096]
|
return article.to_markdown()[:4096]
|
||||||
|
|||||||
177
backend/tests/test_jina_client.py
Normal file
177
backend/tests/test_jina_client.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
"""Tests for JinaClient async crawl method."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import deerflow.community.jina_ai.jina_client as jina_client_module
|
||||||
|
from deerflow.community.jina_ai.jina_client import JinaClient
|
||||||
|
from deerflow.community.jina_ai.tools import web_fetch_tool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def jina_client():
|
||||||
|
return JinaClient()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_success(jina_client, monkeypatch):
|
||||||
|
"""Test successful crawl returns response text."""
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
return httpx.Response(200, text="<html><body>Hello</body></html>", request=httpx.Request("POST", url))
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
result = await jina_client.crawl("https://example.com")
|
||||||
|
assert result == "<html><body>Hello</body></html>"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_non_200_status(jina_client, monkeypatch):
|
||||||
|
"""Test that non-200 status returns error message."""
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
return httpx.Response(429, text="Rate limited", request=httpx.Request("POST", url))
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
result = await jina_client.crawl("https://example.com")
|
||||||
|
assert result.startswith("Error:")
|
||||||
|
assert "429" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_empty_response(jina_client, monkeypatch):
|
||||||
|
"""Test that empty response returns error message."""
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
return httpx.Response(200, text="", request=httpx.Request("POST", url))
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
result = await jina_client.crawl("https://example.com")
|
||||||
|
assert result.startswith("Error:")
|
||||||
|
assert "empty" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_whitespace_only_response(jina_client, monkeypatch):
|
||||||
|
"""Test that whitespace-only response returns error message."""
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
return httpx.Response(200, text=" \n ", request=httpx.Request("POST", url))
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
result = await jina_client.crawl("https://example.com")
|
||||||
|
assert result.startswith("Error:")
|
||||||
|
assert "empty" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_network_error(jina_client, monkeypatch):
|
||||||
|
"""Test that network errors are handled gracefully."""
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
raise httpx.ConnectError("Connection refused")
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
result = await jina_client.crawl("https://example.com")
|
||||||
|
assert result.startswith("Error:")
|
||||||
|
assert "failed" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_passes_headers(jina_client, monkeypatch):
|
||||||
|
"""Test that correct headers are sent."""
|
||||||
|
captured_headers = {}
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
captured_headers.update(kwargs.get("headers", {}))
|
||||||
|
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
await jina_client.crawl("https://example.com", return_format="markdown", timeout=30)
|
||||||
|
assert captured_headers["X-Return-Format"] == "markdown"
|
||||||
|
assert captured_headers["X-Timeout"] == "30"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_includes_api_key_when_set(jina_client, monkeypatch):
|
||||||
|
"""Test that Authorization header is set when JINA_API_KEY is available."""
|
||||||
|
captured_headers = {}
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
captured_headers.update(kwargs.get("headers", {}))
|
||||||
|
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
monkeypatch.setenv("JINA_API_KEY", "test-key-123")
|
||||||
|
await jina_client.crawl("https://example.com")
|
||||||
|
assert captured_headers["Authorization"] == "Bearer test-key-123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_warns_once_when_api_key_missing(jina_client, monkeypatch, caplog):
|
||||||
|
"""Test that the missing API key warning is logged only once."""
|
||||||
|
jina_client_module._api_key_warned = False
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
monkeypatch.delenv("JINA_API_KEY", raising=False)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger="deerflow.community.jina_ai.jina_client"):
|
||||||
|
await jina_client.crawl("https://example.com")
|
||||||
|
await jina_client.crawl("https://example.com")
|
||||||
|
|
||||||
|
warning_count = sum(1 for record in caplog.records if "Jina API key is not set" in record.message)
|
||||||
|
assert warning_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_crawl_no_auth_header_without_api_key(jina_client, monkeypatch):
|
||||||
|
"""Test that no Authorization header is set when JINA_API_KEY is not available."""
|
||||||
|
jina_client_module._api_key_warned = False
|
||||||
|
captured_headers = {}
|
||||||
|
|
||||||
|
async def mock_post(self, url, **kwargs):
|
||||||
|
captured_headers.update(kwargs.get("headers", {}))
|
||||||
|
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
|
||||||
|
|
||||||
|
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||||
|
monkeypatch.delenv("JINA_API_KEY", raising=False)
|
||||||
|
await jina_client.crawl("https://example.com")
|
||||||
|
assert "Authorization" not in captured_headers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
|
||||||
|
"""Test that web_fetch_tool short-circuits and returns the error string when crawl fails."""
|
||||||
|
|
||||||
|
async def mock_crawl(self, url, **kwargs):
|
||||||
|
return "Error: Jina API returned status 429: Rate limited"
|
||||||
|
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_tool_config.return_value = None
|
||||||
|
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
|
||||||
|
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
|
||||||
|
result = await web_fetch_tool.ainvoke("https://example.com")
|
||||||
|
assert result.startswith("Error:")
|
||||||
|
assert "429" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
|
||||||
|
"""Test that web_fetch_tool returns extracted markdown on successful crawl."""
|
||||||
|
|
||||||
|
async def mock_crawl(self, url, **kwargs):
|
||||||
|
return "<html><body><p>Hello world</p></body></html>"
|
||||||
|
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_tool_config.return_value = None
|
||||||
|
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
|
||||||
|
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
|
||||||
|
result = await web_fetch_tool.ainvoke("https://example.com")
|
||||||
|
assert "Hello world" in result
|
||||||
|
assert not result.startswith("Error:")
|
||||||
Loading…
x
Reference in New Issue
Block a user