diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index cb048152e..e6f4fa2ae 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -90,6 +90,28 @@ class ThreadSearchRequest(BaseModel): offset: int = Field(default=0, ge=0, description="Pagination offset") status: str | None = Field(default=None, description="Filter by thread status") + @field_validator("metadata") + @classmethod + def _validate_metadata_filters(cls, v: dict[str, Any]) -> dict[str, Any]: + """Reject filter entries the SQL backend cannot compile. + + Enforces consistent behaviour across SQL and memory backends. + See ``deerflow.persistence.json_compat`` for the shared validators. + """ + if not v: + return v + from deerflow.persistence.json_compat import validate_metadata_filter_key, validate_metadata_filter_value + + bad_entries: list[str] = [] + for key, value in v.items(): + if not validate_metadata_filter_key(key): + bad_entries.append(f"{key!r} (unsafe key)") + elif not validate_metadata_filter_value(value): + bad_entries.append(f"{key!r} (unsupported value type {type(value).__name__})") + if bad_entries: + raise ValueError(f"Invalid metadata filter entries: {', '.join(bad_entries)}") + return v + class ThreadStateResponse(BaseModel): """Response model for thread state.""" @@ -294,14 +316,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th (SQL-backed for sqlite/postgres, Store-backed for memory mode). """ from app.gateway.deps import get_thread_store + from deerflow.persistence.thread_meta import InvalidMetadataFilterError repo = get_thread_store(request) - rows = await repo.search( - metadata=body.metadata or None, - status=body.status, - limit=body.limit, - offset=body.offset, - ) + try: + rows = await repo.search( + metadata=body.metadata or None, + status=body.status, + limit=body.limit, + offset=body.offset, + ) + except InvalidMetadataFilterError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc return [ ThreadResponse( thread_id=r["thread_id"], diff --git a/backend/packages/harness/deerflow/persistence/json_compat.py b/backend/packages/harness/deerflow/persistence/json_compat.py new file mode 100644 index 000000000..442b29e22 --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/json_compat.py @@ -0,0 +1,195 @@ +"""Dialect-aware JSON value matching for SQLAlchemy (SQLite + PostgreSQL).""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + +from sqlalchemy import BigInteger, Float, String, bindparam +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.expression import ColumnElement +from sqlalchemy.sql.visitors import InternalTraversal +from sqlalchemy.types import Boolean, TypeEngine + +# Key is interpolated into compiled SQL; restrict charset to prevent injection. +_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$") + +# Allowed value types for metadata filter values (same set accepted by JsonMatch). +ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str) + +# SQLite raises an overflow when binding values outside signed 64-bit range; +# PostgreSQL overflows during BIGINT cast. Reject at validation time instead. +_INT64_MIN = -(2**63) +_INT64_MAX = 2**63 - 1 + + +def validate_metadata_filter_key(key: object) -> bool: + """Return True if *key* is safe for use as a JSON metadata filter key. 
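+
+    Illustrative doctest for the charset rule spelled out below::
+
+        >>> validate_metadata_filter_key("env-2")
+        True
+        >>> validate_metadata_filter_key("a.b")
+        False
+        >>> validate_metadata_filter_key(1)
+        False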
+ + A key is "safe" when it is a string matching ``[A-Za-z0-9_-]+``. The + charset is restricted because the key is interpolated into the + compiled SQL path expression (``$.""`` / ``->`` literal), so any + laxer pattern would open a SQL/JSONPath injection surface. + """ + return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key)) + + +def validate_metadata_filter_value(value: object) -> bool: + """Return True if *value* is an allowed type for a JSON metadata filter. + + Matches the set of types ``_build_clause`` knows how to compile into + a dialect-portable predicate. Anything else (list/dict/bytes/...) is + intentionally rejected rather than silently coerced via ``str()`` — + silent coercion would (a) produce wrong matches and (b) break + SQLAlchemy's ``inherit_cache`` invariant when ``value`` is unhashable. + + Integer values are additionally restricted to the signed 64-bit range + ``[-2**63, 2**63 - 1]``: SQLite overflows when binding larger values + and PostgreSQL overflows during the ``BIGINT`` cast. + """ + if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES): + return False + if isinstance(value, int) and not isinstance(value, bool): + if not (_INT64_MIN <= value <= _INT64_MAX): + return False + return True + + +class JsonMatch(ColumnElement): + """Dialect-portable ``column[key] == value`` for JSON columns. + + Compiles to ``json_type``/``json_extract`` on SQLite and + ``json_typeof``/``->>`` on PostgreSQL, with type-safe comparison + that distinguishes bool vs int and NULL vs missing key. + + *key* must be a single literal key matching ``[A-Za-z0-9_-]+``. + *value* must be one of: ``None``, ``bool``, ``int`` (signed 64-bit), ``float``, ``str``. + """ + + inherit_cache = True + type = Boolean() + _is_implicitly_boolean = True + + _traverse_internals = [ + ("column", InternalTraversal.dp_clauseelement), + ("key", InternalTraversal.dp_string), + ("value", InternalTraversal.dp_plain_obj), + ] + + def __init__(self, column: ColumnElement, key: str, value: object) -> None: + if not validate_metadata_filter_key(key): + raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}") + if not validate_metadata_filter_value(value): + if isinstance(value, int) and not isinstance(value, bool): + raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}") + raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}") + self.column = column + self.key = key + self.value = value + super().__init__() + + +@dataclass(frozen=True) +class _Dialect: + """Per-dialect names used when emitting JSON type/value comparisons.""" + + null_type: str + num_types: tuple[str, ...] + num_cast: str + int_types: tuple[str, ...] + int_cast: str + # None for SQLite where json_type already returns 'integer'/'real'; + # regex literal for PostgreSQL where json_typeof returns 'number' for + # both ints and floats, so an extra guard prevents CAST errors on floats. 
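+    # Illustrative psql behaviour the guard works around (assuming a
+    # float stored under the key):
+    #   SELECT json_typeof('{"k": 1.5}'::json -> 'k');           -- 'number'
+    #   SELECT CAST(('{"k": 1.5}'::json ->> 'k') AS BIGINT);
+    #   -- ERROR:  invalid input syntax for type bigint: "1.5"
+    # With the guard, the CASE arm yields NULL for floats instead, so the
+    # row simply fails the comparison.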
+ int_guard: str | None + string_type: str + bool_type: str | None + + +_SQLITE = _Dialect( + null_type="null", + num_types=("integer", "real"), + num_cast="REAL", + int_types=("integer",), + int_cast="INTEGER", + int_guard=None, + string_type="text", + bool_type=None, +) + +_PG = _Dialect( + null_type="null", + num_types=("number",), + num_cast="DOUBLE PRECISION", + int_types=("number",), + int_cast="BIGINT", + int_guard="'^-?[0-9]+$'", + string_type="string", + bool_type="boolean", +) + + +def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str: + param = bindparam(None, value, type_=sa_type) + return compiler.process(param, **kw) + + +def _type_check(typeof: str, types: tuple[str, ...]) -> str: + if len(types) == 1: + return f"{typeof} = '{types[0]}'" + quoted = ", ".join(f"'{t}'" for t in types) + return f"{typeof} IN ({quoted})" + + +def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str: + if value is None: + return f"{typeof} = '{dialect.null_type}'" + if isinstance(value, bool): + # bool check must precede int check — bool is a subclass of int in Python + bool_str = "true" if value else "false" + if dialect.bool_type is None: + return f"{typeof} = '{bool_str}'" + return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')" + if isinstance(value, int): + bp = _bind(compiler, value, BigInteger(), **kw) + if dialect.int_guard: + # CASE prevents CAST error when json_typeof = 'number' also matches floats + return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})" + return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})" + if isinstance(value, float): + bp = _bind(compiler, value, Float(), **kw) + return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})" + bp = _bind(compiler, str(value), String(), **kw) + return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})" + + +@compiles(JsonMatch, "sqlite") +def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + path = f'$."{element.key}"' + typeof = f"json_type({col}, '{path}')" + extract = f"json_extract({col}, '{path}')" + return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw) + + +@compiles(JsonMatch, "postgresql") +def _compile_pg(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + typeof = f"json_typeof({col} -> '{element.key}')" + extract = f"({col} ->> '{element.key}')" + return _build_clause(compiler, typeof, extract, element.value, _PG, **kw) + + +@compiles(JsonMatch) +def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + raise NotImplementedError(f"JsonMatch supports only sqlite and postgresql; got dialect: {compiler.dialect.name}") + + +def json_match(column: ColumnElement, key: str, value: object) -> JsonMatch: + return JsonMatch(column, key, value) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py index 
080ce8093..b5231f0f9 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING -from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.sql import ThreadMetaRepository @@ -14,6 +14,7 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker __all__ = [ + "InvalidMetadataFilterError", "MemoryThreadMetaStore", "ThreadMetaRepository", "ThreadMetaRow", diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py index c87c10a16..ed55ade8e 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/base.py @@ -15,10 +15,15 @@ three-state semantics (see :mod:`deerflow.runtime.user_context`): from __future__ import annotations import abc +from typing import Any from deerflow.runtime.user_context import AUTO, _AutoSentinel +class InvalidMetadataFilterError(ValueError): + """Raised when all client-supplied metadata filter keys are rejected.""" + + class ThreadMetaStore(abc.ABC): @abc.abstractmethod async def create( @@ -40,12 +45,12 @@ class ThreadMetaStore(abc.ABC): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: pass @abc.abstractmethod diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index fbe66fdaf..4f642a938 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -69,12 +69,12 @@ class MemoryThreadMetaStore(ThreadMetaStore): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search") filter_dict: dict[str, Any] = {} if metadata: diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 688fbb247..0d3f587de 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -2,16 +2,20 @@ from __future__ import annotations +import logging from datetime import UTC, datetime from typing import Any from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.persistence.json_compat import json_match +from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from 
deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id +logger = logging.getLogger(__name__) + class ThreadMetaRepository(ThreadMetaStore): def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: @@ -20,7 +24,7 @@ class ThreadMetaRepository(ThreadMetaStore): @staticmethod def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]: d = row.to_dict() - d["metadata"] = d.pop("metadata_json", {}) + d["metadata"] = d.pop("metadata_json", None) or {} for key in ("created_at", "updated_at"): val = d.get(key) if isinstance(val, datetime): @@ -104,39 +108,43 @@ class ThreadMetaRepository(ThreadMetaStore): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: """Search threads with optional metadata and status filters. Owner filter is enforced by default: caller must be in a user context. Pass ``user_id=None`` to bypass (migration/CLI). """ resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search") - stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) + stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc(), ThreadMetaRow.thread_id.desc()) if resolved_user_id is not None: stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id) if status: stmt = stmt.where(ThreadMetaRow.status == status) if metadata: - # When metadata filter is active, fetch a larger window and filter - # in Python. TODO(Phase 2): use JSON DB operators (Postgres @>, - # SQLite json_extract) for server-side filtering. - stmt = stmt.limit(limit * 5 + offset) - async with self._sf() as session: - result = await session.execute(stmt) - rows = [self._row_to_dict(r) for r in result.scalars()] - rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())] - return rows[offset : offset + limit] - else: - stmt = stmt.limit(limit).offset(offset) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] + applied = 0 + for key, value in metadata.items(): + try: + stmt = stmt.where(json_match(ThreadMetaRow.metadata_json, key, value)) + applied += 1 + except (ValueError, TypeError) as exc: + logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc) + if applied == 0: + # Comma-separated plain string (no list repr / nested + # quoting) so the 400 detail surfaced by the Gateway is + # easy for clients to read. Sorted for determinism. 
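+                # For instance, metadata={"bad;key": "x", 2: "y"} would
+                # produce the detail "All metadata filter keys were
+                # rejected as unsafe: 2, bad;key".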
+ rejected_keys = ", ".join(sorted(str(k) for k in metadata)) + raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}") + + stmt = stmt.limit(limit).offset(offset) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool: """Return True if the row exists and is owned (or filter bypassed).""" diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py index 3a6532567..1cef3752b 100644 --- a/backend/tests/test_thread_meta_repo.py +++ b/backend/tests/test_thread_meta_repo.py @@ -1,28 +1,25 @@ """Tests for ThreadMetaRepository (SQLAlchemy-backed).""" +import logging + import pytest -from deerflow.persistence.thread_meta import ThreadMetaRepository +from deerflow.persistence.thread_meta import InvalidMetadataFilterError, ThreadMetaRepository -async def _make_repo(tmp_path): - from deerflow.persistence.engine import get_session_factory, init_engine +@pytest.fixture +async def repo(tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - return ThreadMetaRepository(get_session_factory()) - - -async def _cleanup(): - from deerflow.persistence.engine import close_engine - + yield ThreadMetaRepository(get_session_factory()) await close_engine() class TestThreadMetaRepository: @pytest.mark.anyio - async def test_create_and_get(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_and_get(self, repo): record = await repo.create("t1") assert record["thread_id"] == "t1" assert record["status"] == "idle" @@ -31,148 +28,523 @@ class TestThreadMetaRepository: fetched = await repo.get("t1") assert fetched is not None assert fetched["thread_id"] == "t1" - await _cleanup() @pytest.mark.anyio - async def test_create_with_assistant_id(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_assistant_id(self, repo): record = await repo.create("t1", assistant_id="agent1") assert record["assistant_id"] == "agent1" - await _cleanup() @pytest.mark.anyio - async def test_create_with_owner_and_display_name(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_owner_and_display_name(self, repo): record = await repo.create("t1", user_id="user1", display_name="My Thread") assert record["user_id"] == "user1" assert record["display_name"] == "My Thread" - await _cleanup() @pytest.mark.anyio - async def test_create_with_metadata(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_metadata(self, repo): record = await repo.create("t1", metadata={"key": "value"}) assert record["metadata"] == {"key": "value"} - await _cleanup() @pytest.mark.anyio - async def test_get_nonexistent(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_get_nonexistent(self, repo): assert await repo.get("nonexistent") is None - await _cleanup() @pytest.mark.anyio - async def test_check_access_no_record_allows(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_no_record_allows(self, repo): assert await repo.check_access("unknown", "user1") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_owner_matches(self, tmp_path): - repo = await _make_repo(tmp_path) 
+ async def test_check_access_owner_matches(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_owner_mismatch(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_owner_mismatch(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2") is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_no_owner_allows_all(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_no_owner_allows_all(self, repo): # Explicit user_id=None to bypass the new AUTO default that # would otherwise pick up the test user from the autouse fixture. await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_missing_row_denied(self, tmp_path): + async def test_check_access_strict_missing_row_denied(self, repo): """require_existing=True flips the missing-row case to *denied*. Closes the delete-idempotence cross-user gap: after a thread is deleted, the row is gone, and the permissive default would let any caller "claim" it as untracked. The strict mode demands a row. """ - repo = await _make_repo(tmp_path) assert await repo.check_access("never-existed", "user1", require_existing=True) is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_owner_match_allowed(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_strict_owner_match_allowed(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1", require_existing=True) is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_owner_mismatch_denied(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_strict_owner_mismatch_denied(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2", require_existing=True) is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_null_owner_still_allowed(self, tmp_path): + async def test_check_access_strict_null_owner_still_allowed(self, repo): """Even in strict mode, a row with NULL user_id stays shared. The strict flag tightens the *missing row* case, not the *shared row* case — legacy pre-auth rows that survived a clean migration without an owner are still everyone's. 
""" - repo = await _make_repo(tmp_path) await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone", require_existing=True) is True - await _cleanup() @pytest.mark.anyio - async def test_update_status(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_status(self, repo): await repo.create("t1") await repo.update_status("t1", "busy") record = await repo.get("t1") assert record["status"] == "busy" - await _cleanup() @pytest.mark.anyio - async def test_delete(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_delete(self, repo): await repo.create("t1") await repo.delete("t1") assert await repo.get("t1") is None - await _cleanup() @pytest.mark.anyio - async def test_delete_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_delete_nonexistent_is_noop(self, repo): await repo.delete("nonexistent") # should not raise - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_merges(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_merges(self, repo): await repo.create("t1", metadata={"a": 1, "b": 2}) await repo.update_metadata("t1", {"b": 99, "c": 3}) record = await repo.get("t1") # Existing key preserved, overlapping key overwritten, new key added assert record["metadata"] == {"a": 1, "b": 99, "c": 3} - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_on_empty(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_on_empty(self, repo): await repo.create("t1") await repo.update_metadata("t1", {"k": "v"}) record = await repo.get("t1") assert record["metadata"] == {"k": "v"} - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_nonexistent_is_noop(self, repo): await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise - await _cleanup() + + # --- search with metadata filter (SQL push-down) --- + + @pytest.mark.anyio + async def test_search_metadata_filter_string(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + await repo.create("t3", metadata={"env": "prod", "region": "us"}) + + results = await repo.search(metadata={"env": "prod"}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_numeric(self, repo): + await repo.create("t1", metadata={"priority": 1}) + await repo.create("t2", metadata={"priority": 2}) + await repo.create("t3", metadata={"priority": 1, "extra": "x"}) + + results = await repo.search(metadata={"priority": 1}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_multiple_keys(self, repo): + await repo.create("t1", metadata={"env": "prod", "region": "us"}) + await repo.create("t2", metadata={"env": "prod", "region": "eu"}) + await repo.create("t3", metadata={"env": "staging", "region": "us"}) + + results = await repo.search(metadata={"env": "prod", "region": "us"}) + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + + @pytest.mark.anyio + async def test_search_metadata_no_match(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + + results = await repo.search(metadata={"env": "dev"}) + assert results == [] + + @pytest.mark.anyio + async def 
test_search_metadata_pagination_correct(self, repo): + """Regression: SQL push-down makes limit/offset exact even when most rows don't match.""" + for i in range(30): + meta = {"target": "yes"} if i % 3 == 0 else {"target": "no"} + await repo.create(f"t{i:03d}", metadata=meta) + + # Total matching rows: i in {0,3,6,9,12,15,18,21,24,27} = 10 rows + all_matches = await repo.search(metadata={"target": "yes"}, limit=100) + assert len(all_matches) == 10 + + # Paginate: first page + page1 = await repo.search(metadata={"target": "yes"}, limit=3, offset=0) + assert len(page1) == 3 + + # Paginate: second page + page2 = await repo.search(metadata={"target": "yes"}, limit=3, offset=3) + assert len(page2) == 3 + + # No overlap between pages + page1_ids = {r["thread_id"] for r in page1} + page2_ids = {r["thread_id"] for r in page2} + assert page1_ids.isdisjoint(page2_ids) + + # Last page + page_last = await repo.search(metadata={"target": "yes"}, limit=3, offset=9) + assert len(page_last) == 1 + + @pytest.mark.anyio + async def test_search_metadata_with_status_filter(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "prod"}) + await repo.update_status("t1", "busy") + + results = await repo.search(metadata={"env": "prod"}, status="busy") + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + + @pytest.mark.anyio + async def test_search_without_metadata_still_works(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2") + + results = await repo.search(limit=10) + assert len(results) == 2 + + @pytest.mark.anyio + async def test_search_metadata_missing_key_no_match(self, repo): + """Rows without the requested metadata key should not match.""" + await repo.create("t1", metadata={"other": "val"}) + await repo.create("t2", metadata={"env": "prod"}) + + results = await repo.search(metadata={"env": "prod"}) + assert len(results) == 1 + assert results[0]["thread_id"] == "t2" + + @pytest.mark.anyio + async def test_search_metadata_all_unsafe_keys_raises(self, repo, caplog): + """When ALL metadata keys are unsafe, raises InvalidMetadataFilterError.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected") as exc_info: + await repo.search(metadata={"bad;key": "x"}) + assert any("bad;key" in r.message for r in caplog.records) + # Subclass of ValueError for backward compatibility + assert isinstance(exc_info.value, ValueError) + + @pytest.mark.anyio + async def test_search_metadata_partial_unsafe_key_skipped(self, repo, caplog): + """Valid keys filter rows; only the invalid key is warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + results = await repo.search(metadata={"env": "prod", "bad;key": "x"}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1"} + assert any("bad;key" in r.message for r in caplog.records) + + @pytest.mark.anyio + async def test_search_metadata_filter_boolean(self, repo): + """True matches only boolean true, not integer 1.""" + await repo.create("t1", metadata={"active": True}) + await repo.create("t2", metadata={"active": False}) + await repo.create("t3", metadata={"active": True, "extra": "x"}) + await 
repo.create("t4", metadata={"active": 1}) + + results = await repo.search(metadata={"active": True}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_none(self, repo): + """Only rows with explicit JSON null match; missing key does not.""" + await repo.create("t1", metadata={"tag": None}) + await repo.create("t2", metadata={"tag": "present"}) + await repo.create("t3", metadata={"other": "val"}) + + results = await repo.search(metadata={"tag": None}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1"} + + @pytest.mark.anyio + async def test_search_metadata_non_string_key_skipped(self, repo, caplog): + """Non-string keys raise ValueError from isinstance check; should be warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={1: "x"}) + assert any("1" in r.message for r in caplog.records) + + @pytest.mark.anyio + async def test_search_metadata_unsupported_value_type_skipped(self, repo, caplog): + """Unsupported value types (list, dict) raise TypeError; should be warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={"env": ["prod", "staging"]}) + + @pytest.mark.anyio + async def test_search_metadata_dotted_key_raises(self, repo, caplog): + """Dotted keys are rejected; when ALL keys are dotted, raises ValueError.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={"a.b": "anything"}) + assert any("a.b" in r.message for r in caplog.records) + + # --- dialect-aware type-safe filtering edge cases --- + + @pytest.mark.anyio + async def test_search_metadata_bool_vs_int_distinction(self, repo): + """True must not match 1; False must not match 0.""" + await repo.create("bool_true", metadata={"flag": True}) + await repo.create("bool_false", metadata={"flag": False}) + await repo.create("int_one", metadata={"flag": 1}) + await repo.create("int_zero", metadata={"flag": 0}) + + true_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": True})} + assert true_hits == {"bool_true"} + + false_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": False})} + assert false_hits == {"bool_false"} + + @pytest.mark.anyio + async def test_search_metadata_int_does_not_match_bool(self, repo): + """Integer 1 must not match boolean True.""" + await repo.create("bool_true", metadata={"val": True}) + await repo.create("int_one", metadata={"val": 1}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"val": 1})} + assert hits == {"int_one"} + + @pytest.mark.anyio + async def test_search_metadata_none_excludes_missing_key(self, repo): + """Filtering by None matches explicit JSON null only, not missing key or empty {}.""" + await repo.create("explicit_null", metadata={"k": None}) + await repo.create("missing_key", metadata={"other": 
"x"}) + await repo.create("empty_obj", metadata={}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"k": None})} + assert hits == {"explicit_null"} + + @pytest.mark.anyio + async def test_search_metadata_float_value(self, repo): + await repo.create("t1", metadata={"score": 3.14}) + await repo.create("t2", metadata={"score": 2.71}) + await repo.create("t3", metadata={"score": 3.14}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"score": 3.14})} + assert hits == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_mixed_types_same_key(self, repo): + """Each type query only matches its own type, even when the key is shared.""" + await repo.create("str_row", metadata={"x": "hello"}) + await repo.create("int_row", metadata={"x": 42}) + await repo.create("bool_row", metadata={"x": True}) + await repo.create("null_row", metadata={"x": None}) + + assert {r["thread_id"] for r in await repo.search(metadata={"x": "hello"})} == {"str_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": 42})} == {"int_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": True})} == {"bool_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": None})} == {"null_row"} + + @pytest.mark.anyio + async def test_search_metadata_large_int_precision(self, repo): + """Integers beyond float precision (> 2**53) must match exactly.""" + large = 2**53 + 1 + await repo.create("t1", metadata={"id": large}) + await repo.create("t2", metadata={"id": large - 1}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"id": large})} + assert hits == {"t1"} + + +class TestJsonMatchCompilation: + """Verify compiled SQL for both SQLite and PostgreSQL dialects.""" + + def test_json_match_compiles_sqlite(self): + from sqlalchemy import Column, MetaData, String, Table, create_engine + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + engine = create_engine("sqlite://") + + cases = [ + (None, "json_type(t.data, '$.\"k\"') = 'null'"), + (True, "json_type(t.data, '$.\"k\"') = 'true'"), + (False, "json_type(t.data, '$.\"k\"') = 'false'"), + ] + for value, expected_fragment in cases: + expr = json_match(t.c.data, "k", value) + sql = expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}) + assert str(sql) == expected_fragment, f"value={value!r}: {sql}" + + # int: uses INTEGER cast for precision, type-check narrows to 'integer' only + int_expr = json_match(t.c.data, "k", 42) + sql = str(int_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "= 'integer'" in sql + assert "INTEGER" in sql + assert "CAST" in sql + + # float: uses REAL cast, type-check spans 'integer' and 'real' + float_expr = json_match(t.c.data, "k", 3.14) + sql = str(float_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "IN ('integer', 'real')" in sql + assert "REAL" in sql + + str_expr = json_match(t.c.data, "k", "hello") + sql = str(str_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "'text'" in sql + + def test_json_match_compiles_pg(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.dialects import postgresql + from sqlalchemy.types import JSON + + from 
deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + dialect = postgresql.dialect() + + cases = [ + (None, "json_typeof(t.data -> 'k') = 'null'"), + (True, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'true')"), + (False, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"), + ] + for value, expected_fragment in cases: + expr = json_match(t.c.data, "k", value) + sql = expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) + assert str(sql) == expected_fragment, f"value={value!r}: {sql}" + + # int: CASE guard prevents CAST error when 'number' also matches floats + int_expr = json_match(t.c.data, "k", 42) + sql = str(int_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'number'" in sql + assert "BIGINT" in sql + assert "CASE WHEN" in sql + assert "'^-?[0-9]+$'" in sql + + # float: uses DOUBLE PRECISION cast + float_expr = json_match(t.c.data, "k", 3.14) + sql = str(float_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'number'" in sql + assert "DOUBLE PRECISION" in sql + + str_expr = json_match(t.c.data, "k", "hello") + sql = str(str_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'string'" in sql + + def test_json_match_rejects_unsafe_key(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + for bad_key in ["a.b", "with space", "bad'quote", 'bad"quote', "back\\slash", "semi;colon", ""]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(t.c.data, bad_key, "x") + + # Non-string keys must also raise ValueError (not TypeError from re.match) + for non_str_key in [42, None, ("k",)]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(t.c.data, non_str_key, "x") + + def test_json_match_rejects_unsupported_value_type(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + for bad_value in [[], {}, object()]: + with pytest.raises(TypeError, match="JsonMatch value must be"): + json_match(t.c.data, "k", bad_value) + + def test_json_match_unsupported_dialect_raises(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.dialects import mysql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + expr = json_match(t.c.data, "k", "v") + + with pytest.raises(NotImplementedError, match="mysql"): + str(expr.compile(dialect=mysql.dialect(), compile_kwargs={"literal_binds": True})) + + def test_json_match_rejects_out_of_range_int(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + # boundary values must be accepted + json_match(t.c.data, "k", 
2**63 - 1)
+        json_match(t.c.data, "k", -(2**63))
+
+        # one beyond each boundary must be rejected
+        for out_of_range in [2**63, -(2**63) - 1, 10**30]:
+            with pytest.raises(TypeError, match="out of signed 64-bit range"):
+                json_match(t.c.data, "k", out_of_range)
+
+    def test_compiler_raises_on_escaped_key(self):
+        """Compiler raises ValueError even when __init__ validation is bypassed."""
+        from sqlalchemy import Column, MetaData, String, Table, create_engine
+        from sqlalchemy.dialects import postgresql
+        from sqlalchemy.types import JSON
+
+        from deerflow.persistence.json_compat import json_match
+
+        metadata = MetaData()
+        t = Table("t", metadata, Column("data", JSON), Column("id", String))
+        engine = create_engine("sqlite://")
+
+        elem = json_match(t.c.data, "k", "v")
+        elem.key = "bad.key"  # bypass __init__ validation to exercise the compiler's re-check
+
+        with pytest.raises(ValueError, match="Key escaped validation"):
+            str(elem.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}))
+
+        with pytest.raises(ValueError, match="Key escaped validation"):
+            str(elem.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py
index daf0c0b13..9e37f3c86 100644
--- a/backend/tests/test_threads_router.py
+++ b/backend/tests/test_threads_router.py
@@ -10,6 +10,7 @@ from langgraph.store.memory import InMemoryStore
 
 from app.gateway.routers import threads
 from deerflow.config.paths import Paths
+from deerflow.persistence.thread_meta import InvalidMetadataFilterError
 from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore
 
 _ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@@ -431,3 +432,56 @@ def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None
     assert entries, "expected at least one history entry"
     for entry in entries:
         assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry
+
+
+# ── Metadata filter validation at API boundary ────────────────────────────────
+
+
+def test_search_threads_rejects_invalid_key_at_api_boundary() -> None:
+    """Keys that don't match [A-Za-z0-9_-]+ are rejected by the Pydantic
+    validator on ThreadSearchRequest.metadata — 422 from both backends.
+    """
+    app, _store, _checkpointer = _build_thread_app()
+
+    with TestClient(app) as client:
+        response = client.post("/api/threads/search", json={"metadata": {"bad;key": "x"}})
+
+    assert response.status_code == 422
+
+
+def test_search_threads_rejects_unsupported_value_type_at_api_boundary() -> None:
+    """Value types outside (None, bool, int, float, str) are rejected."""
+    app, _store, _checkpointer = _build_thread_app()
+
+    with TestClient(app) as client:
+        response = client.post("/api/threads/search", json={"metadata": {"env": ["a", "b"]}})
+
+    assert response.status_code == 422
+
+
+def test_search_threads_returns_400_for_backend_invalid_metadata_filter() -> None:
+    """If the backend still raises InvalidMetadataFilterError (defense in
+    depth), the handler surfaces it as HTTP 400.
+    
+ """ + app, _store, _checkpointer = _build_thread_app() + thread_store = app.state.thread_store + + async def _raise(**kwargs): + raise InvalidMetadataFilterError("rejected") + + with TestClient(app) as client: + with patch.object(thread_store, "search", side_effect=_raise): + response = client.post("/api/threads/search", json={"metadata": {"valid_key": "x"}}) + + assert response.status_code == 400 + assert "rejected" in response.json()["detail"] + + +def test_search_threads_succeeds_with_valid_metadata() -> None: + """Sanity check: valid metadata passes through without error.""" + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"env": "prod"}}) + + assert response.status_code == 200