diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md
index 88295b9ff..6ff7ea258 100644
--- a/backend/CLAUDE.md
+++ b/backend/CLAUDE.md
@@ -158,7 +158,7 @@ from deerflow.config import get_app_config
 
 Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
 
-1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
+1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
 2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
 3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
 4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
@@ -216,6 +216,9 @@ FastAPI application on port 8001 with health check at `GET /health`.
 | **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
 | **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
 | **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
+| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
+| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
+| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
 
 Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
@@ -229,7 +232,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
 
 **Virtual Path System**:
 - Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
-- Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
+- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
 - Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
 - Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
@@ -269,7 +272,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
 - `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
 - ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
 - Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
-- Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
+- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
 - `image_search/` - Image search via DuckDuckGo
 
 ### MCP System (`packages/harness/deerflow/mcp/`)
@@ -338,18 +341,27 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
 
 **Components**:
 - `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
-- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
+- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary
 - `prompt.py` - Prompt templates for memory updates
+- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple
 
-**Data Structure** (stored in `backend/.deer-flow/memory.json`):
+**Per-User Isolation**:
+- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
+- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
+- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
+- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
+- Absolute `storage_path` in config opts out of per-user isolation
+- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
+
+**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
 - **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
 - **History**: `recentMonths`, `earlierContext`, `longTermBackground`
 - **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
 
 **Workflow**:
-1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
+1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id`
 2. Queue debounces (30s default), batches updates, deduplicates per-thread
-3. Background thread invokes LLM to extract context updates and facts
+3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads)
 4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
 5. Next interaction injects top 15 facts + context into `` tags in system prompt
 
@@ -357,7 +369,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
 
 **Configuration** (`config.yaml` → `memory`):
 - `enabled` / `injection_enabled` - Master switches
-- `storage_path` - Path to memory.json
+- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation)
 - `debounce_seconds` - Wait time before processing (default: 30)
 - `model_name` - LLM for updates (null = default model)
 - `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
diff --git a/backend/app/channels/feishu.py b/backend/app/channels/feishu.py
index c2a637ff9..5a80016f0 100644
--- a/backend/app/channels/feishu.py
+++ b/backend/app/channels/feishu.py
@@ -13,6 +13,7 @@ from app.channels.base import Channel
 from app.channels.commands import KNOWN_CHANNEL_COMMANDS
 from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
 from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
+from deerflow.runtime.user_context import get_effective_user_id
 from deerflow.sandbox.sandbox_provider import get_sandbox_provider
 
 logger = logging.getLogger(__name__)
@@ -344,8 +345,9 @@ class FeishuChannel(Channel):
             return f"Failed to obtain the [{type}]"
 
         paths = get_paths()
-        paths.ensure_thread_dirs(thread_id)
-        uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
+        user_id = get_effective_user_id()
+        paths.ensure_thread_dirs(thread_id, user_id=user_id)
+        uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
 
         ext = "png" if type == "image" else "bin"
         raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py
index 2410dcb64..400d29d60 100644
--- a/backend/app/channels/manager.py
+++ b/backend/app/channels/manager.py
@@ -17,6 +17,7 @@ from langgraph_sdk.errors import ConflictError
 from app.channels.commands import KNOWN_CHANNEL_COMMANDS
 from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
 from app.channels.store import ChannelStore
+from deerflow.runtime.user_context import get_effective_user_id
 
 logger = logging.getLogger(__name__)
@@ -341,14 +342,15 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
     attachments: list[ResolvedAttachment] = []
     paths = get_paths()
-    outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
+    user_id = get_effective_user_id()
+    outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
 
     for virtual_path in artifacts:
         # Security: only allow files from the agent outputs directory
         if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
             logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
             continue
         try:
-            actual = paths.resolve_virtual_path(thread_id, virtual_path)
+            actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id)
             # Verify the resolved path is actually under the outputs directory
             # (guards against path-traversal even after prefix check)
             try:
diff --git a/backend/app/gateway/path_utils.py b/backend/app/gateway/path_utils.py
index 4869c9404..ded348c78 100644
--- a/backend/app/gateway/path_utils.py
+++ b/backend/app/gateway/path_utils.py
@@ -5,6 +5,7 @@ from pathlib import Path
 
 from fastapi import HTTPException
 
 from deerflow.config.paths import get_paths
+from deerflow.runtime.user_context import get_effective_user_id
 
 
 def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
@@ -22,7 +23,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
         HTTPException: If the path is invalid or outside allowed directories.
""" try: - return get_paths().resolve_virtual_path(thread_id, virtual_path) + return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id()) except ValueError as e: status = 403 if "traversal" in str(e) else 400 raise HTTPException(status_code=status, detail=str(e)) diff --git a/backend/app/gateway/routers/memory.py b/backend/app/gateway/routers/memory.py index 6ee546924..ca9e5f5e5 100644 --- a/backend/app/gateway/routers/memory.py +++ b/backend/app/gateway/routers/memory.py @@ -13,6 +13,7 @@ from deerflow.agents.memory.updater import ( update_memory_fact, ) from deerflow.config.memory_config import get_memory_config +from deerflow.runtime.user_context import get_effective_user_id router = APIRouter(prefix="/api", tags=["memory"]) @@ -147,7 +148,7 @@ async def get_memory() -> MemoryResponse: } ``` """ - memory_data = get_memory_data() + memory_data = get_memory_data(user_id=get_effective_user_id()) return MemoryResponse(**memory_data) @@ -167,7 +168,7 @@ async def reload_memory() -> MemoryResponse: Returns: The reloaded memory data. 
""" - memory_data = reload_memory_data() + memory_data = reload_memory_data(user_id=get_effective_user_id()) return MemoryResponse(**memory_data) @@ -181,7 +182,7 @@ async def reload_memory() -> MemoryResponse: async def clear_memory() -> MemoryResponse: """Clear all persisted memory data.""" try: - memory_data = clear_memory_data() + memory_data = clear_memory_data(user_id=get_effective_user_id()) except OSError as exc: raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc @@ -202,6 +203,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo content=request.content, category=request.category, confidence=request.confidence, + user_id=get_effective_user_id(), ) except ValueError as exc: raise _map_memory_fact_value_error(exc) from exc @@ -221,7 +223,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse: """Delete a single fact from memory by fact id.""" try: - memory_data = delete_memory_fact(fact_id) + memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id()) except KeyError as exc: raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc except OSError as exc: @@ -245,6 +247,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) - content=request.content, category=request.category, confidence=request.confidence, + user_id=get_effective_user_id(), ) except ValueError as exc: raise _map_memory_fact_value_error(exc) from exc @@ -265,7 +268,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) - ) async def export_memory() -> MemoryResponse: """Export the current memory data.""" - memory_data = get_memory_data() + memory_data = get_memory_data(user_id=get_effective_user_id()) return MemoryResponse(**memory_data) @@ -279,7 +282,7 @@ async def export_memory() -> MemoryResponse: async def 
import_memory(request: MemoryResponse) -> MemoryResponse: """Import and persist memory data.""" try: - memory_data = import_memory_data(request.model_dump()) + memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id()) except OSError as exc: raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc @@ -337,7 +340,7 @@ async def get_memory_status() -> MemoryStatusResponse: Combined memory configuration and current data. """ config = get_memory_config() - memory_data = get_memory_data() + memory_data = get_memory_data(user_id=get_effective_user_id()) return MemoryStatusResponse( config=MemoryConfigResponse( diff --git a/backend/app/gateway/routers/runs.py b/backend/app/gateway/routers/runs.py index 7d17488fc..70e2abb63 100644 --- a/backend/app/gateway/routers/runs.py +++ b/backend/app/gateway/routers/runs.py @@ -11,10 +11,11 @@ import asyncio import logging import uuid -from fastapi import APIRouter, Request +from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import StreamingResponse -from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge +from app.gateway.authz import require_permission +from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.routers.thread_runs import RunCreateRequest from app.gateway.services import sse_consumer, start_run from deerflow.runtime import serialize_channel_values @@ -85,3 +86,57 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict: logger.exception("Failed to fetch final state for run %s", record.run_id) return {"status": record.status.value, "error": record.error} + + +# --------------------------------------------------------------------------- +# Run-scoped read endpoints +# --------------------------------------------------------------------------- + + +async def _resolve_run(run_id: str, 
request: Request) -> dict: + """Fetch run by run_id with user ownership check. Raises 404 if not found.""" + run_store = get_run_store(request) + record = await run_store.get(run_id) # user_id=AUTO filters by contextvar + if record is None: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + return record + + +@router.get("/{run_id}/messages") +@require_permission("runs", "read") +async def run_messages( + run_id: str, + request: Request, + limit: int = Query(default=50, le=200, ge=1), + before_seq: int | None = Query(default=None), + after_seq: int | None = Query(default=None), +) -> dict: + """Return paginated messages for a run (cursor-based). + + Pagination: + - after_seq: messages with seq > after_seq (forward) + - before_seq: messages with seq < before_seq (backward) + - neither: latest messages + + Response: { data: [...], has_more: bool } + """ + run = await _resolve_run(run_id, request) + event_store = get_run_event_store(request) + rows = await event_store.list_messages_by_run( + run["thread_id"], run_id, + limit=limit + 1, + before_seq=before_seq, + after_seq=after_seq, + ) + has_more = len(rows) > limit + data = rows[:limit] if has_more else rows + return {"data": data, "has_more": has_more} + + +@router.get("/{run_id}/feedback") +@require_permission("runs", "read") +async def run_feedback(run_id: str, request: Request) -> list[dict]: + """Return all feedback for a run.""" + run = await _resolve_run(run_id, request) + feedback_repo = get_feedback_repo(request) + return await feedback_repo.list_by_run(run["thread_id"], run_id) diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index e414801ee..e21375ab9 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -325,10 +325,28 @@ async def list_thread_messages( @router.get("/{thread_id}/runs/{run_id}/messages") @require_permission("runs", "read", owner_check=True) -async def 
list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]: - """Return displayable messages for a specific run.""" +async def list_run_messages( + thread_id: str, + run_id: str, + request: Request, + limit: int = Query(default=50, le=200, ge=1), + before_seq: int | None = Query(default=None), + after_seq: int | None = Query(default=None), +) -> dict: + """Return paginated messages for a specific run. + + Response: { data: [...], has_more: bool } + """ event_store = get_run_event_store(request) - return await event_store.list_messages_by_run(thread_id, run_id) + rows = await event_store.list_messages_by_run( + thread_id, run_id, + limit=limit + 1, + before_seq=before_seq, + after_seq=after_seq, + ) + has_more = len(rows) > limit + data = rows[:limit] if has_more else rows + return {"data": data, "has_more": has_more} @router.get("/{thread_id}/runs/{run_id}/events") diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index 5eb4a30b5..c7bfa69b6 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -22,10 +22,11 @@ from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field, field_validator from app.gateway.authz import require_permission -from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store +from app.gateway.deps import get_checkpointer from app.gateway.utils import sanitize_log_param from deerflow.config.paths import Paths, get_paths from deerflow.runtime import serialize_channel_values +from deerflow.runtime.user_context import get_effective_user_id logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/threads", tags=["threads"]) @@ -143,11 +144,11 @@ class ThreadHistoryRequest(BaseModel): # --------------------------------------------------------------------------- -def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse: +def 
_delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse: """Delete local persisted filesystem data for a thread.""" path_manager = paths or get_paths() try: - path_manager.delete_thread_dir(thread_id) + path_manager.delete_thread_dir(thread_id, user_id=user_id) except ValueError as exc: raise HTTPException(status_code=422, detail=str(exc)) from exc except FileNotFoundError: @@ -198,7 +199,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe from app.gateway.deps import get_thread_store # Clean local filesystem - response = _delete_thread_data(thread_id) + response = _delete_thread_data(thread_id, user_id=get_effective_user_id()) # Remove checkpoints (best-effort) checkpointer = getattr(request.app.state, "checkpointer", None) @@ -404,164 +405,6 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse: # --------------------------------------------------------------------------- -# Event-store-backed message loader -# --------------------------------------------------------------------------- - -_LEGACY_CMD_INNER_CONTENT_RE = re.compile( - r"ToolMessage\(content=(?P['\"])(?P.*?)(?P=q)", - re.DOTALL, -) - - -def _sanitize_legacy_command_repr(content_field: Any) -> Any: - """Recover the inner ToolMessage text from a legacy ``str(Command(...))`` repr. - - Runs captured before the ``on_tool_end`` fix in ``journal.py`` stored - ``str(Command(update={'messages':[ToolMessage(content='X', ...)]}))`` as the - tool_result content. New runs store ``'X'`` directly. For legacy rows, try - to extract ``'X'`` defensively; return the original string if extraction - fails (still no worse than the checkpoint fallback for summarized threads). 
- """ - if not isinstance(content_field, str) or not content_field.startswith("Command(update="): - return content_field - match = _LEGACY_CMD_INNER_CONTENT_RE.search(content_field) - return match.group("inner") if match else content_field - - -async def _get_event_store_messages(request: Request, thread_id: str) -> list[dict] | None: - """Load the full message stream for ``thread_id`` from the event store. - - The event store is append-only and unaffected by summarization — the - checkpoint's ``channel_values["messages"]`` is rewritten in-place when the - SummarizationMiddleware runs, which drops all pre-summarize messages. The - event store retains the full transcript, so callers in Gateway mode should - prefer it for rendering the conversation history. - - In addition to the core message content, this helper attaches two extra - fields to every returned dict: - - - ``run_id``: the ``run_id`` of the event that produced this message. - Always present. - - ``feedback``: thumbs-up/down data. Present only on the **final - ``ai_message`` of each run** (matching the per-run feedback semantics - of ``POST /api/threads/{id}/runs/{run_id}/feedback``). The frontend uses - the presence of this field to decide whether to render the feedback - button, which sidesteps the positional-index mapping bug that an - out-of-band ``/messages`` fetch exhibited. - - Behaviour contract: - - - **Full pagination.** ``RunEventStore.list_messages`` returns the newest - ``limit`` records when no cursor is given, so a fixed limit silently - drops older messages on long threads. We size the read from - ``count_messages()`` and then page forward with ``after_seq`` cursors. - - **Copy-on-read.** Each content dict is copied before ``id`` is patched - so the live store object is never mutated; ``MemoryRunEventStore`` - returns live references. 
-    - **Stable ids.** Messages with ``id=None`` (human + tool_result) receive
-      a deterministic ``uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")`` so React
-      keys are stable across requests without altering stored data. AI messages
-      retain their LLM-assigned ``lc_run--*`` ids.
-    - **Legacy Command repr.** Rows captured before the ``journal.py``
-      ``on_tool_end`` fix stored ``str(Command(update={...}))`` as the tool
-      result content. ``_sanitize_legacy_command_repr`` extracts the inner
-      ToolMessage text.
-    - **User context.** ``DbRunEventStore`` is user-scoped by default via
-      ``resolve_user_id(AUTO)`` in ``runtime/user_context.py``. This helper
-      must run inside a request where ``@require_permission`` has populated
-      the user contextvar. Both callers below are decorated appropriately.
-      Do not call this helper from CLI or migration scripts without passing
-      ``user_id=None`` explicitly to the underlying store methods.
-
-    Returns ``None`` when the event store is not configured or has no message
-    events for this thread, so callers fall back to checkpoint messages.
-    """
-    try:
-        event_store = get_run_event_store(request)
-    except Exception:
-        return None
-
-    try:
-        total = await event_store.count_messages(thread_id)
-    except Exception:
-        logger.exception("count_messages failed for thread %s", sanitize_log_param(thread_id))
-        return None
-    if not total:
-        return None
-
-    # Batch by page_size to keep memory bounded for very long threads.
-    page_size = 500
-    collected: list[dict] = []
-    after_seq: int | None = None
-    while True:
-        try:
-            page = await event_store.list_messages(thread_id, limit=page_size, after_seq=after_seq)
-        except Exception:
-            logger.exception("list_messages failed for thread %s", sanitize_log_param(thread_id))
-            return None
-        if not page:
-            break
-        collected.extend(page)
-        if len(page) < page_size:
-            break
-        next_cursor = page[-1].get("seq")
-        if next_cursor is None or (after_seq is not None and next_cursor <= after_seq):
-            break
-        after_seq = next_cursor
-
-    # Build the message list; track the final ``ai_message`` index per run so
-    # feedback can be attached at the right position (matches thread_runs.py).
-    messages: list[dict] = []
-    last_ai_per_run: dict[str, int] = {}
-    for evt in collected:
-        raw = evt.get("content")
-        if not isinstance(raw, dict) or "type" not in raw:
-            continue
-        content = dict(raw)
-        if content.get("id") is None:
-            content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{thread_id}:{evt['seq']}"))
-        if content.get("type") == "tool":
-            content["content"] = _sanitize_legacy_command_repr(content.get("content"))
-        run_id = evt.get("run_id")
-        if run_id:
-            content["run_id"] = run_id
-        if evt.get("event_type") == "ai_message" and run_id:
-            last_ai_per_run[run_id] = len(messages)
-        messages.append(content)
-
-    if not messages:
-        return None
-
-    # Attach feedback to the final ai_message of each run. If the feedback
-    # subsystem is unavailable, leave the ``feedback`` field absent entirely
-    # so the frontend hides the button rather than showing it over a broken
-    # write path.
-    feedback_available = False
-    feedback_map: dict[str, dict] = {}
-    try:
-        feedback_repo = get_feedback_repo(request)
-        user_id = await get_current_user(request)
-        feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
-        feedback_available = True
-    except Exception:
-        logger.exception("feedback lookup failed for thread %s", sanitize_log_param(thread_id))
-
-    if feedback_available:
-        for run_id, idx in last_ai_per_run.items():
-            fb = feedback_map.get(run_id)
-            messages[idx]["feedback"] = (
-                {
-                    "feedback_id": fb["feedback_id"],
-                    "rating": fb["rating"],
-                    "comment": fb.get("comment"),
-                }
-                if fb
-                else None
-            )
-
-    return messages
-
-
 @router.get("/{thread_id}/state", response_model=ThreadStateResponse)
 @require_permission("threads", "read", owner_check=True)
 async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
@@ -602,11 +445,6 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
 
     values = serialize_channel_values(channel_values)
 
-    # Prefer event-store messages: append-only, immune to summarization.
-    es_messages = await _get_event_store_messages(request, thread_id)
-    if es_messages is not None:
-        values["messages"] = es_messages
-
     return ThreadStateResponse(
         values=values,
         next=next_tasks,
@@ -726,11 +564,6 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
     if body.before:
         config["configurable"]["checkpoint_id"] = body.before
 
-    # Load the full event-store message stream once; attach to the latest
-    # checkpoint entry only (matching the prior semantics). The event store
-    # is append-only and immune to summarization.
-    es_messages = await _get_event_store_messages(request, thread_id)
-
     entries: list[HistoryEntry] = []
     is_latest_checkpoint = True
     try:
@@ -754,17 +587,11 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
             if thread_data := channel_values.get("thread_data"):
                 values["thread_data"] = thread_data
 
-            # Attach messages only to the latest checkpoint. Prefer the
-            # event-store stream (complete and unaffected by summarization);
-            # fall back to checkpoint channel_values when the event store is
-            # unavailable or empty.
+            # Attach messages only to the latest checkpoint entry.
             if is_latest_checkpoint:
-                if es_messages is not None:
-                    values["messages"] = es_messages
-                else:
-                    messages = channel_values.get("messages")
-                    if messages:
-                        values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
+                messages = channel_values.get("messages")
+                if messages:
+                    values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
                 is_latest_checkpoint = False
 
             # Derive next tasks
diff --git a/backend/app/gateway/routers/uploads.py b/backend/app/gateway/routers/uploads.py
index 3de297355..aa707e9ea 100644
--- a/backend/app/gateway/routers/uploads.py
+++ b/backend/app/gateway/routers/uploads.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel
 
 from app.gateway.authz import require_permission
 from deerflow.config.paths import get_paths
+from deerflow.runtime.user_context import get_effective_user_id
 from deerflow.sandbox.sandbox_provider import get_sandbox_provider
 from deerflow.uploads.manager import (
     PathTraversalError,
@@ -69,7 +70,7 @@ async def upload_files(
         uploads_dir = ensure_uploads_dir(thread_id)
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
-    sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
+    sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
 
     uploaded_files = []
     sandbox_provider = get_sandbox_provider()
@@ -147,7 +148,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> dict:
     enrich_file_listing(result, thread_id)
 
     # Gateway additionally includes the sandbox-relative path.
-    sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
+    sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
     for f in result["files"]:
         f["path"] = str(sandbox_uploads / f["filename"])
diff --git a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py
index 71af2e653..8e00e1ea4 100644
--- a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py
+++ b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py
@@ -519,12 +519,13 @@ def _get_memory_context(agent_name: str | None = None) -> str:
     try:
         from deerflow.agents.memory import format_memory_for_injection, get_memory_data
         from deerflow.config.memory_config import get_memory_config
+        from deerflow.runtime.user_context import get_effective_user_id
 
         config = get_memory_config()
         if not config.enabled or not config.injection_enabled:
             return ""
 
-        memory_data = get_memory_data(agent_name)
+        memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
         memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
 
         if not memory_content.strip():
diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py
index 1db8c63dc..6de0bdcfc 100644
--- a/backend/packages/harness/deerflow/agents/memory/queue.py
+++ b/backend/packages/harness/deerflow/agents/memory/queue.py
@@ -20,6 +20,7 @@ class ConversationContext:
     messages: list[Any]
     timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
     agent_name: str | None = None
+    user_id: str | None = None
     correction_detected: bool = False
     reinforcement_detected: bool = False
@@ -44,6 +45,7 @@ class MemoryUpdateQueue:
         thread_id: str,
         messages: list[Any],
        agent_name: str | None = None,
+        user_id: str | None = None,
         correction_detected: bool = False,
         reinforcement_detected: bool = False,
     ) -> None:
@@ -53,6 +55,9 @@ class MemoryUpdateQueue:
             thread_id: The thread ID.
             messages: The conversation messages.
             agent_name: If provided, memory is stored per-agent. If None, uses global memory.
+            user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
+                survives the threading.Timer boundary (ContextVar does not propagate across
+                raw threads).
             correction_detected: Whether recent turns include an explicit correction signal.
             reinforcement_detected: Whether recent turns include a positive reinforcement signal.
         """
@@ -71,6 +76,7 @@ class MemoryUpdateQueue:
                 thread_id=thread_id,
                 messages=messages,
                 agent_name=agent_name,
+                user_id=user_id,
                 correction_detected=merged_correction_detected,
                 reinforcement_detected=merged_reinforcement_detected,
             )
@@ -136,6 +142,7 @@ class MemoryUpdateQueue:
                 agent_name=context.agent_name,
                 correction_detected=context.correction_detected,
                 reinforcement_detected=context.reinforcement_detected,
+                user_id=context.user_id,
             )
             if success:
                 logger.info("Memory updated successfully for thread %s", context.thread_id)
diff --git a/backend/packages/harness/deerflow/agents/memory/storage.py b/backend/packages/harness/deerflow/agents/memory/storage.py
index 3d57d059b..d35fefa9d 100644
--- a/backend/packages/harness/deerflow/agents/memory/storage.py
+++ b/backend/packages/harness/deerflow/agents/memory/storage.py
@@ -43,17 +43,17 @@ class MemoryStorage(abc.ABC):
     """Abstract base class for memory storage providers."""

     @abc.abstractmethod
-    def load(self, agent_name: str | None = None) -> dict[str, Any]:
+    def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
         """Load memory data for the given agent."""
         pass

     @abc.abstractmethod
-    def reload(self, agent_name: str | None = None) -> dict[str, Any]:
+    def reload(self, agent_name: str | None = None, *, user_id:
str | None = None) -> dict[str, Any]:
         """Force reload memory data for the given agent."""
         pass

     @abc.abstractmethod
-    def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
+    def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
         """Save memory data for the given agent."""
         pass

@@ -63,9 +63,9 @@ class FileMemoryStorage(MemoryStorage):

     def __init__(self):
         """Initialize the file memory storage."""
-        # Per-agent memory cache: keyed by agent_name (None = global)
+        # Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
         # Value: (memory_data, file_mtime)
-        self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
+        self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}

     def _validate_agent_name(self, agent_name: str) -> None:
         """Validate that the agent name is safe to use in filesystem paths.
@@ -78,21 +78,29 @@ class FileMemoryStorage(MemoryStorage):
         if not AGENT_NAME_PATTERN.match(agent_name):
             raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")

-    def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
+    def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
         """Get the path to the memory file."""
+        if user_id is not None:
+            if agent_name is not None:
+                self._validate_agent_name(agent_name)
+                return get_paths().user_agent_memory_file(user_id, agent_name)
+            config = get_memory_config()
+            if config.storage_path and Path(config.storage_path).is_absolute():
+                return Path(config.storage_path)
+            return get_paths().user_memory_file(user_id)
+        # Legacy: no user_id
         if agent_name is not None:
             self._validate_agent_name(agent_name)
             return get_paths().agent_memory_file(agent_name)
-        config = get_memory_config()
         if config.storage_path:
             p = Path(config.storage_path)
             return p if p.is_absolute()
else get_paths().base_dir / p
         return get_paths().memory_file

-    def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
+    def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
         """Load memory data from file."""
-        file_path = self._get_memory_file_path(agent_name)
+        file_path = self._get_memory_file_path(agent_name, user_id=user_id)
         if not file_path.exists():
             return create_empty_memory()
@@ -105,40 +113,42 @@ class FileMemoryStorage(MemoryStorage):
             logger.warning("Failed to load memory file: %s", e)
             return create_empty_memory()

-    def load(self, agent_name: str | None = None) -> dict[str, Any]:
+    def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
         """Load memory data (cached with file modification time check)."""
-        file_path = self._get_memory_file_path(agent_name)
+        file_path = self._get_memory_file_path(agent_name, user_id=user_id)
         try:
             current_mtime = file_path.stat().st_mtime if file_path.exists() else None
         except OSError:
             current_mtime = None

-        cached = self._memory_cache.get(agent_name)
+        cache_key = (user_id, agent_name)
+        cached = self._memory_cache.get(cache_key)
         if cached is None or cached[1] != current_mtime:
-            memory_data = self._load_memory_from_file(agent_name)
-            self._memory_cache[agent_name] = (memory_data, current_mtime)
+            memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
+            self._memory_cache[cache_key] = (memory_data, current_mtime)
             return memory_data

         return cached[0]

-    def reload(self, agent_name: str | None = None) -> dict[str, Any]:
+    def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
         """Reload memory data from file, forcing cache invalidation."""
-        file_path = self._get_memory_file_path(agent_name)
-        memory_data = self._load_memory_from_file(agent_name)
+        file_path = self._get_memory_file_path(agent_name, user_id=user_id)
+        memory_data =
self._load_memory_from_file(agent_name, user_id=user_id)
         try:
             mtime = file_path.stat().st_mtime if file_path.exists() else None
         except OSError:
             mtime = None
-        self._memory_cache[agent_name] = (memory_data, mtime)
+        cache_key = (user_id, agent_name)
+        self._memory_cache[cache_key] = (memory_data, mtime)
         return memory_data

-    def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
+    def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
         """Save memory data to file and update cache."""
-        file_path = self._get_memory_file_path(agent_name)
+        file_path = self._get_memory_file_path(agent_name, user_id=user_id)
         try:
             file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -155,7 +165,8 @@ class FileMemoryStorage(MemoryStorage):
             except OSError:
                 mtime = None

-            self._memory_cache[agent_name] = (memory_data, mtime)
+            cache_key = (user_id, agent_name)
+            self._memory_cache[cache_key] = (memory_data, mtime)
             logger.info("Memory saved to %s", file_path)
             return True
         except OSError as e:
diff --git a/backend/packages/harness/deerflow/agents/memory/updater.py b/backend/packages/harness/deerflow/agents/memory/updater.py
index d1f124d4c..178c6bf62 100644
--- a/backend/packages/harness/deerflow/agents/memory/updater.py
+++ b/backend/packages/harness/deerflow/agents/memory/updater.py
@@ -27,27 +27,28 @@ def _create_empty_memory() -> dict[str, Any]:
     return create_empty_memory()


-def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
+def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
     """Backward-compatible wrapper around the configured memory storage save path."""
-    return get_memory_storage().save(memory_data, agent_name)
+    return get_memory_storage().save(memory_data, agent_name, user_id=user_id)


-def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
+def
get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
     """Get the current memory data via storage provider."""
-    return get_memory_storage().load(agent_name)
+    return get_memory_storage().load(agent_name, user_id=user_id)


-def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
+def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
     """Reload memory data via storage provider."""
-    return get_memory_storage().reload(agent_name)
+    return get_memory_storage().reload(agent_name, user_id=user_id)


-def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
+def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
     """Persist imported memory data via storage provider.

     Args:
         memory_data: Full memory payload to persist.
         agent_name: If provided, imports into per-agent memory.
+        user_id: If provided, scopes memory to a specific user.

     Returns:
         The saved memory data after storage normalization.
@@ -56,15 +57,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
         OSError: If persisting the imported memory fails.
""" storage = get_memory_storage() - if not storage.save(memory_data, agent_name): + if not storage.save(memory_data, agent_name, user_id=user_id): raise OSError("Failed to save imported memory data") - return storage.load(agent_name) + return storage.load(agent_name, user_id=user_id) -def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]: +def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: """Clear all stored memory data and persist an empty structure.""" cleared_memory = create_empty_memory() - if not _save_memory_to_file(cleared_memory, agent_name): + if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id): raise OSError("Failed to save cleared memory data") return cleared_memory @@ -81,6 +82,8 @@ def create_memory_fact( category: str = "context", confidence: float = 0.5, agent_name: str | None = None, + *, + user_id: str | None = None, ) -> dict[str, Any]: """Create a new fact and persist the updated memory data.""" normalized_content = content.strip() @@ -90,7 +93,7 @@ def create_memory_fact( normalized_category = category.strip() or "context" validated_confidence = _validate_confidence(confidence) now = utc_now_iso_z() - memory_data = get_memory_data(agent_name) + memory_data = get_memory_data(agent_name, user_id=user_id) updated_memory = dict(memory_data) facts = list(memory_data.get("facts", [])) facts.append( @@ -105,15 +108,15 @@ def create_memory_fact( ) updated_memory["facts"] = facts - if not _save_memory_to_file(updated_memory, agent_name): + if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): raise OSError("Failed to save memory data after creating fact") return updated_memory -def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]: +def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: """Delete a fact by its id and persist the updated memory 
data.""" - memory_data = get_memory_data(agent_name) + memory_data = get_memory_data(agent_name, user_id=user_id) facts = memory_data.get("facts", []) updated_facts = [fact for fact in facts if fact.get("id") != fact_id] if len(updated_facts) == len(facts): @@ -122,7 +125,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, updated_memory = dict(memory_data) updated_memory["facts"] = updated_facts - if not _save_memory_to_file(updated_memory, agent_name): + if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'") return updated_memory @@ -134,9 +137,11 @@ def update_memory_fact( category: str | None = None, confidence: float | None = None, agent_name: str | None = None, + *, + user_id: str | None = None, ) -> dict[str, Any]: """Update an existing fact and persist the updated memory data.""" - memory_data = get_memory_data(agent_name) + memory_data = get_memory_data(agent_name, user_id=user_id) updated_memory = dict(memory_data) updated_facts: list[dict[str, Any]] = [] found = False @@ -163,7 +168,7 @@ def update_memory_fact( updated_memory["facts"] = updated_facts - if not _save_memory_to_file(updated_memory, agent_name): + if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): raise OSError(f"Failed to save memory data after updating fact '{fact_id}'") return updated_memory @@ -276,6 +281,7 @@ class MemoryUpdater: agent_name: str | None = None, correction_detected: bool = False, reinforcement_detected: bool = False, + user_id: str | None = None, ) -> bool: """Update memory based on conversation messages. @@ -285,6 +291,7 @@ class MemoryUpdater: agent_name: If provided, updates per-agent memory. If None, updates global memory. correction_detected: Whether recent turns include an explicit correction signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal. 
+            user_id: If provided, scopes memory to a specific user.

         Returns:
             True if update was successful, False otherwise.
@@ -298,7 +305,7 @@ class MemoryUpdater:

         try:
             # Get current memory
-            current_memory = get_memory_data(agent_name)
+            current_memory = get_memory_data(agent_name, user_id=user_id)

             # Format conversation for prompt
             conversation_text = format_conversation_for_update(messages)
@@ -353,7 +360,7 @@ class MemoryUpdater:
             updated_memory = _strip_upload_mentions_from_memory(updated_memory)

             # Save
-            return get_memory_storage().save(updated_memory, agent_name)
+            return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)

         except json.JSONDecodeError as e:
             logger.warning("Failed to parse LLM response for memory update: %s", e)
@@ -455,6 +462,7 @@ def update_memory_from_conversation(
     agent_name: str | None = None,
     correction_detected: bool = False,
     reinforcement_detected: bool = False,
+    user_id: str | None = None,
 ) -> bool:
     """Convenience function to update memory from a conversation.

@@ -464,9 +472,10 @@ def update_memory_from_conversation(
     agent_name: If provided, updates per-agent memory. If None, updates global memory.
     correction_detected: Whether recent turns include an explicit correction signal.
     reinforcement_detected: Whether recent turns include a positive reinforcement signal.
+    user_id: If provided, scopes memory to a specific user.

     Returns:
         True if successful, False otherwise.
""" updater = MemoryUpdater() - return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected) + return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id) diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py index 5e8ca6344..7f239a89e 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py @@ -11,6 +11,7 @@ from langgraph.runtime import Runtime from deerflow.agents.memory.queue import get_memory_queue from deerflow.config.memory_config import get_memory_config +from deerflow.runtime.user_context import get_effective_user_id logger = logging.getLogger(__name__) @@ -236,11 +237,16 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): # Queue the filtered conversation for memory update correction_detected = detect_correction(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) + # Capture user_id at enqueue time while the request context is still alive. + # threading.Timer fires on a different thread where ContextVar values are not + # propagated, so we must store user_id explicitly in ConversationContext. 
+        user_id = get_effective_user_id()
         queue = get_memory_queue()
         queue.add(
             thread_id=thread_id,
             messages=filtered_messages,
             agent_name=self._agent_name,
+            user_id=user_id,
             correction_detected=correction_detected,
             reinforcement_detected=reinforcement_detected,
         )
diff --git a/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py
index c25531e02..828a82621 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py
@@ -8,6 +8,7 @@
 from langgraph.runtime import Runtime

 from deerflow.agents.thread_state import ThreadDataState
 from deerflow.config.paths import Paths, get_paths
+from deerflow.runtime.user_context import get_effective_user_id

 logger = logging.getLogger(__name__)

@@ -46,32 +47,34 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
         self._paths = Paths(base_dir) if base_dir else get_paths()
         self._lazy_init = lazy_init

-    def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
+    def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
         """Get the paths for a thread's data directories.

         Args:
             thread_id: The thread ID.
+            user_id: Optional user ID for per-user path isolation.

         Returns:
             Dictionary with workspace_path, uploads_path, and outputs_path.
""" return { - "workspace_path": str(self._paths.sandbox_work_dir(thread_id)), - "uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)), - "outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)), + "workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)), + "uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)), + "outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)), } - def _create_thread_directories(self, thread_id: str) -> dict[str, str]: + def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]: """Create the thread data directories. Args: thread_id: The thread ID. + user_id: Optional user ID for per-user path isolation. Returns: Dictionary with the created directory paths. """ - self._paths.ensure_thread_dirs(thread_id) - return self._get_thread_paths(thread_id) + self._paths.ensure_thread_dirs(thread_id, user_id=user_id) + return self._get_thread_paths(thread_id, user_id=user_id) @override def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None: @@ -84,12 +87,14 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]): if thread_id is None: raise ValueError("Thread ID is required in runtime context or config.configurable") + user_id = get_effective_user_id() + if self._lazy_init: # Lazy initialization: only compute paths, don't create directories - paths = self._get_thread_paths(thread_id) + paths = self._get_thread_paths(thread_id, user_id=user_id) else: # Eager initialization: create directories immediately - paths = self._create_thread_directories(thread_id) + paths = self._create_thread_directories(thread_id, user_id=user_id) logger.debug("Created thread data directories for thread %s", thread_id) return { diff --git a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py 
index 78c9a7b7b..6622fb695 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py
@@ -10,6 +10,7 @@
 from langchain_core.messages import HumanMessage
 from langgraph.runtime import Runtime

 from deerflow.config.paths import Paths, get_paths
+from deerflow.runtime.user_context import get_effective_user_id
 from deerflow.utils.file_conversion import extract_outline

 logger = logging.getLogger(__name__)

@@ -221,7 +222,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
             thread_id = get_config().get("configurable", {}).get("thread_id")
         except RuntimeError:
             pass  # get_config() raises outside a runnable context (e.g. unit tests)
-        uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
+        uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None

         # Get newly uploaded files from the current message's additional_kwargs.files
         new_files = self._files_from_kwargs(last_message, uploads_dir) or []
diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py
index 950fdb085..7623c8f3e 100644
--- a/backend/packages/harness/deerflow/client.py
+++ b/backend/packages/harness/deerflow/client.py
@@ -40,6 +40,7 @@
 from deerflow.config.app_config import get_app_config, reload_app_config
 from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
 from deerflow.config.paths import get_paths
 from deerflow.models import create_chat_model
+from deerflow.runtime.user_context import get_effective_user_id
 from deerflow.skills.installer import install_skill_from_archive
 from deerflow.uploads.manager import (
     claim_unique_filename,
@@ -769,19 +770,19 @@ class DeerFlowClient:
         """
         from deerflow.agents.memory.updater import get_memory_data

-        return get_memory_data()
+        return
get_memory_data(user_id=get_effective_user_id())

     def export_memory(self) -> dict:
         """Export current memory data for backup or transfer."""
         from deerflow.agents.memory.updater import get_memory_data

-        return get_memory_data()
+        return get_memory_data(user_id=get_effective_user_id())

     def import_memory(self, memory_data: dict) -> dict:
         """Import and persist full memory data."""
         from deerflow.agents.memory.updater import import_memory_data

-        return import_memory_data(memory_data)
+        return import_memory_data(memory_data, user_id=get_effective_user_id())

     def get_model(self, name: str) -> dict | None:
         """Get a specific model's configuration by name.
@@ -956,13 +957,13 @@ class DeerFlowClient:
         """
         from deerflow.agents.memory.updater import reload_memory_data

-        return reload_memory_data()
+        return reload_memory_data(user_id=get_effective_user_id())

     def clear_memory(self) -> dict:
         """Clear all persisted memory data."""
         from deerflow.agents.memory.updater import clear_memory_data

-        return clear_memory_data()
+        return clear_memory_data(user_id=get_effective_user_id())

     def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
         """Create a single fact manually."""
@@ -1179,7 +1180,7 @@ class DeerFlowClient:
             ValueError: If the path is invalid.
""" try: - actual = get_paths().resolve_virtual_path(thread_id, path) + actual = get_paths().resolve_virtual_path(thread_id, path, user_id=get_effective_user_id()) except ValueError as exc: if "traversal" in str(exc): from deerflow.uploads.manager import PathTraversalError diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py index 5bc3c3981..27a20c701 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py @@ -27,6 +27,7 @@ except ImportError: # pragma: no cover - Windows fallback from deerflow.config import get_app_config from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths +from deerflow.runtime.user_context import get_effective_user_id from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox_provider import SandboxProvider @@ -260,15 +261,16 @@ class AioSandboxProvider(SandboxProvider): mounted Docker socket (DooD), the host Docker daemon can resolve the paths. 
""" paths = get_paths() - paths.ensure_thread_dirs(thread_id) + user_id = get_effective_user_id() + paths.ensure_thread_dirs(thread_id, user_id=user_id) return [ - (paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False), - (paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False), - (paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False), + (paths.host_sandbox_work_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False), + (paths.host_sandbox_uploads_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False), + (paths.host_sandbox_outputs_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False), # ACP workspace: read-only inside the sandbox (lead agent reads results; # the ACP subprocess writes from the host side, not from within the container). - (paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True), + (paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True), ] @staticmethod @@ -480,8 +482,9 @@ class AioSandboxProvider(SandboxProvider): across multiple processes, preventing container-name conflicts. """ paths = get_paths() - paths.ensure_thread_dirs(thread_id) - lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock" + user_id = get_effective_user_id() + paths.ensure_thread_dirs(thread_id, user_id=user_id) + lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock" with open(lock_path, "a", encoding="utf-8") as lock_file: locked = False diff --git a/backend/packages/harness/deerflow/config/memory_config.py b/backend/packages/harness/deerflow/config/memory_config.py index 8565aa216..f9153262f 100644 --- a/backend/packages/harness/deerflow/config/memory_config.py +++ b/backend/packages/harness/deerflow/config/memory_config.py @@ -14,8 +14,9 @@ class MemoryConfig(BaseModel): default="", description=( "Path to store memory data. 
" - "If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). " - "Absolute paths are used as-is. " + "If empty, defaults to per-user memory at `{base_dir}/users/{user_id}/memory.json`. " + "Absolute paths are used as-is and opt out of per-user isolation " + "(all users share the same file). " "Relative paths are resolved against `Paths.base_dir` " "(not the backend working directory). " "Note: if you previously set this to `.deer-flow/memory.json`, " diff --git a/backend/packages/harness/deerflow/config/paths.py b/backend/packages/harness/deerflow/config/paths.py index 2d5661e63..f1ce7eae1 100644 --- a/backend/packages/harness/deerflow/config/paths.py +++ b/backend/packages/harness/deerflow/config/paths.py @@ -7,6 +7,7 @@ from pathlib import Path, PureWindowsPath VIRTUAL_PATH_PREFIX = "/mnt/user-data" _SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") +_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") def _default_local_base_dir() -> Path: @@ -22,6 +23,13 @@ def _validate_thread_id(thread_id: str) -> str: return thread_id +def _validate_user_id(user_id: str) -> str: + """Validate a user ID before using it in filesystem paths.""" + if not _SAFE_USER_ID_RE.match(user_id): + raise ValueError(f"Invalid user_id {user_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.") + return user_id + + def _join_host_path(base: str, *parts: str) -> str: """Join host filesystem path segments while preserving native style. 
@@ -134,44 +142,63 @@ class Paths:
         """Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
         return self.agent_dir(name) / "memory.json"

-    def thread_dir(self, thread_id: str) -> Path:
+    def user_dir(self, user_id: str) -> Path:
+        """Directory for a specific user: `{base_dir}/users/{user_id}/`."""
+        return self.base_dir / "users" / _validate_user_id(user_id)
+
+    def user_memory_file(self, user_id: str) -> Path:
+        """Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
+        return self.user_dir(user_id) / "memory.json"
+
+    def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
+        """Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
+        return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
+
+    def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
         """
-        Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
+        Host path for a thread's data.
+
+        When *user_id* is provided:
+            `{base_dir}/users/{user_id}/threads/{thread_id}/`
+        Otherwise (legacy layout):
+            `{base_dir}/threads/{thread_id}/`

         This directory contains a `user-data/` subdirectory that is mounted
         as `/mnt/user-data/` inside the sandbox.

         Raises:
-            ValueError: If `thread_id` contains unsafe characters (path separators
-                or `..`) that could cause directory traversal.
+            ValueError: If `thread_id` or `user_id` contains unsafe characters (path
+                separators or `..`) that could cause directory traversal.
         """
+        if user_id is not None:
+            return self.user_dir(user_id) / "threads" / _validate_thread_id(thread_id)
         return self.base_dir / "threads" / _validate_thread_id(thread_id)

-    def sandbox_work_dir(self, thread_id: str) -> Path:
+    def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
         """
         Host path for the agent's workspace directory.
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
         Sandbox: `/mnt/user-data/workspace/`
         """
-        return self.thread_dir(thread_id) / "user-data" / "workspace"
+        return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace"

-    def sandbox_uploads_dir(self, thread_id: str) -> Path:
+    def sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
         """
         Host path for user-uploaded files.

         Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
         Sandbox: `/mnt/user-data/uploads/`
         """
-        return self.thread_dir(thread_id) / "user-data" / "uploads"
+        return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "uploads"

-    def sandbox_outputs_dir(self, thread_id: str) -> Path:
+    def sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
         """
         Host path for agent-generated artifacts.

         Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
         Sandbox: `/mnt/user-data/outputs/`
         """
-        return self.thread_dir(thread_id) / "user-data" / "outputs"
+        return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "outputs"

-    def acp_workspace_dir(self, thread_id: str) -> Path:
+    def acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
         """
         Host path for the ACP workspace of a specific thread.

         Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
@@ -180,41 +207,43 @@ class Paths:
         Each thread gets its own isolated ACP workspace so that concurrent
         sessions cannot read each other's ACP agent outputs.
         """
-        return self.thread_dir(thread_id) / "acp-workspace"
+        return self.thread_dir(thread_id, user_id=user_id) / "acp-workspace"

-    def sandbox_user_data_dir(self, thread_id: str) -> Path:
+    def sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
         """
         Host path for the user-data root.
Host: `{base_dir}/threads/{thread_id}/user-data/`
         Sandbox: `/mnt/user-data/`
         """
-        return self.thread_dir(thread_id) / "user-data"
+        return self.thread_dir(thread_id, user_id=user_id) / "user-data"

-    def host_thread_dir(self, thread_id: str) -> str:
+    def host_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
         """Host path for a thread directory, preserving Windows path syntax."""
+        if user_id is not None:
+            return _join_host_path(self._host_base_dir_str(), "users", _validate_user_id(user_id), "threads", _validate_thread_id(thread_id))
         return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))

-    def host_sandbox_user_data_dir(self, thread_id: str) -> str:
+    def host_sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
         """Host path for a thread's user-data root."""
-        return _join_host_path(self.host_thread_dir(thread_id), "user-data")
+        return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "user-data")

-    def host_sandbox_work_dir(self, thread_id: str) -> str:
+    def host_sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
         """Host path for the workspace mount source."""
-        return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
+        return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "workspace")

-    def host_sandbox_uploads_dir(self, thread_id: str) -> str:
+    def host_sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
         """Host path for the uploads mount source."""
-        return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
+        return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "uploads")

-    def host_sandbox_outputs_dir(self, thread_id: str) -> str:
+    def host_sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
         """Host path for the outputs mount source."""
-        return
_join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs") + return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "outputs") - def host_acp_workspace_dir(self, thread_id: str) -> str: + def host_acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> str: """Host path for the ACP workspace mount source.""" - return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace") + return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "acp-workspace") - def ensure_thread_dirs(self, thread_id: str) -> None: + def ensure_thread_dirs(self, thread_id: str, *, user_id: str | None = None) -> None: """Create all standard sandbox directories for a thread. Directories are created with mode 0o777 so that sandbox containers @@ -228,24 +257,24 @@ class Paths: ACP agent invocation. """ for d in [ - self.sandbox_work_dir(thread_id), - self.sandbox_uploads_dir(thread_id), - self.sandbox_outputs_dir(thread_id), - self.acp_workspace_dir(thread_id), + self.sandbox_work_dir(thread_id, user_id=user_id), + self.sandbox_uploads_dir(thread_id, user_id=user_id), + self.sandbox_outputs_dir(thread_id, user_id=user_id), + self.acp_workspace_dir(thread_id, user_id=user_id), ]: d.mkdir(parents=True, exist_ok=True) d.chmod(0o777) - def delete_thread_dir(self, thread_id: str) -> None: + def delete_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> None: """Delete all persisted data for a thread. The operation is idempotent: missing thread directories are ignored. """ - thread_dir = self.thread_dir(thread_id) + thread_dir = self.thread_dir(thread_id, user_id=user_id) if thread_dir.exists(): shutil.rmtree(thread_dir) - def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path: + def resolve_virtual_path(self, thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path: """Resolve a sandbox virtual path to the actual host filesystem path. 
Args: @@ -253,6 +282,7 @@ class Paths: virtual_path: Virtual path as seen inside the sandbox, e.g. ``/mnt/user-data/outputs/report.pdf``. Leading slashes are stripped before matching. + user_id: Optional user ID for user-scoped path resolution. Returns: The resolved absolute host filesystem path. @@ -270,7 +300,7 @@ class Paths: raise ValueError(f"Path must start with /{prefix}") relative = stripped[len(prefix) :].lstrip("/") - base = self.sandbox_user_data_dir(thread_id).resolve() + base = self.sandbox_user_data_dir(thread_id, user_id=user_id).resolve() actual = (base / relative).resolve() try: diff --git a/backend/packages/harness/deerflow/runtime/events/store/base.py b/backend/packages/harness/deerflow/runtime/events/store/base.py index e5da4ed82..df5136ba5 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/base.py +++ b/backend/packages/harness/deerflow/runtime/events/store/base.py @@ -83,8 +83,18 @@ class RunEventStore(abc.ABC): self, thread_id: str, run_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, ) -> list[dict]: - """Return displayable messages (category=message) for a specific run, ordered by seq ascending.""" + """Return displayable messages (category=message) for a specific run, ordered by seq ascending. 
+ + Supports bidirectional cursor pagination: + - after_seq: return the first ``limit`` records with seq > after_seq (ascending) + - before_seq: return the last ``limit`` records with seq < before_seq (ascending) + - neither: return the latest ``limit`` records (ascending) + """ @abc.abstractmethod async def count_messages(self, thread_id: str) -> int: diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index 63328db43..e4a21d006 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -205,16 +205,35 @@ class DbRunEventStore(RunEventStore): thread_id, run_id, *, + limit=50, + before_seq=None, + after_seq=None, user_id: str | None | _AutoSentinel = AUTO, ): resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run") - stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message") + stmt = select(RunEventRow).where( + RunEventRow.thread_id == thread_id, + RunEventRow.run_id == run_id, + RunEventRow.category == "message", + ) if resolved_user_id is not None: stmt = stmt.where(RunEventRow.user_id == resolved_user_id) - stmt = stmt.order_by(RunEventRow.seq.asc()) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] + if before_seq is not None: + stmt = stmt.where(RunEventRow.seq < before_seq) + if after_seq is not None: + stmt = stmt.where(RunEventRow.seq > after_seq) + + if after_seq is not None: + stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] + else: + stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit) + async with self._sf() as session: + result = await 
session.execute(stmt) + rows = list(result.scalars()) + return [self._row_to_dict(r) for r in reversed(rows)] async def count_messages( self, diff --git a/backend/packages/harness/deerflow/runtime/events/store/jsonl.py b/backend/packages/harness/deerflow/runtime/events/store/jsonl.py index 1a4aac38c..378713afc 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/jsonl.py +++ b/backend/packages/harness/deerflow/runtime/events/store/jsonl.py @@ -152,9 +152,17 @@ class JsonlRunEventStore(RunEventStore): events = [e for e in events if e.get("event_type") in event_types] return events[:limit] - async def list_messages_by_run(self, thread_id, run_id): + async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None): events = self._read_run_events(thread_id, run_id) - return [e for e in events if e.get("category") == "message"] + filtered = [e for e in events if e.get("category") == "message"] + if before_seq is not None: + filtered = [e for e in filtered if e.get("seq", 0) < before_seq] + if after_seq is not None: + filtered = [e for e in filtered if e.get("seq", 0) > after_seq] + if after_seq is not None: + return filtered[:limit] + else: + return filtered[-limit:] if len(filtered) > limit else filtered async def count_messages(self, thread_id): all_events = self._read_thread_events(thread_id) diff --git a/backend/packages/harness/deerflow/runtime/events/store/memory.py b/backend/packages/harness/deerflow/runtime/events/store/memory.py index 889159086..cf70e1cdf 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/events/store/memory.py @@ -97,9 +97,17 @@ class MemoryRunEventStore(RunEventStore): filtered = [e for e in filtered if e["event_type"] in event_types] return filtered[:limit] - async def list_messages_by_run(self, thread_id, run_id): + async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, 
after_seq=None): all_events = self._events.get(thread_id, []) - return [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"] + filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"] + if before_seq is not None: + filtered = [e for e in filtered if e["seq"] < before_seq] + if after_seq is not None: + filtered = [e for e in filtered if e["seq"] > after_seq] + if after_seq is not None: + return filtered[:limit] + else: + return filtered[-limit:] if len(filtered) > limit else filtered async def count_messages(self, thread_id): all_events = self._events.get(thread_id, []) diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 74581a275..506ab00cd 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -87,6 +87,7 @@ async def run_agent( journal = None + # Track whether "events" was requested but skipped if "events" in requested_modes: logger.info( diff --git a/backend/packages/harness/deerflow/runtime/user_context.py b/backend/packages/harness/deerflow/runtime/user_context.py index 33fce65d5..ffe4be690 100644 --- a/backend/packages/harness/deerflow/runtime/user_context.py +++ b/backend/packages/harness/deerflow/runtime/user_context.py @@ -90,6 +90,25 @@ def require_current_user() -> CurrentUser: return user + +# --------------------------------------------------------------------------- +# Effective user_id helpers (filesystem isolation) +# --------------------------------------------------------------------------- + +DEFAULT_USER_ID: Final[str] = "default" + + +def get_effective_user_id() -> str: + """Return the current user's id as a string, or DEFAULT_USER_ID if unset. + + Unlike :func:`require_current_user` this never raises — it is designed + for filesystem-path resolution where a valid user bucket is always needed.
+ """ + user = _current_user.get() + if user is None: + return DEFAULT_USER_ID + return str(user.id) + + # --------------------------------------------------------------------------- # Sentinel-based user_id resolution # --------------------------------------------------------------------------- diff --git a/backend/packages/harness/deerflow/sandbox/tools.py b/backend/packages/harness/deerflow/sandbox/tools.py index 089fa725d..601a7efb8 100644 --- a/backend/packages/harness/deerflow/sandbox/tools.py +++ b/backend/packages/harness/deerflow/sandbox/tools.py @@ -200,8 +200,9 @@ def _get_acp_workspace_host_path(thread_id: str | None = None) -> str | None: if thread_id is not None: try: from deerflow.config.paths import get_paths + from deerflow.runtime.user_context import get_effective_user_id - host_path = get_paths().acp_workspace_dir(thread_id) + host_path = get_paths().acp_workspace_dir(thread_id, user_id=get_effective_user_id()) if host_path.exists(): return str(host_path) except Exception: diff --git a/backend/packages/harness/deerflow/tools/builtins/invoke_acp_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/invoke_acp_agent_tool.py index baf7f8ff5..618649020 100644 --- a/backend/packages/harness/deerflow/tools/builtins/invoke_acp_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/invoke_acp_agent_tool.py @@ -33,11 +33,12 @@ def _get_work_dir(thread_id: str | None) -> str: An absolute physical filesystem path to use as the working directory. 
""" from deerflow.config.paths import get_paths + from deerflow.runtime.user_context import get_effective_user_id paths = get_paths() if thread_id: try: - work_dir = paths.acp_workspace_dir(thread_id) + work_dir = paths.acp_workspace_dir(thread_id, user_id=get_effective_user_id()) except ValueError: logger.warning("Invalid thread_id %r for ACP workspace, falling back to global", thread_id) work_dir = paths.base_dir / "acp-workspace" diff --git a/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py b/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py index 1e0c76105..39cc61c4f 100644 --- a/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py @@ -8,6 +8,7 @@ from langgraph.typing import ContextT from deerflow.agents.thread_state import ThreadState from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths +from deerflow.runtime.user_context import get_effective_user_id OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs" @@ -47,7 +48,7 @@ def _normalize_presented_filepath( virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/") if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"): - actual_path = get_paths().resolve_virtual_path(thread_id, filepath) + actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id()) else: actual_path = Path(filepath).expanduser().resolve() diff --git a/backend/packages/harness/deerflow/uploads/manager.py b/backend/packages/harness/deerflow/uploads/manager.py index 8c60399e7..c36151b38 100644 --- a/backend/packages/harness/deerflow/uploads/manager.py +++ b/backend/packages/harness/deerflow/uploads/manager.py @@ -10,6 +10,7 @@ from pathlib import Path from urllib.parse import quote from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths +from deerflow.runtime.user_context import get_effective_user_id class PathTraversalError(ValueError): @@ 
-33,7 +34,7 @@ def validate_thread_id(thread_id: str) -> None: def get_uploads_dir(thread_id: str) -> Path: """Return the uploads directory path for a thread (no side effects).""" validate_thread_id(thread_id) - return get_paths().sandbox_uploads_dir(thread_id) + return get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) def ensure_uploads_dir(thread_id: str) -> Path: diff --git a/backend/scripts/migrate_user_isolation.py b/backend/scripts/migrate_user_isolation.py new file mode 100644 index 000000000..4d37a0d1e --- /dev/null +++ b/backend/scripts/migrate_user_isolation.py @@ -0,0 +1,160 @@ +"""One-time migration: move legacy thread dirs and memory into per-user layout. + +Usage: + PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run] + +The script is idempotent — re-running it after a successful migration is a no-op. +""" +import argparse +import json +import logging +import shutil +from pathlib import Path + +from deerflow.config.paths import Paths, get_paths + +logger = logging.getLogger(__name__) + + +def migrate_thread_dirs( + paths: Paths, + thread_owner_map: dict[str, str], + *, + dry_run: bool = False, +) -> list[dict]: + """Move legacy thread directories into per-user layout. + + Args: + paths: Paths instance. + thread_owner_map: Mapping of thread_id -> user_id from threads_meta table. + dry_run: If True, only log what would happen. + + Returns: + List of migration report entries. 
+ """ + report: list[dict] = [] + legacy_threads = paths.base_dir / "threads" + if not legacy_threads.exists(): + logger.info("No legacy threads directory found — nothing to migrate.") + return report + + for thread_dir in sorted(legacy_threads.iterdir()): + if not thread_dir.is_dir(): + continue + thread_id = thread_dir.name + user_id = thread_owner_map.get(thread_id, "default") + dest = paths.base_dir / "users" / user_id / "threads" / thread_id + + entry = {"thread_id": thread_id, "user_id": user_id, "action": ""} + + if dest.exists(): + conflicts_dir = paths.base_dir / "migration-conflicts" / thread_id + entry["action"] = f"conflict -> {conflicts_dir}" + if not dry_run: + conflicts_dir.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(thread_dir), str(conflicts_dir)) + logger.warning("Conflict for thread %s: moved to %s", thread_id, conflicts_dir) + else: + entry["action"] = f"moved -> {dest}" + if not dry_run: + dest.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(thread_dir), str(dest)) + logger.info("Migrated thread %s -> user %s", thread_id, user_id) + + report.append(entry) + + # Clean up empty legacy threads dir + if not dry_run and legacy_threads.exists() and not any(legacy_threads.iterdir()): + legacy_threads.rmdir() + + return report + + +def migrate_memory( + paths: Paths, + user_id: str = "default", + *, + dry_run: bool = False, +) -> None: + """Move legacy global memory.json into per-user layout. + + Args: + paths: Paths instance. + user_id: Target user to receive the legacy memory. + dry_run: If True, only log. 
+ """ + legacy_mem = paths.base_dir / "memory.json" + if not legacy_mem.exists(): + logger.info("No legacy memory.json found — nothing to migrate.") + return + + dest = paths.user_memory_file(user_id) + if dest.exists(): + legacy_backup = paths.base_dir / "memory.legacy.json" + logger.warning("Destination %s exists; renaming legacy to %s", dest, legacy_backup) + if not dry_run: + legacy_mem.rename(legacy_backup) + return + + logger.info("Migrating memory.json -> %s", dest) + if not dry_run: + dest.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(legacy_mem), str(dest)) + + +def _build_owner_map_from_db(paths: Paths) -> dict[str, str]: + """Query threads_meta table for thread_id -> user_id mapping. + + Uses raw sqlite3 to avoid async dependencies. + """ + import sqlite3 + + db_path = paths.base_dir / "deer-flow.db" + if not db_path.exists(): + logger.info("No database found at %s — using empty owner map.", db_path) + return {} + + conn = sqlite3.connect(str(db_path)) + try: + cursor = conn.execute("SELECT thread_id, user_id FROM threads_meta WHERE user_id IS NOT NULL") + return {row[0]: row[1] for row in cursor.fetchall()} + except sqlite3.OperationalError as e: + logger.warning("Failed to query threads_meta: %s", e) + return {} + finally: + conn.close() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout") + parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + paths = get_paths() + logger.info("Base directory: %s", paths.base_dir) + logger.info("Dry run: %s", args.dry_run) + + owner_map = _build_owner_map_from_db(paths) + logger.info("Found %d thread ownership records in DB", len(owner_map)) + + report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run) + migrate_memory(paths, user_id="default", dry_run=args.dry_run) + + if 
report: + logger.info("Migration report:") + for entry in report: + logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"]) + else: + logger.info("No threads to migrate.") + + unowned = [e for e in report if e["thread_id"] not in owner_map] + if unowned: + logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned)) + for e in unowned: + logger.warning(" %s", e["thread_id"]) + + +if __name__ == "__main__": + main() diff --git a/backend/tests/test_aio_sandbox_provider.py b/backend/tests/test_aio_sandbox_provider.py index e797cf7e3..c7984531f 100644 --- a/backend/tests/test_aio_sandbox_provider.py +++ b/backend/tests/test_aio_sandbox_provider.py @@ -57,6 +57,7 @@ def test_get_thread_mounts_includes_acp_workspace(tmp_path, monkeypatch): """_get_thread_mounts must include /mnt/acp-workspace (read-only) for docker sandbox.""" aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path)) + monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None) mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-3") @@ -95,6 +96,7 @@ def test_get_thread_mounts_preserves_windows_host_path_style(tmp_path, monkeypat aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") monkeypatch.setenv("DEER_FLOW_HOST_BASE_DIR", r"C:\Users\demo\deer-flow\backend\.deer-flow") monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path)) + monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None) mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-10") diff --git a/backend/tests/test_channel_file_attachments.py b/backend/tests/test_channel_file_attachments.py index 2843a9cd0..7273b1c82 100644 --- a/backend/tests/test_channel_file_attachments.py +++ b/backend/tests/test_channel_file_attachments.py @@ -231,7 +231,7 @@ class TestResolveAttachments:
mock_paths = MagicMock() mock_paths.sandbox_outputs_dir.return_value = outputs_dir - def resolve_side_effect(tid, vpath): + def resolve_side_effect(tid, vpath, *, user_id=None): if "data.csv" in vpath: return good_file return tmp_path / "missing.txt" diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py index a9b854e8e..d22e36d17 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/test_client.py @@ -1241,7 +1241,10 @@ class TestMemoryManagement: with patch("deerflow.agents.memory.updater.import_memory_data", return_value=imported) as mock_import: result = client.import_memory(imported) - mock_import.assert_called_once_with(imported) + assert mock_import.call_count == 1 + call_args = mock_import.call_args + assert call_args.args == (imported,) + assert "user_id" in call_args.kwargs assert result == imported def test_reload_memory(self, client): @@ -1487,9 +1490,12 @@ class TestUploads: class TestArtifacts: def test_get_artifact(self, client): + from deerflow.runtime.user_context import get_effective_user_id + with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) - outputs = paths.sandbox_outputs_dir("t1") + user_id = get_effective_user_id() + outputs = paths.sandbox_outputs_dir("t1", user_id=user_id) outputs.mkdir(parents=True) (outputs / "result.txt").write_text("artifact content") @@ -1500,9 +1506,12 @@ class TestArtifacts: assert "text" in mime def test_get_artifact_not_found(self, client): + from deerflow.runtime.user_context import get_effective_user_id + with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) - paths.sandbox_user_data_dir("t1").mkdir(parents=True) + user_id = get_effective_user_id() + paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True) with patch("deerflow.client.get_paths", return_value=paths): with pytest.raises(FileNotFoundError): @@ -1513,9 +1522,12 @@ class TestArtifacts: client.get_artifact("t1", "bad/path/file.txt") def test_get_artifact_path_traversal(self, 
client): + from deerflow.runtime.user_context import get_effective_user_id + with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) - paths.sandbox_user_data_dir("t1").mkdir(parents=True) + user_id = get_effective_user_id() + paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True) with patch("deerflow.client.get_paths", return_value=paths): with pytest.raises(PathTraversalError): @@ -1699,13 +1711,16 @@ class TestScenarioFileLifecycle: def test_upload_then_read_artifact(self, client): """Upload a file, simulate agent producing artifact, read it back.""" + from deerflow.runtime.user_context import get_effective_user_id + with tempfile.TemporaryDirectory() as tmp: tmp_path = Path(tmp) uploads_dir = tmp_path / "uploads" uploads_dir.mkdir() paths = Paths(base_dir=tmp_path) - outputs_dir = paths.sandbox_outputs_dir("t-artifact") + user_id = get_effective_user_id() + outputs_dir = paths.sandbox_outputs_dir("t-artifact", user_id=user_id) outputs_dir.mkdir(parents=True) # Upload phase @@ -1955,11 +1970,14 @@ class TestScenarioThreadIsolation: def test_artifacts_isolated_per_thread(self, client): """Artifacts in thread-A are not accessible from thread-B.""" + from deerflow.runtime.user_context import get_effective_user_id + with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) - outputs_a = paths.sandbox_outputs_dir("thread-a") + user_id = get_effective_user_id() + outputs_a = paths.sandbox_outputs_dir("thread-a", user_id=user_id) outputs_a.mkdir(parents=True) - paths.sandbox_user_data_dir("thread-b").mkdir(parents=True) + paths.sandbox_outputs_dir("thread-b", user_id=user_id).mkdir(parents=True) (outputs_a / "result.txt").write_text("thread-a artifact") with patch("deerflow.client.get_paths", return_value=paths): @@ -2864,9 +2882,12 @@ class TestUploadDeleteSymlink: class TestArtifactHardening: def test_artifact_directory_rejected(self, client): """get_artifact rejects paths that resolve to a directory.""" + from 
deerflow.runtime.user_context import get_effective_user_id + with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) - subdir = paths.sandbox_outputs_dir("t1") / "subdir" + user_id = get_effective_user_id() + subdir = paths.sandbox_outputs_dir("t1", user_id=user_id) / "subdir" subdir.mkdir(parents=True) with patch("deerflow.client.get_paths", return_value=paths): @@ -2875,9 +2896,12 @@ class TestArtifactHardening: def test_artifact_leading_slash_stripped(self, client): """Paths with leading slash are handled correctly.""" + from deerflow.runtime.user_context import get_effective_user_id + with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) - outputs = paths.sandbox_outputs_dir("t1") + user_id = get_effective_user_id() + outputs = paths.sandbox_outputs_dir("t1", user_id=user_id) outputs.mkdir(parents=True) (outputs / "file.txt").write_text("content") @@ -2991,9 +3015,12 @@ class TestBugArtifactPrefixMatchTooLoose: def test_exact_prefix_without_subpath_accepted(self, client): """Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix).""" + from deerflow.runtime.user_context import get_effective_user_id + with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) - paths.sandbox_user_data_dir("t1").mkdir(parents=True) + user_id = get_effective_user_id() + paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True) with patch("deerflow.client.get_paths", return_value=paths): # Accepted at prefix check, but fails because it's a directory. 
diff --git a/backend/tests/test_client_e2e.py b/backend/tests/test_client_e2e.py index b26e5bff1..6c688933a 100644 --- a/backend/tests/test_client_e2e.py +++ b/backend/tests/test_client_e2e.py @@ -262,8 +262,9 @@ class TestFileUploadIntegration: # Physically exists from deerflow.config.paths import get_paths + from deerflow.runtime.user_context import get_effective_user_id - assert (get_paths().sandbox_uploads_dir(tid) / "readme.txt").exists() + assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists() def test_upload_duplicate_rename(self, e2e_env, tmp_path): """Uploading two files with the same name auto-renames the second.""" @@ -472,12 +473,13 @@ class TestArtifactAccess: def test_get_artifact_happy_path(self, e2e_env): """Write a file to outputs, then read it back via get_artifact().""" from deerflow.config.paths import get_paths + from deerflow.runtime.user_context import get_effective_user_id c = DeerFlowClient(checkpointer=None, thinking_enabled=False) tid = str(uuid.uuid4()) # Create an output file in the thread's outputs directory - outputs_dir = get_paths().sandbox_outputs_dir(tid) + outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id()) outputs_dir.mkdir(parents=True, exist_ok=True) (outputs_dir / "result.txt").write_text("hello artifact") @@ -488,11 +490,12 @@ class TestArtifactAccess: def test_get_artifact_nested_path(self, e2e_env): """Artifacts in subdirectories are accessible.""" from deerflow.config.paths import get_paths + from deerflow.runtime.user_context import get_effective_user_id c = DeerFlowClient(checkpointer=None, thinking_enabled=False) tid = str(uuid.uuid4()) - outputs_dir = get_paths().sandbox_outputs_dir(tid) + outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id()) sub = outputs_dir / "charts" sub.mkdir(parents=True, exist_ok=True) (sub / "data.json").write_text('{"x": 1}') diff --git a/backend/tests/test_invoke_acp_agent_tool.py 
b/backend/tests/test_invoke_acp_agent_tool.py index 8063875cf..3c5f6f0ff 100644 --- a/backend/tests/test_invoke_acp_agent_tool.py +++ b/backend/tests/test_invoke_acp_agent_tool.py @@ -152,8 +152,10 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path): def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path): """P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/.""" from deerflow.config import paths as paths_module + from deerflow.runtime import user_context as uc_module monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path)) + monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None) result = _get_work_dir("thread-abc-123") expected = tmp_path / "threads" / "thread-abc-123" / "acp-workspace" assert result == str(expected) @@ -310,8 +312,10 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path): async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path): """P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace.""" from deerflow.config import paths as paths_module + from deerflow.runtime import user_context as uc_module monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path)) + monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None) monkeypatch.setattr( "deerflow.config.extensions_config.ExtensionsConfig.from_file", diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py index 204f9d16e..454cf2bf2 100644 --- a/backend/tests/test_memory_queue.py +++ b/backend/tests/test_memory_queue.py @@ -48,6 +48,7 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None: agent_name="lead_agent", correction_detected=True, reinforcement_detected=False, + user_id=None, ) @@ -88,4 +89,5 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> 
None:
         agent_name="lead_agent",
         correction_detected=False,
         reinforcement_detected=True,
+        user_id=None,
     )
diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py
new file mode 100644
index 000000000..1a209d659
--- /dev/null
+++ b/backend/tests/test_memory_queue_user_isolation.py
@@ -0,0 +1,38 @@
+"""Tests for user_id propagation through memory queue."""
+from unittest.mock import MagicMock, patch
+
+from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
+
+
+def test_conversation_context_has_user_id():
+    ctx = ConversationContext(thread_id="t1", messages=[], user_id="alice")
+    assert ctx.user_id == "alice"
+
+
+def test_conversation_context_user_id_default_none():
+    ctx = ConversationContext(thread_id="t1", messages=[])
+    assert ctx.user_id is None
+
+
+def test_queue_add_stores_user_id():
+    q = MemoryUpdateQueue()
+    with patch.object(q, "_reset_timer"):
+        q.add(thread_id="t1", messages=["msg"], user_id="alice")
+        assert len(q._queue) == 1
+        assert q._queue[0].user_id == "alice"
+        q.clear()
+
+
+def test_queue_process_passes_user_id_to_updater():
+    q = MemoryUpdateQueue()
+    with patch.object(q, "_reset_timer"):
+        q.add(thread_id="t1", messages=["msg"], user_id="alice")
+
+    mock_updater = MagicMock()
+    mock_updater.update_memory.return_value = True
+    with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
+        q._process_queue()
+
+    mock_updater.update_memory.assert_called_once()
+    call_kwargs = mock_updater.update_memory.call_args.kwargs
+    assert call_kwargs["user_id"] == "alice"
diff --git a/backend/tests/test_memory_router.py b/backend/tests/test_memory_router.py
index 23a4f30fe..91fd1d662 100644
--- a/backend/tests/test_memory_router.py
+++ b/backend/tests/test_memory_router.py
@@ -258,12 +258,13 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
     )
     assert response.status_code == 200
-    update_fact.assert_called_once_with(
-        fact_id="fact_edit",
-        content="User prefers spaces",
-        category=None,
-        confidence=None,
-    )
+    assert update_fact.call_count == 1
+    call_kwargs = update_fact.call_args.kwargs
+    assert call_kwargs.get("fact_id") == "fact_edit"
+    assert call_kwargs.get("content") == "User prefers spaces"
+    assert call_kwargs.get("category") is None
+    assert call_kwargs.get("confidence") is None
+    assert "user_id" in call_kwargs

     assert response.json()["facts"] == updated_memory["facts"]
diff --git a/backend/tests/test_memory_storage_user_isolation.py b/backend/tests/test_memory_storage_user_isolation.py
new file mode 100644
index 000000000..a82fffa50
--- /dev/null
+++ b/backend/tests/test_memory_storage_user_isolation.py
@@ -0,0 +1,150 @@
+"""Tests for per-user memory storage isolation."""
+import pytest
+from pathlib import Path
+from unittest.mock import patch
+
+from deerflow.agents.memory.storage import FileMemoryStorage, create_empty_memory
+
+
+@pytest.fixture
+def base_dir(tmp_path: Path) -> Path:
+    return tmp_path
+
+
+@pytest.fixture
+def storage() -> FileMemoryStorage:
+    return FileMemoryStorage()
+
+
+class TestUserIsolatedStorage:
+    def test_save_and_load_per_user(self, storage: FileMemoryStorage, base_dir: Path):
+        from deerflow.config.paths import Paths
+
+        paths = Paths(base_dir)
+        with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
+            memory_a = create_empty_memory()
+            memory_a["user"]["workContext"]["summary"] = "User A context"
+            storage.save(memory_a, user_id="alice")
+
+            memory_b = create_empty_memory()
+            memory_b["user"]["workContext"]["summary"] = "User B context"
+            storage.save(memory_b, user_id="bob")
+
+            loaded_a = storage.load(user_id="alice")
+            loaded_b = storage.load(user_id="bob")
+
+            assert loaded_a["user"]["workContext"]["summary"] == "User A context"
+            assert loaded_b["user"]["workContext"]["summary"] == "User B context"
+
+    def test_user_memory_file_location(self, base_dir: Path):
+        from deerflow.config.paths import Paths
+
+        paths = Paths(base_dir)
+        with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
+            s = FileMemoryStorage()
+            memory = create_empty_memory()
+            s.save(memory, user_id="alice")
+            expected_path = base_dir / "users" / "alice" / "memory.json"
+            assert expected_path.exists()
+
+    def test_cache_isolated_per_user(self, base_dir: Path):
+        from deerflow.config.paths import Paths
+
+        paths = Paths(base_dir)
+        with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
+            s = FileMemoryStorage()
+            memory_a = create_empty_memory()
+            memory_a["user"]["workContext"]["summary"] = "A"
+            s.save(memory_a, user_id="alice")
+
+            memory_b = create_empty_memory()
+            memory_b["user"]["workContext"]["summary"] = "B"
+            s.save(memory_b, user_id="bob")
+
+            loaded_a = s.load(user_id="alice")
+            assert loaded_a["user"]["workContext"]["summary"] == "A"
+
+    def test_no_user_id_uses_legacy_path(self, base_dir: Path):
+        from deerflow.config.paths import Paths
+        from deerflow.config.memory_config import MemoryConfig
+
+        paths = Paths(base_dir)
+        with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
+            with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
+                s = FileMemoryStorage()
+                memory = create_empty_memory()
+                s.save(memory, user_id=None)
+                expected_path = base_dir / "memory.json"
+                assert expected_path.exists()
+
+    def test_user_and_legacy_do_not_interfere(self, base_dir: Path):
+        """user_id=None (legacy) and user_id='alice' must use different files and caches."""
+        from deerflow.config.paths import Paths
+        from deerflow.config.memory_config import MemoryConfig
+
+        paths = Paths(base_dir)
+        with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
+            with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
+                s = FileMemoryStorage()
+
+                legacy_mem = create_empty_memory()
+                legacy_mem["user"]["workContext"]["summary"] = "legacy"
+                s.save(legacy_mem, user_id=None)
+
+                user_mem = create_empty_memory()
+                user_mem["user"]["workContext"]["summary"] = "alice"
+                s.save(user_mem, user_id="alice")
+
+                assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
+                assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
+
+    def test_user_agent_memory_file_location(self, base_dir: Path):
+        """Per-user per-agent memory uses the user_agent_memory_file path."""
+        from deerflow.config.paths import Paths
+
+        paths = Paths(base_dir)
+        with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
+            s = FileMemoryStorage()
+            memory = create_empty_memory()
+            memory["user"]["workContext"]["summary"] = "agent scoped"
+            s.save(memory, "test-agent", user_id="alice")
+            expected_path = base_dir / "users" / "alice" / "agents" / "test-agent" / "memory.json"
+            assert expected_path.exists()
+
+    def test_cache_key_is_user_agent_tuple(self, base_dir: Path):
+        """Cache keys must be (user_id, agent_name) tuples, not bare agent names."""
+        from deerflow.config.paths import Paths
+
+        paths = Paths(base_dir)
+        with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
+            s = FileMemoryStorage()
+            memory = create_empty_memory()
+            s.save(memory, user_id="alice")
+            # After save, cache should have tuple key
+            assert ("alice", None) in s._memory_cache
+
+    def test_reload_with_user_id(self, base_dir: Path):
+        """reload() with user_id should force re-read from the user-scoped file."""
+        from deerflow.config.paths import Paths
+
+        paths = Paths(base_dir)
+        with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
+            s = FileMemoryStorage()
+            memory = create_empty_memory()
+            memory["user"]["workContext"]["summary"] = "initial"
+            s.save(memory, user_id="alice")
+
+            # Load once to prime cache
+            s.load(user_id="alice")
+
+            # Write updated content directly to file
+            user_file = base_dir / "users" / "alice" / "memory.json"
+            import json
+
+            updated = create_empty_memory()
+            updated["user"]["workContext"]["summary"] = "updated"
+            user_file.write_text(json.dumps(updated))
+
+            # reload should pick up the new content
+            reloaded = s.reload(user_id="alice")
+            assert reloaded["user"]["workContext"]["summary"] == "updated"
diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py
index 48fdfd89e..995b652ec 100644
--- a/backend/tests/test_memory_updater.py
+++ b/backend/tests/test_memory_updater.py
@@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
     with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
         result = import_memory_data(imported_memory)

-    mock_storage.save.assert_called_once_with(imported_memory, None)
-    mock_storage.load.assert_called_once_with(None)
+    mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
+    mock_storage.load.assert_called_once_with(None, user_id=None)

     assert result == imported_memory
diff --git a/backend/tests/test_memory_updater_user_isolation.py b/backend/tests/test_memory_updater_user_isolation.py
new file mode 100644
index 000000000..d38f3fc90
--- /dev/null
+++ b/backend/tests/test_memory_updater_user_isolation.py
@@ -0,0 +1,29 @@
+"""Tests for user_id propagation in memory updater."""
+from unittest.mock import MagicMock, patch
+
+from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
+
+
+def test_get_memory_data_passes_user_id():
+    mock_storage = MagicMock()
+    mock_storage.load.return_value = {"version": "1.0"}
+    with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
+        get_memory_data(user_id="alice")
+    mock_storage.load.assert_called_once_with(None, user_id="alice")
+
+
+def test_save_memory_passes_user_id():
+    mock_storage = MagicMock()
+    mock_storage.save.return_value = True
+    with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
+        _save_memory_to_file({"version": "1.0"}, user_id="bob")
+    mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
+
+
+def test_clear_memory_data_passes_user_id():
+    mock_storage = MagicMock()
+    mock_storage.save.return_value = True
+    with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
+        clear_memory_data(user_id="charlie")
+    # Verify save was called with user_id
+    assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
diff --git a/backend/tests/test_migration_user_isolation.py b/backend/tests/test_migration_user_isolation.py
new file mode 100644
index 000000000..8a07c2130
--- /dev/null
+++ b/backend/tests/test_migration_user_isolation.py
@@ -0,0 +1,116 @@
+"""Tests for per-user data migration."""
+import json
+import pytest
+from pathlib import Path
+
+from deerflow.config.paths import Paths
+
+
+@pytest.fixture
+def base_dir(tmp_path: Path) -> Path:
+    return tmp_path
+
+
+@pytest.fixture
+def paths(base_dir: Path) -> Paths:
+    return Paths(base_dir)
+
+
+class TestMigrateThreadDirs:
+    def test_moves_thread_to_user_dir(self, base_dir: Path, paths: Paths):
+        legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
+        legacy.mkdir(parents=True)
+        (legacy / "file.txt").write_text("hello")
+
+        from scripts.migrate_user_isolation import migrate_thread_dirs
+        migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
+
+        expected = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace" / "file.txt"
+        assert expected.exists()
+        assert expected.read_text() == "hello"
+        assert not (base_dir / "threads" / "t1").exists()
+
+    def test_unowned_thread_goes_to_default(self, base_dir: Path, paths: Paths):
+        legacy = base_dir / "threads" / "t2" / "user-data" / "workspace"
+        legacy.mkdir(parents=True)
+
+        from scripts.migrate_user_isolation import migrate_thread_dirs
+        migrate_thread_dirs(paths, thread_owner_map={})
+
+        expected = base_dir / "users" / "default" / "threads" / "t2"
+        assert expected.exists()
+
+    def test_idempotent_skip_already_migrated(self, base_dir: Path, paths: Paths):
+        new_dir = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
+        new_dir.mkdir(parents=True)
+
+        from scripts.migrate_user_isolation import migrate_thread_dirs
+        migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
+        assert new_dir.exists()
+
+    def test_conflict_preserved(self, base_dir: Path, paths: Paths):
+        legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
+        legacy.mkdir(parents=True)
+        (legacy / "old.txt").write_text("old")
+
+        dest = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
+        dest.mkdir(parents=True)
+        (dest / "new.txt").write_text("new")
+
+        from scripts.migrate_user_isolation import migrate_thread_dirs
+        migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
+
+        assert (dest / "new.txt").read_text() == "new"
+        conflicts = base_dir / "migration-conflicts" / "t1"
+        assert conflicts.exists()
+
+    def test_cleans_up_empty_legacy_dir(self, base_dir: Path, paths: Paths):
+        legacy = base_dir / "threads" / "t1" / "user-data"
+        legacy.mkdir(parents=True)
+
+        from scripts.migrate_user_isolation import migrate_thread_dirs
+        migrate_thread_dirs(paths, thread_owner_map={})
+
+        assert not (base_dir / "threads").exists()
+
+    def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
+        legacy = base_dir / "threads" / "t1" / "user-data"
+        legacy.mkdir(parents=True)
+
+        from scripts.migrate_user_isolation import migrate_thread_dirs
+        report = migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"}, dry_run=True)
+
+        assert len(report) == 1
+        assert (base_dir / "threads" / "t1").exists()  # not moved
+        assert not (base_dir / "users" / "alice" / "threads" / "t1").exists()
+
+
+class TestMigrateMemory:
+    def test_moves_global_memory(self, base_dir: Path, paths: Paths):
+        legacy_mem = base_dir / "memory.json"
+        legacy_mem.write_text(json.dumps({"version": "1.0", "facts": []}))
+
+        from scripts.migrate_user_isolation import migrate_memory
+        migrate_memory(paths, user_id="default")
+
+        expected = base_dir / "users" / "default" / "memory.json"
+        assert expected.exists()
+        assert not legacy_mem.exists()
+
+    def test_skips_if_destination_exists(self, base_dir: Path, paths: Paths):
+        legacy_mem = base_dir / "memory.json"
+        legacy_mem.write_text(json.dumps({"version": "old"}))
+
+        dest = base_dir / "users" / "default" / "memory.json"
+        dest.parent.mkdir(parents=True)
+        dest.write_text(json.dumps({"version": "new"}))
+
+        from scripts.migrate_user_isolation import migrate_memory
+        migrate_memory(paths, user_id="default")
+
+        assert json.loads(dest.read_text())["version"] == "new"
+        assert (base_dir / "memory.legacy.json").exists()
+
+    def test_no_legacy_memory_is_noop(self, base_dir: Path, paths: Paths):
+        from scripts.migrate_user_isolation import migrate_memory
+        migrate_memory(paths, user_id="default")  # should not raise
diff --git a/backend/tests/test_paths_user_isolation.py b/backend/tests/test_paths_user_isolation.py
new file mode 100644
index 000000000..e74276a32
--- /dev/null
+++ b/backend/tests/test_paths_user_isolation.py
@@ -0,0 +1,167 @@
+"""Tests for user-scoped path resolution in Paths."""
+import pytest
+from pathlib import Path
+
+from deerflow.config.paths import Paths
+
+
+@pytest.fixture
+def paths(tmp_path: Path) -> Paths:
+    return Paths(tmp_path)
+
+
+class TestValidateUserId:
+    def test_valid_user_id(self, paths: Paths):
+        d = paths.user_dir("u-abc-123")
+        assert d == paths.base_dir / "users" / "u-abc-123"
+
+    def test_rejects_path_traversal(self, paths: Paths):
+        with pytest.raises(ValueError, match="Invalid user_id"):
+            paths.user_dir("../escape")
+
+    def test_rejects_slash(self, paths: Paths):
+        with pytest.raises(ValueError, match="Invalid user_id"):
+            paths.user_dir("foo/bar")
+
+    def test_rejects_empty(self, paths: Paths):
+        with pytest.raises(ValueError, match="Invalid user_id"):
+            paths.user_dir("")
+
+
+class TestUserDir:
+    def test_user_dir(self, paths: Paths):
+        assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
+
+
+class TestUserMemoryFile:
+    def test_user_memory_file(self, paths: Paths):
+        assert paths.user_memory_file("bob") == paths.base_dir / "users" / "bob" / "memory.json"
+
+
+class TestUserAgentMemoryFile:
+    def test_user_agent_memory_file(self, paths: Paths):
+        expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
+        assert paths.user_agent_memory_file("bob", "myagent") == expected
+
+    def test_user_agent_memory_file_lowercases_name(self, paths: Paths):
+        expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
+        assert paths.user_agent_memory_file("bob", "MyAgent") == expected
+
+
+class TestUserThreadDir:
+    def test_user_thread_dir(self, paths: Paths):
+        expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
+        assert paths.thread_dir("t1", user_id="u1") == expected
+
+    def test_thread_dir_no_user_id_falls_back_to_legacy(self, paths: Paths):
+        expected = paths.base_dir / "threads" / "t1"
+        assert paths.thread_dir("t1") == expected
+
+
+class TestUserSandboxDirs:
+    def test_sandbox_work_dir(self, paths: Paths):
+        expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "workspace"
+        assert paths.sandbox_work_dir("t1", user_id="u1") == expected
+
+    def test_sandbox_uploads_dir(self, paths: Paths):
+        expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "uploads"
+        assert paths.sandbox_uploads_dir("t1", user_id="u1") == expected
+
+    def test_sandbox_outputs_dir(self, paths: Paths):
+        expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "outputs"
+        assert paths.sandbox_outputs_dir("t1", user_id="u1") == expected
+
+    def test_sandbox_user_data_dir(self, paths: Paths):
+        expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data"
+        assert paths.sandbox_user_data_dir("t1", user_id="u1") == expected
+
+    def test_acp_workspace_dir(self, paths: Paths):
+        expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "acp-workspace"
+        assert paths.acp_workspace_dir("t1", user_id="u1") == expected
+
+    def test_legacy_sandbox_work_dir(self, paths: Paths):
+        expected = paths.base_dir / "threads" / "t1" / "user-data" / "workspace"
+        assert paths.sandbox_work_dir("t1") == expected
+
+
+class TestHostPathsWithUserId:
+    def test_host_thread_dir_with_user_id(self, paths: Paths):
+        result = paths.host_thread_dir("t1", user_id="u1")
+        assert "users" in result
+        assert "u1" in result
+        assert "threads" in result
+        assert "t1" in result
+
+    def test_host_thread_dir_legacy(self, paths: Paths):
+        result = paths.host_thread_dir("t1")
+        assert "threads" in result
+        assert "t1" in result
+        assert "users" not in result
+
+    def test_host_sandbox_user_data_dir_with_user_id(self, paths: Paths):
+        result = paths.host_sandbox_user_data_dir("t1", user_id="u1")
+        assert "users" in result
+        assert "user-data" in result
+
+    def test_host_sandbox_work_dir_with_user_id(self, paths: Paths):
+        result = paths.host_sandbox_work_dir("t1", user_id="u1")
+        assert "workspace" in result
+
+    def test_host_sandbox_uploads_dir_with_user_id(self, paths: Paths):
+        result = paths.host_sandbox_uploads_dir("t1", user_id="u1")
+        assert "uploads" in result
+
+    def test_host_sandbox_outputs_dir_with_user_id(self, paths: Paths):
+        result = paths.host_sandbox_outputs_dir("t1", user_id="u1")
+        assert "outputs" in result
+
+    def test_host_acp_workspace_dir_with_user_id(self, paths: Paths):
+        result = paths.host_acp_workspace_dir("t1", user_id="u1")
+        assert "acp-workspace" in result
+
+
+class TestEnsureAndDeleteWithUserId:
+    def test_ensure_thread_dirs_creates_user_scoped(self, paths: Paths):
+        paths.ensure_thread_dirs("t1", user_id="u1")
+        assert paths.sandbox_work_dir("t1", user_id="u1").is_dir()
+        assert paths.sandbox_uploads_dir("t1", user_id="u1").is_dir()
+        assert paths.sandbox_outputs_dir("t1", user_id="u1").is_dir()
+        assert paths.acp_workspace_dir("t1", user_id="u1").is_dir()
+
+    def test_delete_thread_dir_removes_user_scoped(self, paths: Paths):
+        paths.ensure_thread_dirs("t1", user_id="u1")
+        assert paths.thread_dir("t1", user_id="u1").exists()
+        paths.delete_thread_dir("t1", user_id="u1")
+        assert not paths.thread_dir("t1", user_id="u1").exists()
+
+    def test_delete_thread_dir_idempotent(self, paths: Paths):
+        paths.delete_thread_dir("nonexistent", user_id="u1")  # should not raise
+
+    def test_ensure_thread_dirs_legacy_still_works(self, paths: Paths):
+        paths.ensure_thread_dirs("t1")
+        assert paths.sandbox_work_dir("t1").is_dir()
+
+    def test_user_scoped_and_legacy_are_independent(self, paths: Paths):
+        paths.ensure_thread_dirs("t1", user_id="u1")
+        paths.ensure_thread_dirs("t1")
+        # Both exist independently
+        assert paths.thread_dir("t1", user_id="u1").exists()
+        assert paths.thread_dir("t1").exists()
+        # Delete one doesn't affect the other
+        paths.delete_thread_dir("t1", user_id="u1")
+        assert not paths.thread_dir("t1", user_id="u1").exists()
+        assert paths.thread_dir("t1").exists()
+
+
+class TestResolveVirtualPathWithUserId:
+    def test_resolve_virtual_path_with_user_id(self, paths: Paths):
+        paths.ensure_thread_dirs("t1", user_id="u1")
+        result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt", user_id="u1")
+        expected_base = paths.sandbox_user_data_dir("t1", user_id="u1").resolve()
+        assert str(result).startswith(str(expected_base))
+
+    def test_resolve_virtual_path_legacy(self, paths: Paths):
+        paths.ensure_thread_dirs("t1")
+        result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt")
+        expected_base = paths.sandbox_user_data_dir("t1").resolve()
+        assert str(result).startswith(str(expected_base))
diff --git a/backend/tests/test_present_file_tool_core_logic.py b/backend/tests/test_present_file_tool_core_logic.py
index 3068ca507..de1a90e52 100644
--- a/backend/tests/test_present_file_tool_core_logic.py
+++ b/backend/tests/test_present_file_tool_core_logic.py
@@ -38,7 +38,7 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
     monkeypatch.setattr(
         present_file_tool_module,
         "get_paths",
-        lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
+        lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path, *, user_id=None: artifact_path),
     )

     result = present_file_tool_module.present_file_tool.func(
diff --git a/backend/tests/test_run_event_store_pagination.py b/backend/tests/test_run_event_store_pagination.py
new file mode 100644
index 000000000..ac5ba4c2d
--- /dev/null
+++ b/backend/tests/test_run_event_store_pagination.py
@@ -0,0 +1,107 @@
+"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
+import pytest
+
+from deerflow.runtime.events.store.memory import MemoryRunEventStore
+
+
+@pytest.fixture
+def base_store():
+    return MemoryRunEventStore()
+
+
+@pytest.mark.anyio
+async def test_list_messages_by_run_default_returns_all(base_store):
+    store = base_store
+    for i in range(7):
+        await store.put(
+            thread_id="t1", run_id="run-a",
+            event_type="human_message" if i % 2 == 0 else "ai_message",
+            category="message", content=f"msg-a-{i}",
+        )
+    for i in range(3):
+        await store.put(
+            thread_id="t1", run_id="run-b",
+            event_type="human_message", category="message", content=f"msg-b-{i}",
+        )
+    await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
+
+    msgs = await store.list_messages_by_run("t1", "run-a")
+    assert len(msgs) == 7
+    assert all(m["category"] == "message" for m in msgs)
+    assert all(m["run_id"] == "run-a" for m in msgs)
+
+
+@pytest.mark.anyio
+async def test_list_messages_by_run_with_limit(base_store):
+    store = base_store
+    for i in range(7):
+        await store.put(
+            thread_id="t1", run_id="run-a",
+            event_type="human_message" if i % 2 == 0 else "ai_message",
+            category="message", content=f"msg-a-{i}",
+        )
+
+    msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
+    assert len(msgs) == 3
+    seqs = [m["seq"] for m in msgs]
+    assert seqs == sorted(seqs)
+
+
+@pytest.mark.anyio
+async def test_list_messages_by_run_after_seq(base_store):
+    store = base_store
+    for i in range(7):
+        await store.put(
+            thread_id="t1", run_id="run-a",
+            event_type="human_message" if i % 2 == 0 else "ai_message",
+            category="message", content=f"msg-a-{i}",
+        )
+
+    all_msgs = await store.list_messages_by_run("t1", "run-a")
+    cursor_seq = all_msgs[2]["seq"]
+    msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50)
+    assert all(m["seq"] > cursor_seq for m in msgs)
+    assert len(msgs) == 4
+
+
+@pytest.mark.anyio
+async def test_list_messages_by_run_before_seq(base_store):
+    store = base_store
+    for i in range(7):
+        await store.put(
+            thread_id="t1", run_id="run-a",
+            event_type="human_message" if i % 2 == 0 else "ai_message",
+            category="message", content=f"msg-a-{i}",
+        )
+
+    all_msgs = await store.list_messages_by_run("t1", "run-a")
+    cursor_seq = all_msgs[4]["seq"]
+    msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50)
+    assert all(m["seq"] < cursor_seq for m in msgs)
+    assert len(msgs) == 4
+
+
+@pytest.mark.anyio
+async def test_list_messages_by_run_does_not_include_other_run(base_store):
+    store = base_store
+    for i in range(7):
+        await store.put(
+            thread_id="t1", run_id="run-a",
+            event_type="human_message", category="message", content=f"msg-a-{i}",
+        )
+    for i in range(3):
+        await store.put(
+            thread_id="t1", run_id="run-b",
+            event_type="human_message", category="message", content=f"msg-b-{i}",
+        )
+
+    msgs = await store.list_messages_by_run("t1", "run-b")
+    assert len(msgs) == 3
+    assert all(m["run_id"] == "run-b" for m in msgs)
+
+
+@pytest.mark.anyio
+async def test_list_messages_by_run_empty_run(base_store):
+    store = base_store
+    msgs = await store.list_messages_by_run("t1", "nonexistent")
+    assert msgs == []
diff --git a/backend/tests/test_runs_api_endpoints.py b/backend/tests/test_runs_api_endpoints.py
new file mode 100644
index 000000000..e6b73d865
--- /dev/null
+++ b/backend/tests/test_runs_api_endpoints.py
@@ -0,0 +1,243 @@
+"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from _router_auth_helpers import make_authed_test_app
+from fastapi.testclient import TestClient
+
+from app.gateway.routers import runs
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_app(run_store=None, event_store=None, feedback_repo=None):
+    """Build a test FastAPI app with stub auth and mocked state."""
+    app = make_authed_test_app()
+    app.include_router(runs.router)
+
+    if run_store is not None:
+        app.state.run_store = run_store
+    if event_store is not None:
+        app.state.run_event_store = event_store
+    if feedback_repo is not None:
+        app.state.feedback_repo = feedback_repo
+
+    return app
+
+
+def _make_run_store(run_record: dict | None):
+    """Return an AsyncMock run store whose get() returns run_record."""
+    store = MagicMock()
+    store.get = AsyncMock(return_value=run_record)
+    return store
+
+
+def _make_event_store(rows: list[dict]):
+    """Return an AsyncMock event store whose list_messages_by_run() returns rows."""
+    store = MagicMock()
+    store.list_messages_by_run = AsyncMock(return_value=rows)
+    return store
+
+
+def _make_message(seq: int) -> dict:
+    return {"seq": seq, "event_type": "on_chat_model_stream", "category": "message", "content": f"msg-{seq}"}
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+def test_run_messages_returns_envelope():
+    """GET /api/runs/{run_id}/messages returns {data: [...], has_more: bool}."""
+    rows = [_make_message(i) for i in range(1, 4)]
+    run_record = {"run_id": "run-1", "thread_id": "thread-1"}
+    app = _make_app(
+        run_store=_make_run_store(run_record),
+        event_store=_make_event_store(rows),
+    )
+    with TestClient(app) as client:
+        response = client.get("/api/runs/run-1/messages")
+        assert response.status_code == 200
+        body = response.json()
+        assert "data" in body
+        assert "has_more" in body
+        assert body["has_more"] is False
+        assert len(body["data"]) == 3
+
+
+def test_run_messages_404_when_run_not_found():
+    """Returns 404 when the run store returns None."""
+    app = _make_app(
+        run_store=_make_run_store(None),
+        event_store=_make_event_store([]),
+    )
+    with TestClient(app) as client:
+        response = client.get("/api/runs/missing-run/messages")
+        assert response.status_code == 404
+        assert "missing-run" in response.json()["detail"]
+
+
+def test_run_messages_has_more_true_when_extra_row_returned():
+    """has_more=True when event store returns limit+1 rows."""
+    # Default limit is 50; provide 51 rows
+    rows = [_make_message(i) for i in range(1, 52)]  # 51 rows
+    run_record = {"run_id": "run-2", "thread_id": "thread-2"}
+    app = _make_app(
+        run_store=_make_run_store(run_record),
+        event_store=_make_event_store(rows),
+    )
+    with TestClient(app) as client:
+        response = client.get("/api/runs/run-2/messages")
+        assert response.status_code == 200
+        body = response.json()
+        assert body["has_more"] is True
+        assert len(body["data"]) == 50  # trimmed to limit
+
+
+def test_run_messages_passes_after_seq_to_event_store():
+    """after_seq query param is forwarded to event_store.list_messages_by_run."""
+    rows = [_make_message(10)]
+    run_record = {"run_id": "run-3", "thread_id": "thread-3"}
+    event_store = _make_event_store(rows)
+    app = _make_app(
+        run_store=_make_run_store(run_record),
+        event_store=event_store,
+    )
+    with TestClient(app) as client:
+        response = client.get("/api/runs/run-3/messages?after_seq=5")
+        assert response.status_code == 200
+        event_store.list_messages_by_run.assert_awaited_once_with(
+            "thread-3", "run-3",
+            limit=51,  # default limit(50) + 1
+            before_seq=None,
+            after_seq=5,
+        )
+
+
+def test_run_messages_respects_custom_limit():
+    """Custom limit is respected and capped at 200."""
+    rows = [_make_message(i) for i in range(1, 6)]
+    run_record = {"run_id": "run-4", "thread_id": "thread-4"}
+    event_store = _make_event_store(rows)
+    app = _make_app(
+        run_store=_make_run_store(run_record),
+        event_store=event_store,
+    )
+    with TestClient(app) as client:
+        response = client.get("/api/runs/run-4/messages?limit=10")
+        assert response.status_code == 200
+        event_store.list_messages_by_run.assert_awaited_once_with(
+            "thread-4", "run-4",
+            limit=11,  # 10 + 1
+            before_seq=None,
+            after_seq=None,
+        )
+
+
+def test_run_messages_passes_before_seq_to_event_store():
+    """before_seq query param is forwarded to event_store.list_messages_by_run."""
+    rows = [_make_message(3)]
+    run_record = {"run_id": "run-5", "thread_id": "thread-5"}
+    event_store = _make_event_store(rows)
+    app = _make_app(
+        run_store=_make_run_store(run_record),
+        event_store=event_store,
+    )
+    with TestClient(app) as client:
+        response = client.get("/api/runs/run-5/messages?before_seq=10")
+        assert response.status_code == 200
+        event_store.list_messages_by_run.assert_awaited_once_with(
+            "thread-5", "run-5",
+            limit=51,
+            before_seq=10,
+            after_seq=None,
+        )
+
+
+def test_run_messages_empty_data():
+    """Returns empty data list when no messages exist."""
+    run_record = {"run_id": "run-6", "thread_id": "thread-6"}
+    app = _make_app(
+        run_store=_make_run_store(run_record),
+        event_store=_make_event_store([]),
+    )
+    with TestClient(app) as client:
+        response = client.get("/api/runs/run-6/messages")
+        assert response.status_code == 200
+        body = response.json()
+        assert body["data"] == []
+        assert body["has_more"] is False
+
+
+def _make_feedback_repo(rows: list[dict]):
+    """Return an AsyncMock feedback repo whose list_by_run() returns rows."""
+    repo = MagicMock()
+    repo.list_by_run = AsyncMock(return_value=rows)
+    return repo
+
+
+def _make_feedback(run_id: str, idx: int) -> dict:
+    return {"id": f"fb-{idx}", "run_id": run_id, "thread_id": "thread-x", "value": "up"}
+
+
+# ---------------------------------------------------------------------------
+# TestRunFeedback
+# ---------------------------------------------------------------------------
+
+
+class TestRunFeedback:
+    def test_returns_list_of_feedback_dicts(self):
+        """GET /api/runs/{run_id}/feedback returns a list of feedback dicts."""
+        run_record = {"run_id": "run-fb-1", "thread_id": "thread-fb-1"}
+        rows = [_make_feedback("run-fb-1", i) for i in range(3)]
+        app = _make_app(
+            run_store=_make_run_store(run_record),
+            feedback_repo=_make_feedback_repo(rows),
+        )
+        with TestClient(app) as client:
+            response = client.get("/api/runs/run-fb-1/feedback")
+            assert response.status_code == 200
+            body = response.json()
+            assert isinstance(body, list)
+            assert len(body) == 3
+
+    def test_404_when_run_not_found(self):
+        """Returns 404 when run store returns None."""
+        app = _make_app(
+            run_store=_make_run_store(None),
+            feedback_repo=_make_feedback_repo([]),
+        )
+        with TestClient(app) as client:
+            response = client.get("/api/runs/missing-run/feedback")
+            assert response.status_code == 404
+            assert "missing-run" in response.json()["detail"]
+
+    def test_empty_list_when_no_feedback(self):
+        """Returns empty list when no feedback exists for the run."""
+        run_record = {"run_id": "run-fb-2", "thread_id": "thread-fb-2"}
+        app = _make_app(
+            run_store=_make_run_store(run_record),
+            feedback_repo=_make_feedback_repo([]),
+        )
+        with TestClient(app) as client:
+            response = client.get("/api/runs/run-fb-2/feedback")
+            assert response.status_code == 200
+            assert response.json() == []
+
+    def test_503_when_feedback_repo_not_configured(self):
+        """Returns 503 when feedback_repo is None (no DB configured)."""
+        run_record = {"run_id": "run-fb-3", "thread_id": "thread-fb-3"}
+        app = _make_app(
+            run_store=_make_run_store(run_record),
+        )
+        # Explicitly set feedback_repo to None to simulate missing DB
+        app.state.feedback_repo = None
+        with TestClient(app) as client:
+            response = client.get("/api/runs/run-fb-3/feedback")
+            assert response.status_code == 503
diff --git a/backend/tests/test_thread_run_messages_pagination.py b/backend/tests/test_thread_run_messages_pagination.py
new file mode 100644
index 000000000..f00100cad
--- /dev/null
+++ b/backend/tests/test_thread_run_messages_pagination.py
@@ -0,0 +1,128 @@
+"""Tests for paginated GET /api/threads/{thread_id}/runs/{run_id}/messages endpoint."""
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from _router_auth_helpers import make_authed_test_app
+from fastapi.testclient import TestClient
+
+from app.gateway.routers import thread_runs
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_app(event_store=None):
+    """Build a test FastAPI app with stub auth and mocked state."""
+    app = make_authed_test_app()
+    app.include_router(thread_runs.router)
+
+    if event_store is not None:
+        app.state.run_event_store = event_store
+
+    return app
+
+
+def _make_event_store(rows: list[dict]):
+    """Return an AsyncMock event store whose list_messages_by_run() returns rows."""
+    store = MagicMock()
+    store.list_messages_by_run = AsyncMock(return_value=rows)
+    return store
+
+
+def _make_message(seq: int) -> dict:
+    return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"}
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+def test_returns_paginated_envelope():
+    """GET /api/threads/{tid}/runs/{rid}/messages returns {data: [...], has_more: bool}."""
+    rows = [_make_message(i) for i in range(1, 4)]
+    app = _make_app(event_store=_make_event_store(rows))
+    with TestClient(app) as client:
+        response = client.get("/api/threads/thread-1/runs/run-1/messages")
+        assert response.status_code == 200
+        body = response.json()
+        assert "data" in body
+        assert "has_more" in body
+        assert body["has_more"] is False
+        assert len(body["data"]) == 3
+
+
+def test_has_more_true_when_extra_row_returned():
+    """has_more=True when event store returns limit+1 rows."""
+    # Default limit is 50; provide 51 rows
+    rows = [_make_message(i) for i in range(1, 52)]  # 51 rows
+    app = _make_app(event_store=_make_event_store(rows))
+    with TestClient(app) as client:
+        response = client.get("/api/threads/thread-2/runs/run-2/messages")
+        assert response.status_code == 200
+        body = response.json()
+        assert body["has_more"] is True
+        assert len(body["data"]) == 50  # trimmed to limit
+
+
+def test_after_seq_forwarded_to_event_store():
+    """after_seq query param is forwarded to event_store.list_messages_by_run."""
+    rows = [_make_message(10)]
+    event_store = _make_event_store(rows)
+    app = _make_app(event_store=event_store)
+    with TestClient(app) as client:
+        response = client.get("/api/threads/thread-3/runs/run-3/messages?after_seq=5")
+        assert response.status_code == 200
+        event_store.list_messages_by_run.assert_awaited_once_with(
+            "thread-3", "run-3",
+            limit=51,  # default limit(50) + 1
+            before_seq=None,
+            after_seq=5,
+        )
+
+
+def test_before_seq_forwarded_to_event_store():
+    """before_seq query param is forwarded to event_store.list_messages_by_run."""
+    rows = [_make_message(3)]
+    event_store = _make_event_store(rows)
+    app = _make_app(event_store=event_store)
+    with TestClient(app) as client:
+        response = client.get("/api/threads/thread-4/runs/run-4/messages?before_seq=10")
+        assert response.status_code == 200
+        event_store.list_messages_by_run.assert_awaited_once_with(
+            "thread-4", "run-4",
+            limit=51,
+            before_seq=10,
+            after_seq=None,
+        )
+
+
+def test_custom_limit_forwarded_to_event_store():
+    """Custom limit is forwarded as limit+1 to the event store."""
+    rows = [_make_message(i) for i in range(1, 6)]
+    event_store = _make_event_store(rows)
+    app = _make_app(event_store=event_store)
+    with TestClient(app) as client:
+        response = client.get("/api/threads/thread-5/runs/run-5/messages?limit=10")
+        assert response.status_code == 200
+        event_store.list_messages_by_run.assert_awaited_once_with(
+            "thread-5", "run-5",
+            limit=11,  # 10 + 1
+            before_seq=None,
+            after_seq=None,
+        )
+
+
+def test_empty_data_when_no_messages():
+    """Returns empty data list with has_more=False when no messages exist."""
+    app = _make_app(event_store=_make_event_store([]))
+    with TestClient(app) as client:
+        response = client.get("/api/threads/thread-6/runs/run-6/messages")
+        assert response.status_code == 200
+        body = response.json()
+        assert body["data"] == []
+        assert body["has_more"] is False
diff --git a/backend/tests/test_thread_state_event_store.py b/backend/tests/test_thread_state_event_store.py
deleted file mode 100644
index 0d3b19761..000000000
--- a/backend/tests/test_thread_state_event_store.py
+++ /dev/null
@@ -1,439 +0,0 @@
-"""Tests for event-store-backed message loading in thread state/history endpoints.
-
-Covers the helper functions added to ``app/gateway/routers/threads.py``:
-
-- ``_sanitize_legacy_command_repr`` — extracts inner ToolMessage text from
-  legacy ``str(Command(...))`` strings captured before the ``journal.py``
-  fix for state-updating tools like ``present_files``.
-- ``_get_event_store_messages`` — loads the full message stream with full
-  pagination, copy-on-read id patching, legacy Command sanitization, and
-  a clean fallback to ``None`` when the event store is unavailable.
-"""
-
-from __future__ import annotations
-
-import uuid
-from types import SimpleNamespace
-from typing import Any
-
-import pytest
-
-from app.gateway.routers.threads import (
-    _get_event_store_messages,
-    _sanitize_legacy_command_repr,
-)
-from deerflow.runtime.events.store.memory import MemoryRunEventStore
-
-
-@pytest.fixture()
-def event_store() -> MemoryRunEventStore:
-    return MemoryRunEventStore()
-
-
-class _FakeFeedbackRepo:
-    """Minimal ``FeedbackRepository`` stand-in that returns a configured map."""
-
-    def __init__(self, by_run: dict[str, dict] | None = None) -> None:
-        self._by_run = by_run or {}
-
-    async def list_by_thread_grouped(self, thread_id: str, *, user_id: str | None) -> dict[str, dict]:
-        return dict(self._by_run)
-
-
-def _make_request(
-    event_store: MemoryRunEventStore,
-    feedback_repo: _FakeFeedbackRepo | None = None,
-) -> Any:
-    """Build a minimal FastAPI-like Request object.
-
-    ``get_run_event_store(request)`` reads ``request.app.state.run_event_store``.
-    ``get_feedback_repo(request)`` reads ``request.app.state.feedback_repo``.
-    ``get_current_user`` is monkey-patched separately in tests that need it.
- """ - state = SimpleNamespace( - run_event_store=event_store, - feedback_repo=feedback_repo or _FakeFeedbackRepo(), - ) - app = SimpleNamespace(state=state) - return SimpleNamespace(app=app) - - -@pytest.fixture(autouse=True) -def _stub_current_user(monkeypatch): - """Stub out ``get_current_user`` so tests don't need real auth context.""" - import app.gateway.routers.threads as threads_mod - - async def _fake(_request): - return None - - monkeypatch.setattr(threads_mod, "get_current_user", _fake) - - -async def _seed_simple_run(store: MemoryRunEventStore, thread_id: str, run_id: str) -> None: - """Seed one run: human + ai_tool_call + tool_result + final ai_message, plus a trace.""" - await store.put( - thread_id=thread_id, run_id=run_id, - event_type="human_message", category="message", - content={ - "type": "human", "id": None, - "content": [{"type": "text", "text": "hello"}], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - }, - ) - await store.put( - thread_id=thread_id, run_id=run_id, - event_type="ai_tool_call", category="message", - content={ - "type": "ai", "id": "lc_run--tc1", - "content": "", - "tool_calls": [{"name": "search", "args": {"q": "x"}, "id": "call_1", "type": "tool_call"}], - "invalid_tool_calls": [], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - "usage_metadata": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, - }, - ) - await store.put( - thread_id=thread_id, run_id=run_id, - event_type="tool_result", category="message", - content={ - "type": "tool", "id": None, - "content": "results", - "tool_call_id": "call_1", "name": "search", - "artifact": None, "status": "success", - "additional_kwargs": {}, "response_metadata": {}, - }, - ) - await store.put( - thread_id=thread_id, run_id=run_id, - event_type="ai_message", category="message", - content={ - "type": "ai", "id": "lc_run--final1", - "content": "done", - "tool_calls": [], "invalid_tool_calls": [], - "additional_kwargs": {}, 
"response_metadata": {"finish_reason": "stop"}, "name": None, - "usage_metadata": {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}, - }, - ) - # Non-message trace — must be filtered out. - await store.put( - thread_id=thread_id, run_id=run_id, - event_type="llm_request", category="trace", - content={"model": "test"}, - ) - - -class TestSanitizeLegacyCommandRepr: - def test_passthrough_non_string(self): - assert _sanitize_legacy_command_repr(None) is None - assert _sanitize_legacy_command_repr(42) == 42 - assert _sanitize_legacy_command_repr([{"type": "text", "text": "x"}]) == [{"type": "text", "text": "x"}] - - def test_passthrough_plain_string(self): - assert _sanitize_legacy_command_repr("Successfully presented files") == "Successfully presented files" - assert _sanitize_legacy_command_repr("") == "" - - def test_extracts_inner_content_single_quotes(self): - legacy = ( - "Command(update={'artifacts': ['/mnt/user-data/outputs/report.md'], " - "'messages': [ToolMessage(content='Successfully presented files', " - "tool_call_id='call_abc')]})" - ) - assert _sanitize_legacy_command_repr(legacy) == "Successfully presented files" - - def test_extracts_inner_content_double_quotes(self): - legacy = 'Command(update={"messages": [ToolMessage(content="ok", tool_call_id="x")]})' - assert _sanitize_legacy_command_repr(legacy) == "ok" - - def test_unparseable_command_returns_original(self): - legacy = "Command(update={'something_else': 1})" - assert _sanitize_legacy_command_repr(legacy) == legacy - - -class TestGetEventStoreMessages: - @pytest.mark.anyio - async def test_returns_none_when_store_empty(self, event_store): - request = _make_request(event_store) - assert await _get_event_store_messages(request, "t_missing") is None - - @pytest.mark.anyio - async def test_extracts_all_message_types_in_order(self, event_store): - await _seed_simple_run(event_store, "t1", "r1") - request = _make_request(event_store) - messages = await _get_event_store_messages(request, 
"t1") - assert messages is not None - types = [m["type"] for m in messages] - assert types == ["human", "ai", "tool", "ai"] - # Trace events must not appear - for m in messages: - assert m.get("type") in {"human", "ai", "tool"} - - @pytest.mark.anyio - async def test_null_ids_get_deterministic_uuid5(self, event_store): - await _seed_simple_run(event_store, "t1", "r1") - request = _make_request(event_store) - messages = await _get_event_store_messages(request, "t1") - assert messages is not None - - # AI messages keep their LLM ids - assert messages[1]["id"] == "lc_run--tc1" - assert messages[3]["id"] == "lc_run--final1" - - # Human (seq=1) + tool (seq=3) get deterministic uuid5 - expected_human_id = str(uuid.uuid5(uuid.NAMESPACE_URL, "t1:1")) - expected_tool_id = str(uuid.uuid5(uuid.NAMESPACE_URL, "t1:3")) - assert messages[0]["id"] == expected_human_id - assert messages[2]["id"] == expected_tool_id - - # Re-running produces the same ids (stability across requests) - messages2 = await _get_event_store_messages(request, "t1") - assert [m["id"] for m in messages2] == [m["id"] for m in messages] - - @pytest.mark.anyio - async def test_helper_does_not_mutate_store(self, event_store): - """Helper must copy content dicts; the live store must stay unchanged.""" - await _seed_simple_run(event_store, "t1", "r1") - request = _make_request(event_store) - _ = await _get_event_store_messages(request, "t1") - - # Raw store records still have id=None for human/tool - raw = await event_store.list_messages("t1", limit=500) - human = next(e for e in raw if e["content"]["type"] == "human") - tool = next(e for e in raw if e["content"]["type"] == "tool") - assert human["content"]["id"] is None - assert tool["content"]["id"] is None - - @pytest.mark.anyio - async def test_legacy_command_repr_sanitized(self, event_store): - """A tool_result whose content is a legacy ``str(Command(...))`` is cleaned.""" - legacy = ( - "Command(update={'artifacts': ['/mnt/user-data/outputs/x.md'], " - 
"'messages': [ToolMessage(content='Successfully presented files', " - "tool_call_id='call_p')]})" - ) - await event_store.put( - thread_id="t2", run_id="r1", - event_type="tool_result", category="message", - content={ - "type": "tool", "id": None, - "content": legacy, - "tool_call_id": "call_p", "name": "present_files", - "artifact": None, "status": "success", - "additional_kwargs": {}, "response_metadata": {}, - }, - ) - request = _make_request(event_store) - messages = await _get_event_store_messages(request, "t2") - assert messages is not None and len(messages) == 1 - assert messages[0]["content"] == "Successfully presented files" - - @pytest.mark.anyio - async def test_pagination_covers_more_than_one_page(self, event_store, monkeypatch): - """Simulate a long thread that exceeds a single page to exercise the loop.""" - thread_id = "t_long" - # Seed 12 human messages - for i in range(12): - await event_store.put( - thread_id=thread_id, run_id="r1", - event_type="human_message", category="message", - content={ - "type": "human", "id": None, - "content": [{"type": "text", "text": f"msg {i}"}], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - }, - ) - - # Force small page size to exercise pagination - import app.gateway.routers.threads as threads_mod - original = threads_mod._get_event_store_messages - - # Monkeypatch MemoryRunEventStore.list_messages to assert it's called with cursor pagination - calls: list[dict] = [] - real_list = event_store.list_messages - - async def spy_list_messages(tid, *, limit=50, before_seq=None, after_seq=None): - calls.append({"limit": limit, "after_seq": after_seq}) - return await real_list(tid, limit=limit, before_seq=before_seq, after_seq=after_seq) - - monkeypatch.setattr(event_store, "list_messages", spy_list_messages) - - request = _make_request(event_store) - messages = await original(request, thread_id) - assert messages is not None - assert len(messages) == 12 - assert [m["content"][0]["text"] for m in 
messages] == [f"msg {i}" for i in range(12)] - # At least one call was made with after_seq=None (the initial page) - assert any(c["after_seq"] is None for c in calls) - - @pytest.mark.anyio - async def test_summarize_regression_recovers_pre_summarize_messages(self, event_store): - """The exact bug: checkpoint would have only post-summarize messages; - event store must surface the original pre-summarize human query.""" - # Run 1 (pre-summarize) - await event_store.put( - thread_id="t_sum", run_id="r1", - event_type="human_message", category="message", - content={ - "type": "human", "id": None, - "content": [{"type": "text", "text": "original question"}], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - }, - ) - await event_store.put( - thread_id="t_sum", run_id="r1", - event_type="ai_message", category="message", - content={ - "type": "ai", "id": "lc_run--r1", - "content": "first answer", - "tool_calls": [], "invalid_tool_calls": [], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - "usage_metadata": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, - }, - ) - # Run 2 (post-summarize — what the checkpoint still has) - await event_store.put( - thread_id="t_sum", run_id="r2", - event_type="human_message", category="message", - content={ - "type": "human", "id": None, - "content": [{"type": "text", "text": "follow up"}], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - }, - ) - await event_store.put( - thread_id="t_sum", run_id="r2", - event_type="ai_message", category="message", - content={ - "type": "ai", "id": "lc_run--r2", - "content": "second answer", - "tool_calls": [], "invalid_tool_calls": [], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - "usage_metadata": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, - }, - ) - - request = _make_request(event_store) - messages = await _get_event_store_messages(request, "t_sum") - assert messages is not None - # 4 messages, not 
2 (which is what the summarized checkpoint would yield) - assert len(messages) == 4 - assert messages[0]["content"][0]["text"] == "original question" - assert messages[1]["id"] == "lc_run--r1" - assert messages[3]["id"] == "lc_run--r2" - - @pytest.mark.anyio - async def test_run_id_attached_to_every_message(self, event_store): - await _seed_simple_run(event_store, "t1", "r1") - request = _make_request(event_store) - messages = await _get_event_store_messages(request, "t1") - assert messages is not None - assert all(m.get("run_id") == "r1" for m in messages) - - @pytest.mark.anyio - async def test_feedback_attached_only_to_final_ai_message_per_run(self, event_store): - await _seed_simple_run(event_store, "t1", "r1") - feedback_repo = _FakeFeedbackRepo( - {"r1": {"feedback_id": "fb1", "rating": 1, "comment": "great"}} - ) - request = _make_request(event_store, feedback_repo=feedback_repo) - messages = await _get_event_store_messages(request, "t1") - assert messages is not None - - # human (0), ai_tool_call (1), tool (2), ai_message (3) - final_ai = messages[3] - assert final_ai["feedback"] == { - "feedback_id": "fb1", - "rating": 1, - "comment": "great", - } - # Non-final messages must NOT have a feedback key at all — the - # frontend keys button visibility off of this. - assert "feedback" not in messages[0] - assert "feedback" not in messages[1] - assert "feedback" not in messages[2] - - @pytest.mark.anyio - async def test_feedback_none_when_no_row_for_run(self, event_store): - await _seed_simple_run(event_store, "t1", "r1") - request = _make_request(event_store, feedback_repo=_FakeFeedbackRepo({})) - messages = await _get_event_store_messages(request, "t1") - assert messages is not None - # Final ai_message gets an explicit ``None`` — distinguishes "eligible - # but unrated" from "not eligible" (field absent). 
- assert messages[3]["feedback"] is None - - @pytest.mark.anyio - async def test_feedback_per_run_for_multi_run_thread(self, event_store): - """A thread with two runs: each final ai_message should get its own feedback.""" - # Run 1 - await event_store.put( - thread_id="t_multi", run_id="r1", - event_type="human_message", category="message", - content={"type": "human", "id": None, "content": "q1", - "additional_kwargs": {}, "response_metadata": {}, "name": None}, - ) - await event_store.put( - thread_id="t_multi", run_id="r1", - event_type="ai_message", category="message", - content={"type": "ai", "id": "lc_run--a1", "content": "a1", - "tool_calls": [], "invalid_tool_calls": [], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - "usage_metadata": None}, - ) - # Run 2 - await event_store.put( - thread_id="t_multi", run_id="r2", - event_type="human_message", category="message", - content={"type": "human", "id": None, "content": "q2", - "additional_kwargs": {}, "response_metadata": {}, "name": None}, - ) - await event_store.put( - thread_id="t_multi", run_id="r2", - event_type="ai_message", category="message", - content={"type": "ai", "id": "lc_run--a2", "content": "a2", - "tool_calls": [], "invalid_tool_calls": [], - "additional_kwargs": {}, "response_metadata": {}, "name": None, - "usage_metadata": None}, - ) - feedback_repo = _FakeFeedbackRepo({ - "r1": {"feedback_id": "fb_r1", "rating": 1, "comment": None}, - "r2": {"feedback_id": "fb_r2", "rating": -1, "comment": "meh"}, - }) - request = _make_request(event_store, feedback_repo=feedback_repo) - messages = await _get_event_store_messages(request, "t_multi") - assert messages is not None - # human[r1], ai[r1], human[r2], ai[r2] - assert messages[1]["feedback"]["feedback_id"] == "fb_r1" - assert messages[1]["feedback"]["rating"] == 1 - assert messages[3]["feedback"]["feedback_id"] == "fb_r2" - assert messages[3]["feedback"]["rating"] == -1 - # Humans don't get feedback - assert "feedback" not in 
messages[0] - assert "feedback" not in messages[2] - - @pytest.mark.anyio - async def test_feedback_repo_failure_does_not_break_helper(self, monkeypatch, event_store): - """If feedback lookup throws, messages still come back without feedback.""" - await _seed_simple_run(event_store, "t1", "r1") - - class _BoomRepo: - async def list_by_thread_grouped(self, *a, **kw): - raise RuntimeError("db down") - - request = _make_request(event_store, feedback_repo=_BoomRepo()) - messages = await _get_event_store_messages(request, "t1") - assert messages is not None - assert len(messages) == 4 - for m in messages: - assert "feedback" not in m - - @pytest.mark.anyio - async def test_returns_none_when_dep_raises(self, monkeypatch, event_store): - """When ``get_run_event_store`` is not configured, helper returns None.""" - import app.gateway.routers.threads as threads_mod - - def boom(_request): - raise RuntimeError("no store") - - monkeypatch.setattr(threads_mod, "get_run_event_store", boom) - request = _make_request(event_store) - assert await threads_mod._get_event_store_messages(request, "t1") is None diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py index c6f063e32..4ffa28a8c 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/test_threads_router.py @@ -50,10 +50,13 @@ def test_delete_thread_data_rejects_invalid_thread_id(tmp_path): def test_delete_thread_route_cleans_thread_directory(tmp_path): + from deerflow.runtime.user_context import get_effective_user_id + paths = Paths(tmp_path) - thread_dir = paths.thread_dir("thread-route") - paths.sandbox_work_dir("thread-route").mkdir(parents=True, exist_ok=True) - (paths.sandbox_work_dir("thread-route") / "notes.txt").write_text("hello", encoding="utf-8") + user_id = get_effective_user_id() + thread_dir = paths.thread_dir("thread-route", user_id=user_id) + paths.sandbox_work_dir("thread-route", user_id=user_id).mkdir(parents=True, exist_ok=True) + 
(paths.sandbox_work_dir("thread-route", user_id=user_id) / "notes.txt").write_text("hello", encoding="utf-8") app = make_authed_test_app() app.include_router(threads.router) diff --git a/backend/tests/test_uploads_middleware_core_logic.py b/backend/tests/test_uploads_middleware_core_logic.py index 72639fb09..2c562b179 100644 --- a/backend/tests/test_uploads_middleware_core_logic.py +++ b/backend/tests/test_uploads_middleware_core_logic.py @@ -34,7 +34,9 @@ def _runtime(thread_id: str | None = THREAD_ID) -> MagicMock: def _uploads_dir(tmp_path: Path, thread_id: str = THREAD_ID) -> Path: - d = Paths(str(tmp_path)).sandbox_uploads_dir(thread_id) + from deerflow.runtime.user_context import get_effective_user_id + + d = Paths(str(tmp_path)).sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) d.mkdir(parents=True, exist_ok=True) return d diff --git a/backend/tests/test_user_context.py b/backend/tests/test_user_context.py index b7dd1efd0..8c7cbd13c 100644 --- a/backend/tests/test_user_context.py +++ b/backend/tests/test_user_context.py @@ -11,7 +11,9 @@ import pytest from deerflow.runtime.user_context import ( CurrentUser, + DEFAULT_USER_ID, get_current_user, + get_effective_user_id, require_current_user, reset_current_user, set_current_user, @@ -67,3 +69,42 @@ def test_protocol_rejects_no_id(): """Objects without .id do not satisfy CurrentUser Protocol.""" not_a_user = SimpleNamespace(email="no-id@example.com") assert not isinstance(not_a_user, CurrentUser) + + +# --------------------------------------------------------------------------- +# get_effective_user_id / DEFAULT_USER_ID tests +# --------------------------------------------------------------------------- + + +def test_default_user_id_is_default(): + assert DEFAULT_USER_ID == "default" + + +@pytest.mark.no_auto_user +def test_effective_user_id_returns_default_when_no_user(): + """No user in context -> fallback to DEFAULT_USER_ID.""" + assert get_effective_user_id() == "default" + + 
+@pytest.mark.no_auto_user +def test_effective_user_id_returns_user_id_when_set(): + user = SimpleNamespace(id="u-abc-123") + token = set_current_user(user) + try: + assert get_effective_user_id() == "u-abc-123" + finally: + reset_current_user(token) + + +@pytest.mark.no_auto_user +def test_effective_user_id_coerces_to_str(): + """User.id might be a UUID object; must come back as str.""" + import uuid + uid = uuid.uuid4() + + user = SimpleNamespace(id=uid) + token = set_current_user(user) + try: + assert get_effective_user_id() == str(uid) + finally: + reset_current_user(token)
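The pagination tests above (`test_has_more_true_when_extra_row_returned`, the `limit=51` assertions) all exercise the same over-fetch trick: the handler asks the event store for `limit + 1` rows, trims the surplus row, and derives `has_more` from whether it was present. A minimal sketch of that envelope-building step, assuming the store call was made with `limit + 1` — the `paginate` helper name is illustrative, not taken from the Gateway code:

```python
# Sketch of the limit+1 pagination pattern the tests exercise: fetch one
# extra row, trim it, and use its presence to set has_more. The helper
# name `paginate` is hypothetical; the real handler inlines this logic.

def paginate(rows: list[dict], limit: int) -> dict:
    """Build the {data, has_more} envelope from an over-fetched row list.

    `rows` is assumed to come from a store call made with limit=limit + 1,
    e.g. list_messages_by_run(thread_id, run_id, limit=limit + 1, ...).
    """
    has_more = len(rows) > limit          # the extra row signals another page
    return {"data": rows[:limit], "has_more": has_more}


if __name__ == "__main__":
    # 51 rows fetched for a requested limit of 50: one full page, more available.
    rows = [{"seq": i} for i in range(1, 52)]
    page = paginate(rows, limit=50)
    print(len(page["data"]), page["has_more"])  # 50 True
```

The caller then passes the last returned `seq` back as `after_seq` (or the first as `before_seq`) to fetch the adjacent page, which is why the tests assert those cursor params are forwarded verbatim to `list_messages_by_run`.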