From ad02059e5d82f42c2573e98bc3c0d356c6b85b8e Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 3 Apr 2026 02:04:21 +0800 Subject: [PATCH] fix(documentary): validate batch response contract before success --- .../documentary/frame_analysis_service.py | 41 +++++++++++++++---- ...test_documentary_frame_analysis_service.py | 39 ++++++++++++++++++ 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/app/services/documentary/frame_analysis_service.py b/app/services/documentary/frame_analysis_service.py index 05dfa13..9cefda9 100644 --- a/app/services/documentary/frame_analysis_service.py +++ b/app/services/documentary/frame_analysis_service.py @@ -97,8 +97,6 @@ JSON 必须包含以下键: ) -> FrameBatchResult: try: payload = json.loads(self._strip_code_fence(raw_response)) - if not isinstance(payload, dict): - raise ValueError("Batch response JSON payload must be an object") except Exception as exc: return self._build_failed_batch_result( batch_index=batch_index, @@ -108,9 +106,17 @@ JSON 必须包含以下键: time_range=time_range, ) - raw_observations = payload.get("frame_observations") - if not isinstance(raw_observations, list): - raw_observations = [] + validation_error = self._validate_batch_payload_contract(payload, expected_frame_count=len(frame_paths)) + if validation_error: + return self._build_failed_batch_result( + batch_index=batch_index, + raw_response=raw_response, + error_message=validation_error, + frame_paths=frame_paths, + time_range=time_range, + ) + + raw_observations = payload["frame_observations"] frame_observations: list[dict] = [] for index, frame_path in enumerate(frame_paths): @@ -129,9 +135,7 @@ JSON 必须包含以下键: } ) - summary = payload.get("overall_activity_summary", "") - if not isinstance(summary, str): - summary = str(summary or "") + summary = payload["overall_activity_summary"] return FrameBatchResult( batch_index=batch_index, @@ -142,3 +146,24 @@ JSON 必须包含以下键: frame_observations=frame_observations, overall_activity_summary=summary, ) + + def _validate_batch_payload_contract(self, payload: object, *, expected_frame_count: int) -> str: + if not isinstance(payload, dict): + return "Batch response JSON payload must be an object" + + if "frame_observations" not in payload or not isinstance(payload["frame_observations"], list): + return "Batch response must include frame_observations as a list" + + if len(payload["frame_observations"]) < expected_frame_count: + return ( + "Batch response frame_observations length is shorter than provided frame_paths: " + f"{len(payload['frame_observations'])} < {expected_frame_count}" + ) + + if "overall_activity_summary" not in payload: + return "Batch response must include overall_activity_summary" + + if not isinstance(payload["overall_activity_summary"], str): + return "Batch response overall_activity_summary must be a string" + + return "" diff --git a/tests/test_documentary_frame_analysis_service.py b/tests/test_documentary_frame_analysis_service.py index 1d3415f..edf585c 100644 --- a/tests/test_documentary_frame_analysis_service.py +++ b/tests/test_documentary_frame_analysis_service.py @@ -82,6 +82,45 @@ class DocumentaryFrameAnalysisServiceTests(unittest.TestCase): self.assertEqual([], batch.frame_observations) self.assertEqual("", batch.overall_activity_summary) + def test_parse_batch_returns_failed_result_for_empty_json_object(self): + service = DocumentaryFrameAnalysisService() + + batch = service._parse_batch_response( + batch_index=0, + raw_response="{}", + frame_paths=["/tmp/keyframe_000000_000000000.jpg"], + time_range="00:00:00,000-00:00:03,000", + ) + + self.assertEqual("failed", batch.status) + self.assertEqual("{}", batch.raw_response) + self.assertIn("frame_observations", batch.error_message) + + def test_parse_batch_returns_failed_result_when_observations_are_too_short(self): + service = DocumentaryFrameAnalysisService() + raw_response = """ +{ + "frame_observations": [ + {"observation": "第一帧画面"} + ], + "overall_activity_summary": "只有一条帧观察" +} +""".strip() + + batch = service._parse_batch_response( + batch_index=1, + raw_response=raw_response, + frame_paths=[ + "/tmp/keyframe_000000_000000000.jpg", + "/tmp/keyframe_000075_000003000.jpg", + ], + time_range="00:00:00,000-00:00:06,000", + ) + + self.assertEqual("failed", batch.status) + self.assertEqual(raw_response, batch.raw_response) + self.assertIn("frame_observations", batch.error_message) + def test_parse_batch_parses_code_fenced_json_into_structured_result(self): service = DocumentaryFrameAnalysisService() raw_response = """```json