mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 19:28:23 +00:00
Merge branch 'main' into rayhpeng/persistence-scaffold
# Conflicts: # backend/tests/test_model_factory.py
This commit is contained in:
commit
a5831d3abf
@ -17,6 +17,7 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
||||
# DEEPSEEK_API_KEY=your-deepseek-api-key
|
||||
# NOVITA_API_KEY=your-novita-api-key # OpenAI-compatible, see https://novita.ai
|
||||
# MINIMAX_API_KEY=your-minimax-api-key # OpenAI-compatible, see https://platform.minimax.io
|
||||
# VLLM_API_KEY=your-vllm-api-key # OpenAI-compatible
|
||||
# FEISHU_APP_ID=your-feishu-app-id
|
||||
# FEISHU_APP_SECRET=your-feishu-app-secret
|
||||
|
||||
|
||||
14
README.md
14
README.md
@ -141,12 +141,26 @@ That prompt is intended for coding agents. It tells the agent to clone the repo
|
||||
api_key: $OPENAI_API_KEY
|
||||
use_responses_api: true
|
||||
output_version: responses/v1
|
||||
|
||||
- name: qwen3-32b-vllm
|
||||
display_name: Qwen3 32B (vLLM)
|
||||
use: deerflow.models.vllm_provider:VllmChatModel
|
||||
model: Qwen/Qwen3-32B
|
||||
api_key: $VLLM_API_KEY
|
||||
base_url: http://localhost:8000/v1
|
||||
supports_thinking: true
|
||||
when_thinking_enabled:
|
||||
extra_body:
|
||||
chat_template_kwargs:
|
||||
enable_thinking: true
|
||||
```
|
||||
|
||||
OpenRouter and similar OpenAI-compatible gateways should be configured with `langchain_openai:ChatOpenAI` plus `base_url`. If you prefer a provider-specific environment variable name, point `api_key` at that variable explicitly (for example `api_key: $OPENROUTER_API_KEY`).
|
||||
|
||||
To route OpenAI models through `/v1/responses`, keep using `langchain_openai:ChatOpenAI` and set `use_responses_api: true` with `output_version: responses/v1`.
|
||||
|
||||
For vLLM 0.19.0, use `deerflow.models.vllm_provider:VllmChatModel`. For Qwen-style reasoning models, DeerFlow toggles reasoning with `extra_body.chat_template_kwargs.enable_thinking` and preserves vLLM's non-standard `reasoning` field across multi-turn tool-call conversations. Legacy `thinking` configs are normalized automatically for backward compatibility. Reasoning models may also require the server to be started with `--reasoning-parser ...`. If your local vLLM deployment accepts any non-empty API key, you can still set `VLLM_API_KEY` to a placeholder value.
|
||||
|
||||
CLI-backed provider examples:
|
||||
|
||||
```yaml
|
||||
|
||||
@ -293,10 +293,17 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
|
||||
- `create_chat_model(name, thinking_enabled)` instantiates LLM from config via reflection
|
||||
- Supports `thinking_enabled` flag with per-model `when_thinking_enabled` overrides
|
||||
- Supports vLLM-style thinking toggles via `when_thinking_enabled.extra_body.chat_template_kwargs.enable_thinking` for Qwen reasoning models, while normalizing legacy `thinking` configs for backward compatibility
|
||||
- Supports `supports_vision` flag for image understanding models
|
||||
- Config values starting with `$` resolved as environment variables
|
||||
- Missing provider modules surface actionable install hints from reflection resolvers (for example `uv add langchain-google-genai`)
|
||||
|
||||
### vLLM Provider (`packages/harness/deerflow/models/vllm_provider.py`)
|
||||
|
||||
- `VllmChatModel` subclasses `langchain_openai:ChatOpenAI` for vLLM 0.19.0 OpenAI-compatible endpoints
|
||||
- Preserves vLLM's non-standard assistant `reasoning` field on full responses, streaming deltas, and follow-up tool-call turns
|
||||
- Designed for configs that enable thinking through `extra_body.chat_template_kwargs.enable_thinking` on vLLM 0.19.0 Qwen reasoning models, while accepting the older `thinking` alias
|
||||
|
||||
### IM Channels System (`app/channels/`)
|
||||
|
||||
Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow agent via the LangGraph Server.
|
||||
@ -365,6 +372,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
||||
|
||||
**`config.yaml`** key sections:
|
||||
- `models[]` - LLM configs with `use` class path, `supports_thinking`, `supports_vision`, provider-specific fields
|
||||
- vLLM reasoning models should use `deerflow.models.vllm_provider:VllmChatModel`; for Qwen-style parsers prefer `when_thinking_enabled.extra_body.chat_template_kwargs.enable_thinking`, and DeerFlow will also normalize the older `thinking` alias
|
||||
- `tools[]` - Tool configs with `use` variable path and `group`
|
||||
- `tool_groups[]` - Logical groupings for tools
|
||||
- `sandbox.use` - Sandbox provider class path
|
||||
|
||||
@ -51,6 +51,7 @@ async def stateless_stream(body: RunCreateRequest, request: Request) -> Streamin
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -119,8 +119,9 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
# LangGraph Platform includes run metadata in this header.
|
||||
# The SDK's _get_run_metadata_from_response() parses it.
|
||||
"Content-Location": (f"/api/threads/{thread_id}/runs/{record.run_id}/stream?thread_id={thread_id}&run_id={record.run_id}"),
|
||||
# The SDK uses a greedy regex to extract the run id from this path,
|
||||
# so it must point at the canonical run resource without extra suffixes.
|
||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -375,8 +375,9 @@ async def sse_consumer(
|
||||
- ``cancel``: abort the background task on client disconnect.
|
||||
- ``continue``: let the task run; events are discarded.
|
||||
"""
|
||||
last_event_id = request.headers.get("Last-Event-ID")
|
||||
try:
|
||||
async for entry in bridge.subscribe(record.run_id):
|
||||
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
|
||||
@ -1,22 +1,19 @@
|
||||
"""Middleware for injecting image details into conversation before LLM call."""
|
||||
|
||||
import logging
|
||||
from typing import NotRequired, override
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ViewedImageData
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ViewImageMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
viewed_images: NotRequired[dict[str, ViewedImageData] | None]
|
||||
class ViewImageMiddlewareState(ThreadState):
|
||||
"""Reuse the thread state so reducer-backed keys keep their annotations."""
|
||||
|
||||
|
||||
class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
|
||||
|
||||
@ -74,5 +74,10 @@ class SandboxConfig(BaseModel):
|
||||
ge=0,
|
||||
description="Maximum characters to keep from read_file tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
|
||||
)
|
||||
ls_output_max_chars: int = Field(
|
||||
default=20000,
|
||||
ge=0,
|
||||
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@ -9,6 +9,27 @@ from deerflow.tracing import build_tracing_callbacks
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _deep_merge_dicts(base: dict | None, override: dict) -> dict:
|
||||
"""Recursively merge two dictionaries without mutating the inputs."""
|
||||
merged = dict(base or {})
|
||||
for key, value in override.items():
|
||||
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||
merged[key] = _deep_merge_dicts(merged[key], value)
|
||||
else:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
|
||||
def _vllm_disable_chat_template_kwargs(chat_template_kwargs: dict) -> dict:
|
||||
"""Build the disable payload for vLLM/Qwen chat template kwargs."""
|
||||
disable_kwargs: dict[str, bool] = {}
|
||||
if "thinking" in chat_template_kwargs:
|
||||
disable_kwargs["thinking"] = False
|
||||
if "enable_thinking" in chat_template_kwargs:
|
||||
disable_kwargs["enable_thinking"] = False
|
||||
return disable_kwargs
|
||||
|
||||
|
||||
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
|
||||
"""Create a chat model instance from the config.
|
||||
|
||||
@ -54,13 +75,23 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
if not thinking_enabled and has_thinking_settings:
|
||||
if effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
|
||||
# OpenAI-compatible gateway: thinking is nested under extra_body
|
||||
kwargs.update({"extra_body": {"thinking": {"type": "disabled"}}})
|
||||
kwargs.update({"reasoning_effort": "minimal"})
|
||||
model_settings_from_config["extra_body"] = _deep_merge_dicts(
|
||||
model_settings_from_config.get("extra_body"),
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
model_settings_from_config["reasoning_effort"] = "minimal"
|
||||
elif disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {}):
|
||||
# vLLM uses chat template kwargs to switch thinking on/off.
|
||||
model_settings_from_config["extra_body"] = _deep_merge_dicts(
|
||||
model_settings_from_config.get("extra_body"),
|
||||
{"chat_template_kwargs": disable_chat_template_kwargs},
|
||||
)
|
||||
elif effective_wte.get("thinking", {}).get("type"):
|
||||
# Native langchain_anthropic: thinking is a direct constructor parameter
|
||||
kwargs.update({"thinking": {"type": "disabled"}})
|
||||
if not model_config.supports_reasoning_effort and "reasoning_effort" in kwargs:
|
||||
del kwargs["reasoning_effort"]
|
||||
model_settings_from_config["thinking"] = {"type": "disabled"}
|
||||
if not model_config.supports_reasoning_effort:
|
||||
kwargs.pop("reasoning_effort", None)
|
||||
model_settings_from_config.pop("reasoning_effort", None)
|
||||
|
||||
# For Codex Responses API models: map thinking mode to reasoning_effort
|
||||
from deerflow.models.openai_codex_provider import CodexChatModel
|
||||
|
||||
258
backend/packages/harness/deerflow/models/vllm_provider.py
Normal file
258
backend/packages/harness/deerflow/models/vllm_provider.py
Normal file
@ -0,0 +1,258 @@
|
||||
"""Custom vLLM provider built on top of LangChain ChatOpenAI.
|
||||
|
||||
vLLM 0.19.0 exposes reasoning models through an OpenAI-compatible API, but
|
||||
LangChain's default OpenAI adapter drops the non-standard ``reasoning`` field
|
||||
from assistant messages and streaming deltas. That breaks interleaved
|
||||
thinking/tool-call flows because vLLM expects the assistant's prior reasoning to
|
||||
be echoed back on subsequent turns.
|
||||
|
||||
This provider preserves ``reasoning`` on:
|
||||
- non-streaming responses
|
||||
- streaming deltas
|
||||
- multi-turn request payloads
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessageChunk,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
HumanMessageChunk,
|
||||
SystemMessageChunk,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import _create_usage_metadata
|
||||
|
||||
|
||||
def _normalize_vllm_chat_template_kwargs(payload: dict[str, Any]) -> None:
|
||||
"""Map DeerFlow's legacy ``thinking`` toggle to vLLM/Qwen's ``enable_thinking``.
|
||||
|
||||
DeerFlow originally documented ``extra_body.chat_template_kwargs.thinking``
|
||||
for vLLM, but vLLM 0.19.0's Qwen reasoning parser reads
|
||||
``chat_template_kwargs.enable_thinking``. Normalize the payload just before
|
||||
it is sent so existing configs keep working and flash mode can truly
|
||||
disable reasoning.
|
||||
"""
|
||||
extra_body = payload.get("extra_body")
|
||||
if not isinstance(extra_body, dict):
|
||||
return
|
||||
|
||||
chat_template_kwargs = extra_body.get("chat_template_kwargs")
|
||||
if not isinstance(chat_template_kwargs, dict):
|
||||
return
|
||||
|
||||
if "thinking" not in chat_template_kwargs:
|
||||
return
|
||||
|
||||
normalized_chat_template_kwargs = dict(chat_template_kwargs)
|
||||
normalized_chat_template_kwargs.setdefault("enable_thinking", normalized_chat_template_kwargs["thinking"])
|
||||
normalized_chat_template_kwargs.pop("thinking", None)
|
||||
extra_body["chat_template_kwargs"] = normalized_chat_template_kwargs
|
||||
|
||||
|
||||
def _reasoning_to_text(reasoning: Any) -> str:
|
||||
"""Best-effort extraction of readable reasoning text from vLLM payloads."""
|
||||
if isinstance(reasoning, str):
|
||||
return reasoning
|
||||
|
||||
if isinstance(reasoning, list):
|
||||
parts = [_reasoning_to_text(item) for item in reasoning]
|
||||
return "".join(part for part in parts if part)
|
||||
|
||||
if isinstance(reasoning, dict):
|
||||
for key in ("text", "content", "reasoning"):
|
||||
value = reasoning.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if value is not None:
|
||||
text = _reasoning_to_text(value)
|
||||
if text:
|
||||
return text
|
||||
try:
|
||||
return json.dumps(reasoning, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(reasoning)
|
||||
|
||||
try:
|
||||
return json.dumps(reasoning, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(reasoning)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk_with_reasoning(_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]) -> BaseMessageChunk:
|
||||
"""Convert a streaming delta to a LangChain message chunk while preserving reasoning."""
|
||||
id_ = _dict.get("id")
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: dict[str, Any] = {}
|
||||
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
|
||||
reasoning = _dict.get("reasoning")
|
||||
if reasoning is not None:
|
||||
additional_kwargs["reasoning"] = reasoning
|
||||
reasoning_text = _reasoning_to_text(reasoning)
|
||||
if reasoning_text:
|
||||
additional_kwargs["reasoning_content"] = reasoning_text
|
||||
|
||||
tool_call_chunks = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc["index"],
|
||||
)
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content, id=id_)
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
id=id_,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
)
|
||||
if role in ("system", "developer") or default_class == SystemMessageChunk:
|
||||
role_kwargs = {"__openai_role__": "developer"} if role == "developer" else {}
|
||||
return SystemMessageChunk(content=content, id=id_, additional_kwargs=role_kwargs)
|
||||
if role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
|
||||
if role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"], id=id_)
|
||||
if role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role, id=id_) # type: ignore[arg-type]
|
||||
return default_class(content=content, id=id_) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _restore_reasoning_field(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
|
||||
"""Re-inject vLLM reasoning onto outgoing assistant messages."""
|
||||
reasoning = orig_msg.additional_kwargs.get("reasoning")
|
||||
if reasoning is None:
|
||||
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning is not None:
|
||||
payload_msg["reasoning"] = reasoning
|
||||
|
||||
|
||||
class VllmChatModel(ChatOpenAI):
|
||||
"""ChatOpenAI variant that preserves vLLM reasoning fields across turns."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "vllm-openai-compatible"
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Restore assistant reasoning in request payloads for interleaved thinking."""
|
||||
original_messages = self._convert_input(input_).to_messages()
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
_normalize_vllm_chat_template_kwargs(payload)
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
_restore_reasoning_field(payload_msg, orig_msg)
|
||||
else:
|
||||
ai_messages = [message for message in original_messages if isinstance(message, AIMessage)]
|
||||
assistant_payloads = [message for message in payload_messages if message.get("role") == "assistant"]
|
||||
for payload_msg, ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_reasoning_field(payload_msg, ai_msg)
|
||||
|
||||
return payload
|
||||
|
||||
def _create_chat_result(self, response: dict | openai.BaseModel, generation_info: dict | None = None) -> ChatResult:
|
||||
"""Preserve vLLM reasoning on non-streaming responses."""
|
||||
result = super()._create_chat_result(response, generation_info=generation_info)
|
||||
response_dict = response if isinstance(response, dict) else response.model_dump()
|
||||
|
||||
for generation, choice in zip(result.generations, response_dict.get("choices", [])):
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
continue
|
||||
message = generation.message
|
||||
if not isinstance(message, AIMessage):
|
||||
continue
|
||||
reasoning = choice.get("message", {}).get("reasoning")
|
||||
if reasoning is None:
|
||||
continue
|
||||
message.additional_kwargs["reasoning"] = reasoning
|
||||
reasoning_text = _reasoning_to_text(reasoning)
|
||||
if reasoning_text:
|
||||
message.additional_kwargs["reasoning_content"] = reasoning_text
|
||||
|
||||
return result
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: type,
|
||||
base_generation_info: dict | None,
|
||||
) -> ChatGenerationChunk | None:
|
||||
"""Preserve vLLM reasoning on streaming deltas."""
|
||||
if chunk.get("type") == "content.delta":
|
||||
return None
|
||||
|
||||
token_usage = chunk.get("usage")
|
||||
choices = chunk.get("choices", []) or chunk.get("chunk", {}).get("choices", [])
|
||||
usage_metadata = _create_usage_metadata(token_usage, chunk.get("service_tier")) if token_usage else None
|
||||
|
||||
if len(choices) == 0:
|
||||
generation_chunk = ChatGenerationChunk(message=default_chunk_class(content="", usage_metadata=usage_metadata), generation_info=base_generation_info)
|
||||
if self.output_version == "v1":
|
||||
generation_chunk.message.content = []
|
||||
generation_chunk.message.response_metadata["output_version"] = "v1"
|
||||
return generation_chunk
|
||||
|
||||
choice = choices[0]
|
||||
if choice["delta"] is None:
|
||||
return None
|
||||
|
||||
message_chunk = _convert_delta_to_message_chunk_with_reasoning(choice["delta"], default_chunk_class)
|
||||
generation_info = {**base_generation_info} if base_generation_info else {}
|
||||
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
if model_name := chunk.get("model"):
|
||||
generation_info["model_name"] = model_name
|
||||
if system_fingerprint := chunk.get("system_fingerprint"):
|
||||
generation_info["system_fingerprint"] = system_fingerprint
|
||||
if service_tier := chunk.get("service_tier"):
|
||||
generation_info["service_tier"] = service_tier
|
||||
|
||||
if logprobs := choice.get("logprobs"):
|
||||
generation_info["logprobs"] = logprobs
|
||||
|
||||
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
|
||||
message_chunk.usage_metadata = usage_metadata
|
||||
|
||||
message_chunk.response_metadata["model_provider"] = "openai"
|
||||
return ChatGenerationChunk(message=message_chunk, generation_info=generation_info or None)
|
||||
@ -1,4 +1,4 @@
|
||||
"""In-memory stream bridge backed by :class:`asyncio.Queue`."""
|
||||
"""In-memory stream bridge backed by an in-process event log."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -6,35 +6,41 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .base import END_SENTINEL, HEARTBEAT_SENTINEL, StreamBridge, StreamEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PUBLISH_TIMEOUT = 30.0 # seconds to wait when queue is full
|
||||
|
||||
@dataclass
|
||||
class _RunStream:
|
||||
events: list[StreamEvent] = field(default_factory=list)
|
||||
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
|
||||
ended: bool = False
|
||||
start_offset: int = 0
|
||||
|
||||
|
||||
class MemoryStreamBridge(StreamBridge):
|
||||
"""Per-run ``asyncio.Queue`` implementation.
|
||||
"""Per-run in-memory event log implementation.
|
||||
|
||||
Each *run_id* gets its own queue on first :meth:`publish` call.
|
||||
Events are retained for a bounded time window per run so late subscribers
|
||||
and reconnecting clients can replay buffered events from ``Last-Event-ID``.
|
||||
"""
|
||||
|
||||
def __init__(self, *, queue_maxsize: int = 256) -> None:
|
||||
self._maxsize = queue_maxsize
|
||||
self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
|
||||
self._streams: dict[str, _RunStream] = {}
|
||||
self._counters: dict[str, int] = {}
|
||||
self._dropped_counts: dict[str, int] = {}
|
||||
|
||||
# -- helpers ---------------------------------------------------------------
|
||||
|
||||
def _get_or_create_queue(self, run_id: str) -> asyncio.Queue[StreamEvent]:
|
||||
if run_id not in self._queues:
|
||||
self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize)
|
||||
def _get_or_create_stream(self, run_id: str) -> _RunStream:
|
||||
if run_id not in self._streams:
|
||||
self._streams[run_id] = _RunStream()
|
||||
self._counters[run_id] = 0
|
||||
self._dropped_counts[run_id] = 0
|
||||
return self._queues[run_id]
|
||||
return self._streams[run_id]
|
||||
|
||||
def _next_id(self, run_id: str) -> str:
|
||||
self._counters[run_id] = self._counters.get(run_id, 0) + 1
|
||||
@ -42,49 +48,39 @@ class MemoryStreamBridge(StreamBridge):
|
||||
seq = self._counters[run_id] - 1
|
||||
return f"{ts}-{seq}"
|
||||
|
||||
def _resolve_start_offset(self, stream: _RunStream, last_event_id: str | None) -> int:
|
||||
if last_event_id is None:
|
||||
return stream.start_offset
|
||||
|
||||
for index, entry in enumerate(stream.events):
|
||||
if entry.id == last_event_id:
|
||||
return stream.start_offset + index + 1
|
||||
|
||||
if stream.events:
|
||||
logger.warning(
|
||||
"last_event_id=%s not found in retained buffer; replaying from earliest retained event",
|
||||
last_event_id,
|
||||
)
|
||||
return stream.start_offset
|
||||
|
||||
# -- StreamBridge API ------------------------------------------------------
|
||||
|
||||
async def publish(self, run_id: str, event: str, data: Any) -> None:
|
||||
queue = self._get_or_create_queue(run_id)
|
||||
stream = self._get_or_create_stream(run_id)
|
||||
entry = StreamEvent(id=self._next_id(run_id), event=event, data=data)
|
||||
try:
|
||||
await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
|
||||
except TimeoutError:
|
||||
self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1
|
||||
logger.warning(
|
||||
"Stream bridge queue full for run %s — dropping event %s (total dropped: %d)",
|
||||
run_id,
|
||||
event,
|
||||
self._dropped_counts[run_id],
|
||||
)
|
||||
async with stream.condition:
|
||||
stream.events.append(entry)
|
||||
if len(stream.events) > self._maxsize:
|
||||
overflow = len(stream.events) - self._maxsize
|
||||
del stream.events[:overflow]
|
||||
stream.start_offset += overflow
|
||||
stream.condition.notify_all()
|
||||
|
||||
async def publish_end(self, run_id: str) -> None:
|
||||
queue = self._get_or_create_queue(run_id)
|
||||
|
||||
# END sentinel is critical — it is the only signal that allows
|
||||
# subscribers to terminate. If the queue is full we evict the
|
||||
# oldest *regular* events to make room rather than dropping END,
|
||||
# which would cause the SSE connection to hang forever and leak
|
||||
# the queue/counter resources for this run_id.
|
||||
if queue.full():
|
||||
evicted = 0
|
||||
while queue.full():
|
||||
try:
|
||||
queue.get_nowait()
|
||||
evicted += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break # pragma: no cover – defensive
|
||||
if evicted:
|
||||
logger.warning(
|
||||
"Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery",
|
||||
run_id,
|
||||
evicted,
|
||||
)
|
||||
|
||||
# After eviction the queue is guaranteed to have space, so a
|
||||
# simple non-blocking put is safe. We still use put() (which
|
||||
# blocks until space is available) as a defensive measure.
|
||||
await queue.put(END_SENTINEL)
|
||||
stream = self._get_or_create_stream(run_id)
|
||||
async with stream.condition:
|
||||
stream.ended = True
|
||||
stream.condition.notify_all()
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
@ -93,16 +89,34 @@ class MemoryStreamBridge(StreamBridge):
|
||||
last_event_id: str | None = None,
|
||||
heartbeat_interval: float = 15.0,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
if last_event_id is not None:
|
||||
logger.debug("last_event_id=%s accepted but ignored (memory bridge has no replay)", last_event_id)
|
||||
stream = self._get_or_create_stream(run_id)
|
||||
async with stream.condition:
|
||||
next_offset = self._resolve_start_offset(stream, last_event_id)
|
||||
|
||||
queue = self._get_or_create_queue(run_id)
|
||||
while True:
|
||||
try:
|
||||
entry = await asyncio.wait_for(queue.get(), timeout=heartbeat_interval)
|
||||
except TimeoutError:
|
||||
yield HEARTBEAT_SENTINEL
|
||||
continue
|
||||
async with stream.condition:
|
||||
if next_offset < stream.start_offset:
|
||||
logger.warning(
|
||||
"subscriber for run %s fell behind retained buffer; resuming from offset %s",
|
||||
run_id,
|
||||
stream.start_offset,
|
||||
)
|
||||
next_offset = stream.start_offset
|
||||
|
||||
local_index = next_offset - stream.start_offset
|
||||
if 0 <= local_index < len(stream.events):
|
||||
entry = stream.events[local_index]
|
||||
next_offset += 1
|
||||
elif stream.ended:
|
||||
entry = END_SENTINEL
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(stream.condition.wait(), timeout=heartbeat_interval)
|
||||
except TimeoutError:
|
||||
entry = HEARTBEAT_SENTINEL
|
||||
else:
|
||||
continue
|
||||
|
||||
if entry is END_SENTINEL:
|
||||
yield END_SENTINEL
|
||||
return
|
||||
@ -111,20 +125,9 @@ class MemoryStreamBridge(StreamBridge):
|
||||
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
self._queues.pop(run_id, None)
|
||||
self._streams.pop(run_id, None)
|
||||
self._counters.pop(run_id, None)
|
||||
self._dropped_counts.pop(run_id, None)
|
||||
|
||||
async def close(self) -> None:
|
||||
self._queues.clear()
|
||||
self._streams.clear()
|
||||
self._counters.clear()
|
||||
self._dropped_counts.clear()
|
||||
|
||||
def dropped_count(self, run_id: str) -> int:
|
||||
"""Return the number of events dropped for *run_id*."""
|
||||
return self._dropped_counts.get(run_id, 0)
|
||||
|
||||
@property
|
||||
def dropped_total(self) -> int:
|
||||
"""Return the total number of events dropped across all runs."""
|
||||
return sum(self._dropped_counts.values())
|
||||
|
||||
@ -963,6 +963,29 @@ def _truncate_read_file_output(output: str, max_chars: int) -> str:
|
||||
return f"{output[:kept]}{marker}"
|
||||
|
||||
|
||||
def _truncate_ls_output(output: str, max_chars: int) -> str:
|
||||
"""Head-truncate ls output, preserving the beginning of the listing.
|
||||
|
||||
Directory listings are read top-to-bottom; the head shows the most
|
||||
relevant structure.
|
||||
|
||||
The returned string (including the truncation marker) is guaranteed to be
|
||||
no longer than max_chars characters. Pass max_chars=0 to disable truncation
|
||||
and return the full output unchanged.
|
||||
"""
|
||||
if max_chars == 0:
|
||||
return output
|
||||
if len(output) <= max_chars:
|
||||
return output
|
||||
total = len(output)
|
||||
marker_max_len = len(f"\n... [truncated: showing first {total} of {total} chars. Use a more specific path to see fewer results] ...")
|
||||
kept = max(0, max_chars - marker_max_len)
|
||||
if kept == 0:
|
||||
return output[:max_chars]
|
||||
marker = f"\n... [truncated: showing first {kept} of {total} chars. Use a more specific path to see fewer results] ..."
|
||||
return f"{output[:kept]}{marker}"
|
||||
|
||||
|
||||
@tool("bash", parse_docstring=True)
|
||||
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
|
||||
"""Execute a bash command in a Linux environment.
|
||||
@ -1037,7 +1060,15 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
children = sandbox.list_dir(path)
|
||||
if not children:
|
||||
return "(empty)"
|
||||
return "\n".join(children)
|
||||
output = "\n".join(children)
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
sandbox_cfg = get_app_config().sandbox
|
||||
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
return _truncate_ls_output(output, max_chars)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
"""Tests for create_deerflow_agent SDK entry point."""
|
||||
|
||||
from typing import get_type_hints
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.factory import create_deerflow_agent
|
||||
from deerflow.agents.features import Next, Prev, RuntimeFeatures
|
||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
|
||||
def _make_mock_model():
|
||||
@ -127,6 +130,13 @@ def test_vision_injects_view_image_tool(mock_create_agent):
|
||||
assert "view_image" in tool_names
|
||||
|
||||
|
||||
def test_view_image_middleware_preserves_viewed_images_reducer():
|
||||
middleware_hints = get_type_hints(ViewImageMiddleware.state_schema, include_extras=True)
|
||||
thread_hints = get_type_hints(ThreadState, include_extras=True)
|
||||
|
||||
assert middleware_hints["viewed_images"] == thread_hints["viewed_images"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Subagent feature auto-injects task_tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -604,6 +604,63 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
|
||||
assert "max_tokens" not in FakeChatModel.captured_kwargs
|
||||
|
||||
|
||||
def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
|
||||
wte = {"extra_body": {"chat_template_kwargs": {"thinking": True}}}
|
||||
model = _make_model(
|
||||
"vllm-qwen",
|
||||
use="deerflow.models.vllm_provider:VllmChatModel",
|
||||
supports_thinking=True,
|
||||
when_thinking_enabled=wte,
|
||||
)
|
||||
model.extra_body = {"top_k": 20}
|
||||
cfg = _make_app_config([model])
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
class CapturingModel(FakeChatModel):
|
||||
def __init__(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
BaseChatModel.__init__(self, **kwargs)
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False)
|
||||
|
||||
assert captured.get("extra_body") == {"top_k": 20, "chat_template_kwargs": {"thinking": False}}
|
||||
assert captured.get("reasoning_effort") is None
|
||||
|
||||
|
||||
def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
|
||||
wte = {"extra_body": {"chat_template_kwargs": {"enable_thinking": True}}}
|
||||
model = _make_model(
|
||||
"vllm-qwen-enable",
|
||||
use="deerflow.models.vllm_provider:VllmChatModel",
|
||||
supports_thinking=True,
|
||||
when_thinking_enabled=wte,
|
||||
)
|
||||
model.extra_body = {"top_k": 20}
|
||||
cfg = _make_app_config([model])
|
||||
_patch_factory(monkeypatch, cfg)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
class CapturingModel(FakeChatModel):
|
||||
def __init__(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
BaseChatModel.__init__(self, **kwargs)
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False)
|
||||
|
||||
assert captured.get("extra_body") == {
|
||||
"top_k": 20,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
}
|
||||
assert captured.get("reasoning_effort") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stream_usage injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
|
||||
@ -44,7 +45,7 @@ async def test_publish_subscribe(bridge: MemoryStreamBridge):
|
||||
async def test_heartbeat(bridge: MemoryStreamBridge):
|
||||
"""When no events arrive within the heartbeat interval, yield a heartbeat."""
|
||||
run_id = "run-heartbeat"
|
||||
bridge._get_or_create_queue(run_id) # ensure queue exists
|
||||
bridge._get_or_create_stream(run_id) # ensure stream exists
|
||||
|
||||
received = []
|
||||
|
||||
@ -61,37 +62,35 @@ async def test_heartbeat(bridge: MemoryStreamBridge):
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup(bridge: MemoryStreamBridge):
|
||||
"""After cleanup, the run's queue is removed."""
|
||||
"""After cleanup, the run's stream/event log is removed."""
|
||||
run_id = "run-cleanup"
|
||||
await bridge.publish(run_id, "test", {})
|
||||
assert run_id in bridge._queues
|
||||
assert run_id in bridge._streams
|
||||
|
||||
await bridge.cleanup(run_id)
|
||||
assert run_id not in bridge._queues
|
||||
assert run_id not in bridge._streams
|
||||
assert run_id not in bridge._counters
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_backpressure():
|
||||
"""With maxsize=1, publish should not block forever."""
|
||||
async def test_history_is_bounded():
|
||||
"""Retained history should be bounded by queue_maxsize."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-bp"
|
||||
|
||||
await bridge.publish(run_id, "first", {})
|
||||
await bridge.publish(run_id, "second", {})
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# Second publish should either succeed after queue drains or warn+drop
|
||||
# It should not hang indefinitely
|
||||
async def publish_second():
|
||||
await bridge.publish(run_id, "second", {})
|
||||
received = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
|
||||
received.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
# Give it a generous timeout — the publish timeout is 30s but we don't
|
||||
# want to wait that long in tests. Instead, drain the queue first.
|
||||
async def drain():
|
||||
await asyncio.sleep(0.05)
|
||||
bridge._queues[run_id].get_nowait()
|
||||
|
||||
await asyncio.gather(publish_second(), drain())
|
||||
assert bridge._queues[run_id].qsize() == 1
|
||||
assert len(received) == 2
|
||||
assert received[0].event == "second"
|
||||
assert received[1] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@ -140,54 +139,116 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
|
||||
assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_subscribe_replays_after_last_event_id(bridge: MemoryStreamBridge):
|
||||
"""Reconnect should replay buffered events after the provided Last-Event-ID."""
|
||||
run_id = "run-replay"
|
||||
await bridge.publish(run_id, "metadata", {"run_id": run_id})
|
||||
await bridge.publish(run_id, "values", {"step": 1})
|
||||
await bridge.publish(run_id, "updates", {"step": 2})
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
first_pass = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
|
||||
first_pass.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
received = []
|
||||
async for entry in bridge.subscribe(
|
||||
run_id,
|
||||
last_event_id=first_pass[0].id,
|
||||
heartbeat_interval=1.0,
|
||||
):
|
||||
received.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
assert [entry.event for entry in received[:-1]] == ["values", "updates"]
|
||||
assert received[-1] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_slow_subscriber_does_not_skip_after_buffer_trim():
|
||||
"""A slow subscriber should continue from the correct absolute offset."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=2)
|
||||
run_id = "run-slow-subscriber"
|
||||
await bridge.publish(run_id, "e1", {"step": 1})
|
||||
await bridge.publish(run_id, "e2", {"step": 2})
|
||||
|
||||
stream = bridge._streams[run_id]
|
||||
e1_id = stream.events[0].id
|
||||
assert stream.start_offset == 0
|
||||
|
||||
await bridge.publish(run_id, "e3", {"step": 3}) # trims e1
|
||||
assert stream.start_offset == 1
|
||||
assert [entry.event for entry in stream.events] == ["e2", "e3"]
|
||||
|
||||
resumed_after_e1 = []
|
||||
async for entry in bridge.subscribe(
|
||||
run_id,
|
||||
last_event_id=e1_id,
|
||||
heartbeat_interval=1.0,
|
||||
):
|
||||
resumed_after_e1.append(entry)
|
||||
if len(resumed_after_e1) == 2:
|
||||
break
|
||||
|
||||
assert [entry.event for entry in resumed_after_e1] == ["e2", "e3"]
|
||||
e2_id = resumed_after_e1[0].id
|
||||
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
received = []
|
||||
async for entry in bridge.subscribe(
|
||||
run_id,
|
||||
last_event_id=e2_id,
|
||||
heartbeat_interval=1.0,
|
||||
):
|
||||
received.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
assert [entry.event for entry in received[:-1]] == ["e3"]
|
||||
assert received[-1] is END_SENTINEL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# END sentinel guarantee tests
|
||||
# Stream termination tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_delivered_when_queue_full():
|
||||
"""END sentinel must always be delivered, even when the queue is completely full.
|
||||
|
||||
This is the critical regression test for the bug where publish_end()
|
||||
would silently drop the END sentinel when the queue was full, causing
|
||||
subscribe() to hang forever and leaking resources.
|
||||
"""
|
||||
async def test_publish_end_terminates_even_when_history_is_full():
|
||||
"""publish_end() should terminate subscribers without mutating retained history."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=2)
|
||||
run_id = "run-end-full"
|
||||
run_id = "run-end-history-full"
|
||||
|
||||
# Fill the queue to capacity
|
||||
await bridge.publish(run_id, "event-1", {"n": 1})
|
||||
await bridge.publish(run_id, "event-2", {"n": 2})
|
||||
assert bridge._queues[run_id].full()
|
||||
stream = bridge._streams[run_id]
|
||||
assert [entry.event for entry in stream.events] == ["event-1", "event-2"]
|
||||
|
||||
# publish_end should succeed by evicting old events
|
||||
await bridge.publish_end(run_id)
|
||||
assert [entry.event for entry in stream.events] == ["event-1", "event-2"]
|
||||
|
||||
# Subscriber must receive END_SENTINEL
|
||||
events = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
events.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"
|
||||
assert [entry.event for entry in events[:-1]] == ["event-1", "event-2"]
|
||||
assert events[-1] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_evicts_oldest_events():
|
||||
"""When queue is full, publish_end evicts the oldest events to make room."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-evict"
|
||||
|
||||
# Fill queue with one event
|
||||
await bridge.publish(run_id, "will-be-evicted", {})
|
||||
assert bridge._queues[run_id].full()
|
||||
|
||||
# publish_end must succeed
|
||||
async def test_publish_end_without_history_yields_end_immediately():
|
||||
"""Subscribers should still receive END when a run completes without events."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=2)
|
||||
run_id = "run-end-empty"
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# The only event we should get is END_SENTINEL (the regular event was evicted)
|
||||
events = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
events.append(entry)
|
||||
@ -199,8 +260,8 @@ async def test_end_sentinel_evicts_oldest_events():
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_no_eviction_when_space_available():
|
||||
"""When queue has space, publish_end should not evict anything."""
|
||||
async def test_publish_end_preserves_history_when_space_available():
|
||||
"""When history has spare capacity, publish_end should preserve prior events."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=10)
|
||||
run_id = "run-no-evict"
|
||||
|
||||
@ -244,87 +305,23 @@ async def test_concurrent_tasks_end_sentinel():
|
||||
return events
|
||||
return events # pragma: no cover
|
||||
|
||||
# Run producers and consumers concurrently
|
||||
run_ids = [f"concurrent-{i}" for i in range(num_runs)]
|
||||
producers = [producer(rid) for rid in run_ids]
|
||||
consumers = [consumer(rid) for rid in run_ids]
|
||||
results: dict[str, list] = {}
|
||||
|
||||
# Start consumers first, then producers
|
||||
consumer_tasks = [asyncio.create_task(c) for c in consumers]
|
||||
await asyncio.gather(*producers)
|
||||
async def consume_into(run_id: str) -> None:
|
||||
results[run_id] = await consumer(run_id)
|
||||
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*consumer_tasks),
|
||||
timeout=10.0,
|
||||
)
|
||||
with anyio.fail_after(10):
|
||||
async with anyio.create_task_group() as task_group:
|
||||
for run_id in run_ids:
|
||||
task_group.start_soon(consume_into, run_id)
|
||||
await anyio.sleep(0)
|
||||
for run_id in run_ids:
|
||||
task_group.start_soon(producer, run_id)
|
||||
|
||||
for i, events in enumerate(results):
|
||||
assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Drop counter tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dropped_count_tracking():
|
||||
"""Dropped events should be tracked per run_id."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-drop-count"
|
||||
|
||||
# Fill the queue
|
||||
await bridge.publish(run_id, "first", {})
|
||||
|
||||
# This publish will time out and be dropped (we patch timeout to be instant)
|
||||
# Instead, we verify the counter after publish_end eviction
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# dropped_count tracks publish() drops, not publish_end evictions
|
||||
assert bridge.dropped_count(run_id) == 0
|
||||
|
||||
# cleanup should also clear the counter
|
||||
await bridge.cleanup(run_id)
|
||||
assert bridge.dropped_count(run_id) == 0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dropped_total():
|
||||
"""dropped_total should sum across all runs."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
|
||||
# No drops yet
|
||||
assert bridge.dropped_total == 0
|
||||
|
||||
# Manually set some counts to verify the property
|
||||
bridge._dropped_counts["run-a"] = 3
|
||||
bridge._dropped_counts["run-b"] = 7
|
||||
assert bridge.dropped_total == 10
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup_clears_dropped_counts():
|
||||
"""cleanup() should clear the dropped counter for the run."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
run_id = "run-cleanup-drops"
|
||||
|
||||
bridge._get_or_create_queue(run_id)
|
||||
bridge._dropped_counts[run_id] = 5
|
||||
|
||||
await bridge.cleanup(run_id)
|
||||
assert run_id not in bridge._dropped_counts
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_close_clears_dropped_counts():
|
||||
"""close() should clear all dropped counters."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
bridge._dropped_counts["run-x"] = 10
|
||||
bridge._dropped_counts["run-y"] = 20
|
||||
|
||||
await bridge.close()
|
||||
assert bridge.dropped_total == 0
|
||||
assert len(bridge._dropped_counts) == 0
|
||||
for run_id in run_ids:
|
||||
events = results[run_id]
|
||||
assert events[-1] is END_SENTINEL, f"Run {run_id} did not receive END sentinel"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -3,9 +3,10 @@
|
||||
These functions truncate long tool outputs to prevent context window overflow.
|
||||
- _truncate_bash_output: middle-truncation (head + tail), for bash tool
|
||||
- _truncate_read_file_output: head-truncation, for read_file tool
|
||||
- _truncate_ls_output: head-truncation, for ls tool
|
||||
"""
|
||||
|
||||
from deerflow.sandbox.tools import _truncate_bash_output, _truncate_read_file_output
|
||||
from deerflow.sandbox.tools import _truncate_bash_output, _truncate_ls_output, _truncate_read_file_output
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncate_bash_output
|
||||
@ -159,3 +160,71 @@ class TestTruncateReadFileOutput:
|
||||
for max_chars in [100, 1000, 5000, 20000, 49999]:
|
||||
result = _truncate_read_file_output(output, max_chars)
|
||||
assert len(result) <= max_chars, f"failed for max_chars={max_chars}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncate_ls_output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTruncateLsOutput:
|
||||
def test_short_output_returned_unchanged(self):
|
||||
output = "dir1\ndir2\nfile1.txt"
|
||||
assert _truncate_ls_output(output, 20000) == output
|
||||
|
||||
def test_output_equal_to_limit_returned_unchanged(self):
|
||||
output = "X" * 20000
|
||||
assert _truncate_ls_output(output, 20000) == output
|
||||
|
||||
def test_long_output_is_truncated(self):
|
||||
output = "\n".join(f"file_{i}.txt" for i in range(5000))
|
||||
result = _truncate_ls_output(output, 20000)
|
||||
assert len(result) < len(output)
|
||||
|
||||
def test_result_never_exceeds_max_chars(self):
|
||||
output = "\n".join(f"subdir/file_{i}.txt" for i in range(5000))
|
||||
max_chars = 20000
|
||||
result = _truncate_ls_output(output, max_chars)
|
||||
assert len(result) <= max_chars
|
||||
|
||||
def test_head_is_preserved(self):
|
||||
head = "first_dir\nsecond_dir\n"
|
||||
output = head + "\n".join(f"file_{i}" for i in range(5000))
|
||||
result = _truncate_ls_output(output, 20000)
|
||||
assert result.startswith(head)
|
||||
|
||||
def test_truncation_marker_present(self):
|
||||
output = "\n".join(f"file_{i}.txt" for i in range(5000))
|
||||
result = _truncate_ls_output(output, 20000)
|
||||
assert "[truncated:" in result
|
||||
assert "showing first" in result
|
||||
|
||||
def test_total_chars_reported_correctly(self):
|
||||
output = "X" * 30000
|
||||
result = _truncate_ls_output(output, 20000)
|
||||
assert "of 30000 chars" in result
|
||||
|
||||
def test_hint_suggests_specific_path(self):
|
||||
output = "X" * 30000
|
||||
result = _truncate_ls_output(output, 20000)
|
||||
assert "Use a more specific path" in result
|
||||
|
||||
def test_max_chars_zero_disables_truncation(self):
|
||||
output = "\n".join(f"file_{i}.txt" for i in range(10000))
|
||||
assert _truncate_ls_output(output, 0) == output
|
||||
|
||||
def test_tail_is_not_preserved(self):
|
||||
output = "H" * 20000 + "TAIL_SHOULD_NOT_APPEAR"
|
||||
result = _truncate_ls_output(output, 20000)
|
||||
assert "TAIL_SHOULD_NOT_APPEAR" not in result
|
||||
|
||||
def test_small_max_chars_does_not_crash(self):
|
||||
output = "\n".join(f"file_{i}.txt" for i in range(100))
|
||||
result = _truncate_ls_output(output, 10)
|
||||
assert len(result) <= 10
|
||||
|
||||
def test_result_never_exceeds_max_chars_various_sizes(self):
|
||||
output = "\n".join(f"file_{i}.txt" for i in range(5000))
|
||||
for max_chars in [100, 1000, 5000, 20000, len(output) - 1]:
|
||||
result = _truncate_ls_output(output, max_chars)
|
||||
assert len(result) <= max_chars, f"failed for max_chars={max_chars}"
|
||||
|
||||
138
backend/tests/test_vllm_provider.py
Normal file
138
backend/tests/test_vllm_provider.py
Normal file
@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||
|
||||
from deerflow.models.vllm_provider import VllmChatModel
|
||||
|
||||
|
||||
def _make_model() -> VllmChatModel:
|
||||
return VllmChatModel(
|
||||
model="Qwen/QwQ-32B",
|
||||
api_key="dummy",
|
||||
base_url="http://localhost:8000/v1",
|
||||
)
|
||||
|
||||
|
||||
def test_vllm_provider_restores_reasoning_in_request_payload():
|
||||
model = _make_model()
|
||||
payload = model._get_request_payload(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "bash", "args": {"cmd": "pwd"}, "id": "tool-1", "type": "tool_call"}],
|
||||
additional_kwargs={"reasoning": "Need to inspect the workspace first."},
|
||||
),
|
||||
HumanMessage(content="Continue"),
|
||||
]
|
||||
)
|
||||
|
||||
assistant_message = payload["messages"][0]
|
||||
assert assistant_message["role"] == "assistant"
|
||||
assert assistant_message["reasoning"] == "Need to inspect the workspace first."
|
||||
assert assistant_message["tool_calls"][0]["function"]["name"] == "bash"
|
||||
|
||||
|
||||
def test_vllm_provider_normalizes_legacy_thinking_kwarg_to_enable_thinking():
|
||||
model = VllmChatModel(
|
||||
model="qwen3",
|
||||
api_key="dummy",
|
||||
base_url="http://localhost:8000/v1",
|
||||
extra_body={"chat_template_kwargs": {"thinking": True}},
|
||||
)
|
||||
|
||||
payload = model._get_request_payload([HumanMessage(content="Hello")])
|
||||
|
||||
assert payload["extra_body"]["chat_template_kwargs"] == {"enable_thinking": True}
|
||||
|
||||
|
||||
def test_vllm_provider_preserves_explicit_enable_thinking_kwarg():
|
||||
model = VllmChatModel(
|
||||
model="qwen3",
|
||||
api_key="dummy",
|
||||
base_url="http://localhost:8000/v1",
|
||||
extra_body={"chat_template_kwargs": {"enable_thinking": False, "foo": "bar"}},
|
||||
)
|
||||
|
||||
payload = model._get_request_payload([HumanMessage(content="Hello")])
|
||||
|
||||
assert payload["extra_body"]["chat_template_kwargs"] == {
|
||||
"enable_thinking": False,
|
||||
"foo": "bar",
|
||||
}
|
||||
|
||||
|
||||
def test_vllm_provider_preserves_reasoning_in_chat_result():
|
||||
model = _make_model()
|
||||
result = model._create_chat_result(
|
||||
{
|
||||
"model": "Qwen/QwQ-32B",
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "42",
|
||||
"reasoning": "I compared the two numbers directly.",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
}
|
||||
)
|
||||
|
||||
message = result.generations[0].message
|
||||
assert message.additional_kwargs["reasoning"] == "I compared the two numbers directly."
|
||||
assert message.additional_kwargs["reasoning_content"] == "I compared the two numbers directly."
|
||||
|
||||
|
||||
def test_vllm_provider_preserves_reasoning_in_streaming_chunks():
|
||||
model = _make_model()
|
||||
chunk = model._convert_chunk_to_generation_chunk(
|
||||
{
|
||||
"model": "Qwen/QwQ-32B",
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"reasoning": "First, call the weather tool.",
|
||||
"content": "Calling tool...",
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
AIMessageChunk,
|
||||
{},
|
||||
)
|
||||
|
||||
assert chunk is not None
|
||||
assert chunk.message.additional_kwargs["reasoning"] == "First, call the weather tool."
|
||||
assert chunk.message.additional_kwargs["reasoning_content"] == "First, call the weather tool."
|
||||
assert chunk.message.content == "Calling tool..."
|
||||
|
||||
|
||||
def test_vllm_provider_preserves_empty_reasoning_values_in_streaming_chunks():
|
||||
model = _make_model()
|
||||
chunk = model._convert_chunk_to_generation_chunk(
|
||||
{
|
||||
"model": "Qwen/QwQ-32B",
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"reasoning": "",
|
||||
"content": "Still replying...",
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
AIMessageChunk,
|
||||
{},
|
||||
)
|
||||
|
||||
assert chunk is not None
|
||||
assert "reasoning" in chunk.message.additional_kwargs
|
||||
assert chunk.message.additional_kwargs["reasoning"] == ""
|
||||
assert "reasoning_content" not in chunk.message.additional_kwargs
|
||||
assert chunk.message.content == "Still replying..."
|
||||
@ -245,6 +245,28 @@ models:
|
||||
# max_tokens: 8192
|
||||
# temperature: 0.7
|
||||
|
||||
# Example: vLLM 0.19.0 (OpenAI-compatible, with reasoning toggle)
|
||||
# DeerFlow's vLLM provider preserves vLLM reasoning across tool-call turns and
|
||||
# toggles Qwen-style reasoning by writing
|
||||
# extra_body.chat_template_kwargs.enable_thinking=true/false.
|
||||
# Some reasoning models also require the server to be started with
|
||||
# `vllm serve ... --reasoning-parser <parser>`.
|
||||
# - name: qwen3-32b-vllm
|
||||
# display_name: Qwen3 32B (vLLM)
|
||||
# use: deerflow.models.vllm_provider:VllmChatModel
|
||||
# model: Qwen/Qwen3-32B
|
||||
# api_key: $VLLM_API_KEY
|
||||
# base_url: http://localhost:8000/v1
|
||||
# request_timeout: 600.0
|
||||
# max_retries: 2
|
||||
# max_tokens: 8192
|
||||
# supports_thinking: true
|
||||
# supports_vision: false
|
||||
# when_thinking_enabled:
|
||||
# extra_body:
|
||||
# chat_template_kwargs:
|
||||
# enable_thinking: true
|
||||
|
||||
# ============================================================================
|
||||
# Tool Groups Configuration
|
||||
# ============================================================================
|
||||
@ -392,10 +414,11 @@ sandbox:
|
||||
|
||||
# Tool output truncation limits (characters).
|
||||
# bash uses middle-truncation (head + tail) since errors can appear anywhere in the output.
|
||||
# read_file uses head-truncation since source code context is front-loaded.
|
||||
# read_file and ls use head-truncation since their content is front-loaded.
|
||||
# Set to 0 to disable truncation.
|
||||
bash_output_max_chars: 20000
|
||||
read_file_output_max_chars: 50000
|
||||
ls_output_max_chars: 20000
|
||||
|
||||
# Option 2: Container-based AIO Sandbox
|
||||
# Executes commands in isolated containers (Docker or Apple Container)
|
||||
|
||||
@ -127,7 +127,7 @@ services:
|
||||
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
|
||||
UV_EXTRAS: ${UV_EXTRAS:-}
|
||||
container_name: deer-flow-langgraph
|
||||
command: sh -c 'cd /app/backend && allow_blocking="" && if [ "\${LANGGRAPH_ALLOW_BLOCKING:-0}" = "1" ]; then allow_blocking="--allow-blocking"; fi && uv run langgraph dev --no-browser \${allow_blocking} --no-reload --host 0.0.0.0 --port 2024 --n-jobs-per-worker \${LANGGRAPH_JOBS_PER_WORKER:-10}'
|
||||
command: sh -c 'cd /app/backend && allow_blocking="" && if [ "$${LANGGRAPH_ALLOW_BLOCKING:-0}" = "1" ]; then allow_blocking="--allow-blocking"; fi && uv run langgraph dev --no-browser $${allow_blocking} --no-reload --host 0.0.0.0 --port 2024 --n-jobs-per-worker $${LANGGRAPH_JOBS_PER_WORKER:-10}'
|
||||
volumes:
|
||||
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
|
||||
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import type { NextRequest } from "next/server";
|
||||
|
||||
const BACKEND_BASE_URL =
|
||||
process.env.NEXT_PUBLIC_BACKEND_BASE_URL ?? "http://127.0.0.1:8010";
|
||||
process.env.NEXT_PUBLIC_BACKEND_BASE_URL ?? "http://127.0.0.1:8001";
|
||||
|
||||
function buildBackendUrl(pathname: string) {
|
||||
return new URL(pathname, BACKEND_BASE_URL);
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import type { LucideIcon } from "lucide-react";
|
||||
import { Children, type ComponentProps } from "react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { LucideIcon } from "lucide-react";
|
||||
import { Children, type ComponentProps } from "react";
|
||||
|
||||
const STAGGER_DELAY_MS = 60;
|
||||
const STAGGER_DELAY_MS_OFFSET = 250;
|
||||
@ -16,12 +17,15 @@ export const Suggestions = ({
|
||||
children,
|
||||
...props
|
||||
}: SuggestionsProps) => (
|
||||
<ScrollArea className="overflow-x-auto whitespace-nowrap" {...props}>
|
||||
<div className={cn("flex w-max flex-nowrap items-center gap-2", className)}>
|
||||
<ScrollArea className="overflow-x-auto whitespace-normal" {...props}>
|
||||
<div
|
||||
className={cn("flex w-full flex-wrap items-center gap-2", className)}
|
||||
data-slot="suggestions-list"
|
||||
>
|
||||
{Children.map(children, (child, index) =>
|
||||
child != null ? (
|
||||
<span
|
||||
className="animate-fade-in-up inline-block opacity-0"
|
||||
className="animate-fade-in-up max-w-full opacity-0"
|
||||
style={{
|
||||
animationDelay: `${STAGGER_DELAY_MS_OFFSET + index * STAGGER_DELAY_MS}ms`,
|
||||
}}
|
||||
@ -60,7 +64,7 @@ export const Suggestion = ({
|
||||
return (
|
||||
<Button
|
||||
className={cn(
|
||||
"text-muted-foreground cursor-pointer rounded-full px-4 text-xs font-normal",
|
||||
"text-muted-foreground h-auto max-w-full cursor-pointer rounded-full px-4 py-2 text-center text-xs font-normal whitespace-normal",
|
||||
className,
|
||||
)}
|
||||
onClick={handleClick}
|
||||
@ -70,7 +74,7 @@ export const Suggestion = ({
|
||||
{...props}
|
||||
>
|
||||
{Icon && <Icon className="size-4" />}
|
||||
{children || suggestion}
|
||||
{children ?? suggestion}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
@ -79,17 +79,17 @@ export function ArtifactFileList({
|
||||
className="relative cursor-pointer p-3"
|
||||
onClick={() => handleClick(file)}
|
||||
>
|
||||
<CardHeader className="pr-2 pl-1">
|
||||
<CardTitle className="relative pl-8">
|
||||
<div>{getFileName(file)}</div>
|
||||
<CardHeader className="grid-cols-[minmax(0,1fr)_auto] items-center gap-x-3 gap-y-1 pr-2 pl-1">
|
||||
<CardTitle className="relative min-w-0 pl-8 leading-tight [overflow-wrap:anywhere] break-words">
|
||||
<div className="min-w-0">{getFileName(file)}</div>
|
||||
<div className="absolute top-2 -left-0.5">
|
||||
{getFileIcon(file, "size-6")}
|
||||
</div>
|
||||
</CardTitle>
|
||||
<CardDescription className="pl-8 text-xs">
|
||||
<CardDescription className="min-w-0 pl-8 text-xs">
|
||||
{getFileExtensionDisplayName(file)} file
|
||||
</CardDescription>
|
||||
<CardAction>
|
||||
<CardAction className="row-span-1 self-center">
|
||||
{file.endsWith(".skill") && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
|
||||
@ -429,7 +429,38 @@ export function InputBox({
|
||||
}, [context.model_name, disabled, isMock, status, thread.messages, threadId]);
|
||||
|
||||
return (
|
||||
<div ref={promptRootRef} className="relative">
|
||||
<div ref={promptRootRef} className="relative flex flex-col gap-4">
|
||||
{showFollowups && (
|
||||
<div className="flex items-center justify-center pb-2">
|
||||
<div className="flex items-center gap-2">
|
||||
{followupsLoading ? (
|
||||
<div className="text-muted-foreground bg-background/80 rounded-full border px-4 py-2 text-xs backdrop-blur-sm">
|
||||
{t.inputBox.followupLoading}
|
||||
</div>
|
||||
) : (
|
||||
<Suggestions className="min-h-16 w-fit items-start">
|
||||
{followups.map((s) => (
|
||||
<Suggestion
|
||||
key={s}
|
||||
suggestion={s}
|
||||
onClick={() => handleFollowupClick(s)}
|
||||
/>
|
||||
))}
|
||||
<Button
|
||||
aria-label={t.common.close}
|
||||
className="text-muted-foreground cursor-pointer rounded-full px-3 text-xs font-normal"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
type="button"
|
||||
onClick={() => setFollowupsHidden(true)}
|
||||
>
|
||||
<XIcon className="size-4" />
|
||||
</Button>
|
||||
</Suggestions>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<PromptInput
|
||||
className={cn(
|
||||
"bg-background/85 rounded-2xl backdrop-blur-sm transition-all duration-300 ease-out *:data-[slot='input-group']:rounded-2xl",
|
||||
@ -807,45 +838,14 @@ export function InputBox({
|
||||
/>
|
||||
</PromptInputTools>
|
||||
</PromptInputFooter>
|
||||
{isNewThread && searchParams.get("mode") !== "skill" && (
|
||||
<div className="absolute right-0 -bottom-20 left-0 z-0 flex items-center justify-center">
|
||||
<SuggestionList />
|
||||
</div>
|
||||
)}
|
||||
{!isNewThread && (
|
||||
<div className="bg-background absolute right-0 -bottom-[17px] left-0 z-0 h-4"></div>
|
||||
)}
|
||||
</PromptInput>
|
||||
|
||||
{showFollowups && (
|
||||
<div className="absolute -top-20 right-0 left-0 z-20 flex items-center justify-center">
|
||||
<div className="flex items-center gap-2">
|
||||
{followupsLoading ? (
|
||||
<div className="text-muted-foreground bg-background/80 rounded-full border px-4 py-2 text-xs backdrop-blur-sm">
|
||||
{t.inputBox.followupLoading}
|
||||
</div>
|
||||
) : (
|
||||
<Suggestions className="min-h-16 w-fit items-start">
|
||||
{followups.map((s) => (
|
||||
<Suggestion
|
||||
key={s}
|
||||
suggestion={s}
|
||||
onClick={() => handleFollowupClick(s)}
|
||||
/>
|
||||
))}
|
||||
<Button
|
||||
aria-label={t.common.close}
|
||||
className="text-muted-foreground cursor-pointer rounded-full px-3 text-xs font-normal"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
type="button"
|
||||
onClick={() => setFollowupsHidden(true)}
|
||||
>
|
||||
<XIcon className="size-4" />
|
||||
</Button>
|
||||
</Suggestions>
|
||||
)}
|
||||
</div>
|
||||
{isNewThread && searchParams.get("mode") !== "skill" && (
|
||||
<div className="flex items-center justify-center pt-2">
|
||||
<SuggestionList />
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
@ -36,6 +36,81 @@ type SendMessageOptions = {
|
||||
additionalKwargs?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
function normalizeStoredRunId(runId: string | null): string | null {
|
||||
if (!runId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const trimmed = runId.trim();
|
||||
if (!trimmed) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const queryIndex = trimmed.indexOf("?");
|
||||
if (queryIndex >= 0) {
|
||||
const params = new URLSearchParams(trimmed.slice(queryIndex + 1));
|
||||
const queryRunId = params.get("run_id")?.trim();
|
||||
if (queryRunId) {
|
||||
return queryRunId;
|
||||
}
|
||||
}
|
||||
|
||||
const pathWithoutQueryOrHash = trimmed.split(/[?#]/, 1)[0]?.trim() ?? "";
|
||||
if (!pathWithoutQueryOrHash) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const runsMarker = "/runs/";
|
||||
const runsIndex = pathWithoutQueryOrHash.lastIndexOf(runsMarker);
|
||||
if (runsIndex >= 0) {
|
||||
const runIdAfterMarker = pathWithoutQueryOrHash
|
||||
.slice(runsIndex + runsMarker.length)
|
||||
.split("/", 1)[0]
|
||||
?.trim();
|
||||
if (runIdAfterMarker) {
|
||||
return runIdAfterMarker;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
const segments = pathWithoutQueryOrHash
|
||||
.split("/")
|
||||
.map((segment) => segment.trim())
|
||||
.filter(Boolean);
|
||||
return segments.at(-1) ?? null;
|
||||
}
|
||||
|
||||
function getRunMetadataStorage(): {
|
||||
getItem(key: `lg:stream:${string}`): string | null;
|
||||
setItem(key: `lg:stream:${string}`, value: string): void;
|
||||
removeItem(key: `lg:stream:${string}`): void;
|
||||
} {
|
||||
return {
|
||||
getItem(key) {
|
||||
const normalized = normalizeStoredRunId(
|
||||
window.sessionStorage.getItem(key),
|
||||
);
|
||||
if (normalized) {
|
||||
window.sessionStorage.setItem(key, normalized);
|
||||
return normalized;
|
||||
}
|
||||
window.sessionStorage.removeItem(key);
|
||||
return null;
|
||||
},
|
||||
setItem(key, value) {
|
||||
const normalized = normalizeStoredRunId(value);
|
||||
if (normalized) {
|
||||
window.sessionStorage.setItem(key, normalized);
|
||||
return;
|
||||
}
|
||||
window.sessionStorage.removeItem(key);
|
||||
},
|
||||
removeItem(key) {
|
||||
window.sessionStorage.removeItem(key);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function getStreamErrorMessage(error: unknown): string {
|
||||
if (typeof error === "string" && error.trim()) {
|
||||
return error;
|
||||
@ -113,12 +188,24 @@ export function useThreadStream({
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const updateSubtask = useUpdateSubtask();
|
||||
const runMetadataStorageRef = useRef<
|
||||
ReturnType<typeof getRunMetadataStorage> | undefined
|
||||
>(undefined);
|
||||
|
||||
if (
|
||||
typeof window !== "undefined" &&
|
||||
runMetadataStorageRef.current === undefined
|
||||
) {
|
||||
runMetadataStorageRef.current = getRunMetadataStorage();
|
||||
}
|
||||
|
||||
const thread = useStream<AgentThreadState>({
|
||||
client: getAPIClient(isMock),
|
||||
assistantId: "lead_agent",
|
||||
threadId: onStreamThreadId,
|
||||
reconnectOnMount: true,
|
||||
reconnectOnMount: runMetadataStorageRef.current
|
||||
? () => runMetadataStorageRef.current!
|
||||
: false,
|
||||
fetchStateHistory: { limit: 1 },
|
||||
onCreated(meta) {
|
||||
handleStreamStart(meta.thread_id);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user