[fix] support chat interface

This commit is contained in:
NA-Wen 2026-01-08 21:26:00 +08:00
parent 1079544652
commit 82dcf4a587
2 changed files with 214 additions and 20 deletions

View File

@ -4,7 +4,7 @@ import base64
import binascii
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from urllib.parse import unquote_to_bytes
import openai
@ -46,7 +46,6 @@ class OpenAIProvider(ModelProvider):
)
else:
return OpenAI(
base_url="https://api.openai.com/v1",
api_key=self.api_key,
)
@ -60,26 +59,43 @@ class OpenAIProvider(ModelProvider):
) -> ModelResponse:
"""
Call the OpenAI model with the given messages and parameters.
Args:
client: OpenAI client instance
conversation: List of messages in the conversation
tool_specs: Optional tool specifications
**kwargs: Additional parameters for the model call
Returns:
ModelResponse containing content and potentially tool calls
"""
# 1. Determine if we should use Chat Completions directly
is_chat = self._is_chat_completions_mode(client)
if is_chat:
request_payload = self._build_chat_payload(timeline, tool_specs, kwargs)
response = client.chat.completions.create(**request_payload)
self._track_token_usage(response)
self._append_chat_response_output(timeline, response)
message = self._deserialize_chat_response(response)
return ModelResponse(message=message, raw_response=response)
# 2. Try Responses API with fallback
request_payload = self._build_request_payload(timeline, tool_specs, kwargs)
# print(request_payload)
try:
response = client.responses.create(**request_payload)
self._track_token_usage(response)
self._append_response_output(timeline, response)
message = self._deserialize_response(response)
return ModelResponse(message=message, raw_response=response)
except Exception as e:
new_request_payload = self._build_chat_payload(timeline, tool_specs, kwargs)
response = client.chat.completions.create(**new_request_payload)
self._track_token_usage(response)
self._append_chat_response_output(timeline, response)
message = self._deserialize_chat_response(response)
return ModelResponse(message=message, raw_response=response)
response = client.responses.create(**request_payload)
# print(response)
self._track_token_usage(response)
self._append_response_output(timeline, response)
message = self._deserialize_response(response)
return ModelResponse(message=message, raw_response=response)
def _is_chat_completions_mode(self, client: Any) -> bool:
"""Determine if we should use standard chat completions instead of responses API."""
protocol = self.params.get("protocol")
if protocol == "chat":
return True
if protocol == "responses":
return False
# Default to Responses API only if it exists on the client
return not hasattr(client, "responses")
def extract_token_usage(self, response: Any) -> TokenUsage:
"""
@ -208,6 +224,185 @@ class OpenAIProvider(ModelProvider):
payload.update(params)
return payload
def _build_chat_payload(
self,
timeline: List[Any],
tool_specs: Optional[List[ToolSpec]],
raw_params: Dict[str, Any],
) -> Dict[str, Any]:
"""Construct standard Chat Completions API payload."""
params = dict(raw_params)
max_output_tokens = params.pop("max_output_tokens", None)
max_tokens = params.pop("max_tokens", None)
if max_tokens is None and max_output_tokens is not None:
max_tokens = max_output_tokens
messages: List[Any] = []
for item in timeline:
serialized = self._serialize_timeline_item_for_chat(item)
if serialized is not None:
messages.append(serialized)
if not messages:
messages = [{"role": "user", "content": ""}]
payload: Dict[str, Any] = {
"model": self.model_name,
"messages": messages,
"temperature": params.pop("temperature", 0.7),
}
if max_tokens is not None:
payload["max_tokens"] = max_tokens
elif self.params.get("max_tokens"):
payload["max_tokens"] = self.params["max_tokens"]
user_tools = params.pop("tools", None)
merged_tools: List[Any] = []
if isinstance(user_tools, list):
merged_tools.extend(user_tools)
if tool_specs:
for spec in tool_specs:
merged_tools.append({
"type": "function",
"function": {
"name": spec.name,
"description": spec.description,
"parameters": spec.parameters or {"type": "object", "properties": {}},
}
})
if merged_tools:
payload["tools"] = merged_tools
tool_choice = params.pop("tool_choice", None)
if tool_choice is not None:
payload["tool_choice"] = tool_choice
elif tool_specs:
payload.setdefault("tool_choice", "auto")
payload.update(params)
return payload
def _serialize_timeline_item_for_chat(self, item: Any) -> Optional[Any]:
    """Convert one timeline entry into a Chat Completions message.

    ``Message`` and ``FunctionCallOutputEvent`` instances are delegated to
    their dedicated serializers. Plain dicts are accepted when they carry a
    ``role`` plus either ``content`` or ``tool_calls``.

    Args:
        item: A timeline entry of any type.

    Returns:
        A message dict, or ``None`` when the item cannot be represented
        (unknown type, or a dict without a role or without any payload).
    """
    if isinstance(item, Message):
        return self._serialize_message_for_chat(item)
    if isinstance(item, FunctionCallOutputEvent):
        return self._serialize_function_call_output_event_for_chat(item)
    if isinstance(item, dict):
        # Basic conversion for dict-shaped entries (e.g. stored API output).
        role = item.get("role")
        content = item.get("content")
        tool_calls = item.get("tool_calls")
        if role and (content or tool_calls):
            if isinstance(content, list):
                content = self._transform_blocks_for_chat(content)
            serialized: Dict[str, Any] = {"role": role, "content": content}
            # Only attach tool_calls when there are some. Emitting
            # ``"tool_calls": None`` for a plain assistant turn sends a null
            # where the Chat Completions message schema defines an array.
            if tool_calls:
                serialized["tool_calls"] = tool_calls
            return serialized
    return None
def _serialize_message_for_chat(self, message: Message) -> Dict[str, Any]:
    """Render an internal ``Message`` as a Chat Completions message dict.

    Content comes from the message blocks when there are any (serialized via
    ``_serialize_blocks`` and then converted with
    ``_transform_blocks_for_chat``); otherwise the plain text content is
    used. ``name``, ``tool_call_id`` and ``tool_calls`` are only included
    when they are truthy.

    Args:
        message: The message to convert.

    Returns:
        A dict with at least ``role`` and ``content``.
    """
    role_value = message.role.value
    blocks = message.blocks()
    if blocks:
        serialized_blocks = self._serialize_blocks(blocks, message.role)
        content = self._transform_blocks_for_chat(serialized_blocks)
    else:
        content = message.text_content()

    payload: Dict[str, Any] = {"role": role_value, "content": content}

    # Optional scalar fields are copied across only when set.
    for field in ("name", "tool_call_id"):
        value = getattr(message, field)
        if value:
            payload[field] = value

    calls = message.tool_calls
    if calls:
        payload["tool_calls"] = [call.to_openai_dict() for call in calls]
    return payload
def _serialize_function_call_output_event_for_chat(self, event: FunctionCallOutputEvent) -> Dict[str, Any]:
    """Render a tool result event as a Chat Completions ``tool`` message.

    When the event has output blocks, their ``describe()`` strings are joined
    with newlines and replace the plain output text; otherwise the output
    text (or an empty string) is used.

    Args:
        event: The function-call output event to convert.

    Returns:
        A dict with ``role``, ``tool_call_id`` and ``content``. The call id
        falls back to the literal ``"tool_call"`` when the event has none.
    """
    fallback_text = event.output_text or ""
    blocks = event.output_blocks
    # Simple concatenation of block descriptions for tool output in chat mode.
    content = "\n".join(block.describe() for block in blocks) if blocks else fallback_text
    call_id = event.call_id or "tool_call"
    return {"role": "tool", "tool_call_id": call_id, "content": content}
def _transform_blocks_for_chat(self, blocks: List[Dict[str, Any]]) -> Union[str, List[Dict[str, Any]]]:
"""Convert Responses block types to Chat block types (e.g., input_text -> text)."""
transformed: List[Dict[str, Any]] = []
for block in blocks:
b_type = block.get("type", "")
if b_type in ("input_text", "output_text"):
transformed.append({"type": "text", "text": block.get("text", "")})
elif b_type in ("input_image", "output_image"):
transformed.append({"type": "image_url", "image_url": {"url": block.get("image_url", "")}})
else:
# Keep as is or drop if complex
transformed.append(block)
# If only one text block, return as string for better compatibility
if len(transformed) == 1 and transformed[0]["type"] == "text":
return transformed[0]["text"]
return transformed
def _deserialize_chat_response(self, response: Any) -> Message:
    """Build an internal assistant ``Message`` from a Chat Completions response.

    Only the first choice is read. A response without choices yields an
    assistant message with empty content.

    Args:
        response: The Chat Completions response; fields are read through
            ``self._get_attr``.

    Returns:
        An assistant ``Message`` carrying the content (empty string when
        missing) and any tool calls found on the first choice.
    """
    choices = self._get_attr(response, "choices") or []
    if not choices:
        return Message(role=MessageRole.ASSISTANT, content="")

    first_message = self._get_attr(choices[0], "message")
    raw_calls = self._get_attr(first_message, "tool_calls")

    parsed_calls: List[ToolCallPayload] = []
    for raw_call in raw_calls or ():
        function_data = self._get_attr(raw_call, "function") or {}
        call_id = self._get_attr(raw_call, "id")
        function_name = self._get_attr(function_data, "name")
        arguments = self._get_attr(function_data, "arguments")
        parsed_calls.append(
            ToolCallPayload(
                id=call_id,
                function_name=function_name,
                arguments=arguments,
                type="function",
            )
        )

    text = self._get_attr(first_message, "content") or ""
    return Message(role=MessageRole.ASSISTANT, content=text, tool_calls=parsed_calls)
def _append_chat_response_output(self, timeline: List[Any], response: Any) -> None:
"""Add chat response to timeline, preserving tool_calls (Chat API compatible)."""
msg = response.choices[0].message
assistant_msg = {
"role": "assistant",
"content": msg.content or ""
}
if getattr(msg, "tool_calls", None):
assistant_msg["tool_calls"] = []
for tc in msg.tool_calls:
assistant_msg["tool_calls"].append({
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
})
timeline.append(assistant_msg)
def _serialize_timeline_item(self, item: Any) -> Optional[Any]:
if isinstance(item, Message):
return self._serialize_message_for_responses(item)

View File

@ -283,7 +283,6 @@ class AgentNodeExecutor(NodeExecutor):
self._record_model_call(node, last_input, None, CallStage.BEFORE)
response = self._execute_with_retry(node, retry_policy, _call_provider)
self.log_manager.debug(response.str_raw_response())
# print(timeline)
self._record_model_call(node, last_input, response, CallStage.AFTER)
return response