From 5ceb19f6f650c397569177fda5e5129768364f71 Mon Sep 17 00:00:00 2001
From: Markus Corazzione <83182424+corazzione@users.noreply.github.com>
Date: Sun, 29 Mar 2026 20:41:18 -0300
Subject: [PATCH] fix(oauth): Harden Claude OAuth cache-control handling (#1583)

---
 .../deerflow/models/claude_provider.py        | 47 ++++++++++++-
 .../test_claude_provider_oauth_billing.py     | 66 +++++++++++++++++++
 2 files changed, 111 insertions(+), 2 deletions(-)

diff --git a/backend/packages/harness/deerflow/models/claude_provider.py b/backend/packages/harness/deerflow/models/claude_provider.py
index 1e732bd79..2c0050313 100644
--- a/backend/packages/harness/deerflow/models/claude_provider.py
+++ b/backend/packages/harness/deerflow/models/claude_provider.py
@@ -27,6 +27,7 @@ from typing import Any
 import anthropic
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import BaseMessage
+from pydantic import PrivateAttr
 
 logger = logging.getLogger(__name__)
 
@@ -56,8 +57,8 @@ class ClaudeChatModel(ChatAnthropic):
     prompt_cache_size: int = 3
     auto_thinking_budget: bool = True
     retry_max_attempts: int = MAX_RETRIES
-    _is_oauth: bool = False
-    _oauth_access_token: str = ""
+    _is_oauth: bool = PrivateAttr(default=False)
+    _oauth_access_token: str = PrivateAttr(default="")
 
     model_config = {"arbitrary_types_allowed": True}
 
@@ -244,6 +245,48 @@ class ClaudeChatModel(ChatAnthropic):
         max_tokens = payload.get("max_tokens", 8192)
         thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)
 
+    @staticmethod
+    def _strip_cache_control(payload: dict) -> None:
+        """Remove cache_control markers before OAuth requests reach Anthropic.
+
+        Mutates ``payload`` in place. Markers are removed from the entries of
+        ``system`` and ``messages`` and from every block reachable through
+        their nested ``content`` lists (for example the blocks inside a
+        ``tool_result``), plus the top level of each ``tools`` entry.
+        """
+
+        def strip(node: Any) -> None:
+            # Walk lists and dicts only; strings and other scalars are left alone.
+            if isinstance(node, list):
+                for child in node:
+                    strip(child)
+            elif isinstance(node, dict):
+                node.pop("cache_control", None)
+                strip(node.get("content"))
+
+        for section in ("system", "messages"):
+            items = payload.get(section)
+            if isinstance(items, list):
+                strip(items)
+
+        tools = payload.get("tools")
+        if isinstance(tools, list):
+            for tool in tools:
+                if isinstance(tool, dict):
+                    tool.pop("cache_control", None)
+
+    def _create(self, payload: dict) -> Any:
+        """Send a sync request, dropping cache_control markers for OAuth."""
+        if self._is_oauth:
+            self._strip_cache_control(payload)
+        return super()._create(payload)
+
+    async def _acreate(self, payload: dict) -> Any:
+        """Send an async request, dropping cache_control markers for OAuth."""
+        if self._is_oauth:
+            self._strip_cache_control(payload)
+        return await super()._acreate(payload)
+
     def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
         """Override with OAuth patching and retry logic."""
         if self._is_oauth:
diff --git a/backend/tests/test_claude_provider_oauth_billing.py b/backend/tests/test_claude_provider_oauth_billing.py
index 9f9329bb1..9cb45e430 100644
--- a/backend/tests/test_claude_provider_oauth_billing.py
+++ b/backend/tests/test_claude_provider_oauth_billing.py
@@ -1,6 +1,8 @@
 """Tests for ClaudeChatModel._apply_oauth_billing."""
 
+import asyncio
 import json
+from unittest import mock
 
 import pytest
 
@@ -108,3 +110,67 @@ def test_metadata_non_dict_replaced_with_dict(model):
     model._apply_oauth_billing(payload)
     assert isinstance(payload["metadata"], dict)
     assert "user_id" in payload["metadata"]
+
+
+def test_sync_create_strips_cache_control_from_oauth_payload(model):
+    payload = {
+        "system": [{"type": "text", "text": "sys", "cache_control": {"type": "ephemeral"}}],
+        "messages": [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "hi", "cache_control": {"type": "ephemeral"}}],
+            }
+        ],
+        "tools": [{"name": "demo", "input_schema": {"type": "object"}, "cache_control": {"type": "ephemeral"}}],
+    }
+
+    with mock.patch.object(model._client.messages, "create", return_value=object()) as create:
+        model._create(payload)
+
+    sent_payload = create.call_args.kwargs
+    assert "cache_control" not in sent_payload["system"][0]
+    assert "cache_control" not in sent_payload["messages"][0]["content"][0]
+    assert "cache_control" not in sent_payload["tools"][0]
+
+
+def test_async_create_strips_cache_control_from_oauth_payload(model):
+    payload = {
+        "system": [{"type": "text", "text": "sys", "cache_control": {"type": "ephemeral"}}],
+        "messages": [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "hi", "cache_control": {"type": "ephemeral"}}],
+            }
+        ],
+        "tools": [{"name": "demo", "input_schema": {"type": "object"}, "cache_control": {"type": "ephemeral"}}],
+    }
+
+    with mock.patch.object(model._async_client.messages, "create", new=mock.AsyncMock(return_value=object())) as create:
+        asyncio.run(model._acreate(payload))
+
+    sent_payload = create.call_args.kwargs
+    assert "cache_control" not in sent_payload["system"][0]
+    assert "cache_control" not in sent_payload["messages"][0]["content"][0]
+    assert "cache_control" not in sent_payload["tools"][0]
+
+
+def test_strip_cache_control_reaches_nested_tool_result_content(model):
+    payload = {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_1",
+                        "content": [{"type": "text", "text": "out", "cache_control": {"type": "ephemeral"}}],
+                    }
+                ],
+            }
+        ],
+    }
+
+    model._strip_cache_control(payload)
+
+    nested_block = payload["messages"][0]["content"][0]["content"][0]
+    assert "cache_control" not in nested_block