diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index f4fdad473..20da78af9 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -8,13 +8,17 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. from __future__ import annotations -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from contextlib import AsyncExitStack, asynccontextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, cast from fastapi import FastAPI, HTTPException, Request +from langgraph.types import Checkpointer -from deerflow.runtime import RunContext, RunManager +from deerflow.persistence.feedback import FeedbackRepository +from deerflow.runtime import RunContext, RunManager, StreamBridge +from deerflow.runtime.events.store.base import RunEventStore +from deerflow.runtime.runs.store.base import RunStore if TYPE_CHECKING: from app.gateway.auth.local_provider import LocalAuthProvider @@ -22,6 +26,9 @@ if TYPE_CHECKING: from deerflow.persistence.thread_meta.base import ThreadMetaStore +T = TypeVar("T") + + @asynccontextmanager async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: """Bootstrap and tear down all LangGraph runtime singletons. 
@@ -84,25 +91,25 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: # --------------------------------------------------------------------------- -def _require(attr: str, label: str): +def _require(attr: str, label: str) -> Callable[[Request], T]: """Create a FastAPI dependency that returns ``app.state.`` or 503.""" - def dep(request: Request): + def dep(request: Request) -> T: val = getattr(request.app.state, attr, None) if val is None: raise HTTPException(status_code=503, detail=f"{label} not available") - return val + return cast(T, val) dep.__name__ = dep.__qualname__ = f"get_{attr}" return dep -get_stream_bridge = _require("stream_bridge", "Stream bridge") -get_run_manager = _require("run_manager", "Run manager") -get_checkpointer = _require("checkpointer", "Checkpointer") -get_run_event_store = _require("run_event_store", "Run event store") -get_feedback_repo = _require("feedback_repo", "Feedback") -get_run_store = _require("run_store", "Run store") +get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge") +get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager") +get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer") +get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store") +get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback") +get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store") def get_store(request: Request): @@ -121,10 +128,7 @@ def get_thread_store(request: Request) -> ThreadMetaStore: def get_run_context(request: Request) -> RunContext: """Build a :class:`RunContext` from ``app.state`` singletons. - Returns a *base* context with infrastructure dependencies. Callers that - need per-run fields (e.g. 
``follow_up_to_run_id``) should use - ``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it - to :func:`run_agent`. + Returns a *base* context with infrastructure dependencies. """ from deerflow.config import get_app_config @@ -137,7 +141,6 @@ def get_run_context(request: Request) -> RunContext: ) - # --------------------------------------------------------------------------- # Auth helpers (used by authz.py and auth middleware) # --------------------------------------------------------------------------- diff --git a/backend/app/gateway/routers/runs.py b/backend/app/gateway/routers/runs.py index 70e2abb63..f2775466c 100644 --- a/backend/app/gateway/routers/runs.py +++ b/backend/app/gateway/routers/runs.py @@ -123,7 +123,8 @@ async def run_messages( run = await _resolve_run(run_id, request) event_store = get_run_event_store(request) rows = await event_store.list_messages_by_run( - run["thread_id"], run_id, + run["thread_id"], + run_id, limit=limit + 1, before_seq=before_seq, after_seq=after_seq, diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index e21375ab9..e6847c50f 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -54,7 +54,6 @@ class RunCreateRequest(BaseModel): after_seconds: float | None = Field(default=None, description="Delayed execution") if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy") feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys") - follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. 
Auto-detected from latest successful run if not provided.") class RunResponse(BaseModel): @@ -312,11 +311,15 @@ async def list_thread_messages( if i in last_ai_indices: run_id = msg["run_id"] fb = feedback_map.get(run_id) - msg["feedback"] = { - "feedback_id": fb["feedback_id"], - "rating": fb["rating"], - "comment": fb.get("comment"), - } if fb else None + msg["feedback"] = ( + { + "feedback_id": fb["feedback_id"], + "rating": fb["rating"], + "comment": fb.get("comment"), + } + if fb + else None + ) else: msg["feedback"] = None @@ -339,7 +342,8 @@ async def list_run_messages( """ event_store = get_run_event_store(request) rows = await event_store.list_messages_by_run( - thread_id, run_id, + thread_id, + run_id, limit=limit + 1, before_seq=before_seq, after_seq=after_seq, diff --git a/backend/app/gateway/routers/uploads.py b/backend/app/gateway/routers/uploads.py index aa707e9ea..e31ff11d2 100644 --- a/backend/app/gateway/routers/uploads.py +++ b/backend/app/gateway/routers/uploads.py @@ -56,7 +56,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None: @router.post("", response_model=UploadResponse) -@require_permission("threads", "write", owner_check=True, require_existing=True) +@require_permission("threads", "write", owner_check=True, require_existing=False) async def upload_files( thread_id: str, request: Request, diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 72f074907..9a1cdb12f 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -195,21 +195,6 @@ async def start_run( disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ - # Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run - follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None) - if follow_up_to_run_id is None: - run_store = get_run_store(request) - try: - recent_runs = await run_store.list_by_thread(thread_id, 
limit=1) - if recent_runs and recent_runs[0].get("status") == "success": - follow_up_to_run_id = recent_runs[0]["run_id"] - except Exception: - pass # Don't block run creation - - # Enrich base context with per-run field - if follow_up_to_run_id: - run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id) - try: record = await run_mgr.create_or_reject( thread_id, @@ -218,7 +203,6 @@ async def start_run( metadata=body.metadata or {}, kwargs={"input": body.input, "config": body.config}, multitask_strategy=body.multitask_strategy, - follow_up_to_run_id=follow_up_to_run_id, ) except ConflictError as exc: raise HTTPException(status_code=409, detail=str(exc)) from exc diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index 352ff021e..ef8224e6b 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -1,7 +1,7 @@ import logging from langchain.agents import create_agent -from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware +from langchain.agents.middleware import AgentMiddleware from langchain_core.runnables import RunnableConfig from deerflow.agents.lead_agent.prompt import apply_prompt_template @@ -9,6 +9,7 @@ from deerflow.agents.middlewares.clarification_middleware import ClarificationMi from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware +from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware from deerflow.agents.middlewares.todo_middleware import TodoMiddleware from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware 
diff --git a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py index 0b161152c..1fdc01fcc 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py @@ -283,7 +283,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): # the conversation; injecting one mid-conversation crashes # langchain_anthropic's _format_messages(). HumanMessage works # with all providers. See #1299. - return {"messages": [HumanMessage(content=warning)]} + return {"messages": [HumanMessage(content=warning, name="loop_warning")]} return None diff --git a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py new file mode 100644 index 000000000..243cdb39f --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py @@ -0,0 +1,13 @@ +from typing import override + +from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware +from langchain_core.messages.human import HumanMessage + + +class SummarizationMiddleware(BaseSummarizationMiddleware): + @override + def _build_new_messages(self, summary: str) -> list[HumanMessage]: + """Override the base implementation to let the human message with the special name 'summary'. + And this message will be ignored to display in the frontend, but still can be used as context for the model. 
+ """ + return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")] diff --git a/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py index 828a82621..8d93de4ff 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py @@ -1,8 +1,10 @@ import logging +from datetime import UTC, datetime from typing import NotRequired, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import HumanMessage from langgraph.config import get_config from langgraph.runtime import Runtime @@ -97,8 +99,20 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]): paths = self._create_thread_directories(thread_id, user_id=user_id) logger.debug("Created thread data directories for thread %s", thread_id) + messages = list(state.get("messages", [])) + last_message = messages[-1] if messages else None + + if last_message and isinstance(last_message, HumanMessage): + messages[-1] = HumanMessage( + content=last_message.content, + id=last_message.id, + name=last_message.name or "user-input", + additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()}, + ) + return { "thread_data": { **paths, - } + }, + "messages": messages, } diff --git a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py index 6622fb695..4f584c5c3 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py @@ -279,6 +279,7 @@ class 
UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]): updated_message = HumanMessage( content=f"{files_message}\n\n{original_content}", id=last_message.id, + name=last_message.name, additional_kwargs=last_message.additional_kwargs, ) diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index a70404e11..5f1838888 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -6,7 +6,10 @@ handles token usage accumulation. Key design decisions: - on_llm_new_token is NOT implemented -- only complete messages via on_llm_end -- on_chat_model_start captures structured prompts as llm_request (OpenAI format) +- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and + extracts the first human message for run.input, because it is more reliable than + on_chain_start (fires on every node) — messages here are fully structured. +- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation. 
- on_llm_end emits llm_response in OpenAI Chat Completions format - Token usage accumulated in memory, written to RunRow on run completion - Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name}) @@ -18,10 +21,12 @@ import asyncio import logging import time from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from uuid import UUID from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage +from langgraph.types import Command if TYPE_CHECKING: from deerflow.runtime.events.store.base import RunEventStore @@ -72,34 +77,39 @@ class RunJournal(BaseCallbackHandler): # LLM request/response tracking self._llm_call_index = 0 self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages - self._cached_models: dict[str, str] = {} # langchain run_id -> model name - - # Tool call ID cache - self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id # -- Lifecycle callbacks -- - def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None: - if kwargs.get("parent_run_id") is not None: - return - self._put( - event_type="run_start", - category="lifecycle", - metadata={"input_preview": str(inputs)[:500]}, - ) + def on_chain_start( + self, + serialized: dict[str, Any], + inputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + caller = self._identify_caller(tags) + if parent_run_id is None: + # Root graph invocation — emit a single trace event for the run start. 
+ chain_name = (serialized or {}).get("name", "unknown") + self._put( + event_type="run.start", + category="trace", + content={"chain": chain_name}, + metadata={"caller": caller, **(metadata or {})}, + ) def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None: - if kwargs.get("parent_run_id") is not None: - return - self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"}) + self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"}) self._flush_sync() def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: - if kwargs.get("parent_run_id") is not None: - return self._put( - event_type="run_error", - category="lifecycle", + event_type="run.error", + category="error", content=str(error), metadata={"error_type": type(error).__name__}, ) @@ -107,266 +117,132 @@ class RunJournal(BaseCallbackHandler): # -- LLM callbacks -- - def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None: - """Capture structured prompt messages for llm_request event.""" - from deerflow.runtime.converters import langchain_messages_to_openai + def on_chat_model_start( + self, + serialized: dict, + messages: list[list[BaseMessage]], + *, + run_id: UUID, + tags: list[str] | None = None, + **kwargs: Any, + ) -> None: + """Capture structured prompt messages for llm_request event. + This is also the canonical place to extract the first human message: + messages are fully structured here, it fires only on real LLM calls, + and the content is never compressed by checkpoint trimming. + """ rid = str(run_id) self._llm_start_times[rid] = time.monotonic() self._llm_call_index += 1 + # Mark this run_id as seen so on_llm_end knows not to increment again. 
+        self._cached_prompts[rid] = [] -        model_name = serialized.get("name", "") -        self._cached_models[rid] = model_name +        logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}") -        # Convert the first message list (LangChain passes list-of-lists) -        prompt_msgs = messages[0] if messages else [] -        openai_msgs = langchain_messages_to_openai(prompt_msgs) -        self._cached_prompts[rid] = openai_msgs +        # Capture the first human message sent to any LLM in this run. +        if not self._first_human_msg: +            for batch in reversed(messages): +                for m in reversed(batch): +                    if isinstance(m, HumanMessage) and m.name != "summary": +                        caller = self._identify_caller(tags) +                        self.set_first_human_message(m.text) +                        self._put( +                            event_type="llm.human.input", +                            category="message", +                            content=m.model_dump(), +                            metadata={"caller": caller}, +                        ) +                        break +                if self._first_human_msg: +                    break -        caller = self._identify_caller(kwargs) -        self._put( -            event_type="llm_request", -            category="trace", -            content={"model": model_name, "messages": openai_msgs}, -            metadata={"caller": caller, "llm_call_index": self._llm_call_index}, -        ) - -    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None: +    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:         # Fallback: on_chat_model_start is preferred. This just tracks latency. 
self._llm_start_times[str(run_id)] = time.monotonic() - def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None: - from deerflow.runtime.converters import langchain_to_openai_completion - - try: - message = response.generations[0][0].message - except (IndexError, AttributeError): - logger.debug("on_llm_end: could not extract message from response") - return - - caller = self._identify_caller(kwargs) - - # Latency - rid = str(run_id) - start = self._llm_start_times.pop(rid, None) - latency_ms = int((time.monotonic() - start) * 1000) if start else None - - # Token usage from message - usage = getattr(message, "usage_metadata", None) - usage_dict = dict(usage) if usage else {} - - # Resolve call index - call_index = self._llm_call_index - if rid not in self._cached_prompts: - # Fallback: on_chat_model_start was not called - self._llm_call_index += 1 - call_index = self._llm_call_index - - # Clean up caches - self._cached_prompts.pop(rid, None) - self._cached_models.pop(rid, None) - - # Trace event: llm_response (OpenAI completion format) - content = getattr(message, "content", "") - self._put( - event_type="llm_response", - category="trace", - content=langchain_to_openai_completion(message), - metadata={ - "caller": caller, - "usage": usage_dict, - "latency_ms": latency_ms, - "llm_call_index": call_index, - }, - ) - - # Message events: only lead_agent gets message-category events. - # Content uses message.model_dump() to align with checkpoint format. 
- tool_calls = getattr(message, "tool_calls", None) or [] - if caller == "lead_agent": - resp_meta = getattr(message, "response_metadata", None) or {} - model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None - if tool_calls: - # ai_tool_call: agent decided to use tools - self._put( - event_type="ai_tool_call", - category="message", - content=message.model_dump(), - metadata={"model_name": model_name, "finish_reason": "tool_calls"}, - ) - elif isinstance(content, str) and content: - # ai_message: final text reply - self._put( - event_type="ai_message", - category="message", - content=message.model_dump(), - metadata={"model_name": model_name, "finish_reason": "stop"}, - ) - self._last_ai_msg = content - self._msg_count += 1 - - # Token accumulation - if self._track_tokens: - input_tk = usage_dict.get("input_tokens", 0) or 0 - output_tk = usage_dict.get("output_tokens", 0) or 0 - total_tk = usage_dict.get("total_tokens", 0) or 0 - if total_tk == 0: - total_tk = input_tk + output_tk - if total_tk > 0: - self._total_input_tokens += input_tk - self._total_output_tokens += output_tk - self._total_tokens += total_tk - self._llm_call_count += 1 - if caller.startswith("subagent:"): - self._subagent_tokens += total_tk - elif caller.startswith("middleware:"): - self._middleware_tokens += total_tk + def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None: + messages: list[AnyMessage] = [] + logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}") + for generation in response.generations: + for gen in generation: + if hasattr(gen, "message"): + messages.append(gen.message) else: - self._lead_agent_tokens += total_tk + logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}") + + for message in messages: + caller = self._identify_caller(tags) + + # Latency + rid = str(run_id) + start = self._llm_start_times.pop(rid, None) + latency_ms = int((time.monotonic() - start) * 1000) if start else None + + 
# Token usage from message + usage = getattr(message, "usage_metadata", None) + usage_dict = dict(usage) if usage else {} + + # Resolve call index + call_index = self._llm_call_index + if rid not in self._cached_prompts: + # Fallback: on_chat_model_start was not called + self._llm_call_index += 1 + call_index = self._llm_call_index + + # Trace event: llm_response (OpenAI completion format) + self._put( + event_type="llm.ai.response", + category="message", + content=message.model_dump(), + metadata={ + "caller": caller, + "usage": usage_dict, + "latency_ms": latency_ms, + "llm_call_index": call_index, + }, + ) + + # Token accumulation + if self._track_tokens: + input_tk = usage_dict.get("input_tokens", 0) or 0 + output_tk = usage_dict.get("output_tokens", 0) or 0 + total_tk = usage_dict.get("total_tokens", 0) or 0 + if total_tk == 0: + total_tk = input_tk + output_tk + if total_tk > 0: + self._total_input_tokens += input_tk + self._total_output_tokens += output_tk + self._total_tokens += total_tk + self._llm_call_count += 1 def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: self._llm_start_times.pop(str(run_id), None) - self._put(event_type="llm_error", category="trace", content=str(error)) + self._put(event_type="llm.error", category="trace", content=str(error)) - # -- Tool callbacks -- + def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs): + """Handle tool start event, cache tool call ID for later correlation""" + tool_call_id = str(run_id) + logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}") - def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None: - tool_call_id = kwargs.get("tool_call_id") - if tool_call_id: - self._tool_call_ids[str(run_id)] = tool_call_id - self._put( - event_type="tool_start", - category="trace", - metadata={ - "tool_name": 
serialized.get("name", ""), - "tool_call_id": tool_call_id, - "args": str(input_str)[:2000], - }, - ) - - def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: - from langchain_core.messages import ToolMessage - from langgraph.types import Command - - # Tools that update graph state return a ``Command`` (e.g. - # ``present_files``). LangGraph later unwraps the inner ToolMessage - # into checkpoint state, so to stay checkpoint-aligned we must - # extract it here rather than storing ``str(Command(...))``. - if isinstance(output, Command): - update = getattr(output, "update", None) or {} - inner_msgs = update.get("messages") if isinstance(update, dict) else None - if isinstance(inner_msgs, list): - inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None) - if inner_tool_msg is not None: - output = inner_tool_msg - - # Extract fields from ToolMessage object when LangChain provides one. - # LangChain's _format_output wraps tool results into a ToolMessage - # with tool_call_id, name, status, and artifact — more complete than - # what kwargs alone provides. - if isinstance(output, ToolMessage): - tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None) - tool_name = output.name or kwargs.get("name", "") - status = getattr(output, "status", "success") or "success" - content_str = output.content if isinstance(output.content, str) else str(output.content) - # Use model_dump() for checkpoint-aligned message content. - # Override tool_call_id if it was resolved from cache. - msg_content = output.model_dump() - if msg_content.get("tool_call_id") != tool_call_id: - msg_content["tool_call_id"] = tool_call_id - else: - tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None) - tool_name = kwargs.get("name", "") - status = "success" - content_str = str(output) - # Construct checkpoint-aligned dict when output is a plain string. 
- msg_content = ToolMessage( - content=content_str, - tool_call_id=tool_call_id or "", - name=tool_name, - status=status, - ).model_dump() - - # Trace event (always) - self._put( - event_type="tool_end", - category="trace", - content=content_str, - metadata={ - "tool_name": tool_name, - "tool_call_id": tool_call_id, - "status": status, - }, - ) - - # Message event: tool_result (checkpoint-aligned model_dump format) - self._put( - event_type="tool_result", - category="message", - content=msg_content, - metadata={"tool_name": tool_name, "status": status}, - ) - - def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: - from langchain_core.messages import ToolMessage - - tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None) - tool_name = kwargs.get("name", "") - - # Trace event - self._put( - event_type="tool_error", - category="trace", - content=str(error), - metadata={ - "tool_name": tool_name, - "tool_call_id": tool_call_id, - }, - ) - - # Message event: tool_result with error status (checkpoint-aligned) - msg_content = ToolMessage( - content=str(error), - tool_call_id=tool_call_id or "", - name=tool_name, - status="error", - ).model_dump() - self._put( - event_type="tool_result", - category="message", - content=msg_content, - metadata={"tool_name": tool_name, "status": "error"}, - ) - - # -- Custom event callback -- - - def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None: - from deerflow.runtime.serialization import serialize_lc_object - - if name == "summarization": - data_dict = data if isinstance(data, dict) else {} - self._put( - event_type="summarization", - category="trace", - content=data_dict.get("summary", ""), - metadata={ - "replaced_message_ids": data_dict.get("replaced_message_ids", []), - "replaced_count": data_dict.get("replaced_count", 0), - }, - ) - self._put( - event_type="middleware:summarize", - category="middleware", - content={"role": 
"system", "content": data_dict.get("summary", "")}, - metadata={"replaced_count": data_dict.get("replaced_count", 0)}, - ) - else: - event_data = serialize_lc_object(data) if not isinstance(data, dict) else data - self._put( - event_type=name, - category="trace", - metadata=event_data if isinstance(event_data, dict) else {"data": event_data}, - ) + def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs): + """Handle tool end event, append message and clear node data""" + try: + if isinstance(output, ToolMessage): + msg = cast(ToolMessage, output) + self._put(event_type="llm.tool.result", category="message", content=msg.model_dump()) + elif isinstance(output, Command): + cmd = cast(Command, output) + messages = cmd.update.get("messages", []) + for message in messages: + if isinstance(message, BaseMessage): + self._put(event_type="llm.tool.result", category="message", content=message.model_dump()) + else: + logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}") + else: + logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}") + finally: + logger.info(f"Tool end for node {run_id}") # -- Internal methods -- @@ -431,8 +307,9 @@ class RunJournal(BaseCallbackHandler): if exc: logger.warning("Journal flush task failed: %s", exc) - def _identify_caller(self, kwargs: dict) -> str: - for tag in kwargs.get("tags") or []: + def _identify_caller(self, tags: list[str] | None, **kwargs) -> str: + _tags = tags or kwargs.get("tags", []) + for tag in _tags: if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"): return tag # Default to lead_agent: the main agent graph does not inject diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 0a0794d87..a54a408b8 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ 
b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -54,7 +54,7 @@ class RunManager:         self._lock = asyncio.Lock()         self._store = store -    async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None: +    async def _persist_to_store(self, record: RunRecord) -> None:         """Best-effort persist run record to backing store."""         if self._store is None:             return @@ -68,7 +68,6 @@                 metadata=record.metadata or {},                 kwargs=record.kwargs or {},                 created_at=record.created_at, -                follow_up_to_run_id=follow_up_to_run_id,             )         except Exception:             logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) @@ -90,7 +89,6 @@         metadata: dict | None = None,         kwargs: dict | None = None,         multitask_strategy: str = "reject", -        follow_up_to_run_id: str | None = None,     ) -> RunRecord:         """Create a new pending run and register it."""         run_id = str(uuid.uuid4()) @@ -109,7 +107,7 @@         )         async with self._lock:             self._runs[run_id] = record -        await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id) +        await self._persist_to_store(record)         logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)         return record @@ -122,7 +120,7 @@         async with self._lock:             # Dict insertion order matches creation order, so reversing it gives             # us deterministic newest-first results even when timestamps tie. -            return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id] +            return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]     async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:         """Transition a run to a new status.""" @@ -176,7 +174,6 @@         metadata: dict | None = None,         kwargs: dict | None = None,         multitask_strategy: str = "reject", -        follow_up_to_run_id: str | None = None,     ) -> RunRecord:         """Atomically check for inflight runs and create a new one. 
@@ -230,7 +227,7 @@ class RunManager: ) self._runs[run_id] = record - await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id) + await self._persist_to_store(record) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index 3212e8ca3..518a1903c 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -29,7 +29,6 @@ class RunStore(abc.ABC): kwargs: dict[str, Any] | None = None, error: str | None = None, created_at: str | None = None, - follow_up_to_run_id: str | None = None, ) -> None: pass diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 0b2b05f07..5a14af3df 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -28,7 +28,6 @@ class MemoryRunStore(RunStore): kwargs=None, error=None, created_at=None, - follow_up_to_run_id=None, ): now = datetime.now(UTC).isoformat() self._runs[run_id] = { @@ -41,7 +40,6 @@ class MemoryRunStore(RunStore): "metadata": metadata or {}, "kwargs": kwargs or {}, "error": error, - "follow_up_to_run_id": follow_up_to_run_id, "created_at": created_at or now, "updated_at": now, } diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 506ab00cd..c018bcabd 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -51,7 +51,6 @@ class RunContext: event_store: Any | None = field(default=None) run_events_config: Any | None = field(default=None) thread_store: Any | None = field(default=None) - follow_up_to_run_id: str | None = 
field(default=None) async def run_agent( @@ -76,7 +75,6 @@ async def run_agent( event_store = ctx.event_store run_events_config = ctx.run_events_config thread_store = ctx.thread_store - follow_up_to_run_id = ctx.follow_up_to_run_id run_id = record.run_id thread_id = record.thread_id @@ -113,22 +111,6 @@ async def run_agent( track_token_usage=getattr(run_events_config, "track_token_usage", True), ) - human_msg = _extract_human_message(graph_input) - if human_msg is not None: - msg_metadata = {} - if follow_up_to_run_id: - msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id - await event_store.put( - thread_id=thread_id, - run_id=run_id, - event_type="human_message", - category="message", - content=human_msg.model_dump(), - metadata=msg_metadata or None, - ) - content = human_msg.content - journal.set_first_human_message(content if isinstance(content, str) else str(content)) - # 1. Mark running await run_manager.set_status(run_id, RunStatus.running) @@ -166,12 +148,13 @@ async def run_agent( # Inject runtime context so middlewares can access thread_id # (langgraph-cli does this automatically; we must do it manually) - runtime = Runtime(context={"thread_id": thread_id}, store=store) + runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store) # If the caller already set a ``context`` key (LangGraph >= 0.6.0 # prefers it over ``configurable`` for thread-level data), make # sure ``thread_id`` is available there too. if "context" in config and isinstance(config["context"], dict): config["context"].setdefault("thread_id", thread_id) + config["context"].setdefault("run_id", run_id) config.setdefault("configurable", {})["__pregel_runtime"] = runtime # Inject RunJournal as a LangChain callback handler. 
diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index b306f59ec..0a274df33 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -62,59 +62,62 @@ class TestLlmCallbacks: j, store = journal_setup run_id = uuid4() j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) await j.flush() events = await store.list_events("t1", "r1") - trace_events = [e for e in events if e["event_type"] == "llm_response"] + trace_events = [e for e in events if e["event_type"] == "llm.ai.response"] assert len(trace_events) == 1 - assert trace_events[0]["category"] == "trace" + assert trace_events[0]["category"] == "message" @pytest.mark.anyio async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup): j, store = journal_setup run_id = uuid4() j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) await j.flush() messages = await store.list_messages("t1") assert len(messages) == 1 - assert messages[0]["event_type"] == "ai_message" + assert messages[0]["event_type"] == "llm.ai.response" # Content is checkpoint-aligned model_dump format assert messages[0]["content"]["type"] == "ai" assert messages[0]["content"]["content"] == "Answer" @pytest.mark.anyio async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup): - """LLM response with pending tool_calls should produce ai_tool_call event.""" + """LLM response with pending tool_calls emits llm.ai.response with tool_calls in content.""" j, store = journal_setup run_id = uuid4() j.on_llm_end( _make_llm_response("Let me search", tool_calls=[{"id": "call_1", "name": "search", 
"args": {}}]), run_id=run_id, + parent_run_id=None, tags=["lead_agent"], ) await j.flush() messages = await store.list_messages("t1") assert len(messages) == 1 - assert messages[0]["event_type"] == "ai_tool_call" + assert messages[0]["event_type"] == "llm.ai.response" + assert len(messages[0]["content"]["tool_calls"]) == 1 @pytest.mark.anyio async def test_on_llm_end_subagent_no_ai_message(self, journal_setup): j, store = journal_setup run_id = uuid4() j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"]) - j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"]) + j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, parent_run_id=None, tags=["subagent:research"]) await j.flush() messages = await store.list_messages("t1") - assert len(messages) == 0 + # subagent responses still emit llm.ai.response with category="message" + assert len(messages) == 1 @pytest.mark.anyio async def test_token_accumulation(self, journal_setup): j, store = journal_setup usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} - j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) assert j._total_input_tokens == 30 assert j._total_output_tokens == 15 assert j._total_tokens == 45 @@ -127,26 +130,26 @@ class TestLlmCallbacks: j.on_llm_end( _make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}), run_id=uuid4(), + parent_run_id=None, tags=["lead_agent"], ) assert j._total_tokens == 150 - assert j._lead_agent_tokens == 150 @pytest.mark.anyio async def 
test_caller_token_classification(self, journal_setup): j, store = journal_setup usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} - j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"]) - j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"]) - assert j._lead_agent_tokens == 15 - assert j._subagent_tokens == 15 - assert j._middleware_tokens == 15 + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"]) + j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarization"]) + # token tracking not broken by caller type + assert j._total_tokens == 45 + assert j._llm_call_count == 3 @pytest.mark.anyio async def test_usage_metadata_none_no_crash(self, journal_setup): j, store = journal_setup - j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) await j.flush() @pytest.mark.anyio @@ -154,103 +157,106 @@ class TestLlmCallbacks: j, store = journal_setup run_id = uuid4() j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) await j.flush() events = await store.list_events("t1", "r1") - llm_resp = [e for e in events if e["event_type"] == "llm_response"][0] + llm_resp = [e for e in events if e["event_type"] == "llm.ai.response"][0] assert "latency_ms" in llm_resp["metadata"] assert llm_resp["metadata"]["latency_ms"] is not None 
class TestLifecycleCallbacks: @pytest.mark.anyio - async def test_chain_start_end_produce_lifecycle_events(self, journal_setup): + async def test_chain_start_end_produce_trace_events(self, journal_setup): j, store = journal_setup j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) - j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) + j.on_chain_end({}, run_id=uuid4()) await asyncio.sleep(0.05) await j.flush() events = await store.list_events("t1", "r1") - types = [e["event_type"] for e in events if e["category"] == "lifecycle"] - assert "run_start" in types - assert "run_end" in types + types = {e["event_type"] for e in events} + assert "run.start" in types + assert "run.end" in types @pytest.mark.anyio - async def test_nested_chain_ignored(self, journal_setup): + async def test_nested_chain_no_run_start(self, journal_setup): + """Nested chains (parent_run_id set) should NOT produce run.start.""" j, store = journal_setup parent_id = uuid4() j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id) - j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id) + j.on_chain_end({}, run_id=uuid4()) await j.flush() events = await store.list_events("t1", "r1") - lifecycle = [e for e in events if e["category"] == "lifecycle"] - assert len(lifecycle) == 0 + assert not any(e["event_type"] == "run.start" for e in events) class TestToolCallbacks: @pytest.mark.anyio - async def test_tool_start_end_produce_trace(self, journal_setup): + async def test_tool_end_with_tool_message(self, journal_setup): + """on_tool_end with a ToolMessage stores it as llm.tool.result.""" + from langchain_core.messages import ToolMessage + j, store = journal_setup - j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4()) - j.on_tool_end("results", run_id=uuid4(), name="web_search") + tool_msg = ToolMessage(content="results", tool_call_id="call_1", name="web_search") + j.on_tool_end(tool_msg, run_id=uuid4()) await j.flush() - events = await store.list_events("t1", "r1") 
- trace_types = {e["event_type"] for e in events if e["category"] == "trace"} - assert "tool_start" in trace_types - assert "tool_end" in trace_types + messages = await store.list_messages("t1") + assert len(messages) == 1 + assert messages[0]["event_type"] == "llm.tool.result" + assert messages[0]["content"]["type"] == "tool" @pytest.mark.anyio - async def test_on_tool_error(self, journal_setup): + async def test_tool_end_with_command_unwraps_tool_message(self, journal_setup): + """on_tool_end with Command(update={'messages':[ToolMessage]}) unwraps inner message.""" + from langchain_core.messages import ToolMessage + from langgraph.types import Command + + j, store = journal_setup + inner = ToolMessage(content="file list", tool_call_id="call_2", name="present_files") + cmd = Command(update={"messages": [inner]}) + j.on_tool_end(cmd, run_id=uuid4()) + await j.flush() + messages = await store.list_messages("t1") + assert len(messages) == 1 + assert messages[0]["event_type"] == "llm.tool.result" + assert messages[0]["content"]["content"] == "file list" + + @pytest.mark.anyio + async def test_on_tool_error_no_crash(self, journal_setup): + """on_tool_error should not crash (no event emitted by default).""" j, store = journal_setup j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch") await j.flush() + # Base implementation does not emit tool_error — just verify no crash events = await store.list_events("t1", "r1") - assert any(e["event_type"] == "tool_error" for e in events) + assert isinstance(events, list) class TestCustomEvents: @pytest.mark.anyio - async def test_summarization_event(self, journal_setup): + async def test_on_custom_event_not_implemented(self, journal_setup): + """RunJournal does not implement on_custom_event — no crash expected.""" j, store = journal_setup - j.on_custom_event( - "summarization", - {"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]}, - run_id=uuid4(), - ) + # 
BaseCallbackHandler.on_custom_event is a no-op by default + j.on_custom_event("task_running", {"task_id": "t1"}, run_id=uuid4()) await j.flush() events = await store.list_events("t1", "r1") - trace = [e for e in events if e["event_type"] == "summarization"] - assert len(trace) == 1 - # Summarization goes to middleware category, not message - mw_events = [e for e in events if e["event_type"] == "middleware:summarize"] - assert len(mw_events) == 1 - assert mw_events[0]["category"] == "middleware" - assert mw_events[0]["content"] == {"role": "system", "content": "Context was summarized."} - # No message events from summarization - messages = await store.list_messages("t1") - assert len(messages) == 0 - - @pytest.mark.anyio - async def test_non_summarization_custom_event(self, journal_setup): - j, store = journal_setup - j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4()) - await j.flush() - events = await store.list_events("t1", "r1") - assert any(e["event_type"] == "task_running" for e in events) + assert isinstance(events, list) class TestBufferFlush: @pytest.mark.anyio async def test_flush_threshold(self, journal_setup): j, store = journal_setup - j._flush_threshold = 3 - j.on_tool_start({"name": "a"}, "x", run_id=uuid4()) - j.on_tool_start({"name": "b"}, "x", run_id=uuid4()) - assert len(j._buffer) == 2 - j.on_tool_start({"name": "c"}, "x", run_id=uuid4()) + j._flush_threshold = 2 + # Each on_llm_end emits 1 event + j.on_llm_end(_make_llm_response("A"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + assert len(j._buffer) == 1 + j.on_llm_end(_make_llm_response("B"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + # At threshold the buffer should have been flushed asynchronously await asyncio.sleep(0.1) events = await store.list_events("t1", "r1") - assert len(events) >= 3 + assert len(events) >= 2 @pytest.mark.anyio async def test_events_retained_when_no_loop(self, journal_setup): @@ -266,44 +272,44 @@ 
class TestBufferFlush: asyncio.get_running_loop = no_loop try: - j._put(event_type="llm_response", category="trace", content="test") + j._put(event_type="llm.ai.response", category="message", content="test") finally: asyncio.get_running_loop = original assert len(j._buffer) == 1 await j.flush() events = await store.list_events("t1", "r1") - assert any(e["event_type"] == "llm_response" for e in events) + assert any(e["event_type"] == "llm.ai.response" for e in events) class TestIdentifyCaller: def test_lead_agent_tag(self, journal_setup): j, _ = journal_setup - assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent" + assert j._identify_caller(["lead_agent"]) == "lead_agent" def test_subagent_tag(self, journal_setup): j, _ = journal_setup - assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research" + assert j._identify_caller(["subagent:research"]) == "subagent:research" def test_middleware_tag(self, journal_setup): j, _ = journal_setup - assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization" + assert j._identify_caller(["middleware:summarization"]) == "middleware:summarization" def test_no_tags_returns_lead_agent(self, journal_setup): j, _ = journal_setup - assert j._identify_caller({"tags": []}) == "lead_agent" - assert j._identify_caller({}) == "lead_agent" + assert j._identify_caller([]) == "lead_agent" + assert j._identify_caller(None) == "lead_agent" class TestChainErrorCallback: @pytest.mark.anyio async def test_on_chain_error_writes_run_error(self, journal_setup): j, store = journal_setup - j.on_chain_error(ValueError("boom"), run_id=uuid4(), parent_run_id=None) + j.on_chain_error(ValueError("boom"), run_id=uuid4()) await asyncio.sleep(0.05) await j.flush() events = await store.list_events("t1", "r1") - error_events = [e for e in events if e["event_type"] == "run_error"] + error_events = [e for e in events if e["event_type"] == "run.error"] assert len(error_events) == 1 assert "boom" 
in error_events[0]["content"] assert error_events[0]["metadata"]["error_type"] == "ValueError" @@ -317,6 +323,7 @@ class TestTokenTrackingDisabled: j.on_llm_end( _make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), run_id=uuid4(), + parent_run_id=None, tags=["lead_agent"], ) data = j.get_completion_data() @@ -325,15 +332,6 @@ class TestTokenTrackingDisabled: class TestConvenienceFields: - @pytest.mark.anyio - async def test_last_ai_message_tracks_latest(self, journal_setup): - j, store = journal_setup - j.on_llm_end(_make_llm_response("First"), run_id=uuid4(), tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("Second"), run_id=uuid4(), tags=["lead_agent"]) - data = j.get_completion_data() - assert data["last_ai_message"] == "Second" - assert data["message_count"] == 2 - @pytest.mark.anyio async def test_first_human_message_via_set(self, journal_setup): j, _ = journal_setup @@ -351,613 +349,6 @@ class TestConvenienceFields: assert data["message_count"] == 5 -class TestUnknownCallerTokens: - @pytest.mark.anyio - async def test_unknown_caller_tokens_go_to_lead(self, journal_setup): - j, store = journal_setup - j.on_llm_end( - _make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), - run_id=uuid4(), - tags=[], - ) - assert j._lead_agent_tokens == 15 - - -# --------------------------------------------------------------------------- -# SQLite-backed end-to-end test -# --------------------------------------------------------------------------- - - -class TestDbBackedLifecycle: - @pytest.mark.anyio - async def test_full_lifecycle_with_sqlite(self, tmp_path): - """Full lifecycle with SQLite-backed RunRepository + DbRunEventStore.""" - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.persistence.run import RunRepository - from deerflow.runtime.events.store.db import DbRunEventStore - from deerflow.runtime.runs.manager import RunManager - - 
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - sf = get_session_factory() - - run_store = RunRepository(sf) - event_store = DbRunEventStore(sf) - mgr = RunManager(store=run_store) - - # Create run - record = await mgr.create("t1", "lead_agent") - run_id = record.run_id - - # Write human_message (checkpoint-aligned format) - from langchain_core.messages import HumanMessage - - human_msg = HumanMessage(content="Hello DB") - await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content=human_msg.model_dump()) - - # Simulate journal - journal = RunJournal(run_id, "t1", event_store, flush_threshold=100) - journal.set_first_human_message("Hello DB") - - journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) - llm_rid = uuid4() - journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"]) - journal.on_llm_end( - _make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), - run_id=llm_rid, - tags=["lead_agent"], - ) - journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None) - await asyncio.sleep(0.05) - await journal.flush() - - # Verify run persisted - row = await run_store.get(run_id) - assert row is not None - assert row["status"] == "pending" - - # Update completion - completion = journal.get_completion_data() - await run_store.update_run_completion(run_id, status="success", **completion) - row = await run_store.get(run_id) - assert row["status"] == "success" - assert row["total_tokens"] == 15 - - # Verify messages from DB (checkpoint-aligned format) - messages = await event_store.list_messages("t1") - assert len(messages) == 2 - assert messages[0]["event_type"] == "human_message" - assert messages[0]["content"]["type"] == "human" - assert messages[1]["event_type"] == "ai_message" - assert messages[1]["content"]["type"] == "ai" - assert messages[1]["content"]["content"] == 
"DB response" - - # Verify events from DB - events = await event_store.list_events("t1", run_id) - event_types = {e["event_type"] for e in events} - assert "run_start" in event_types - assert "llm_response" in event_types - assert "run_end" in event_types - - await close_engine() - - -class TestDictContentFlag: - """Verify that content_is_dict metadata flag controls deserialization.""" - - @pytest.mark.anyio - async def test_db_store_str_starting_with_brace_not_deserialized(self, tmp_path): - """Plain string content starting with { should NOT be deserialized.""" - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - sf = get_session_factory() - store = DbRunEventStore(sf) - - await store.put( - thread_id="t1", - run_id="r1", - event_type="tool_end", - category="trace", - content="{not json, just a string}", - ) - events = await store.list_events("t1", "r1") - assert events[0]["content"] == "{not json, just a string}" - assert isinstance(events[0]["content"], str) - - await close_engine() - - @pytest.mark.anyio - async def test_db_store_str_starting_with_bracket_not_deserialized(self, tmp_path): - """Plain string content like '[1, 2, 3]' should NOT be deserialized.""" - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - sf = get_session_factory() - store = DbRunEventStore(sf) - - await store.put( - thread_id="t1", - run_id="r1", - event_type="tool_end", - category="trace", - content="[1, 2, 3]", - ) - events = await store.list_events("t1", "r1") - assert events[0]["content"] == "[1, 2, 3]" - assert isinstance(events[0]["content"], 
str) - - await close_engine() - - -class TestDictContent: - """Verify that store backends accept str | dict content.""" - - @pytest.mark.anyio - async def test_memory_store_dict_content(self): - store = MemoryRunEventStore() - record = await store.put( - thread_id="t1", - run_id="r1", - event_type="ai_message", - category="message", - content={"role": "assistant", "content": "Hello"}, - ) - assert record["content"] == {"role": "assistant", "content": "Hello"} - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["content"] == {"role": "assistant", "content": "Hello"} - - @pytest.mark.anyio - async def test_memory_store_str_content_unchanged(self): - store = MemoryRunEventStore() - record = await store.put( - thread_id="t1", - run_id="r1", - event_type="ai_message", - category="message", - content="plain string", - ) - assert record["content"] == "plain string" - assert isinstance(record["content"], str) - - @pytest.mark.anyio - async def test_db_store_dict_content_roundtrip(self, tmp_path): - """Dict content survives DB roundtrip (JSON serialize on write, deserialize on read).""" - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - sf = get_session_factory() - store = DbRunEventStore(sf) - - nested = {"role": "assistant", "content": "Hi", "metadata": {"model": "gpt-4", "tokens": [1, 2, 3]}} - record = await store.put( - thread_id="t1", - run_id="r1", - event_type="ai_message", - category="message", - content=nested, - ) - assert record["content"] == nested - - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["content"] == nested - - await close_engine() - - @pytest.mark.anyio - async def test_db_store_trace_dict_truncation(self, tmp_path): - """Large dict trace 
content is truncated with metadata flag.""" - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - sf = get_session_factory() - store = DbRunEventStore(sf, max_trace_content=100) - - large_dict = {"role": "assistant", "content": "x" * 200} - record = await store.put( - thread_id="t1", - run_id="r1", - event_type="llm_end", - category="trace", - content=large_dict, - ) - assert record["metadata"].get("content_truncated") is True - # Content should be a truncated string (serialized JSON was too long) - assert isinstance(record["content"], str) - assert len(record["content"]) <= 100 - - await close_engine() - - -class TestCheckpointAlignedHumanMessage: - @pytest.mark.anyio - async def test_human_message_checkpoint_format(self): - """human_message content uses model_dump() checkpoint format.""" - from langchain_core.messages import HumanMessage - - store = MemoryRunEventStore() - human_msg = HumanMessage(content="What is AI?") - await store.put( - thread_id="t1", - run_id="r1", - event_type="human_message", - category="message", - content=human_msg.model_dump(), - metadata={"message_id": "msg_001"}, - ) - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["content"]["type"] == "human" - assert messages[0]["content"]["content"] == "What is AI?" 
- - -class TestCheckpointAlignedMessageFormat: - @pytest.mark.anyio - async def test_ai_message_checkpoint_format(self, journal_setup): - """ai_message content should be checkpoint-aligned model_dump dict.""" - j, store = journal_setup - j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), tags=["lead_agent"]) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["content"]["type"] == "ai" - assert messages[0]["content"]["content"] == "Answer" - assert "response_metadata" in messages[0]["content"] - assert "additional_kwargs" in messages[0]["content"] - - @pytest.mark.anyio - async def test_ai_tool_call_event(self, journal_setup): - """LLM response with tool_calls should produce ai_tool_call with model_dump content.""" - j, store = journal_setup - tool_calls = [{"id": "call_1", "name": "search", "args": {"query": "test"}}] - j.on_llm_end( - _make_llm_response("Let me search", tool_calls=tool_calls), - run_id=uuid4(), - tags=["lead_agent"], - ) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["event_type"] == "ai_tool_call" - assert messages[0]["content"]["type"] == "ai" - assert messages[0]["content"]["content"] == "Let me search" - assert len(messages[0]["content"]["tool_calls"]) == 1 - tc = messages[0]["content"]["tool_calls"][0] - assert tc["id"] == "call_1" - assert tc["name"] == "search" - - @pytest.mark.anyio - async def test_ai_tool_call_only_from_lead_agent(self, journal_setup): - """ai_tool_call should only be emitted for lead_agent, not subagents.""" - j, store = journal_setup - tool_calls = [{"id": "call_1", "name": "search", "args": {}}] - j.on_llm_end( - _make_llm_response("searching", tool_calls=tool_calls), - run_id=uuid4(), - tags=["subagent:research"], - ) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 0 - - -class TestToolResultMessage: - @pytest.mark.anyio - async def 
test_tool_end_produces_tool_result_message(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - j.on_tool_start({"name": "web_search"}, '{"query": "test"}', run_id=run_id, tool_call_id="call_abc") - j.on_tool_end("search results here", run_id=run_id, name="web_search", tool_call_id="call_abc") - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["event_type"] == "tool_result" - # Content is checkpoint-aligned model_dump format - assert messages[0]["content"]["type"] == "tool" - assert messages[0]["content"]["tool_call_id"] == "call_abc" - assert messages[0]["content"]["content"] == "search results here" - assert messages[0]["content"]["name"] == "web_search" - - @pytest.mark.anyio - async def test_tool_result_missing_tool_call_id(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - j.on_tool_start({"name": "bash"}, "ls", run_id=run_id) - j.on_tool_end("file1.txt", run_id=run_id, name="bash") - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["content"]["type"] == "tool" - - @pytest.mark.anyio - async def test_tool_end_extracts_from_tool_message_object(self, journal_setup): - """When LangChain passes a ToolMessage object as output, extract fields from it.""" - from langchain_core.messages import ToolMessage - - j, store = journal_setup - run_id = uuid4() - tool_msg = ToolMessage( - content="search results", - tool_call_id="call_from_obj", - name="web_search", - status="success", - ) - j.on_tool_end(tool_msg, run_id=run_id) - await j.flush() - - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["content"]["type"] == "tool" - assert messages[0]["content"]["tool_call_id"] == "call_from_obj" - assert messages[0]["content"]["content"] == "search results" - assert messages[0]["content"]["name"] == "web_search" - assert messages[0]["metadata"]["tool_name"] == "web_search" - assert 
messages[0]["metadata"]["status"] == "success" - - events = await store.list_events("t1", "r1") - tool_end = [e for e in events if e["event_type"] == "tool_end"][0] - assert tool_end["metadata"]["tool_call_id"] == "call_from_obj" - assert tool_end["metadata"]["tool_name"] == "web_search" - - @pytest.mark.anyio - async def test_tool_invoke_end_to_end_unwraps_command(self, journal_setup): - """End-to-end: invoke a real LangChain tool that returns Command(update={'messages':[ToolMessage]}). - - This goes through the real LangChain callback path (tool.invoke -> CallbackManager - -> on_tool_start/on_tool_end), which is what the production agent uses. Mirrors - the ``present_files`` tool shape exactly. - """ - from langchain_core.callbacks import CallbackManager - from langchain_core.messages import ToolMessage - from langchain_core.tools import tool - from langgraph.types import Command - - j, store = journal_setup - - @tool - def fake_present_files(filepaths: list[str]) -> Command: - """Fake present_files that returns a Command with an inner ToolMessage.""" - return Command( - update={ - "artifacts": filepaths, - "messages": [ToolMessage("Successfully presented files", tool_call_id="tc_123")], - }, - ) - - # Real LangChain callback dispatch (matches production agent path) - cm = CallbackManager(handlers=[j]) - fake_present_files.invoke( - {"filepaths": ["/mnt/user-data/outputs/report.md"]}, - config={"callbacks": cm, "run_id": uuid4()}, - ) - await j.flush() - - messages = await store.list_messages("t1") - assert len(messages) == 1, f"expected 1 message event, got {len(messages)}: {messages}" - content = messages[0]["content"] - assert content["type"] == "tool" - # CRITICAL: must be the inner ToolMessage text, not str(Command(...)) - assert content["content"] == "Successfully presented files", ( - f"Command unwrap failed; stored content = {content['content']!r}" - ) - assert "Command(update=" not in str(content["content"]) - - @pytest.mark.anyio - async def 
test_tool_end_unwraps_command_with_inner_tool_message(self, journal_setup): - """Tools like ``present_files`` return Command(update={'messages': [ToolMessage(...)]}). - - LangGraph unwraps the inner ToolMessage into checkpoint state, so the - event store must do the same — otherwise it captures ``str(Command(...))`` - and the /history response diverges from the real rendered message. - """ - from langchain_core.messages import ToolMessage - from langgraph.types import Command - - j, store = journal_setup - run_id = uuid4() - inner = ToolMessage( - content="Successfully presented files", - tool_call_id="call_present", - name="present_files", - status="success", - ) - cmd = Command(update={"artifacts": ["/mnt/user-data/outputs/report.md"], "messages": [inner]}) - j.on_tool_end(cmd, run_id=run_id) - await j.flush() - - messages = await store.list_messages("t1") - assert len(messages) == 1 - content = messages[0]["content"] - assert content["type"] == "tool" - assert content["content"] == "Successfully presented files" - assert content["tool_call_id"] == "call_present" - assert content["name"] == "present_files" - assert "Command(update=" not in str(content["content"]) - - @pytest.mark.anyio - async def test_tool_message_object_overrides_kwargs(self, journal_setup): - """ToolMessage object fields take priority over kwargs.""" - from langchain_core.messages import ToolMessage - - j, store = journal_setup - run_id = uuid4() - tool_msg = ToolMessage( - content="result", - tool_call_id="call_obj", - name="tool_a", - status="success", - ) - # Pass different values in kwargs — ToolMessage should win - j.on_tool_end(tool_msg, run_id=run_id, name="tool_b", tool_call_id="call_kwarg") - await j.flush() - - messages = await store.list_messages("t1") - assert messages[0]["content"]["tool_call_id"] == "call_obj" - assert messages[0]["content"]["name"] == "tool_a" - assert messages[0]["metadata"]["tool_name"] == "tool_a" - - @pytest.mark.anyio - async def 
test_tool_message_error_status(self, journal_setup): - """ToolMessage with status='error' propagates status to metadata.""" - from langchain_core.messages import ToolMessage - - j, store = journal_setup - run_id = uuid4() - tool_msg = ToolMessage( - content="something went wrong", - tool_call_id="call_err", - name="web_fetch", - status="error", - ) - j.on_tool_end(tool_msg, run_id=run_id) - await j.flush() - - events = await store.list_events("t1", "r1") - tool_end = [e for e in events if e["event_type"] == "tool_end"][0] - assert tool_end["metadata"]["status"] == "error" - - messages = await store.list_messages("t1") - assert messages[0]["content"]["status"] == "error" - assert messages[0]["metadata"]["status"] == "error" - - @pytest.mark.anyio - async def test_tool_message_fallback_to_cache(self, journal_setup): - """If ToolMessage has empty tool_call_id, fall back to cache from on_tool_start.""" - from langchain_core.messages import ToolMessage - - j, store = journal_setup - run_id = uuid4() - j.on_tool_start({"name": "bash"}, "ls", run_id=run_id, tool_call_id="call_cached") - tool_msg = ToolMessage( - content="file list", - tool_call_id="", - name="bash", - ) - j.on_tool_end(tool_msg, run_id=run_id) - await j.flush() - - messages = await store.list_messages("t1") - assert messages[0]["content"]["tool_call_id"] == "call_cached" - - @pytest.mark.anyio - async def test_tool_error_produces_tool_result_message(self, journal_setup): - j, store = journal_setup - j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1") - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["event_type"] == "tool_result" - assert messages[0]["content"]["type"] == "tool" - assert messages[0]["content"]["tool_call_id"] == "call_1" - assert "timeout" in messages[0]["content"]["content"] - assert messages[0]["content"]["status"] == "error" - assert messages[0]["metadata"]["status"] == "error" - 
- @pytest.mark.anyio - async def test_tool_error_uses_cached_tool_call_id(self, journal_setup): - """on_tool_error should fall back to cached tool_call_id from on_tool_start.""" - j, store = journal_setup - run_id = uuid4() - j.on_tool_start({"name": "web_fetch"}, "url", run_id=run_id, tool_call_id="call_cached") - j.on_tool_error(TimeoutError("timeout"), run_id=run_id, name="web_fetch") - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["content"]["tool_call_id"] == "call_cached" - - -def _make_base_messages(): - """Create mock LangChain BaseMessages for on_chat_model_start.""" - sys_msg = MagicMock() - sys_msg.content = "You are helpful." - sys_msg.type = "system" - sys_msg.tool_calls = [] - sys_msg.tool_call_id = None - - user_msg = MagicMock() - user_msg.content = "Hello" - user_msg.type = "human" - user_msg.tool_calls = [] - user_msg.tool_call_id = None - - return [sys_msg, user_msg] - - -class TestLlmRequestResponse: - @pytest.mark.anyio - async def test_llm_request_event(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - messages = _make_base_messages() - j.on_chat_model_start({"name": "gpt-4o"}, [messages], run_id=run_id, tags=["lead_agent"]) - await j.flush() - events = await store.list_events("t1", "r1") - req_events = [e for e in events if e["event_type"] == "llm_request"] - assert len(req_events) == 1 - content = req_events[0]["content"] - assert content["model"] == "gpt-4o" - assert len(content["messages"]) == 2 - assert content["messages"][0]["role"] == "system" - assert content["messages"][1]["role"] == "user" - - @pytest.mark.anyio - async def test_llm_response_event(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) - j.on_llm_end( - _make_llm_response("Answer", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), - run_id=run_id, - tags=["lead_agent"], - ) - await j.flush() - 
events = await store.list_events("t1", "r1") - assert not any(e["event_type"] == "llm_end" for e in events) - resp_events = [e for e in events if e["event_type"] == "llm_response"] - assert len(resp_events) == 1 - content = resp_events[0]["content"] - assert "choices" in content - assert content["choices"][0]["message"]["role"] == "assistant" - assert content["choices"][0]["message"]["content"] == "Answer" - assert content["usage"]["prompt_tokens"] == 10 - - @pytest.mark.anyio - async def test_llm_request_response_paired(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - messages = _make_base_messages() - j.on_chat_model_start({"name": "gpt-4o"}, [messages], run_id=run_id, tags=["lead_agent"]) - j.on_llm_end( - _make_llm_response("Hi", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), - run_id=run_id, - tags=["lead_agent"], - ) - await j.flush() - events = await store.list_events("t1", "r1") - req = [e for e in events if e["event_type"] == "llm_request"][0] - resp = [e for e in events if e["event_type"] == "llm_response"][0] - assert req["metadata"]["llm_call_index"] == resp["metadata"]["llm_call_index"] - - @pytest.mark.anyio - async def test_no_llm_start_event(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - j.on_llm_start({"name": "test"}, [], run_id=run_id, tags=["lead_agent"]) - await j.flush() - events = await store.list_events("t1", "r1") - assert not any(e["event_type"] == "llm_start" for e in events) - - class TestMiddlewareEvents: @pytest.mark.anyio async def test_record_middleware_uses_middleware_category(self, journal_setup): @@ -979,21 +370,6 @@ class TestMiddlewareEvents: assert mw_events[0]["content"]["action"] == "generate_title" assert mw_events[0]["content"]["changes"]["title"] == "Test Title" - @pytest.mark.anyio - async def test_middleware_events_not_in_messages(self, journal_setup): - """Middleware events should not appear in list_messages().""" - j, store = journal_setup - 
j.record_middleware( - "title", - name="TitleMiddleware", - hook="after_model", - action="generate_title", - changes={"title": "Test"}, - ) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 0 - @pytest.mark.anyio async def test_middleware_tag_variants(self, journal_setup): """Different middleware tags produce distinct event_types.""" @@ -1007,111 +383,3 @@ class TestMiddlewareEvents: assert "middleware:guardrail" in event_types -class TestFullRunSequence: - @pytest.mark.anyio - async def test_complete_run_event_sequence(self): - """Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply. - - All message events use checkpoint-aligned model_dump format. - """ - from langchain_core.messages import HumanMessage - - store = MemoryRunEventStore() - j = RunJournal("r1", "t1", store, flush_threshold=100) - - # 1. Human message (written by worker, using model_dump format) - human_msg = HumanMessage(content="Search for quantum computing") - await store.put( - thread_id="t1", - run_id="r1", - event_type="human_message", - category="message", - content=human_msg.model_dump(), - ) - j.set_first_human_message("Search for quantum computing") - - # 2. Run start - j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) - - # 3. First LLM call -> tool_calls - llm1_id = uuid4() - sys_msg = MagicMock(content="You are helpful.", type="system", tool_calls=[], tool_call_id=None) - user_msg = MagicMock(content="Search for quantum computing", type="human", tool_calls=[], tool_call_id=None) - j.on_chat_model_start({"name": "gpt-4o"}, [[sys_msg, user_msg]], run_id=llm1_id, tags=["lead_agent"]) - j.on_llm_end( - _make_llm_response( - "Let me search", - tool_calls=[{"id": "call_1", "name": "web_search", "args": {"query": "quantum computing"}}], - usage={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120}, - ), - run_id=llm1_id, - tags=["lead_agent"], - ) - - # 4. 
Tool execution - tool_id = uuid4() - j.on_tool_start({"name": "web_search"}, '{"query": "quantum computing"}', run_id=tool_id, tool_call_id="call_1") - j.on_tool_end("Quantum computing results...", run_id=tool_id, name="web_search", tool_call_id="call_1") - - # 5. Middleware: title generation - j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={"title": "Quantum Computing"}) - - # 6. Second LLM call -> final reply - llm2_id = uuid4() - j.on_chat_model_start({"name": "gpt-4o"}, [[sys_msg, user_msg]], run_id=llm2_id, tags=["lead_agent"]) - j.on_llm_end( - _make_llm_response( - "Here are the results about quantum computing...", - usage={"input_tokens": 200, "output_tokens": 100, "total_tokens": 300}, - ), - run_id=llm2_id, - tags=["lead_agent"], - ) - - # 7. Run end - j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) - await asyncio.sleep(0.05) - await j.flush() - - # Verify message sequence - messages = await store.list_messages("t1") - msg_types = [m["event_type"] for m in messages] - assert msg_types == ["human_message", "ai_tool_call", "tool_result", "ai_message"] - - # Verify checkpoint-aligned format: all messages use "type" not "role" - assert messages[0]["content"]["type"] == "human" - assert messages[0]["content"]["content"] == "Search for quantum computing" - assert messages[1]["content"]["type"] == "ai" - assert "tool_calls" in messages[1]["content"] - assert messages[2]["content"]["type"] == "tool" - assert messages[2]["content"]["tool_call_id"] == "call_1" - assert messages[3]["content"]["type"] == "ai" - assert messages[3]["content"]["content"] == "Here are the results about quantum computing..." 
- - # Verify trace events - events = await store.list_events("t1", "r1") - trace_types = [e["event_type"] for e in events if e["category"] == "trace"] - assert "llm_request" in trace_types - assert "llm_response" in trace_types - assert "tool_start" in trace_types - assert "tool_end" in trace_types - assert "llm_start" not in trace_types - assert "llm_end" not in trace_types - - # Verify middleware events are in their own category - mw_events = [e for e in events if e["category"] == "middleware"] - assert len(mw_events) == 1 - assert mw_events[0]["event_type"] == "middleware:title" - - # Verify token accumulation - data = j.get_completion_data() - assert data["total_tokens"] == 420 # 120 + 300 - assert data["llm_call_count"] == 2 - assert data["lead_agent_tokens"] == 420 - assert data["message_count"] == 1 # only final ai_message counts - assert data["last_ai_message"] == "Here are the results about quantum computing..." - - # Verify all message contents are checkpoint-aligned dicts with "type" field - for m in messages: - assert isinstance(m["content"], dict) - assert "type" in m["content"] diff --git a/backend/tests/test_run_manager.py b/backend/tests/test_run_manager.py index 2d6a0199c..58ecf1f26 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/test_run_manager.py @@ -75,27 +75,27 @@ async def test_cancel_not_inflight(manager: RunManager): @pytest.mark.anyio async def test_list_by_thread(manager: RunManager): - """Same thread should return multiple runs, newest first.""" + """Same thread should return multiple runs.""" r1 = await manager.create("thread-1") r2 = await manager.create("thread-1") await manager.create("thread-2") runs = await manager.list_by_thread("thread-1") assert len(runs) == 2 - assert runs[0].run_id == r2.run_id - assert runs[1].run_id == r1.run_id + assert runs[0].run_id == r1.run_id + assert runs[1].run_id == r2.run_id @pytest.mark.anyio async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, 
monkeypatch: pytest.MonkeyPatch): - """Newest-first ordering should not depend on timestamp precision.""" + """Ordering should be stable (insertion order) even when timestamps tie.""" monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00") r1 = await manager.create("thread-1") r2 = await manager.create("thread-1") runs = await manager.list_by_thread("thread-1") - assert [run.run_id for run in runs] == [r2.run_id, r1.run_id] + assert [run.run_id for run in runs] == [r1.run_id, r2.run_id] @pytest.mark.anyio diff --git a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx index 7b288a40d..cf3a3e381 100644 --- a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx @@ -46,7 +46,13 @@ export default function AgentChatPage() { const [settings, setSettings] = useThreadSettings(threadId); const { showNotification } = useNotification(); - const [thread, sendMessage] = useThreadStream({ + const { + thread, + sendMessage, + isHistoryLoading, + hasMoreHistory, + loadMoreHistory, + } = useThreadStream({ threadId: isNewThread ? 
undefined : threadId, context: { ...settings.context, agent_name: agent_name }, onStart: (createdThreadId) => { @@ -141,6 +147,9 @@ export default function AgentChatPage() { threadId={threadId} thread={thread} paddingBottom={messageListPaddingBottom} + hasMoreHistory={hasMoreHistory} + loadMoreHistory={loadMoreHistory} + isHistoryLoading={isHistoryLoading} /> diff --git a/frontend/src/app/workspace/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/chats/[thread_id]/page.tsx index c5ff83dec..98e53e1c9 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useCallback, useEffect, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; import { type PromptInputMessage } from "@/components/ai-elements/prompt-input"; import { ArtifactTrigger } from "@/components/workspace/artifacts"; @@ -35,19 +35,30 @@ export default function ChatPage() { const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } = useThreadChat(); const [settings, setSettings] = useThreadSettings(threadId); - const [mounted, setMounted] = useState(false); + const mountedRef = useRef(false); useSpecificChatMode(); useEffect(() => { - setMounted(true); + mountedRef.current = true; }, []); const { showNotification } = useNotification(); - const [thread, sendMessage, isUploading] = useThreadStream({ + const { + thread, + sendMessage, + isUploading, + isHistoryLoading, + hasMoreHistory, + loadMoreHistory, + } = useThreadStream({ threadId: isNewThread ? 
undefined : threadId, context: settings.context, isMock, + onSend: (_threadId) => { + setThreadId(_threadId); + setIsNewThread(false); + }, onStart: (createdThreadId) => { setThreadId(createdThreadId); setIsNewThread(false); @@ -115,6 +126,9 @@ export default function ChatPage() { threadId={threadId} thread={thread} paddingBottom={messageListPaddingBottom} + hasMoreHistory={hasMoreHistory} + loadMoreHistory={loadMoreHistory} + isHistoryLoading={isHistoryLoading} />
@@ -138,7 +152,7 @@ export default function ChatPage() { />
- {mounted ? ( + {mountedRef.current ? (