Merge pull request #502 from zxrys/main

fix: OpenAI chat.completions.create API function calling error
This commit is contained in:
Yufan Dang 2026-01-09 23:15:55 +08:00 committed by GitHub
commit 48c4263cdb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
"""OpenAI provider implementation."""
import base64
import hashlib
import binascii
import os
@ -64,7 +65,7 @@ class OpenAIProvider(ModelProvider):
is_chat = self._is_chat_completions_mode(client)
if is_chat:
request_payload = self._build_chat_payload(timeline, tool_specs, kwargs)
request_payload = self._build_chat_payload(conversation, tool_specs, kwargs)
response = client.chat.completions.create(**request_payload)
self._track_token_usage(response)
self._append_chat_response_output(timeline, response)
@ -80,7 +81,7 @@ class OpenAIProvider(ModelProvider):
message = self._deserialize_response(response)
return ModelResponse(message=message, raw_response=response)
except Exception as e:
new_request_payload = self._build_chat_payload(timeline, tool_specs, kwargs)
new_request_payload = self._build_chat_payload(conversation, tool_specs, kwargs)
response = client.chat.completions.create(**new_request_payload)
self._track_token_usage(response)
self._append_chat_response_output(timeline, response)
@ -226,7 +227,7 @@ class OpenAIProvider(ModelProvider):
def _build_chat_payload(
self,
timeline: List[Any],
conversation: List[Message],
tool_specs: Optional[List[ToolSpec]],
raw_params: Dict[str, Any],
) -> Dict[str, Any]:
@ -238,8 +239,8 @@ class OpenAIProvider(ModelProvider):
max_tokens = max_output_tokens
messages: List[Any] = []
for item in timeline:
serialized = self._serialize_timeline_item_for_chat(item)
for item in conversation:
serialized = self._serialize_message_for_chat(item)
if serialized is not None:
messages.append(serialized)
@ -366,12 +367,19 @@ class OpenAIProvider(ModelProvider):
tool_calls: List[ToolCallPayload] = []
tc_data = self._get_attr(msg, "tool_calls")
if tc_data:
for tc in tc_data:
for idx, tc in enumerate(tc_data):
f_data = self._get_attr(tc, "function") or {}
function_name = self._get_attr(f_data, "name") or ""
arguments = self._get_attr(f_data, "arguments") or ""
if not isinstance(arguments, str):
arguments = str(arguments)
call_id = self._get_attr(tc, "id")
if not call_id:
call_id = self._build_tool_call_id(function_name, arguments, fallback_prefix=f"tool_call_{idx}")
tool_calls.append(ToolCallPayload(
id=self._get_attr(tc, "id"),
function_name=self._get_attr(f_data, "name"),
arguments=self._get_attr(f_data, "arguments"),
id=call_id,
function_name=function_name,
arguments=arguments,
type="function"
))
@ -391,13 +399,18 @@ class OpenAIProvider(ModelProvider):
if getattr(msg, "tool_calls", None):
assistant_msg["tool_calls"] = []
for tc in msg.tool_calls:
for idx, tc in enumerate(msg.tool_calls):
function_name = tc.function.name
arguments = tc.function.arguments or ""
if not isinstance(arguments, str):
arguments = str(arguments)
call_id = tc.id or self._build_tool_call_id(function_name, arguments, fallback_prefix=f"tool_call_{idx}")
assistant_msg["tool_calls"].append({
"id": tc.id,
"id": call_id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
"name": function_name,
"arguments": arguments,
},
})
@ -691,7 +704,6 @@ class OpenAIProvider(ModelProvider):
)
def _parse_tool_call(self, payload: Any) -> Optional[ToolCallPayload]:
call_id = self._get_attr(payload, "call_id") or self._get_attr(payload, "id") or ""
function_payload = self._get_attr(payload, "function") or {}
function_name = self._get_attr(function_payload, "name") or self._get_attr(payload, "name") or ""
arguments = self._get_attr(function_payload, "arguments") or self._get_attr(payload, "arguments") or ""
@ -706,6 +718,9 @@ class OpenAIProvider(ModelProvider):
arguments_str = str(arguments)
else:
arguments_str = str(arguments)
call_id = self._get_attr(payload, "call_id") or self._get_attr(payload, "id") or ""
if not call_id:
call_id = self._build_tool_call_id(function_name, arguments_str)
return ToolCallPayload(
id=call_id,
function_name=function_name,
@ -713,6 +728,12 @@ class OpenAIProvider(ModelProvider):
type="function",
)
def _build_tool_call_id(self, function_name: str, arguments: str, *, fallback_prefix: str = "tool_call") -> str:
base = function_name or fallback_prefix
payload = f"{base}:{arguments or ''}".encode("utf-8")
digest = hashlib.md5(payload).hexdigest()[:8]
return f"{base}_{digest}"
def _get_attr(self, payload: Any, key: str) -> Any:
if hasattr(payload, key):
return getattr(payload, key)