mirror of
https://github.com/linyqh/NarratoAI.git
synced 2026-05-01 14:18:19 +00:00
fix(documentary): restore narration repair and explicit vision overrides
This commit is contained in:
parent
a8b6a5bb6b
commit
4e2560651f
@ -77,8 +77,13 @@ JSON 必须包含以下键:
|
||||
)
|
||||
|
||||
markdown_output = parse_frame_analysis_to_markdown(analysis_json_path)
|
||||
narration_input = self._build_narration_input(
|
||||
markdown_output=markdown_output,
|
||||
video_theme=video_theme,
|
||||
custom_prompt=custom_prompt,
|
||||
)
|
||||
narration_raw = generate_narration(
|
||||
markdown_output,
|
||||
narration_input,
|
||||
text_api_key,
|
||||
base_url=text_base_url,
|
||||
model=text_model,
|
||||
@ -172,21 +177,7 @@ JSON 必须包含以下键:
|
||||
}
|
||||
|
||||
def _parse_narration_items(self, narration_raw: str) -> list[dict[str, Any]]:
|
||||
def load_json_candidate(payload: str) -> dict[str, Any] | None:
|
||||
try:
|
||||
return json.loads(payload)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
cleaned = (narration_raw or "").strip()
|
||||
parsed = load_json_candidate(cleaned)
|
||||
if parsed is None:
|
||||
parsed = load_json_candidate(cleaned.replace("```json", "").replace("```", "").strip())
|
||||
if parsed is None:
|
||||
start = cleaned.find("{")
|
||||
end = cleaned.rfind("}")
|
||||
if start >= 0 and end > start:
|
||||
parsed = load_json_candidate(cleaned[start : end + 1])
|
||||
parsed = self._repair_narration_payload(narration_raw)
|
||||
|
||||
items: list[dict[str, Any]] = []
|
||||
if isinstance(parsed, dict):
|
||||
@ -199,6 +190,64 @@ JSON 必须包含以下键:
|
||||
|
||||
return items
|
||||
|
||||
def _build_narration_input(self, *, markdown_output: str, video_theme: str, custom_prompt: str) -> str:
|
||||
context_lines: list[str] = []
|
||||
if (video_theme or "").strip():
|
||||
context_lines.append(f"视频主题:{video_theme.strip()}")
|
||||
if (custom_prompt or "").strip():
|
||||
context_lines.append(f"补充创作要求:{custom_prompt.strip()}")
|
||||
|
||||
if not context_lines:
|
||||
return markdown_output
|
||||
|
||||
context_block = "\n".join(f"- {line}" for line in context_lines)
|
||||
return f"{markdown_output.rstrip()}\n\n## 创作上下文\n{context_block}\n"
|
||||
|
||||
def _repair_narration_payload(self, narration_raw: str) -> dict[str, Any] | None:
|
||||
def load_json_candidate(payload: str) -> dict[str, Any] | None:
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
cleaned = (narration_raw or "").strip()
|
||||
if not cleaned:
|
||||
return None
|
||||
|
||||
candidates: list[str] = [cleaned]
|
||||
candidates.append(cleaned.replace("{{", "{").replace("}}", "}"))
|
||||
|
||||
json_block = re.search(r"```json\s*(.*?)\s*```", cleaned, re.DOTALL)
|
||||
if json_block:
|
||||
candidates.append(json_block.group(1).strip())
|
||||
|
||||
start = cleaned.find("{")
|
||||
end = cleaned.rfind("}")
|
||||
if start >= 0 and end > start:
|
||||
candidates.append(cleaned[start : end + 1])
|
||||
|
||||
for candidate in candidates:
|
||||
parsed = load_json_candidate(candidate)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
|
||||
fixed = cleaned.replace("{{", "{").replace("}}", "}")
|
||||
fixed_start = fixed.find("{")
|
||||
fixed_end = fixed.rfind("}")
|
||||
if fixed_start >= 0 and fixed_end > fixed_start:
|
||||
fixed = fixed[fixed_start : fixed_end + 1]
|
||||
|
||||
fixed = re.sub(r"^\s*#.*$", "", fixed, flags=re.MULTILINE)
|
||||
fixed = re.sub(r"^\s*//.*$", "", fixed, flags=re.MULTILINE)
|
||||
fixed = re.sub(r",\s*}", "}", fixed)
|
||||
fixed = re.sub(r",\s*]", "]", fixed)
|
||||
fixed = re.sub(r"'([^']*)'\s*:", r'"\1":', fixed)
|
||||
fixed = re.sub(r'([{\[,]\s*)([A-Za-z_][\w\u4e00-\u9fff]*)(\s*:)', r'\1"\2"\3', fixed)
|
||||
fixed = re.sub(r'""([^"]*?)""', r'"\1"', fixed)
|
||||
|
||||
return load_json_candidate(fixed)
|
||||
|
||||
def _resolve_frame_interval(self, frame_interval_input: int | float | None) -> float:
|
||||
interval = frame_interval_input
|
||||
if interval in (None, ""):
|
||||
|
||||
@ -13,6 +13,7 @@ from loguru import logger
|
||||
|
||||
from .unified_service import UnifiedLLMService
|
||||
from .exceptions import LLMServiceError
|
||||
from .manager import LLMServiceManager
|
||||
# 导入新的提示词管理系统
|
||||
from app.services.prompts import PromptManager
|
||||
|
||||
@ -155,6 +156,23 @@ class VisionAnalyzerAdapter:
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
|
||||
def _build_provider_with_explicit_settings(self):
    """Instantiate a vision provider from this adapter's explicit settings.

    Looks up the provider class registered under ``self.provider`` and builds
    it with the adapter's own api_key/model/base_url instead of any globally
    configured values.

    Raises:
        LLMServiceError: If no vision provider is registered under that name.
    """
    name = (self.provider or "").lower()

    # Lazily register the built-in providers on first use; the import is
    # local to avoid a circular import at module load time.
    if not LLMServiceManager.is_registered():
        from .providers import register_all_providers

        register_all_providers()

    cls = LLMServiceManager._vision_providers.get(name)
    if cls is None:
        raise LLMServiceError(f"视觉模型提供商未注册: {name}")

    return cls(api_key=self.api_key, model_name=self.model, base_url=self.base_url)
|
||||
|
||||
async def analyze_images(self,
|
||||
images: List[Union[str, Path, PIL.Image.Image]],
|
||||
@ -174,11 +192,10 @@ class VisionAnalyzerAdapter:
|
||||
分析结果列表,格式与旧实现兼容
|
||||
"""
|
||||
try:
|
||||
# 使用统一服务分析图片
|
||||
results = await UnifiedLLMService.analyze_images(
|
||||
provider = self._build_provider_with_explicit_settings()
|
||||
results = await provider.analyze_images(
|
||||
images=images,
|
||||
prompt=prompt,
|
||||
provider=self.provider,
|
||||
batch_size=batch_size,
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
|
||||
@ -6,6 +6,7 @@ import unittest
|
||||
from app.config import config
|
||||
from app.services.llm.base import TextModelProvider
|
||||
from app.services.llm.manager import LLMServiceManager
|
||||
from app.services.llm.migration_adapter import VisionAnalyzerAdapter
|
||||
from app.services.llm.openai_compatible_provider import OpenAICompatibleVisionProvider
|
||||
from app.services.llm.providers import register_all_providers
|
||||
|
||||
@ -114,5 +115,53 @@ class OpenAICompatVisionConcurrencyTests(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(2, max_in_flight)
|
||||
|
||||
|
||||
class ExplicitVisionAdapterSettingsTests(unittest.IsolatedAsyncioTestCase):
    """Verify VisionAnalyzerAdapter prefers explicit settings over global config."""

    class _CapturingVisionProvider:
        """Stub vision provider that records how it was constructed."""

        # Constructor arguments of the most recent instantiation, so the test
        # can assert which credentials the adapter actually used.
        last_init: tuple[str, str, str | None] | None = None

        def __init__(self, api_key: str, model_name: str, base_url: str | None = None):
            self.api_key = api_key
            self.model_name = model_name
            self.base_url = base_url
            # Store on the class (not the instance) so the test can inspect it
            # without holding a reference to the instance the adapter built.
            ExplicitVisionAdapterSettingsTests._CapturingVisionProvider.last_init = (api_key, model_name, base_url)

        async def analyze_images(self, images, prompt, batch_size=10, max_concurrency=1, **kwargs):
            # Echo the settings back so the caller can assert on the response.
            return [f"{self.model_name}|{self.api_key}|{self.base_url}"]

    def setUp(self):
        # Start from a clean provider registry and snapshot the global app
        # config so tearDown can restore it exactly.
        _reset_manager_state()
        self._original_app = dict(config.app)

    def tearDown(self):
        # Undo registry and config mutations so other tests are unaffected.
        _reset_manager_state()
        config.app.clear()
        config.app.update(self._original_app)

    async def test_adapter_uses_explicit_settings_instead_of_global_config(self):
        LLMServiceManager.register_vision_provider("openai", self._CapturingVisionProvider)
        # Global config deliberately points at *different* credentials; the
        # adapter must ignore these in favour of its constructor arguments.
        config.app["vision_openai_api_key"] = "config-key"
        config.app["vision_openai_model_name"] = "config-model"
        config.app["vision_openai_base_url"] = "https://config.example/v1"

        adapter = VisionAnalyzerAdapter(
            provider="openai",
            api_key="explicit-key",
            model="explicit-model",
            base_url="https://explicit.example/v1",
        )
        result = await adapter.analyze_images(
            images=["/tmp/keyframe_000001_000000100.jpg"],
            prompt="描述画面",
            batch_size=1,
            max_concurrency=1,
        )

        # The provider must have been built with the explicit settings...
        self.assertEqual(
            ("explicit-key", "explicit-model", "https://explicit.example/v1"),
            self._CapturingVisionProvider.last_init,
        )
        # ...and the adapter wraps raw string responses into dicts with a
        # "response" key (legacy-compatible result format).
        self.assertEqual("explicit-model|explicit-key|https://explicit.example/v1", result[0]["response"])
|
||||
|
||||
|
||||
# Allow running this test module directly without a test runner.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
@ -183,6 +183,86 @@ class DocumentaryFrameAnalysisServiceScriptGenerationTests(unittest.IsolatedAsyn
|
||||
self.assertIn("解说文案格式错误", str(ctx.exception))
|
||||
self.assertIn("items", str(ctx.exception))
|
||||
|
||||
def test_parse_narration_items_recovers_from_common_json_damage(self):
    """Parsing survives prose wrappers, fenced blocks, doubled braces and trailing commas."""
    service = DocumentaryFrameAnalysisService()
    # Simulates typical LLM output damage: explanatory prose around a
    # ```json fence, template-escaped {{ }} braces, and trailing commas.
    damaged_payload = """
解释文字
```json
{{
    "items": [
        {{
            "timestamp": "00:00:00,000-00:00:03,000",
            "picture": "镜头里有一只猫",
            "narration": "一只猫警觉地望向镜头。",
        }},
    ],
}}
```
补充文字
""".strip()

    parsed_items = service._parse_narration_items(damaged_payload)

    # The single repaired item must round-trip all three fields intact.
    self.assertEqual(1, len(parsed_items))
    self.assertEqual("00:00:00,000-00:00:03,000", parsed_items[0]["timestamp"])
    self.assertEqual("镜头里有一只猫", parsed_items[0]["picture"])
    self.assertEqual("一只猫警觉地望向镜头。", parsed_items[0]["narration"])
|
||||
|
||||
def test_parse_narration_items_raises_for_unrecoverable_payload(self):
    """Garbage no repair heuristic can salvage must raise a descriptive error."""
    service = DocumentaryFrameAnalysisService()

    with self.assertRaises(ValueError) as ctx:
        service._parse_narration_items("not-json-at-all ::: ???")

    # The error names both the problem and the JSON key that was expected.
    self.assertIn("解说文案格式错误", str(ctx.exception))
    self.assertIn("items", str(ctx.exception))
|
||||
|
||||
async def test_generate_documentary_script_includes_theme_and_custom_prompt_for_narration(self):
    """The narration prompt must carry the user's theme and custom instructions."""
    service = DocumentaryFrameAnalysisService()
    # Minimal single-batch analysis result standing in for a real vision pass.
    analysis_payload = {
        "batches": [
            {
                "batch_index": 0,
                "time_range": "00:00:00,000-00:00:03,000",
                "overall_activity_summary": "测试摘要",
                "fallback_summary": "",
                "frame_observations": [
                    {"timestamp": "00:00:00,000", "observation": "镜头里有一只猫"},
                ],
            }
        ]
    }

    with TemporaryDirectory() as temp_dir:
        analysis_path = Path(temp_dir) / "frame_analysis_test.json"
        analysis_path.write_text(json.dumps(analysis_payload, ensure_ascii=False), encoding="utf-8")

        # Stub the vision stage, pin the text-LLM config, and stub the
        # narration generator so only prompt construction is exercised.
        with patch.object(
            DocumentaryFrameAnalysisService,
            "analyze_video",
            AsyncMock(return_value={"analysis_json_path": str(analysis_path)}),
        ), patch.dict(
            "app.services.documentary.frame_analysis_service.config.app",
            {
                "text_llm_provider": "openai",
                "text_openai_api_key": "test-key",
                "text_openai_model_name": "test-model",
                "text_openai_base_url": "https://example.com/v1",
            },
        ), patch(
            "app.services.documentary.frame_analysis_service.generate_narration",
            return_value='{"items":[{"timestamp":"00:00:00,000-00:00:03,000","picture":"镜头里有一只猫","narration":"一只猫警觉地望向镜头。"}]}',
        ) as mocked_generate:
            await service.generate_documentary_script(
                video_path="demo.mp4",
                video_theme="野生动物纪录片",
                custom_prompt="重点描述危险信号",
            )

        # The first positional argument to generate_narration is the merged
        # markdown + creative-context prompt; both user inputs must appear.
        narration_input = mocked_generate.call_args.args[0]
        self.assertIn("视频主题:野生动物纪录片", narration_input)
        self.assertIn("补充创作要求:重点描述危险信号", narration_input)
|
||||
|
||||
|
||||
# Allow running this test module directly without a test runner.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user