ChatDev/runtime/node/agent/providers/gemini_provider.py
Shu Yao 2833fcd7ff fix: parallel image generation error
upgrade: gemini provider enable setting image ratio
2026-01-16 16:10:28 +08:00

834 lines
32 KiB
Python
Executable File

"""Gemini provider implementation."""
import base64
import binascii
import json
import os
import uuid
from typing import Any, Dict, List, Optional, Sequence, Tuple
from google import genai
from google.genai import types as genai_types
from google.genai.types import GenerateContentResponse
from entity.messages import (
AttachmentRef,
FunctionCallOutputEvent,
Message,
MessageBlock,
MessageBlockType,
MessageRole,
ToolCallPayload,
)
from entity.tool_spec import ToolSpec
from runtime.node.agent import ModelProvider
from runtime.node.agent import ModelResponse
from utils.token_tracker import TokenUsage
class GeminiProvider(ModelProvider):
    """Gemini provider implementation."""

    # Maximum number of characters kept when a CSV attachment is inlined as
    # text; longer content is truncated.
    CSV_INLINE_CHAR_LIMIT = 200_000
    # CSV attachments strictly larger than this many bytes are converted to
    # inline text instead of being sent as a media part.
    CSV_INLINE_SIZE_THRESHOLD_BYTES = 3 * 1024 * 1024 # 3 MB
def create_client(self):
    """Build a ``genai.Client`` from this provider's API key and base URL.

    The API key is forwarded only when it is set, and HTTP options are
    attached only when the (stripped) base URL yields usable options.
    """
    kwargs: Dict[str, Any] = {}
    if self.api_key:
        kwargs["api_key"] = self.api_key
    options = self._build_http_options((self.base_url or "").strip())
    if options:
        kwargs["http_options"] = options
    return genai.Client(**kwargs)
def call_model(
    self,
    client,
    conversation: List[Message],
    timeline: List[Any],
    tool_specs: Optional[List[ToolSpec]] = None,
    **kwargs,
) -> ModelResponse:
    """Send the timeline to Gemini and wrap the reply in a ``ModelResponse``.

    Request contents and the system instruction are derived from
    ``timeline``; ``conversation`` is accepted but not read here. After the
    call, token usage is tracked and every candidate's content is appended
    to ``timeline`` before the response is deserialized into a message.
    """
    request_contents, system_text = self._build_contents(timeline)
    generation_config = self._build_generation_config(system_text, tool_specs, kwargs)
    raw: GenerateContentResponse = client.models.generate_content(
        model=self.model_name,
        contents=request_contents,
        config=generation_config,
    )
    self._track_token_usage(raw)
    self._append_response_contents(timeline, raw)
    return ModelResponse(message=self._deserialize_response(raw), raw_response=raw)
def extract_token_usage(self, response: Any) -> TokenUsage:
    """Build a ``TokenUsage`` from the response's ``usage_metadata``.

    Returns an empty ``TokenUsage`` when the metadata is missing. Absent
    prompt/candidate counts count as 0, and the total falls back to their
    sum when no non-zero total is reported. The metadata dict always holds
    the prompt and candidate counts; total and cached counts are added only
    when they are not None.
    """
    usage = getattr(response, "usage_metadata", None)
    if not usage:
        return TokenUsage()

    n_prompt = getattr(usage, "prompt_token_count", None) or 0
    n_candidates = getattr(usage, "candidates_token_count", None) or 0
    n_total = getattr(usage, "total_token_count", None)
    n_cached = getattr(usage, "cached_content_token_count", None)

    details = {
        "prompt_token_count": n_prompt,
        "candidates_token_count": n_candidates,
    }
    optional_counts = (
        ("total_token_count", n_total),
        ("cached_content_token_count", n_cached),
    )
    for key, count in optional_counts:
        if count is not None:
            details[key] = count

    return TokenUsage(
        input_tokens=n_prompt,
        output_tokens=n_candidates,
        total_tokens=n_total or (n_prompt + n_candidates),
        metadata=details,
    )
# ---------------------------------------------------------------------
# Serialization helpers
# ---------------------------------------------------------------------
def _build_contents(
    self,
    timeline: List[Any],
) -> Tuple[List[genai_types.Content], Optional[str]]:
    """Convert the timeline into Gemini contents plus a system instruction.

    System messages are not emitted as contents; their non-blank text is
    collected and joined with blank lines. Other messages and function-call
    output events are converted, existing ``Content`` objects pass through,
    and any other item type is ignored. When nothing is produced, a single
    empty user content is returned so the request is never empty.
    """
    system_texts: List[str] = []
    converted: List[genai_types.Content] = []
    for entry in timeline:
        if isinstance(entry, Message):
            if entry.role is not MessageRole.SYSTEM:
                converted.append(self._message_to_content(entry))
                continue
            stripped = entry.text_content().strip()
            if stripped:
                system_texts.append(stripped)
        elif isinstance(entry, FunctionCallOutputEvent):
            converted.append(self._function_output_event_to_content(entry))
        elif isinstance(entry, genai_types.Content):
            converted.append(entry)

    if not converted:
        placeholder = genai_types.Content(
            role="user",
            parts=[genai_types.Part(text="")],
        )
        converted.append(placeholder)

    if system_texts:
        return converted, "\n\n".join(system_texts)
    return converted, None
def _append_response_contents(self, timeline: List[Any], response: Any) -> None:
candidates = getattr(response, "candidates", None)
if not candidates:
return
for candidate in candidates:
content = getattr(candidate, "content", None)
if content:
timeline.append(content)
def _message_to_content(self, message: Message) -> genai_types.Content:
    """Convert one message into a Gemini ``Content``.

    Tool messages become a single function-response part sent with the
    ``user`` role. Other messages have each block converted to parts; when
    no part results, the message's plain text is sent as one text part.
    """
    if message.role is MessageRole.TOOL:
        tool_part = self._build_tool_response_part(message)
        return genai_types.Content(role="user", parts=[tool_part])

    parts: List[genai_types.Part] = [
        part
        for block in message.blocks()
        for part in self._block_to_parts(block)
    ]
    if not parts:
        parts.append(genai_types.Part(text=message.text_content()))
    return genai_types.Content(role=self._map_role(message.role), parts=parts)
def _function_output_event_to_content(
    self,
    event: FunctionCallOutputEvent,
) -> genai_types.Content:
    """Convert a function-call output event into a ``user`` content.

    The content holds one function-response part. Its ``result`` text is
    the newline-joined block descriptions (plus the text of attachments
    that can be inlined), or ``event.output_text`` when the event has no
    output blocks. Attachments that are not inlined as text are passed as
    function-response parts.
    """
    # The response name falls back to the call id, then to the literal "tool".
    function_name = event.function_name or event.call_id or "tool"
    payload: Dict[str, Any] = {}
    function_result_parts: List[genai_types.FunctionResponsePart] = []
    result_texts: List[str] = []
    if event.output_blocks:
        for block in event.output_blocks:
            # Describe the block for the text result
            desc = self._describe_block(block)
            if desc:
                result_texts.append(desc)
            if self._block_has_attachment(block):
                # Check if we should inline this attachment as text
                if self._should_inline_attachment_as_text(block):
                    text_content = self._read_attachment_text(block.attachment)
                    if text_content:
                        result_texts.append(f"\n[Attachment Content: {block.attachment.name}]\n{text_content}")
                        # Inlined successfully, so no binary part is needed.
                        continue
                # Otherwise treat as binary part (also reached when a
                # text-like attachment could not be read or was empty)
                general_parts = self._block_to_parts(block)
                function_result_parts.extend(self._general_parts_to_function_response_parts(general_parts))
    else:
        if event.output_text:
            result_texts.append(event.output_text)
    payload["result"] = "\n".join(result_texts)
    # payload always contains "result" here, so the `or` fallback is only a
    # safeguard.
    function_part = genai_types.Part.from_function_response(
        name=function_name,
        response=payload or {"result": ""},
        parts=function_result_parts or None
    )
    parts: List[genai_types.Part] = [function_part]
    return genai_types.Content(role="user", parts=parts)
def _should_inline_attachment_as_text(self, block: MessageBlock) -> bool:
if not block.attachment:
return False
mime = (block.attachment.mime_type or "").lower()
return (
mime.startswith("text/") or
mime == "application/json" or
mime.endswith("+json") or
mime.endswith("+xml")
)
def _read_attachment_text(self, attachment: AttachmentRef) -> Optional[str]:
    """Read the attachment's bytes and decode them to text (None if unreadable)."""
    return self._bytes_to_text(self._read_attachment_bytes(attachment))
def _general_parts_to_function_response_parts(self, parts: List[genai_types.Part]) -> List[genai_types.FunctionResponsePart]:
    """Re-wrap generic parts as function-response parts.

    Inline blobs are converted with ``FunctionResponsePart.from_bytes`` and
    file references with ``FunctionResponsePart.from_uri``. A part carrying
    both produces two entries (inline first); a part with neither is
    dropped. A missing MIME type defaults to ``application/octet-stream``.
    """
    fallback_mime = "application/octet-stream"
    converted: List[genai_types.FunctionResponsePart] = []
    for part in parts:
        blob = part.inline_data
        if blob:
            converted.append(
                genai_types.FunctionResponsePart.from_bytes(
                    data=blob.data,
                    mime_type=blob.mime_type or fallback_mime,
                )
            )
        file_ref = part.file_data
        if file_ref:
            converted.append(
                genai_types.FunctionResponsePart.from_uri(
                    file_uri=file_ref.file_uri,
                    mime_type=file_ref.mime_type or fallback_mime,
                )
            )
    return converted
def _build_tool_response_part(self, message: Message) -> genai_types.Part:
    """Wrap a tool message in a Gemini function-response part.

    The response name is taken from ``metadata["tool_name"]`` (when the
    metadata is a dict), then ``tool_call_id``, then the literal ``"tool"``.
    """
    metadata = message.metadata
    name_from_metadata = metadata.get("tool_name") if isinstance(metadata, dict) else None
    resolved_name = name_from_metadata or message.tool_call_id or "tool"
    response_payload, attachment_parts = self._serialize_tool_message_payload(message)
    function_response = genai_types.FunctionResponse(
        name=resolved_name,
        response=response_payload,
        parts=attachment_parts or None,
    )
    return genai_types.Part(function_response=function_response)
def _block_has_attachment(self, block: Any) -> bool:
    """Return True only for ``MessageBlock`` instances that carry an attachment."""
    if not isinstance(block, MessageBlock):
        return False
    return block.attachment is not None
def _serialize_tool_message_payload(self, message: Message) -> Tuple[Dict[str, Any], List[genai_types.FunctionResponsePart]]:
content = message.content
blocks: List[MessageBlock] = []
if isinstance(content, str):
stripped = content.strip()
if stripped:
try:
payload = json.loads(stripped)
except json.JSONDecodeError:
payload = {"result": stripped}
else:
payload = {"result": ""}
return payload, []
if isinstance(content, list):
blocks_payload = []
for block in content:
if isinstance(block, MessageBlock):
blocks_payload.append(block.to_dict())
blocks.append(block)
elif isinstance(block, dict):
blocks_payload.append(block)
try:
blocks.append(MessageBlock.from_dict(block))
except Exception:
continue
parts = self._blocks_to_function_parts(blocks)
return {"blocks": blocks_payload, "result": message.text_content()}, parts
parts = self._blocks_to_function_parts(blocks)
return {"result": message.text_content()}, parts
def _describe_block(self, block: Any) -> str:
    """Return a text description of a block.

    ``MessageBlock`` instances describe themselves; dicts yield their
    truthy ``"text"`` entry; everything else falls back to ``str(block)``.
    """
    if isinstance(block, MessageBlock):
        return block.describe()
    text = block.get("text") if isinstance(block, dict) else None
    return str(text) if text else str(block)
def _block_to_parts(self, block: MessageBlock) -> List[genai_types.Part]:
    """Convert one message block into zero or one Gemini parts.

    Text blocks become a text part; data blocks become a text part holding
    the block text or, failing that, its data rendered as JSON. File blocks
    are first offered to the large-CSV inliner; image/audio/video/file
    blocks are otherwise sent as a media part when the attachment can be
    resolved. Unknown block types produce no parts.
    """
    kind = block.type
    if kind is MessageBlockType.TEXT:
        return [genai_types.Part(text=block.text or "")]
    if kind is MessageBlockType.DATA:
        rendered = block.text or json.dumps(block.data or {}, ensure_ascii=False)
        return [genai_types.Part(text=rendered)]
    if kind is MessageBlockType.FILE:
        inlined_csv = self._maybe_inline_large_csv(block)
        if inlined_csv is not None:
            return [genai_types.Part(text=inlined_csv)]

    media_kinds = (
        MessageBlockType.IMAGE,
        MessageBlockType.AUDIO,
        MessageBlockType.VIDEO,
        MessageBlockType.FILE,
    )
    if kind not in media_kinds:
        return []
    part = self._attachment_block_to_part(block)
    return [part] if part else []
def _maybe_inline_large_csv(self, block: MessageBlock) -> Optional[str]:
"""Convert large CSV attachments to inline text to avoid Gemini upload size limits."""
attachment = block.attachment
if not attachment:
return None
mime = (attachment.mime_type or "").lower()
name = (attachment.name or "").lower()
if "text/csv" not in mime and not name.endswith(".csv"):
return None
if attachment.remote_file_id:
return None
threshold = getattr(
self,
"csv_inline_size_threshold_bytes",
self.CSV_INLINE_SIZE_THRESHOLD_BYTES,
)
size_bytes = attachment.size
data_bytes: Optional[bytes] = None
if size_bytes is None:
data_bytes = self._read_attachment_bytes(attachment)
if data_bytes is None:
return None
size_bytes = len(data_bytes)
if size_bytes is None or size_bytes <= threshold:
return None
if data_bytes is None:
data_bytes = self._read_attachment_bytes(attachment)
if data_bytes is None:
return None
text = self._bytes_to_text(data_bytes)
if text is None:
return None
char_limit = getattr(self, "csv_inline_char_limit", self.CSV_INLINE_CHAR_LIMIT)
truncated = False
if len(text) > char_limit:
text = text[:char_limit]
truncated = True
display_name = attachment.name or attachment.attachment_id or "attachment.csv"
suffix = f"\n\n[truncated after {char_limit} characters]" if truncated else ""
return f"CSV file '{display_name}' (converted from >3MB upload):\n{text}{suffix}"
def _bytes_to_text(self, data_bytes: Optional[bytes]) -> Optional[str]:
if data_bytes is None:
return None
try:
return data_bytes.decode("utf-8")
except UnicodeDecodeError:
return data_bytes.decode("utf-8", errors="replace")
def _attachment_block_to_part(self, block: MessageBlock) -> Optional[genai_types.Part]:
    """Build a media part for the block's attachment, or None.

    A file URI (``metadata["gemini_file_uri"]`` or ``remote_file_id``) is
    preferred and produces a ``file_data`` part. Otherwise the attachment
    bytes are read and sent as ``inline_data``; None is returned when the
    block has no attachment or its bytes cannot be read.
    """
    attachment = block.attachment
    if not attachment:
        return None

    resolved_mime = attachment.mime_type or self._guess_mime_from_block(block)
    remote_uri = (attachment.metadata or {}).get("gemini_file_uri") or attachment.remote_file_id
    if remote_uri:
        file_ref = genai_types.FileData(file_uri=remote_uri, mime_type=resolved_mime)
        return genai_types.Part(file_data=file_ref)

    payload = self._read_attachment_bytes(attachment)
    if payload is None:
        return None
    blob = genai_types.Blob(
        mime_type=resolved_mime or "application/octet-stream",
        data=payload,
    )
    return genai_types.Part(inline_data=blob)
def _blocks_to_function_parts(
self,
blocks: Optional[Sequence[Any]],
) -> List[genai_types.FunctionResponsePart]:
if not blocks:
return []
parts: List[genai_types.FunctionResponsePart] = []
for block in blocks:
if not isinstance(block, MessageBlock):
if isinstance(block, dict):
try:
block = MessageBlock.from_dict(block)
except Exception:
continue
else:
continue
attachment = block.attachment
if not attachment:
continue
mime_type = attachment.mime_type or self._guess_mime_from_block(block)
file_uri = (attachment.metadata or {}).get("gemini_file_uri") or attachment.remote_file_id
if file_uri:
parts.append(
genai_types.FunctionResponsePart(
file_data=genai_types.FunctionResponseFileData(
file_uri=file_uri,
mime_type=mime_type,
display_name=attachment.name,
)
)
)
continue
data_bytes = self._read_attachment_bytes(attachment)
if not data_bytes:
continue
parts.append(
genai_types.FunctionResponsePart(
inline_data=genai_types.FunctionResponseBlob(
mime_type=mime_type or "application/octet-stream",
data=data_bytes,
display_name=attachment.name,
)
)
)
return parts
def _coerce_message_blocks(self, payload: Any) -> List[MessageBlock]:
if not isinstance(payload, Sequence) or isinstance(payload, (str, bytes, bytearray)):
return []
blocks: List[MessageBlock] = []
for item in payload:
if isinstance(item, MessageBlock):
blocks.append(item)
elif isinstance(item, dict):
try:
blocks.append(MessageBlock.from_dict(item))
except Exception:
continue
return blocks
def _encode_thought_signature(self, value: Any) -> Optional[str]:
if value is None:
return None
if isinstance(value, bytes):
return base64.b64encode(value).decode("ascii")
try:
return str(value)
except Exception:
return None
def _read_attachment_bytes(self, attachment: AttachmentRef) -> Optional[bytes]:
if attachment.data_uri:
decoded = self._decode_data_uri(attachment.data_uri)
if decoded is not None:
return decoded
if attachment.local_path and os.path.exists(attachment.local_path):
try:
with open(attachment.local_path, "rb") as handle:
return handle.read()
except OSError:
return None
return None
def _decode_data_uri(self, data_uri: str) -> Optional[bytes]:
if not data_uri.startswith("data:"):
return None
header, _, data = data_uri.partition(",")
if not _:
return None
if ";base64" in header:
try:
return base64.b64decode(data)
except (ValueError, binascii.Error):
return None
return data.encode("utf-8")
def _guess_mime_from_block(self, block: MessageBlock) -> str:
    """Return the attachment's MIME type, or a default chosen by block type.

    Defaults: ``image/png`` for images, ``audio/mpeg`` for audio,
    ``video/mp4`` for video, and ``application/octet-stream`` otherwise.
    """
    attachment = block.attachment
    if attachment and attachment.mime_type:
        return attachment.mime_type

    defaults = (
        (MessageBlockType.IMAGE, "image/png"),
        (MessageBlockType.AUDIO, "audio/mpeg"),
        (MessageBlockType.VIDEO, "video/mp4"),
    )
    for kind, mime_type in defaults:
        if block.type is kind:
            return mime_type
    return "application/octet-stream"
def _map_role(self, role: MessageRole) -> str:
    """Map a message role to its Gemini role string.

    Assistant maps to ``"model"`` and tool to ``"tool"``; user and every
    other role map to ``"user"``.
    """
    if role is MessageRole.ASSISTANT:
        return "model"
    if role is MessageRole.TOOL:
        return "tool"
    return "user"
# ---------------------------------------------------------------------
# Config builders
# ---------------------------------------------------------------------
def _build_generation_config(
    self,
    system_instruction: Optional[str],
    tool_specs: Optional[List[ToolSpec]],
    call_params: Dict[str, Any],
) -> genai_types.GenerateContentConfig:
    """Assemble the ``GenerateContentConfig`` for one request.

    Provider-level ``self.params`` are overlaid with ``call_params``. Only
    the keys handled below are forwarded; anything else left in the merged
    parameters is ignored. Configuration objects that originate from
    ``self.params`` (``tool_config`` dicts, ``ImageConfig`` instances) are
    copied before being modified so one request cannot alter the settings
    seen by another.

    Args:
        system_instruction: Joined system prompt text, or None.
        tool_specs: Tool specifications to expose as function declarations.
        call_params: Per-call overrides for the provider parameters.

    Returns:
        The populated ``GenerateContentConfig``.

    Raises:
        ValueError: Propagated from the coercion helpers when a function
            calling or automatic function calling payload is invalid.
    """
    # Local import: only needed to copy a shared ImageConfig before editing.
    import copy

    params = dict(self.params or {})
    params.update(call_params)
    config_kwargs: Dict[str, Any] = {}
    if system_instruction:
        config_kwargs["system_instruction"] = system_instruction

    # Scalar/sequence options are passed through unchanged when present.
    for key in (
        "temperature",
        "top_p",
        "top_k",
        "candidate_count",
        "max_output_tokens",
        "response_modalities",
        "stop_sequences",
        "seed",
        "presence_penalty",
        "frequency_penalty",
    ):
        if key in params:
            config_kwargs[key] = params.pop(key)

    safety_settings = params.pop("safety_settings", None)
    if safety_settings:
        config_kwargs["safety_settings"] = safety_settings

    image_config = params.pop("image_config", None)
    aspect_ratio = params.pop("aspect_ratio", None)
    if aspect_ratio:
        if image_config is None:
            image_config = {"aspect_ratio": aspect_ratio}
        elif isinstance(image_config, dict):
            # An aspect ratio already present in the dict wins.
            image_config = dict(image_config)
            image_config.setdefault("aspect_ratio", aspect_ratio)
        elif isinstance(image_config, genai_types.ImageConfig):
            try:
                # Copy first: the instance may be shared through self.params
                # and must not be edited in place.
                image_config = copy.copy(image_config)
                image_config.aspect_ratio = aspect_ratio
            except Exception:
                image_config = {"aspect_ratio": aspect_ratio}
        else:
            image_config = {"aspect_ratio": aspect_ratio}
    if image_config:
        config_kwargs["image_config"] = self._coerce_image_config(image_config)

    audio_config = params.pop("audio_config", None)
    if audio_config:
        config_kwargs["audio_config"] = audio_config
    video_config = params.pop("video_config", None)
    if video_config:
        config_kwargs["video_config"] = video_config

    tools = self._build_tools(tool_specs or [])
    if tools:
        config_kwargs["tools"] = tools

    tool_config_payload = params.pop("tool_config", None)
    function_calling_payload = params.pop("function_calling_config", None)
    if function_calling_payload:
        # `params` is only a shallow copy of self.params, so a dict popped
        # from it is the shared object; merge into a fresh dict instead.
        if isinstance(tool_config_payload, dict):
            tool_config_payload = dict(tool_config_payload)
        else:
            tool_config_payload = tool_config_payload or {}
        tool_config_payload["function_calling_config"] = function_calling_payload
    if tool_config_payload:
        config_kwargs["tool_config"] = self._coerce_tool_config(tool_config_payload)

    automatic_fn_calling = params.pop("automatic_function_calling", None)
    if automatic_fn_calling:
        config_kwargs["automatic_function_calling"] = self._coerce_automatic_function_calling(
            automatic_fn_calling
        )
    return genai_types.GenerateContentConfig(**config_kwargs)
def _build_http_options(self, base_url: str) -> Optional[genai_types.HttpOptions]:
if not base_url:
return None
try:
return genai_types.HttpOptions(base_url=base_url, timeout=4 * 60 * 1000) # 4 min
except Exception:
return None
def _coerce_image_config(self, image_config: Any) -> Any:
    """Convert a dict image config into ``ImageConfig`` when possible.

    ``ImageConfig`` instances and non-dict values are returned unchanged,
    as is a dict that the ``ImageConfig`` constructor rejects.
    """
    already_typed = isinstance(image_config, genai_types.ImageConfig)
    if already_typed or not isinstance(image_config, dict):
        return image_config
    try:
        return genai_types.ImageConfig(**image_config)
    except Exception:
        return image_config
def _build_tools(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]:
    """Convert tool specs into a single Gemini ``Tool`` of function declarations.

    For each spec the parameter schema is normalized: the top-level
    ``title`` is removed, and each property's ``title`` is removed and used
    as its ``description`` only when the property has no description of its
    own. Normalization is done on copies, so the schema returned by
    ``spec.to_gemini_function()`` is left untouched.

    Args:
        tool_specs: Specs exposing ``to_gemini_function()``.

    Returns:
        An empty list when no specs are given, otherwise a one-element list
        holding a ``Tool`` with one declaration per spec.
    """
    if not tool_specs:
        return []

    declarations = []
    for spec in tool_specs:
        fn_payload = spec.to_gemini_function()
        parameters = dict(fn_payload.get("parameters") or {"type": "object", "properties": {}})
        parameters.pop("title", None)

        properties = parameters.get("properties")
        if isinstance(properties, dict):
            normalized = {}
            for prop_name, prop_schema in properties.items():
                if isinstance(prop_schema, dict) and "title" in prop_schema:
                    prop_schema = dict(prop_schema)
                    title = prop_schema.pop("title")
                    # The title is only a fallback; an explicit description
                    # is more informative and must not be overwritten.
                    if not prop_schema.get("description"):
                        prop_schema["description"] = title
                normalized[prop_name] = prop_schema
            parameters["properties"] = normalized

        declarations.append(
            genai_types.FunctionDeclaration(
                name=fn_payload.get("name", ""),
                description=fn_payload.get("description") or "",
                parameters=parameters,
            )
        )
    return [genai_types.Tool(function_declarations=declarations)]
def _coerce_tool_config(self, payload: Any) -> genai_types.ToolConfig:
    """Return a ``ToolConfig`` for the given payload.

    An existing ``ToolConfig`` is returned unchanged. From a dict only a
    truthy ``function_calling_config`` entry is used; any other payload
    produces a default ``ToolConfig``.
    """
    if isinstance(payload, genai_types.ToolConfig):
        return payload
    fn_config = payload.get("function_calling_config") if isinstance(payload, dict) else None
    if not fn_config:
        return genai_types.ToolConfig()
    return genai_types.ToolConfig(
        function_calling_config=self._coerce_function_calling_config(fn_config)
    )
def _coerce_function_calling_config(self, payload: Any) -> genai_types.FunctionCallingConfig:
    """Return a ``FunctionCallingConfig`` built from ``payload``.

    Accepts an existing config (returned as-is), a mode string, or a dict
    of constructor keyword arguments.

    Raises:
        ValueError: If the payload is of any other type.
    """
    config_cls = genai_types.FunctionCallingConfig
    if isinstance(payload, config_cls):
        return payload
    if isinstance(payload, str):
        return config_cls(mode=payload)
    if not isinstance(payload, dict):
        raise ValueError("Invalid function calling configuration payload")
    return config_cls(**payload)
def _coerce_automatic_function_calling(self, payload: Any) -> Any:
    """Return an ``AutomaticFunctionCallingConfig`` built from ``payload``.

    Accepts an existing config instance (returned as-is) or a dict of
    constructor keyword arguments.

    Raises:
        ValueError: If the SDK does not expose the config class, or the
            payload is neither an instance of it nor a dict.
    """
    config_cls = getattr(genai_types, "AutomaticFunctionCallingConfig", None)
    if config_cls is None:
        raise ValueError("Automatic function calling config not supported in current SDK version")
    if isinstance(payload, config_cls):
        return payload
    if not isinstance(payload, dict):
        raise ValueError("Invalid automatic function calling config payload")
    return config_cls(**payload)
# ---------------------------------------------------------------------
# Response parsing
# ---------------------------------------------------------------------
def _deserialize_response(self, response: Any) -> Message:
    """Convert a Gemini response into an assistant ``Message``.

    Only the first candidate is used. Its parts are parsed into message
    blocks and tool calls. When there is no candidate the message content
    is an empty string; when the candidate has no content, or its parts
    yield no blocks, the content falls back to ``response.text`` (or an
    empty string when that is missing or None).

    Args:
        response: The SDK response object.

    Returns:
        An assistant message whose content is a list of blocks or a string,
        never None.
    """
    candidate = self._select_primary_candidate(response)
    if not candidate:
        return Message(role=MessageRole.ASSISTANT, content="")

    content = getattr(candidate, "content", None)
    if not content:
        # `response.text` may exist but be None; normalise it to "" so this
        # path agrees with the other fallbacks in this method.
        return Message(
            role=MessageRole.ASSISTANT,
            content=getattr(response, "text", None) or "",
        )

    blocks, tool_calls = self._parse_candidate_parts(getattr(content, "parts", []) or [])
    if not blocks:
        fallback = getattr(response, "text", None) or ""
        blocks = [MessageBlock(MessageBlockType.TEXT, text=fallback)] if fallback else []
    return Message(
        role=MessageRole.ASSISTANT,
        content=blocks or "",
        tool_calls=tool_calls,
    )
def _select_primary_candidate(self, response: Any) -> Any:
candidates = getattr(response, "candidates", None) or []
if not candidates:
return None
return candidates[0]
def _parse_candidate_parts(
self,
parts: Sequence[Any],
) -> Tuple[List[MessageBlock], List[ToolCallPayload]]:
blocks: List[MessageBlock] = []
tool_calls: List[ToolCallPayload] = []
for part in parts:
if hasattr(part, "text") and part.text is not None:
blocks.append(MessageBlock(MessageBlockType.TEXT, text=part.text))
continue
function_call = getattr(part, "function_call", None)
if function_call:
thought_signature = getattr(part, "thought_signature", None)
tool_calls.append(
self._build_tool_call_payload(function_call, thought_signature=thought_signature)
)
continue
inline_data = getattr(part, "inline_data", None)
if inline_data:
blocks.append(self._build_inline_block(inline_data))
continue
file_data = getattr(part, "file_data", None)
if file_data:
blocks.append(self._build_file_block(file_data))
continue
function_response = getattr(part, "function_response", None)
if function_response:
blocks.append(
MessageBlock(
type=MessageBlockType.DATA,
text=json.dumps(function_response.response or {}, ensure_ascii=False),
data={
"function_name": getattr(function_response, "name", ""),
"response": function_response.response or {},
},
)
)
continue
return blocks, tool_calls
def _build_tool_call_payload(self, fn_call: Any, *, thought_signature: Any = None) -> ToolCallPayload:
    """Build a ``ToolCallPayload`` from a Gemini function call.

    The call id is the function name, or a random hex id when the name is
    empty (in which case that id also serves as the function name).
    Arguments are serialized as JSON, falling back to ``str()`` when they
    are not JSON-serializable. An encoded thought signature, when present,
    is stored under ``metadata["gemini_thought_signature_b64"]``.
    """
    name = getattr(fn_call, "name", "")
    call_id = name or uuid.uuid4().hex

    raw_args = getattr(fn_call, "args", {}) or {}
    try:
        serialized_args = json.dumps(raw_args, ensure_ascii=False)
    except (TypeError, ValueError):
        serialized_args = str(raw_args)

    signature = self._encode_thought_signature(thought_signature)
    metadata: Dict[str, Any] = {"gemini_thought_signature_b64": signature} if signature else {}

    return ToolCallPayload(
        id=call_id,
        function_name=name or call_id,
        arguments=serialized_args,
        type="function",
        metadata=metadata,
    )
def _build_inline_block(self, blob: Any) -> MessageBlock:
    """Wrap an inline blob from the response as an attachment block.

    The bytes are stored as a base64 data URI; the block type is derived
    from the MIME type, which defaults to ``application/octet-stream``.
    """
    mime = getattr(blob, "mime_type", "") or "application/octet-stream"
    payload = getattr(blob, "data", None) or b""
    attachment = AttachmentRef(
        attachment_id=uuid.uuid4().hex,
        mime_type=mime,
        data_uri=self._encode_data_uri(mime, payload),
        metadata={"source": "gemini_inline"},
    )
    return MessageBlock(type=self._block_type_from_mime(mime), attachment=attachment)
def _build_file_block(self, file_data: Any) -> MessageBlock:
    """Wrap a file reference from the response as an attachment block.

    The URI is read from ``file_uri`` (or ``file`` as a fallback) and stored
    both as the remote file id and under ``metadata["gemini_file_uri"]``.
    """
    mime = getattr(file_data, "mime_type", None)
    uri = getattr(file_data, "file_uri", None) or getattr(file_data, "file", None)
    attachment = AttachmentRef(
        attachment_id=uuid.uuid4().hex,
        mime_type=mime,
        remote_file_id=uri,
        metadata={"gemini_file_uri": uri, "source": "gemini_file"},
    )
    return MessageBlock(type=self._block_type_from_mime(mime or ""), attachment=attachment)
def _block_type_from_mime(self, mime_type: str) -> MessageBlockType:
    """Map a MIME type to a block type by its ``image/``, ``audio/`` or ``video/`` prefix.

    Any other MIME type maps to ``MessageBlockType.FILE``.
    """
    prefix_to_kind = (
        ("image/", MessageBlockType.IMAGE),
        ("audio/", MessageBlockType.AUDIO),
        ("video/", MessageBlockType.VIDEO),
    )
    for prefix, kind in prefix_to_kind:
        if mime_type.startswith(prefix):
            return kind
    return MessageBlockType.FILE
def _encode_data_uri(self, mime_type: str, data: bytes) -> str:
encoded = base64.b64encode(data).decode("utf-8")
return f"data:{mime_type};base64,{encoded}"
# ---------------------------------------------------------------------
# Token tracking
# ---------------------------------------------------------------------
def _track_token_usage(self, response: Any) -> None:
    """Record the response's token usage with the configured tracker.

    Does nothing when ``self.config`` has no token tracker, or when the
    extracted usage has zero input and output tokens and no metadata. The
    node id defaults to ``"ALL"`` when the config does not define one.
    """
    tracker = getattr(self.config, "token_tracker", None)
    if not tracker:
        return

    usage = self.extract_token_usage(response)
    nothing_to_record = (
        usage.input_tokens == 0
        and usage.output_tokens == 0
        and not usage.metadata
    )
    if nothing_to_record:
        return

    node_id = getattr(self.config, "node_id", "ALL")
    usage.node_id = node_id
    usage.model_name = self.model_name
    usage.workflow_id = tracker.workflow_id
    usage.provider = "gemini"
    tracker.record_usage(node_id, self.model_name, usage, provider="gemini")