mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 19:28:23 +00:00
fix(middleware): fix present_files thread id fallback (#2181)
* fix present files thread id fallback * fix: resolve present_files thread id from runtime config
This commit is contained in:
parent
1df389b9d0
commit
f4c17c66ce
@ -3,6 +3,7 @@ from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
@ -12,6 +13,23 @@ from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||
|
||||
|
||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
||||
"""Resolve the current thread id from runtime context or RunnableConfig."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
return thread_id
|
||||
|
||||
runtime_config = getattr(runtime, "config", None) or {}
|
||||
thread_id = runtime_config.get("configurable", {}).get("thread_id")
|
||||
if thread_id:
|
||||
return thread_id
|
||||
|
||||
try:
|
||||
return get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_presented_filepath(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
filepath: str,
|
||||
@ -33,9 +51,9 @@ def _normalize_presented_filepath(
|
||||
if runtime.state is None:
|
||||
raise ValueError("Thread runtime state is not available")
|
||||
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
thread_id = _get_thread_id(runtime)
|
||||
if not thread_id:
|
||||
raise ValueError("Thread ID is not available in runtime context")
|
||||
raise ValueError("Thread ID is not available in runtime context or runtime config")
|
||||
|
||||
thread_data = runtime.state.get("thread_data") or {}
|
||||
outputs_path = thread_data.get("outputs_path")
|
||||
|
||||
@ -10,6 +10,7 @@ def _make_runtime(outputs_path: str) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
state={"thread_data": {"outputs_path": outputs_path}},
|
||||
context={"thread_id": "thread-1"},
|
||||
config={},
|
||||
)
|
||||
|
||||
|
||||
@ -50,6 +51,34 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
||||
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
|
||||
|
||||
|
||||
def test_present_files_uses_config_thread_id_when_context_missing(tmp_path, monkeypatch):
|
||||
outputs_dir = tmp_path / "threads" / "thread-from-config" / "user-data" / "outputs"
|
||||
outputs_dir.mkdir(parents=True)
|
||||
artifact_path = outputs_dir / "summary.json"
|
||||
artifact_path.write_text("{}")
|
||||
|
||||
monkeypatch.setattr(
|
||||
present_file_tool_module,
|
||||
"get_paths",
|
||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
|
||||
)
|
||||
|
||||
runtime = SimpleNamespace(
|
||||
state={"thread_data": {"outputs_path": str(outputs_dir)}},
|
||||
context={},
|
||||
config={"configurable": {"thread_id": "thread-from-config"}},
|
||||
)
|
||||
|
||||
result = present_file_tool_module.present_file_tool.func(
|
||||
runtime=runtime,
|
||||
filepaths=["/mnt/user-data/outputs/summary.json"],
|
||||
tool_call_id="tc-config",
|
||||
)
|
||||
|
||||
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
|
||||
assert result.update["messages"][0].content == "Successfully presented files"
|
||||
|
||||
|
||||
def test_present_files_rejects_paths_outside_outputs(tmp_path):
|
||||
outputs_dir = tmp_path / "threads" / "thread-1" / "user-data" / "outputs"
|
||||
workspace_dir = tmp_path / "threads" / "thread-1" / "user-data" / "workspace"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user