diff --git a/backend/app/gateway/routers/suggestions.py b/backend/app/gateway/routers/suggestions.py index 3fac15dd6..0a388a3df 100644 --- a/backend/app/gateway/routers/suggestions.py +++ b/backend/app/gateway/routers/suggestions.py @@ -2,6 +2,7 @@ import json import logging from fastapi import APIRouter +from langchain_core.messages import HumanMessage, SystemMessage from pydantic import BaseModel, Field from deerflow.models import create_chat_model @@ -106,7 +107,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S if not conversation: return SuggestionsResponse(suggestions=[]) - prompt = ( + system_instruction = ( "You are generating follow-up questions to help the user continue the conversation.\n" f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n" "Requirements:\n" @@ -114,14 +115,13 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S "- Questions must be written in the same language as the user.\n" "- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n" "- Do NOT include numbering, markdown, or any extra text.\n" - "- Output MUST be a JSON array of strings only.\n\n" - "Conversation:\n" - f"{conversation}\n" + "- Output MUST be a JSON array of strings only.\n" ) + user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions" try: model = create_chat_model(name=request.model_name, thinking_enabled=False) - response = model.invoke(prompt) + response = model.invoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)]) raw = _extract_response_text(response.content) suggestions = _parse_json_string_list(raw) or [] cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()] diff --git a/backend/tests/test_suggestions_router.py b/backend/tests/test_suggestions_router.py index 75237ee10..862005371 100644 --- a/backend/tests/test_suggestions_router.py +++ b/backend/tests/test_suggestions_router.py @@ -1,6 +1,8 @@ import asyncio from unittest.mock import MagicMock +from langchain_core.messages import HumanMessage, SystemMessage + from app.gateway.routers import suggestions @@ -100,3 +102,26 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch): result = asyncio.run(suggestions.generate_suggestions("t1", req)) assert result.suggestions == [] + + +def test_generate_suggestions_invokes_model_with_system_and_human_messages(monkeypatch): + req = suggestions.SuggestionsRequest( + messages=[ + suggestions.SuggestionMessage(role="user", content="What is Python?"), + suggestions.SuggestionMessage(role="assistant", content="Python is a programming language."), + ], + n=2, + model_name=None, + ) + fake_model = MagicMock() + fake_model.invoke.return_value = MagicMock(content='["Q1", "Q2"]') + monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) + + asyncio.run(suggestions.generate_suggestions("t1", req)) + + call_args = fake_model.invoke.call_args[0][0] + assert len(call_args) == 2 + assert isinstance(call_args[0], SystemMessage) + assert isinstance(call_args[1], HumanMessage) + assert "follow-up questions" in call_args[0].content + assert "What is Python?" in call_args[1].content