mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 22:28:27 +00:00
Fix documentary narration parsing and explicit vision overrides
This commit is contained in:
parent
4e2560651f
commit
abc9db22e5
@ -119,9 +119,11 @@ JSON 必须包含以下键:
|
||||
concurrency = self._resolve_max_concurrency(max_concurrency)
|
||||
provider = (vision_llm_provider or config.app.get("vision_llm_provider", "openai")).lower()
|
||||
|
||||
api_key = vision_api_key or config.app.get(f"vision_{provider}_api_key")
|
||||
model_name = vision_model_name or config.app.get(f"vision_{provider}_model_name")
|
||||
base_url = vision_base_url or config.app.get(f"vision_{provider}_base_url", "")
|
||||
api_key = vision_api_key if vision_api_key is not None else config.app.get(f"vision_{provider}_api_key")
|
||||
model_name = (
|
||||
vision_model_name if vision_model_name is not None else config.app.get(f"vision_{provider}_model_name")
|
||||
)
|
||||
base_url = vision_base_url if vision_base_url is not None else config.app.get(f"vision_{provider}_base_url", "")
|
||||
if not api_key or not model_name:
|
||||
raise ValueError(
|
||||
f"未配置 {provider} 的 API Key 或模型名称。"
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import PIL.Image
|
||||
@ -111,41 +110,11 @@ class LegacyLLMAdapter:
|
||||
temperature=1.5,
|
||||
response_format="json"
|
||||
)
|
||||
|
||||
# 使用增强的JSON解析器
|
||||
from webui.tools.generate_short_summary import parse_and_fix_json
|
||||
parsed_result = parse_and_fix_json(result)
|
||||
|
||||
if not parsed_result:
|
||||
logger.error("无法解析LLM返回的JSON数据")
|
||||
# 返回一个基本的JSON结构而不是错误字符串
|
||||
return json.dumps({
|
||||
"items": [
|
||||
{
|
||||
"_id": 1,
|
||||
"timestamp": "00:00:00-00:00:10",
|
||||
"picture": "解析失败,请检查LLM输出",
|
||||
"narration": "解说文案生成失败,请重试"
|
||||
}
|
||||
]
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 确保返回的是JSON字符串
|
||||
return json.dumps(parsed_result, ensure_ascii=False)
|
||||
return result if isinstance(result, str) else str(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成解说文案失败: {str(e)}")
|
||||
# 返回一个基本的JSON结构而不是错误字符串
|
||||
return json.dumps({
|
||||
"items": [
|
||||
{
|
||||
"_id": 1,
|
||||
"timestamp": "00:00:00-00:00:10",
|
||||
"picture": "生成失败",
|
||||
"narration": f"解说文案生成失败: {str(e)}"
|
||||
}
|
||||
]
|
||||
}, ensure_ascii=False)
|
||||
raise
|
||||
|
||||
|
||||
class VisionAnalyzerAdapter:
|
||||
@ -198,6 +167,8 @@ class VisionAnalyzerAdapter:
|
||||
prompt=prompt,
|
||||
batch_size=batch_size,
|
||||
max_concurrency=max_concurrency,
|
||||
api_key=self.api_key,
|
||||
api_base=self.base_url,
|
||||
)
|
||||
|
||||
# 转换为旧格式以保持向后兼容性
|
||||
|
||||
@ -2,11 +2,12 @@
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.config import config
|
||||
from app.services.llm.base import TextModelProvider
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
from app.services.llm.migration_adapter import VisionAnalyzerAdapter
|
||||
from app.services.llm.migration_adapter import LegacyLLMAdapter, VisionAnalyzerAdapter
|
||||
from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
|
||||
from app.services.llm.providers import register_all_providers
|
||||
|
||||
@ -118,6 +119,7 @@ class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
|
||||
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
|
||||
class _CapturingVisionProvider:
|
||||
last_init: tuple[str, str, str | None] | None = None
|
||||
last_call_kwargs: dict | None = None
|
||||
|
||||
def __init__(self, api_key: str, model_name: str, base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
@ -126,6 +128,7 @@ class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
|
||||
ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url)
|
||||
|
||||
async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs):
|
||||
ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_call_kwargs = dict(kwargs)
|
||||
return [f"{self.model_name}|{self.api_key}|{self.base_url}"]
|
||||
|
||||
def setUp(self):
|
||||
@ -160,8 +163,32 @@ class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
|
||||
("explicit-key", "explicit-model", "https://explicit.example/v1"),
|
||||
self._CapturingVisionProvider.last_init,
|
||||
)
|
||||
self.assertEqual("explicit-key", self._CapturingVisionProvider.last_call_kwargs["api_key"])
|
||||
self.assertEqual("https://explicit.example/v1", self._CapturingVisionProvider.last_call_kwargs["api_base"])
|
||||
self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"])
|
||||
|
||||
|
||||
class LegacyNarrationAdapterBehaviorTests(unittest.TestCase):
|
||||
def test_generate_narration_returns_raw_unrecoverable_payload_without_fabrication(self):
|
||||
raw_payload = "not-json-at-all ::: ???"
|
||||
|
||||
with patch(
|
||||
"app.services.llm.migration_adapter.PromptManager.get_prompt",
|
||||
return_value="prompt",
|
||||
), patch(
|
||||
"app.services.llm.migration_adapter._run_async_safely",
|
||||
return_value=raw_payload,
|
||||
):
|
||||
result = LegacyLLMAdapter.generate_narration(
|
||||
markdown_content="markdown",
|
||||
api_key="test-key",
|
||||
base_url="https://example.com/v1",
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
self.assertEqual(raw_payload, result)
|
||||
self.assertNotIn('"items"', result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -260,9 +260,57 @@ class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyn
|
||||
)
|
||||
|
||||
narration_input = mocked_generate.call_args.args[0]
|
||||
self.assertIn("## 创作上下文", narration_input)
|
||||
self.assertIn("视频主题:野生动物纪录片", narration_input)
|
||||
self.assertIn("补充创作要求:重点描述危险信号", narration_input)
|
||||
|
||||
async def test_analyze_video_forwards_explicit_empty_base_url_without_config_fallback(self):
|
||||
service = DocumentaryFrameAnalysisService()
|
||||
|
||||
with patch.dict(
|
||||
"app.services.documentary.frame_analysis_service.config.app",
|
||||
{
|
||||
"vision_llm_provider": "openai",
|
||||
"vision_openai_api_key": "config-key",
|
||||
"vision_openai_model_name": "config-model",
|
||||
"vision_openai_base_url": "https://config.example/v1",
|
||||
},
|
||||
), patch(
|
||||
"app.services.documentary.frame_analysis_service.os.path.exists",
|
||||
return_value=True,
|
||||
), patch.object(
|
||||
service,
|
||||
"_load_or_extract_keyframes",
|
||||
return_value=["/tmp/keyframe_000001_000000100.jpg"],
|
||||
), patch.object(
|
||||
service,
|
||||
"_analyze_batches",
|
||||
AsyncMock(return_value=[]),
|
||||
), patch.object(
|
||||
service,
|
||||
"_save_analysis_artifact",
|
||||
return_value="/tmp/frame_analysis_test.json",
|
||||
), patch.object(
|
||||
service,
|
||||
"_build_video_clip_json",
|
||||
return_value=[],
|
||||
), patch(
|
||||
"app.services.documentary.frame_analysis_service.create_vision_analyzer",
|
||||
return_value=object(),
|
||||
) as mocked_create_analyzer:
|
||||
await service.analyze_video(
|
||||
video_path="/tmp/demo.mp4",
|
||||
vision_api_key="explicit-key",
|
||||
vision_model_name="explicit-model",
|
||||
vision_base_url="",
|
||||
)
|
||||
|
||||
called_kwargs = mocked_create_analyzer.call_args.kwargs
|
||||
self.assertEqual("openai", called_kwargs["provider"])
|
||||
self.assertEqual("explicit-key", called_kwargs["api_key"])
|
||||
self.assertEqual("explicit-model", called_kwargs["model"])
|
||||
self.assertEqual("", called_kwargs["base_url"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user