mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
Merge pull request #502 from zxrys/main
fix: OpenAI chat.completions.create API function calling error
This commit is contained in:
commit
48c4263cdb
@ -1,6 +1,7 @@
|
||||
"""OpenAI provider implementation."""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
import binascii
|
||||
import os
|
||||
@ -64,7 +65,7 @@ class OpenAIProvider(ModelProvider):
|
||||
is_chat = self._is_chat_completions_mode(client)
|
||||
|
||||
if is_chat:
|
||||
request_payload = self._build_chat_payload(timeline, tool_specs, kwargs)
|
||||
request_payload = self._build_chat_payload(conversation, tool_specs, kwargs)
|
||||
response = client.chat.completions.create(**request_payload)
|
||||
self._track_token_usage(response)
|
||||
self._append_chat_response_output(timeline, response)
|
||||
@ -80,7 +81,7 @@ class OpenAIProvider(ModelProvider):
|
||||
message = self._deserialize_response(response)
|
||||
return ModelResponse(message=message, raw_response=response)
|
||||
except Exception as e:
|
||||
new_request_payload = self._build_chat_payload(timeline, tool_specs, kwargs)
|
||||
new_request_payload = self._build_chat_payload(conversation, tool_specs, kwargs)
|
||||
response = client.chat.completions.create(**new_request_payload)
|
||||
self._track_token_usage(response)
|
||||
self._append_chat_response_output(timeline, response)
|
||||
@ -226,7 +227,7 @@ class OpenAIProvider(ModelProvider):
|
||||
|
||||
def _build_chat_payload(
|
||||
self,
|
||||
timeline: List[Any],
|
||||
conversation: List[Message],
|
||||
tool_specs: Optional[List[ToolSpec]],
|
||||
raw_params: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
@ -238,8 +239,8 @@ class OpenAIProvider(ModelProvider):
|
||||
max_tokens = max_output_tokens
|
||||
|
||||
messages: List[Any] = []
|
||||
for item in timeline:
|
||||
serialized = self._serialize_timeline_item_for_chat(item)
|
||||
for item in conversation:
|
||||
serialized = self._serialize_message_for_chat(item)
|
||||
if serialized is not None:
|
||||
messages.append(serialized)
|
||||
|
||||
@ -366,12 +367,19 @@ class OpenAIProvider(ModelProvider):
|
||||
tool_calls: List[ToolCallPayload] = []
|
||||
tc_data = self._get_attr(msg, "tool_calls")
|
||||
if tc_data:
|
||||
for tc in tc_data:
|
||||
for idx, tc in enumerate(tc_data):
|
||||
f_data = self._get_attr(tc, "function") or {}
|
||||
function_name = self._get_attr(f_data, "name") or ""
|
||||
arguments = self._get_attr(f_data, "arguments") or ""
|
||||
if not isinstance(arguments, str):
|
||||
arguments = str(arguments)
|
||||
call_id = self._get_attr(tc, "id")
|
||||
if not call_id:
|
||||
call_id = self._build_tool_call_id(function_name, arguments, fallback_prefix=f"tool_call_{idx}")
|
||||
tool_calls.append(ToolCallPayload(
|
||||
id=self._get_attr(tc, "id"),
|
||||
function_name=self._get_attr(f_data, "name"),
|
||||
arguments=self._get_attr(f_data, "arguments"),
|
||||
id=call_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
type="function"
|
||||
))
|
||||
|
||||
@ -391,13 +399,18 @@ class OpenAIProvider(ModelProvider):
|
||||
|
||||
if getattr(msg, "tool_calls", None):
|
||||
assistant_msg["tool_calls"] = []
|
||||
for tc in msg.tool_calls:
|
||||
for idx, tc in enumerate(msg.tool_calls):
|
||||
function_name = tc.function.name
|
||||
arguments = tc.function.arguments or ""
|
||||
if not isinstance(arguments, str):
|
||||
arguments = str(arguments)
|
||||
call_id = tc.id or self._build_tool_call_id(function_name, arguments, fallback_prefix=f"tool_call_{idx}")
|
||||
assistant_msg["tool_calls"].append({
|
||||
"id": tc.id,
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
"name": function_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
})
|
||||
|
||||
@ -691,7 +704,6 @@ class OpenAIProvider(ModelProvider):
|
||||
)
|
||||
|
||||
def _parse_tool_call(self, payload: Any) -> Optional[ToolCallPayload]:
|
||||
call_id = self._get_attr(payload, "call_id") or self._get_attr(payload, "id") or ""
|
||||
function_payload = self._get_attr(payload, "function") or {}
|
||||
function_name = self._get_attr(function_payload, "name") or self._get_attr(payload, "name") or ""
|
||||
arguments = self._get_attr(function_payload, "arguments") or self._get_attr(payload, "arguments") or ""
|
||||
@ -706,6 +718,9 @@ class OpenAIProvider(ModelProvider):
|
||||
arguments_str = str(arguments)
|
||||
else:
|
||||
arguments_str = str(arguments)
|
||||
call_id = self._get_attr(payload, "call_id") or self._get_attr(payload, "id") or ""
|
||||
if not call_id:
|
||||
call_id = self._build_tool_call_id(function_name, arguments_str)
|
||||
return ToolCallPayload(
|
||||
id=call_id,
|
||||
function_name=function_name,
|
||||
@ -713,6 +728,12 @@ class OpenAIProvider(ModelProvider):
|
||||
type="function",
|
||||
)
|
||||
|
||||
def _build_tool_call_id(self, function_name: str, arguments: str, *, fallback_prefix: str = "tool_call") -> str:
|
||||
base = function_name or fallback_prefix
|
||||
payload = f"{base}:{arguments or ''}".encode("utf-8")
|
||||
digest = hashlib.md5(payload).hexdigest()[:8]
|
||||
return f"{base}_{digest}"
|
||||
|
||||
def _get_attr(self, payload: Any, key: str) -> Any:
|
||||
if hasattr(payload, key):
|
||||
return getattr(payload, key)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user