mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
834 lines
32 KiB
Python
Executable File
834 lines
32 KiB
Python
Executable File
"""Gemini provider implementation."""
|
|
|
|
import base64
|
|
import binascii
|
|
import json
|
|
import os
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
|
|
|
from google import genai
|
|
from google.genai import types as genai_types
|
|
from google.genai.types import GenerateContentResponse
|
|
|
|
from entity.messages import (
|
|
AttachmentRef,
|
|
FunctionCallOutputEvent,
|
|
Message,
|
|
MessageBlock,
|
|
MessageBlockType,
|
|
MessageRole,
|
|
ToolCallPayload,
|
|
)
|
|
from entity.tool_spec import ToolSpec
|
|
from runtime.node.agent import ModelProvider
|
|
from runtime.node.agent import ModelResponse
|
|
from utils.token_tracker import TokenUsage
|
|
|
|
|
|
class GeminiProvider(ModelProvider):
    """Gemini provider implementation."""

    # Upper bound on the number of characters kept when a CSV attachment is
    # converted to inline text (see `_maybe_inline_large_csv`); longer text is
    # cut off and a truncation notice is appended.
    CSV_INLINE_CHAR_LIMIT = 200_000
    # CSV attachments strictly larger than this many bytes are converted to
    # inline text by `_maybe_inline_large_csv`; smaller ones are left alone.
    CSV_INLINE_SIZE_THRESHOLD_BYTES = 3 * 1024 * 1024  # 3 MB
|
|
|
|
def create_client(self):
    """
    Build the ``genai.Client`` used by this provider.

    The API key is passed only when one is configured, and HTTP options are
    passed only when ``_build_http_options`` returns some for the
    whitespace-stripped base URL.
    """
    kwargs: Dict[str, Any] = {}

    api_key = self.api_key
    if api_key:
        kwargs["api_key"] = api_key

    options = self._build_http_options((self.base_url or "").strip())
    if options:
        kwargs["http_options"] = options

    return genai.Client(**kwargs)
|
|
|
|
def call_model(
    self,
    client,
    conversation: List[Message],
    timeline: List[Any],
    tool_specs: Optional[List[ToolSpec]] = None,
    **kwargs,
) -> ModelResponse:
    """
    Run one ``generate_content`` request against Gemini.

    The request is built from ``timeline``; ``conversation`` is accepted but
    not read by this method. After the call, token usage is tracked, the
    response candidates' contents are appended to ``timeline`` (in place),
    and the response is converted into a ``Message``.

    Args:
        client: Object exposing ``models.generate_content``.
        conversation: Unused here.
        timeline: Timeline items to serialize; mutated with the reply contents.
        tool_specs: Optional tool specifications to expose to the model.
        **kwargs: Per-call parameters merged over the provider's ``params``.

    Returns:
        A ``ModelResponse`` carrying the parsed message and the raw response.
    """
    request_contents, system_text = self._build_contents(timeline)
    request_config = self._build_generation_config(system_text, tool_specs, kwargs)

    raw: GenerateContentResponse = client.models.generate_content(
        model=self.model_name,
        contents=request_contents,
        config=request_config,
    )

    # Order matters: usage is recorded first, then the timeline is extended,
    # and only then is the reply parsed.
    self._track_token_usage(raw)
    self._append_response_contents(timeline, raw)
    return ModelResponse(message=self._deserialize_response(raw), raw_response=raw)
|
|
|
|
def extract_token_usage(self, response: Any) -> TokenUsage:
    """
    Convert the response's ``usage_metadata`` into a ``TokenUsage``.

    Returns a default ``TokenUsage()`` when the response has no (truthy)
    ``usage_metadata``. Missing prompt/candidate counts are treated as 0.
    When the total is missing or 0, it falls back to prompt + candidates.
    The total and cached counts are copied into the metadata dict only when
    they are not ``None``.
    """
    usage = getattr(response, "usage_metadata", None)
    if not usage:
        return TokenUsage()

    def read(field: str) -> Any:
        return getattr(usage, field, None)

    n_prompt = read("prompt_token_count") or 0
    n_candidates = read("candidates_token_count") or 0
    n_total = read("total_token_count")
    n_cached = read("cached_content_token_count")

    details = {
        "prompt_token_count": n_prompt,
        "candidates_token_count": n_candidates,
    }
    optional_fields = (
        ("total_token_count", n_total),
        ("cached_content_token_count", n_cached),
    )
    for field_name, field_value in optional_fields:
        if field_value is not None:
            details[field_name] = field_value

    return TokenUsage(
        input_tokens=n_prompt,
        output_tokens=n_candidates,
        total_tokens=n_total if n_total else n_prompt + n_candidates,
        metadata=details,
    )
|
|
|
|
# ---------------------------------------------------------------------
|
|
# Serialization helpers
|
|
# ---------------------------------------------------------------------
|
|
|
|
def _build_contents(
    self,
    timeline: List[Any],
) -> Tuple[List[genai_types.Content], Optional[str]]:
    """
    Split the timeline into Gemini contents and a system instruction.

    * System ``Message`` items contribute their stripped text (when
      non-empty) to the system instruction and are never emitted as contents.
    * Other ``Message`` items and ``FunctionCallOutputEvent`` items are
      converted; ``genai_types.Content`` items are passed through unchanged.
    * Anything else is ignored.

    Returns:
        ``(contents, system_instruction)``. ``contents`` always has at least
        one entry: an empty user text part is used when nothing else was
        produced. ``system_instruction`` is the system texts joined by blank
        lines, or ``None`` when there are none.
    """
    system_texts: List[str] = []
    converted: List[genai_types.Content] = []

    for entry in timeline:
        if isinstance(entry, Message):
            if entry.role is not MessageRole.SYSTEM:
                converted.append(self._message_to_content(entry))
                continue
            stripped = entry.text_content().strip()
            if stripped:
                system_texts.append(stripped)
        elif isinstance(entry, FunctionCallOutputEvent):
            converted.append(self._function_output_event_to_content(entry))
        elif isinstance(entry, genai_types.Content):
            converted.append(entry)

    if not converted:
        placeholder = genai_types.Content(
            role="user",
            parts=[genai_types.Part(text="")],
        )
        converted.append(placeholder)

    # Every collected text is non-empty, so the join is empty only when the
    # list itself is empty.
    return converted, ("\n\n".join(system_texts) or None)
|
|
|
|
def _append_response_contents(self, timeline: List[Any], response: Any) -> None:
    """
    Append every truthy candidate ``content`` of ``response`` to ``timeline``.

    ``timeline`` is mutated in place. Nothing happens when the response has
    no (truthy) ``candidates`` attribute.
    """
    found = getattr(response, "candidates", None)
    if not found:
        return
    candidate_contents = (getattr(entry, "content", None) for entry in found)
    timeline.extend(item for item in candidate_contents if item)
|
|
|
|
def _message_to_content(self, message: Message) -> genai_types.Content:
    """
    Convert a timeline ``Message`` into a Gemini ``Content``.

    Tool messages become a single function-response part sent with role
    ``"user"``. Other messages have each block converted to parts; when no
    part results, one text part holding ``message.text_content()`` is used.
    """
    if message.role is MessageRole.TOOL:
        return genai_types.Content(
            role="user",
            parts=[self._build_tool_response_part(message)],
        )

    converted: List[genai_types.Part] = [
        part
        for block in message.blocks()
        for part in self._block_to_parts(block)
    ]
    if not converted:
        converted = [genai_types.Part(text=message.text_content())]

    return genai_types.Content(role=self._map_role(message.role), parts=converted)
|
|
|
|
def _function_output_event_to_content(
    self,
    event: FunctionCallOutputEvent,
) -> genai_types.Content:
    """
    Convert a function-call output event into a user ``Content``.

    The content holds one function-response part. Its ``result`` text is the
    newline-joined list of block descriptions plus any attachment text that
    could be inlined; remaining attachments are passed as function-response
    parts. When the event has no output blocks, ``event.output_text`` (if
    any) is the whole result.
    """
    texts: List[str] = []
    attachment_parts: List[genai_types.FunctionResponsePart] = []

    if event.output_blocks:
        for block in event.output_blocks:
            description = self._describe_block(block)
            if description:
                texts.append(description)

            if not self._block_has_attachment(block):
                continue

            # Text-like attachments are inlined into the result text; if the
            # text cannot be read, fall through to the binary path below.
            if self._should_inline_attachment_as_text(block):
                inlined = self._read_attachment_text(block.attachment)
                if inlined:
                    texts.append(f"\n[Attachment Content: {block.attachment.name}]\n{inlined}")
                    continue

            attachment_parts.extend(
                self._general_parts_to_function_response_parts(self._block_to_parts(block))
            )
    elif event.output_text:
        texts.append(event.output_text)

    response_part = genai_types.Part.from_function_response(
        name=event.function_name or event.call_id or "tool",
        response={"result": "\n".join(texts)},
        parts=attachment_parts or None,
    )
    return genai_types.Content(role="user", parts=[response_part])
|
|
|
|
def _should_inline_attachment_as_text(self, block: MessageBlock) -> bool:
    """
    Tell whether the block's attachment has a text-like MIME type.

    True for ``text/*``, ``application/json`` and any type ending in
    ``+json`` or ``+xml`` (compared case-insensitively). False when the
    block has no attachment.
    """
    attachment = block.attachment
    if not attachment:
        return False
    mime = (attachment.mime_type or "").lower()
    if mime.startswith("text/") or mime == "application/json":
        return True
    return mime.endswith(("+json", "+xml"))
|
|
|
|
def _read_attachment_text(self, attachment: AttachmentRef) -> Optional[str]:
    """Read the attachment's bytes and decode them to text (``None`` if unreadable)."""
    return self._bytes_to_text(self._read_attachment_bytes(attachment))
|
|
|
|
def _general_parts_to_function_response_parts(self, parts: List[genai_types.Part]) -> List[genai_types.FunctionResponsePart]:
    """
    Re-wrap generic parts as function-response parts.

    Inline data is converted with ``FunctionResponsePart.from_bytes`` and
    file references with ``FunctionResponsePart.from_uri``. A part carrying
    both yields two entries. Parts with neither (e.g. plain text) are
    dropped. A missing MIME type defaults to ``application/octet-stream``.
    """
    default_mime = "application/octet-stream"
    converted: List[genai_types.FunctionResponsePart] = []
    for part in parts:
        blob = part.inline_data
        if blob:
            converted.append(
                genai_types.FunctionResponsePart.from_bytes(
                    data=blob.data,
                    mime_type=blob.mime_type or default_mime,
                )
            )
        file_ref = part.file_data
        if file_ref:
            converted.append(
                genai_types.FunctionResponsePart.from_uri(
                    file_uri=file_ref.file_uri,
                    mime_type=file_ref.mime_type or default_mime,
                )
            )
    return converted
|
|
|
|
def _build_tool_response_part(self, message: Message) -> genai_types.Part:
    """
    Wrap a tool message as a Gemini function-response part.

    The function name is taken from ``message.metadata["tool_name"]`` (when
    metadata is a dict), then ``message.tool_call_id``, then the literal
    ``"tool"``.
    """
    metadata = message.metadata
    declared_name = metadata.get("tool_name") if isinstance(metadata, dict) else None
    resolved_name = declared_name or message.tool_call_id or "tool"

    payload, attachment_parts = self._serialize_tool_message_payload(message)
    function_response = genai_types.FunctionResponse(
        name=resolved_name,
        response=payload,
        parts=attachment_parts or None,
    )
    return genai_types.Part(function_response=function_response)
|
|
|
|
def _block_has_attachment(self, block: Any) -> bool:
    """Return True only for ``MessageBlock`` instances whose attachment is not ``None``."""
    if not isinstance(block, MessageBlock):
        return False
    return block.attachment is not None
|
|
|
|
def _serialize_tool_message_payload(self, message: Message) -> Tuple[Dict[str, Any], List[genai_types.FunctionResponsePart]]:
    """
    Turn a tool message into a function-response payload plus attachment parts.

    * String content: parsed as JSON when possible. A JSON object is used
      as-is; any other JSON value (list, number, string, bool, null) is
      wrapped as ``{"result": value}`` so the payload is always a dict, as
      the return annotation promises. Non-JSON text becomes
      ``{"result": text}`` and blank text ``{"result": ""}``.
    * List content: block dicts are collected under ``"blocks"`` alongside
      ``"result"`` (the message text), and attachments are converted to
      function-response parts.
    * Anything else: ``{"result": message.text_content()}`` with no parts.

    Returns:
        ``(payload, parts)``.
    """
    content = message.content

    if isinstance(content, str):
        stripped = content.strip()
        if not stripped:
            return {"result": ""}, []
        try:
            parsed = json.loads(stripped)
        except json.JSONDecodeError:
            return {"result": stripped}, []
        if isinstance(parsed, dict):
            return parsed, []
        # Tool output such as "42", "true" or "[1, 2]" is valid JSON but not
        # an object; wrap it rather than handing a non-dict to the SDK.
        return {"result": parsed}, []

    if isinstance(content, list):
        blocks: List[MessageBlock] = []
        blocks_payload: List[Any] = []
        for block in content:
            if isinstance(block, MessageBlock):
                blocks_payload.append(block.to_dict())
                blocks.append(block)
            elif isinstance(block, dict):
                blocks_payload.append(block)
                try:
                    blocks.append(MessageBlock.from_dict(block))
                except Exception:
                    # Best effort: the raw dict stays in the payload even
                    # when it cannot be parsed into a MessageBlock.
                    continue
        parts = self._blocks_to_function_parts(blocks)
        return {"blocks": blocks_payload, "result": message.text_content()}, parts

    return {"result": message.text_content()}, []
|
|
|
|
def _describe_block(self, block: Any) -> str:
    """
    Produce a short text description of an output block.

    ``MessageBlock`` instances describe themselves; dicts use their truthy
    ``"text"`` entry; everything else falls back to ``str(block)``.
    """
    if isinstance(block, MessageBlock):
        return block.describe()
    text = block.get("text") if isinstance(block, dict) else None
    return str(text) if text else str(block)
|
|
|
|
def _block_to_parts(self, block: MessageBlock) -> List[genai_types.Part]:
    """
    Convert one message block into zero or more Gemini parts.

    * TEXT -> a text part (empty string when the block has no text).
    * FILE -> first offered to ``_maybe_inline_large_csv``; if that returns
      text, a text part is produced, otherwise it is handled as media.
    * IMAGE / AUDIO / VIDEO / FILE -> an attachment part, or nothing when
      the attachment cannot be converted.
    * DATA -> a text part with the block text, or the JSON dump of its data.
    * Any other type -> no parts.
    """
    kind = block.type

    if kind is MessageBlockType.TEXT:
        return [genai_types.Part(text=block.text or "")]

    if kind is MessageBlockType.FILE:
        inlined_csv = self._maybe_inline_large_csv(block)
        if inlined_csv is not None:
            return [genai_types.Part(text=inlined_csv)]

    media_kinds = (
        MessageBlockType.IMAGE,
        MessageBlockType.AUDIO,
        MessageBlockType.VIDEO,
        MessageBlockType.FILE,
    )
    if kind in media_kinds:
        part = self._attachment_block_to_part(block)
        return [part] if part else []

    if kind is MessageBlockType.DATA:
        payload = block.data or {}
        rendered = block.text or json.dumps(payload, ensure_ascii=False)
        return [genai_types.Part(text=rendered)]

    return []
|
|
|
|
def _maybe_inline_large_csv(self, block: MessageBlock) -> Optional[str]:
    """
    Return a large CSV attachment as prompt text, or ``None`` to skip inlining.

    Inlining applies only when the attachment looks like CSV (MIME contains
    ``text/csv`` or the name ends with ``.csv``), has no ``remote_file_id``,
    and is strictly larger than the size threshold. The threshold and the
    character limit can be overridden per instance through
    ``csv_inline_size_threshold_bytes`` / ``csv_inline_char_limit``;
    otherwise the class constants apply. Text beyond the character limit is
    cut and a truncation note is appended.
    """
    attachment = block.attachment
    if not attachment:
        return None

    mime = (attachment.mime_type or "").lower()
    filename = (attachment.name or "").lower()
    looks_like_csv = "text/csv" in mime or filename.endswith(".csv")
    if not looks_like_csv or attachment.remote_file_id:
        return None

    threshold = getattr(
        self,
        "csv_inline_size_threshold_bytes",
        self.CSV_INLINE_SIZE_THRESHOLD_BYTES,
    )

    # When the size is unknown the bytes have to be read just to measure
    # them; keep them so they are not read a second time below.
    raw: Optional[bytes] = None
    size = attachment.size
    if size is None:
        raw = self._read_attachment_bytes(attachment)
        if raw is None:
            return None
        size = len(raw)

    if size <= threshold:
        return None

    if raw is None:
        raw = self._read_attachment_bytes(attachment)
        if raw is None:
            return None

    text = self._bytes_to_text(raw)
    if text is None:
        return None

    char_limit = getattr(self, "csv_inline_char_limit", self.CSV_INLINE_CHAR_LIMIT)
    suffix = ""
    if len(text) > char_limit:
        text = text[:char_limit]
        suffix = f"\n\n[truncated after {char_limit} characters]"

    display_name = attachment.name or attachment.attachment_id or "attachment.csv"
    return f"CSV file '{display_name}' (converted from >3MB upload):\n{text}{suffix}"
|
|
|
|
def _bytes_to_text(self, data_bytes: Optional[bytes]) -> Optional[str]:
    """
    Decode bytes as UTF-8, substituting U+FFFD for undecodable sequences.

    Returns ``None`` when given ``None``. Valid UTF-8 decodes exactly as a
    strict decode would, so a single lenient decode covers both cases.
    """
    if data_bytes is None:
        return None
    return data_bytes.decode("utf-8", errors="replace")
|
|
|
|
def _attachment_block_to_part(self, block: MessageBlock) -> Optional[genai_types.Part]:
    """
    Convert a block's attachment into a Gemini part.

    A file URI (``metadata["gemini_file_uri"]`` or ``remote_file_id``) is
    preferred and yields a ``file_data`` part. Otherwise the attachment's
    bytes are sent as ``inline_data``. Returns ``None`` when there is no
    attachment or its bytes cannot be read.
    """
    attachment = block.attachment
    if not attachment:
        return None

    uri = (attachment.metadata or {}).get("gemini_file_uri") or attachment.remote_file_id
    mime_type = attachment.mime_type or self._guess_mime_from_block(block)

    if uri:
        file_ref = genai_types.FileData(file_uri=uri, mime_type=mime_type)
        return genai_types.Part(file_data=file_ref)

    raw = self._read_attachment_bytes(attachment)
    if raw is None:
        return None

    blob = genai_types.Blob(
        mime_type=mime_type or "application/octet-stream",
        data=raw,
    )
    return genai_types.Part(inline_data=blob)
|
|
|
|
def _blocks_to_function_parts(
    self,
    blocks: Optional[Sequence[Any]],
) -> List[genai_types.FunctionResponsePart]:
    """
    Build function-response parts for every block that carries an attachment.

    Dict entries are parsed with ``MessageBlock.from_dict``; entries that
    fail to parse, other non-block entries, and blocks without an attachment
    are skipped. An attachment with a file URI becomes a ``file_data`` part;
    otherwise its bytes are inlined, and attachments whose bytes are missing
    or empty are skipped.
    """
    if not blocks:
        return []

    collected: List[genai_types.FunctionResponsePart] = []
    for entry in blocks:
        if isinstance(entry, MessageBlock):
            block = entry
        elif isinstance(entry, dict):
            try:
                block = MessageBlock.from_dict(entry)
            except Exception:
                continue
        else:
            continue

        attachment = block.attachment
        if not attachment:
            continue

        mime_type = attachment.mime_type or self._guess_mime_from_block(block)
        uri = (attachment.metadata or {}).get("gemini_file_uri") or attachment.remote_file_id

        if uri:
            response_part = genai_types.FunctionResponsePart(
                file_data=genai_types.FunctionResponseFileData(
                    file_uri=uri,
                    mime_type=mime_type,
                    display_name=attachment.name,
                )
            )
        else:
            raw = self._read_attachment_bytes(attachment)
            if not raw:
                continue
            response_part = genai_types.FunctionResponsePart(
                inline_data=genai_types.FunctionResponseBlob(
                    mime_type=mime_type or "application/octet-stream",
                    data=raw,
                    display_name=attachment.name,
                )
            )
        collected.append(response_part)

    return collected
|
|
|
|
def _coerce_message_blocks(self, payload: Any) -> List[MessageBlock]:
    """
    Extract ``MessageBlock`` objects from a sequence payload.

    Strings, bytes, bytearrays and non-sequences yield an empty list. Dict
    items are parsed with ``MessageBlock.from_dict``; items that fail to
    parse, and items of any other type, are skipped.
    """
    if isinstance(payload, (str, bytes, bytearray)) or not isinstance(payload, Sequence):
        return []

    result: List[MessageBlock] = []
    for entry in payload:
        if isinstance(entry, MessageBlock):
            result.append(entry)
            continue
        if not isinstance(entry, dict):
            continue
        try:
            parsed = MessageBlock.from_dict(entry)
        except Exception:
            continue
        result.append(parsed)
    return result
|
|
|
|
def _encode_thought_signature(self, value: Any) -> Optional[str]:
    """
    Render a thought signature as a string.

    ``bytes`` are base64-encoded; ``None`` stays ``None``; anything else is
    passed through ``str()``, yielding ``None`` if that conversion raises.
    """
    if isinstance(value, bytes):
        return base64.b64encode(value).decode("ascii")
    if value is None:
        return None
    try:
        rendered = str(value)
    except Exception:
        return None
    return rendered
|
|
|
|
def _read_attachment_bytes(self, attachment: AttachmentRef) -> Optional[bytes]:
    """
    Load the raw bytes behind an attachment.

    The data URI is tried first; if it is absent or cannot be decoded, the
    local file (when it exists) is read. Returns ``None`` when neither source
    yields data or the file cannot be read.
    """
    data_uri = attachment.data_uri
    if data_uri:
        payload = self._decode_data_uri(data_uri)
        if payload is not None:
            return payload

    path = attachment.local_path
    if not path or not os.path.exists(path):
        return None

    try:
        with open(path, "rb") as stream:
            return stream.read()
    except OSError:
        return None
|
|
|
|
def _decode_data_uri(self, data_uri: str) -> Optional[bytes]:
    """
    Decode the payload of a ``data:`` URI.

    Args:
        data_uri: A string of the form ``data:[<mediatype>][;base64],<data>``.

    Returns:
        The decoded bytes, or ``None`` when the string is not a ``data:``
        URI, has no comma separator, or carries invalid base64.

    Payloads without the ``;base64`` marker are percent-decoded, because the
    data-URI format URL-encodes non-base64 payloads (``%20`` for a space, and
    so on). Sequences that are not valid percent escapes are kept as-is.
    """
    # Local import: the module's import header is not changed by this block.
    from urllib.parse import unquote_to_bytes

    if not data_uri.startswith("data:"):
        return None

    header, separator, data = data_uri.partition(",")
    if not separator:
        return None

    if ";base64" in header:
        try:
            return base64.b64decode(data)
        except (ValueError, binascii.Error):
            return None

    return unquote_to_bytes(data)
|
|
|
|
def _guess_mime_from_block(self, block: MessageBlock) -> str:
    """
    Pick a MIME type for a block.

    The attachment's own MIME type wins when present. Otherwise a default is
    chosen by block type: IMAGE -> ``image/png``, AUDIO -> ``audio/mpeg``,
    VIDEO -> ``video/mp4``, anything else -> ``application/octet-stream``.
    """
    attachment = block.attachment
    if attachment and attachment.mime_type:
        return attachment.mime_type

    defaults_by_type = (
        (MessageBlockType.IMAGE, "image/png"),
        (MessageBlockType.AUDIO, "audio/mpeg"),
        (MessageBlockType.VIDEO, "video/mp4"),
    )
    for block_type, default_mime in defaults_by_type:
        if block.type is block_type:
            return default_mime
    return "application/octet-stream"
|
|
|
|
def _map_role(self, role: MessageRole) -> str:
    """
    Translate a ``MessageRole`` into the role string used in contents.

    USER -> ``"user"``, ASSISTANT -> ``"model"``, TOOL -> ``"tool"``; every
    other role falls back to ``"user"``.
    """
    role_names = (
        (MessageRole.USER, "user"),
        (MessageRole.ASSISTANT, "model"),
        (MessageRole.TOOL, "tool"),
    )
    for known_role, name in role_names:
        if role is known_role:
            return name
    return "user"
|
|
|
|
# ---------------------------------------------------------------------
|
|
# Config builders
|
|
# ---------------------------------------------------------------------
|
|
|
|
def _build_generation_config(
    self,
    system_instruction: Optional[str],
    tool_specs: Optional[List[ToolSpec]],
    call_params: Dict[str, Any],
) -> genai_types.GenerateContentConfig:
    """
    Assemble the ``GenerateContentConfig`` for one request.

    Provider-level ``self.params`` are overlaid with ``call_params``. Only
    the recognised keys handled below are forwarded; any other parameter is
    ignored.

    Args:
        system_instruction: Optional system prompt text.
        tool_specs: Optional tool specifications to declare to the model.
        call_params: Per-call overrides (not mutated).

    Returns:
        The populated ``GenerateContentConfig``.

    Neither ``self.params`` nor a ``tool_config`` stored inside it is
    mutated: a ``function_calling_config`` is merged into a copy. The
    previous in-place assignment leaked a per-call value into later calls and
    failed when ``tool_config`` was already a ``ToolConfig`` object.
    """
    # Shallow copy: nested values are still shared with self.params.
    params = dict(self.params or {})
    params.update(call_params)

    config_kwargs: Dict[str, Any] = {}
    if system_instruction:
        config_kwargs["system_instruction"] = system_instruction

    for key in (
        "temperature",
        "top_p",
        "top_k",
        "candidate_count",
        "max_output_tokens",
        "response_modalities",
        "stop_sequences",
        "seed",
        "presence_penalty",
        "frequency_penalty",
    ):
        if key in params:
            config_kwargs[key] = params.pop(key)

    safety_settings = params.pop("safety_settings", None)
    if safety_settings:
        config_kwargs["safety_settings"] = safety_settings

    image_config = params.pop("image_config", None)
    aspect_ratio = params.pop("aspect_ratio", None)
    if aspect_ratio:
        if image_config is None:
            image_config = {"aspect_ratio": aspect_ratio}
        elif isinstance(image_config, dict):
            # Copy first; an aspect ratio already inside image_config wins.
            image_config = dict(image_config)
            image_config.setdefault("aspect_ratio", aspect_ratio)
        elif isinstance(image_config, genai_types.ImageConfig):
            # NOTE(review): unlike the dict branch, this overrides the value
            # on the caller's object in place - confirm that is intended.
            try:
                image_config.aspect_ratio = aspect_ratio
            except Exception:
                image_config = {"aspect_ratio": aspect_ratio}
        else:
            image_config = {"aspect_ratio": aspect_ratio}
    if image_config:
        config_kwargs["image_config"] = self._coerce_image_config(image_config)

    audio_config = params.pop("audio_config", None)
    if audio_config:
        config_kwargs["audio_config"] = audio_config

    video_config = params.pop("video_config", None)
    if video_config:
        config_kwargs["video_config"] = video_config

    tools = self._build_tools(tool_specs or [])
    if tools:
        config_kwargs["tools"] = tools

    tool_config_payload = params.pop("tool_config", None)
    function_calling_payload = params.pop("function_calling_config", None)
    if function_calling_payload:
        if isinstance(tool_config_payload, genai_types.ToolConfig):
            # model_copy(update=...) does not validate, so coerce first.
            tool_config_payload = tool_config_payload.model_copy(
                update={
                    "function_calling_config": self._coerce_function_calling_config(
                        function_calling_payload
                    )
                }
            )
        else:
            tool_config_payload = dict(tool_config_payload or {})
            tool_config_payload["function_calling_config"] = function_calling_payload

    if tool_config_payload:
        config_kwargs["tool_config"] = self._coerce_tool_config(tool_config_payload)

    automatic_fn_calling = params.pop("automatic_function_calling", None)
    if automatic_fn_calling:
        config_kwargs["automatic_function_calling"] = self._coerce_automatic_function_calling(
            automatic_fn_calling
        )

    return genai_types.GenerateContentConfig(**config_kwargs)
|
|
|
|
def _build_http_options(self, base_url: str) -> Optional[genai_types.HttpOptions]:
    """
    Create HTTP options for a custom base URL.

    Returns ``None`` when ``base_url`` is empty or when constructing
    ``HttpOptions`` raises. The timeout is only applied together with a base
    URL.
    """
    if not base_url:
        return None

    timeout = 4 * 60 * 1000  # 4 min
    try:
        options = genai_types.HttpOptions(base_url=base_url, timeout=timeout)
    except Exception:
        options = None
    return options
|
|
|
|
def _coerce_image_config(self, image_config: Any) -> Any:
    """
    Convert a dict into ``genai_types.ImageConfig`` when possible.

    ``ImageConfig`` instances and non-dict values are returned unchanged; a
    dict that the ``ImageConfig`` constructor rejects is also returned as-is.
    """
    already_typed = isinstance(image_config, genai_types.ImageConfig)
    if already_typed or not isinstance(image_config, dict):
        return image_config
    try:
        return genai_types.ImageConfig(**image_config)
    except Exception:
        return image_config
|
|
|
|
def _build_tools(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]:
    """
    Build the Gemini tool list from tool specifications.

    Each spec's ``to_gemini_function()`` payload becomes a
    ``FunctionDeclaration``. The parameter schema is cleaned up first:

    * the top-level ``title`` is removed;
    * each property's ``title`` is removed and used as its ``description``
      only when the property has no description of its own.

    The cleanup works on copies, so the dicts returned by the spec are not
    modified. Previously an existing description was overwritten by the
    title and the spec's schema was edited in place.

    Returns:
        An empty list when there are no specs, otherwise a one-element list
        with a ``Tool`` holding all declarations.
    """
    if not tool_specs:
        return []

    declarations = []
    for spec in tool_specs:
        fn_payload = spec.to_gemini_function()
        raw_parameters = fn_payload.get("parameters") or {"type": "object", "properties": {}}

        # Copy without the top-level title.
        parameters = {key: value for key, value in raw_parameters.items() if key != "title"}

        properties = parameters.get("properties")
        if isinstance(properties, dict):
            cleaned_properties = {}
            for prop_name, prop_value in properties.items():
                if isinstance(prop_value, dict) and "title" in prop_value:
                    prop_value = dict(prop_value)
                    title = prop_value.pop("title")
                    # An author-written description takes precedence.
                    if not prop_value.get("description"):
                        prop_value["description"] = title
                cleaned_properties[prop_name] = prop_value
            parameters["properties"] = cleaned_properties

        declarations.append(
            genai_types.FunctionDeclaration(
                name=fn_payload.get("name", ""),
                description=fn_payload.get("description") or "",
                parameters=parameters,
            )
        )
    return [genai_types.Tool(function_declarations=declarations)]
|
|
|
|
def _coerce_tool_config(self, payload: Any) -> genai_types.ToolConfig:
    """
    Normalise a tool-config payload into ``genai_types.ToolConfig``.

    ``ToolConfig`` instances are returned unchanged. From a dict, only a
    truthy ``"function_calling_config"`` entry is used; any other payload
    produces an empty ``ToolConfig``.
    """
    if isinstance(payload, genai_types.ToolConfig):
        return payload

    fn_payload = payload.get("function_calling_config") if isinstance(payload, dict) else None
    if not fn_payload:
        return genai_types.ToolConfig()
    return genai_types.ToolConfig(
        function_calling_config=self._coerce_function_calling_config(fn_payload)
    )
|
|
|
|
def _coerce_function_calling_config(self, payload: Any) -> genai_types.FunctionCallingConfig:
    """
    Normalise a payload into ``genai_types.FunctionCallingConfig``.

    Accepts an existing config (returned unchanged), a string (used as the
    ``mode``) or a dict of constructor keyword arguments.

    Raises:
        ValueError: For any other payload type.
    """
    config_cls = genai_types.FunctionCallingConfig
    if isinstance(payload, config_cls):
        return payload
    if isinstance(payload, str):
        return config_cls(mode=payload)
    if not isinstance(payload, dict):
        raise ValueError("Invalid function calling configuration payload")
    return config_cls(**payload)
|
|
|
|
def _coerce_automatic_function_calling(self, payload: Any) -> Any:
    """
    Normalise a payload into the SDK's ``AutomaticFunctionCallingConfig``.

    Accepts an existing config instance (returned unchanged) or a dict of
    constructor keyword arguments.

    Raises:
        ValueError: When ``genai_types`` has no
            ``AutomaticFunctionCallingConfig`` attribute, or when the payload
            is neither a config instance nor a dict.
    """
    config_cls = getattr(genai_types, "AutomaticFunctionCallingConfig", None)
    if config_cls is None:
        raise ValueError("Automatic function calling config not supported in current SDK version")
    if isinstance(payload, config_cls):
        return payload
    if not isinstance(payload, dict):
        raise ValueError("Invalid automatic function calling config payload")
    return config_cls(**payload)
|
|
|
|
# ---------------------------------------------------------------------
|
|
# Response parsing
|
|
# ---------------------------------------------------------------------
|
|
|
|
def _deserialize_response(self, response: Any) -> Message:
    """
    Convert a Gemini response into an assistant ``Message``.

    Only the primary (first) candidate is used.

    * No candidate: an assistant message with empty content.
    * Candidate without content: the message content is the response's
      ``text`` attribute, or ``""`` when that is missing or ``None``.
    * Otherwise the candidate's parts are parsed into blocks and tool calls.
      When no blocks result, the response's ``text`` (if any) becomes a
      single text block.

    The content is always a list of blocks or a string, never ``None``.
    Both fallbacks now read ``text`` the same way; the no-content branch
    previously passed a ``None`` text straight through as the content.
    """
    candidate = self._select_primary_candidate(response)
    if not candidate:
        return Message(role=MessageRole.ASSISTANT, content="")

    content = getattr(candidate, "content", None)
    if not content:
        fallback_text = getattr(response, "text", None) or ""
        return Message(role=MessageRole.ASSISTANT, content=fallback_text)

    blocks, tool_calls = self._parse_candidate_parts(getattr(content, "parts", []) or [])
    if not blocks:
        fallback = getattr(response, "text", None) or ""
        blocks = [MessageBlock(MessageBlockType.TEXT, text=fallback)] if fallback else []

    return Message(
        role=MessageRole.ASSISTANT,
        content=blocks or "",
        tool_calls=tool_calls,
    )
|
|
|
|
def _select_primary_candidate(self, response: Any) -> Any:
    """Return the first entry of ``response.candidates``, or ``None`` when there is none."""
    candidates = getattr(response, "candidates", None)
    return candidates[0] if candidates else None
|
|
|
|
def _parse_candidate_parts(
    self,
    parts: Sequence[Any],
) -> Tuple[List[MessageBlock], List[ToolCallPayload]]:
    """
    Split candidate parts into message blocks and tool calls.

    Each part is classified by the first match, in this order: non-``None``
    text, function call, inline data, file data, function response. Parts
    matching none of these are ignored.

    Returns:
        ``(blocks, tool_calls)`` in the order the parts were encountered.
    """
    blocks: List[MessageBlock] = []
    tool_calls: List[ToolCallPayload] = []

    for part in parts:
        text = getattr(part, "text", None)
        if text is not None:
            blocks.append(MessageBlock(MessageBlockType.TEXT, text=text))
            continue

        call = getattr(part, "function_call", None)
        if call:
            signature = getattr(part, "thought_signature", None)
            tool_calls.append(self._build_tool_call_payload(call, thought_signature=signature))
            continue

        blob = getattr(part, "inline_data", None)
        if blob:
            blocks.append(self._build_inline_block(blob))
            continue

        file_ref = getattr(part, "file_data", None)
        if file_ref:
            blocks.append(self._build_file_block(file_ref))
            continue

        fn_response = getattr(part, "function_response", None)
        if fn_response:
            # The response payload is stored both as JSON text and as data.
            response_payload = fn_response.response or {}
            blocks.append(
                MessageBlock(
                    type=MessageBlockType.DATA,
                    text=json.dumps(response_payload, ensure_ascii=False),
                    data={
                        "function_name": getattr(fn_response, "name", ""),
                        "response": fn_response.response or {},
                    },
                )
            )

    return blocks, tool_calls
|
|
|
|
def _build_tool_call_payload(self, fn_call: Any, *, thought_signature: Any = None) -> ToolCallPayload:
    """
    Build a ``ToolCallPayload`` from a Gemini function call.

    The function name doubles as the call id; when the call has no name, a
    random hex id is generated and used for both fields. Arguments are
    serialised as JSON, falling back to ``str()`` when they are not JSON
    serialisable. A thought signature, when present, is stored (encoded) in
    the metadata under ``gemini_thought_signature_b64``.
    """
    # id and function_name always end up equal: either both are the call's
    # name, or both are the generated fallback id.
    identifier = getattr(fn_call, "name", "") or uuid.uuid4().hex

    raw_arguments = getattr(fn_call, "args", {}) or {}
    try:
        serialized = json.dumps(raw_arguments, ensure_ascii=False)
    except (TypeError, ValueError):
        serialized = str(raw_arguments)

    signature = self._encode_thought_signature(thought_signature)
    metadata: Dict[str, Any] = {"gemini_thought_signature_b64": signature} if signature else {}

    return ToolCallPayload(
        id=identifier,
        function_name=identifier,
        arguments=serialized,
        type="function",
        metadata=metadata,
    )
|
|
|
|
def _build_inline_block(self, blob: Any) -> MessageBlock:
    """
    Wrap inline response data as a ``MessageBlock`` with a data-URI attachment.

    A missing MIME type defaults to ``application/octet-stream`` and missing
    data to empty bytes. The block type is derived from the MIME type and
    the attachment receives a fresh random id.
    """
    mime = getattr(blob, "mime_type", "") or "application/octet-stream"
    raw = getattr(blob, "data", None) or b""

    attachment = AttachmentRef(
        attachment_id=uuid.uuid4().hex,
        mime_type=mime,
        data_uri=self._encode_data_uri(mime, raw),
        metadata={"source": "gemini_inline"},
    )
    return MessageBlock(type=self._block_type_from_mime(mime), attachment=attachment)
|
|
|
|
def _build_file_block(self, file_data: Any) -> MessageBlock:
    """
    Wrap a file reference from the response as a ``MessageBlock``.

    The URI is read from ``file_uri`` (falling back to ``file``) and stored
    both as the attachment's ``remote_file_id`` and under
    ``metadata["gemini_file_uri"]``. The block type is derived from the MIME
    type, with a missing MIME type treated as an empty string for that
    purpose only.
    """
    mime = getattr(file_data, "mime_type", None)
    uri = getattr(file_data, "file_uri", None) or getattr(file_data, "file", None)

    attachment = AttachmentRef(
        attachment_id=uuid.uuid4().hex,
        mime_type=mime,
        remote_file_id=uri,
        metadata={"gemini_file_uri": uri, "source": "gemini_file"},
    )
    return MessageBlock(type=self._block_type_from_mime(mime or ""), attachment=attachment)
|
|
|
|
def _block_type_from_mime(self, mime_type: str) -> MessageBlockType:
    """
    Map a MIME type to a block type by its top-level prefix.

    ``image/`` -> IMAGE, ``audio/`` -> AUDIO, ``video/`` -> VIDEO; anything
    else is FILE.
    """
    prefix_to_type = (
        ("image/", MessageBlockType.IMAGE),
        ("audio/", MessageBlockType.AUDIO),
        ("video/", MessageBlockType.VIDEO),
    )
    for prefix, block_type in prefix_to_type:
        if mime_type.startswith(prefix):
            return block_type
    return MessageBlockType.FILE
|
|
|
|
def _encode_data_uri(self, mime_type: str, data: bytes) -> str:
    """Return ``data`` as a base64 ``data:`` URI labelled with ``mime_type``."""
    payload = base64.b64encode(data).decode("utf-8")
    return "data:{};base64,{}".format(mime_type, payload)
|
|
|
|
# ---------------------------------------------------------------------
|
|
# Token tracking
|
|
# ---------------------------------------------------------------------
|
|
|
|
def _track_token_usage(self, response: Any) -> None:
    """
    Record the response's token usage with the configured tracker, if any.

    Nothing is recorded when ``self.config`` has no (truthy)
    ``token_tracker``, or when the extracted usage has zero input and output
    tokens and empty metadata. Otherwise the usage is stamped with the node
    id (``"ALL"`` when the config has none), model name, workflow id and the
    provider name ``"gemini"`` before being handed to the tracker.
    """
    tracker = getattr(self.config, "token_tracker", None)
    if not tracker:
        return

    usage = self.extract_token_usage(response)
    nothing_to_report = (
        usage.input_tokens == 0
        and usage.output_tokens == 0
        and not usage.metadata
    )
    if nothing_to_report:
        return

    node_id = getattr(self.config, "node_id", "ALL")
    stamps = (
        ("node_id", node_id),
        ("model_name", self.model_name),
        ("workflow_id", tracker.workflow_id),
        ("provider", "gemini"),
    )
    for attribute, value in stamps:
        setattr(usage, attribute, value)

    tracker.record_usage(node_id, self.model_name, usage, provider="gemini")
|