mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor(config): migrate AppConfig.current() to DeerFlowContext in runtime paths
This commit is contained in:
parent
e7bb1e9c54
commit
faec3bf9f2
@ -9,7 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -203,11 +203,12 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
Returns:
|
||||
None (no state changes needed from this middleware).
|
||||
"""
|
||||
memory_config = runtime.context.app_config.memory
|
||||
ctx = resolve_context(runtime)
|
||||
memory_config = ctx.app_config.memory
|
||||
if not memory_config.enabled:
|
||||
return None
|
||||
|
||||
thread_id = runtime.context.thread_id
|
||||
thread_id = ctx.thread_id
|
||||
if not thread_id:
|
||||
logger.debug("No thread_id in context, skipping memory update")
|
||||
return None
|
||||
|
||||
@ -9,6 +9,8 @@ from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -44,10 +46,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
return ""
|
||||
|
||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> bool:
|
||||
"""Check if we should generate a title for this thread."""
|
||||
config = AppConfig.current().title
|
||||
if not config.enabled:
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
if not title_config.enabled:
|
||||
return False
|
||||
|
||||
# Check if thread already has a title in state
|
||||
@ -66,12 +69,13 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
# Generate title after first complete exchange
|
||||
return len(user_messages) == 1 and len(assistant_messages) >= 1
|
||||
|
||||
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
|
||||
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> tuple[str, str]:
|
||||
"""Extract user/assistant messages and build the title prompt.
|
||||
|
||||
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
||||
"""
|
||||
config = AppConfig.current().title
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
messages = state.get("messages", [])
|
||||
|
||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||
@ -80,23 +84,25 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
user_msg = self._normalize_content(user_msg_content)
|
||||
assistant_msg = self._normalize_content(assistant_msg_content)
|
||||
|
||||
prompt = config.prompt_template.format(
|
||||
max_words=config.max_words,
|
||||
prompt = title_config.prompt_template.format(
|
||||
max_words=title_config.max_words,
|
||||
user_msg=user_msg[:500],
|
||||
assistant_msg=assistant_msg[:500],
|
||||
)
|
||||
return prompt, user_msg
|
||||
|
||||
def _parse_title(self, content: object) -> str:
|
||||
def _parse_title(self, content: object, title_config: TitleConfig | None = None) -> str:
|
||||
"""Normalize model output into a clean title string."""
|
||||
config = AppConfig.current().title
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
title_content = self._normalize_content(content)
|
||||
title = title_content.strip().strip('"').strip("'")
|
||||
return title[: config.max_chars] if len(title) > config.max_chars else title
|
||||
return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
|
||||
|
||||
def _fallback_title(self, user_msg: str) -> str:
|
||||
config = AppConfig.current().title
|
||||
fallback_chars = min(config.max_chars, 50)
|
||||
def _fallback_title(self, user_msg: str, title_config: TitleConfig | None = None) -> str:
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
fallback_chars = min(title_config.max_chars, 50)
|
||||
if len(user_msg) > fallback_chars:
|
||||
return user_msg[:fallback_chars].rstrip() + "..."
|
||||
return user_msg if user_msg else "New Conversation"
|
||||
@ -115,39 +121,53 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
|
||||
return config
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None:
|
||||
"""Generate a local fallback title without blocking on an LLM call."""
|
||||
if not self._should_generate_title(state):
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
if not self._should_generate_title(state, title_config):
|
||||
return None
|
||||
|
||||
_, user_msg = self._build_title_prompt(state)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
_, user_msg = self._build_title_prompt(state, title_config)
|
||||
return {"title": self._fallback_title(user_msg, title_config)}
|
||||
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None:
|
||||
"""Generate a title asynchronously and fall back locally on failure."""
|
||||
if not self._should_generate_title(state):
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
if not self._should_generate_title(state, title_config):
|
||||
return None
|
||||
|
||||
config = AppConfig.current().title
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
prompt, user_msg = self._build_title_prompt(state, title_config)
|
||||
|
||||
try:
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
if title_config.model_name:
|
||||
model = create_chat_model(name=title_config.model_name, thinking_enabled=False)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
response = await model.ainvoke(prompt, config=self._get_runnable_config())
|
||||
title = self._parse_title(response.content)
|
||||
title = self._parse_title(response.content, title_config)
|
||||
if title:
|
||||
return {"title": title}
|
||||
except Exception:
|
||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
return {"title": self._fallback_title(user_msg, title_config)}
|
||||
|
||||
def _resolve_title_config(self, runtime: Runtime[DeerFlowContext]) -> TitleConfig | None:
|
||||
"""Resolve TitleConfig from the runtime context when possible.
|
||||
|
||||
Returns None on any failure so callers can fall back to
|
||||
``AppConfig.current()`` without raising.
|
||||
"""
|
||||
try:
|
||||
return resolve_context(runtime).app_config.title
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@override
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
return self._generate_title_result(state)
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
return self._generate_title_result(state, self._resolve_title_config(runtime))
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
return await self._agenerate_title_result(state)
|
||||
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
return await self._agenerate_title_result(state, self._resolve_title_config(runtime))
|
||||
|
||||
@ -8,6 +8,7 @@ from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import resolve_context
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.sandbox.exceptions import (
|
||||
SandboxError,
|
||||
@ -987,7 +988,11 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
if is_local_sandbox(runtime):
|
||||
if not is_host_bash_allowed():
|
||||
try:
|
||||
host_bash_config = resolve_context(runtime).app_config
|
||||
except Exception:
|
||||
host_bash_config = None
|
||||
if not is_host_bash_allowed(host_bash_config):
|
||||
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
|
||||
ensure_thread_directories_exist(runtime)
|
||||
thread_data = get_thread_data(runtime)
|
||||
@ -996,14 +1001,14 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
command = _apply_cwd_prefix(command, thread_data)
|
||||
output = sandbox.execute_command(command)
|
||||
try:
|
||||
sandbox_cfg = AppConfig.current().sandbox
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
try:
|
||||
sandbox_cfg = AppConfig.current().sandbox
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
@ -1043,7 +1048,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
return "(empty)"
|
||||
output = "\n".join(children)
|
||||
try:
|
||||
sandbox_cfg = AppConfig.current().sandbox
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
@ -1214,7 +1219,7 @@ def read_file_tool(
|
||||
if start_line is not None and end_line is not None:
|
||||
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
|
||||
try:
|
||||
sandbox_cfg = AppConfig.current().sandbox
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
|
||||
except Exception:
|
||||
max_chars = 50000
|
||||
|
||||
@ -12,6 +12,7 @@ from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.deer_flow_context import resolve_context
|
||||
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
||||
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
|
||||
@ -66,8 +67,13 @@ async def task_tool(
|
||||
if config is None:
|
||||
available = ", ".join(available_subagent_names)
|
||||
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
|
||||
if subagent_type == "bash" and not is_host_bash_allowed():
|
||||
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
|
||||
if subagent_type == "bash":
|
||||
try:
|
||||
host_bash_config = resolve_context(runtime).app_config
|
||||
except Exception:
|
||||
host_bash_config = None
|
||||
if not is_host_bash_allowed(host_bash_config):
|
||||
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
|
||||
|
||||
# Build config overrides
|
||||
overrides: dict = {}
|
||||
|
||||
@ -357,7 +357,7 @@ def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) ->
|
||||
"deerflow.sandbox.tools.ensure_sandbox_initialized",
|
||||
lambda runtime: SimpleNamespace(execute_command=lambda command: pytest.fail("host bash should not execute")),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_host_bash_allowed", lambda: False)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_host_bash_allowed", lambda *a, **k: False)
|
||||
|
||||
result = bash_tool.func(
|
||||
runtime=runtime,
|
||||
|
||||
@ -109,7 +109,7 @@ def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
||||
|
||||
def test_task_tool_rejects_bash_subagent_when_host_bash_disabled(monkeypatch):
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: _make_subagent_config())
|
||||
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda: False)
|
||||
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda *a, **k: False)
|
||||
|
||||
result = _run_task_tool(
|
||||
runtime=_make_runtime(),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user