diff --git a/runtime/node/agent/providers/openai_provider.py b/runtime/node/agent/providers/openai_provider.py index 87cb7823..d42dd171 100755 --- a/runtime/node/agent/providers/openai_provider.py +++ b/runtime/node/agent/providers/openai_provider.py @@ -1,6 +1,7 @@ """OpenAI provider implementation.""" import base64 +import hashlib import binascii import os @@ -64,7 +65,7 @@ class OpenAIProvider(ModelProvider): is_chat = self._is_chat_completions_mode(client) if is_chat: - request_payload = self._build_chat_payload(timeline, tool_specs, kwargs) + request_payload = self._build_chat_payload(conversation, tool_specs, kwargs) response = client.chat.completions.create(**request_payload) self._track_token_usage(response) self._append_chat_response_output(timeline, response) @@ -80,7 +81,7 @@ class OpenAIProvider(ModelProvider): message = self._deserialize_response(response) return ModelResponse(message=message, raw_response=response) except Exception as e: - new_request_payload = self._build_chat_payload(timeline, tool_specs, kwargs) + new_request_payload = self._build_chat_payload(conversation, tool_specs, kwargs) response = client.chat.completions.create(**new_request_payload) self._track_token_usage(response) self._append_chat_response_output(timeline, response) @@ -226,7 +227,7 @@ class OpenAIProvider(ModelProvider): def _build_chat_payload( self, - timeline: List[Any], + conversation: List[Message], tool_specs: Optional[List[ToolSpec]], raw_params: Dict[str, Any], ) -> Dict[str, Any]: @@ -238,8 +239,8 @@ class OpenAIProvider(ModelProvider): max_tokens = max_output_tokens messages: List[Any] = [] - for item in timeline: - serialized = self._serialize_timeline_item_for_chat(item) + for item in conversation: + serialized = self._serialize_message_for_chat(item) if serialized is not None: messages.append(serialized) @@ -366,12 +367,19 @@ class OpenAIProvider(ModelProvider): tool_calls: List[ToolCallPayload] = [] tc_data = self._get_attr(msg, "tool_calls") if tc_data: - for tc in tc_data: + for idx, tc in enumerate(tc_data): f_data = self._get_attr(tc, "function") or {} + function_name = self._get_attr(f_data, "name") or "" + arguments = self._get_attr(f_data, "arguments") or "" + if not isinstance(arguments, str): + arguments = str(arguments) + call_id = self._get_attr(tc, "id") + if not call_id: + call_id = self._build_tool_call_id(function_name, arguments, fallback_prefix=f"tool_call_{idx}") tool_calls.append(ToolCallPayload( - id=self._get_attr(tc, "id"), - function_name=self._get_attr(f_data, "name"), - arguments=self._get_attr(f_data, "arguments"), + id=call_id, + function_name=function_name, + arguments=arguments, type="function" )) @@ -391,13 +399,18 @@ class OpenAIProvider(ModelProvider): if getattr(msg, "tool_calls", None): assistant_msg["tool_calls"] = [] - for tc in msg.tool_calls: + for idx, tc in enumerate(msg.tool_calls): + function_name = tc.function.name + arguments = tc.function.arguments or "" + if not isinstance(arguments, str): + arguments = str(arguments) + call_id = tc.id or self._build_tool_call_id(function_name, arguments, fallback_prefix=f"tool_call_{idx}") assistant_msg["tool_calls"].append({ - "id": tc.id, + "id": call_id, "type": "function", "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, + "name": function_name, + "arguments": arguments, }, }) @@ -691,7 +704,6 @@ class OpenAIProvider(ModelProvider): ) def _parse_tool_call(self, payload: Any) -> Optional[ToolCallPayload]: - call_id = self._get_attr(payload, "call_id") or self._get_attr(payload, "id") or "" function_payload = self._get_attr(payload, "function") or {} function_name = self._get_attr(function_payload, "name") or self._get_attr(payload, "name") or "" arguments = self._get_attr(function_payload, "arguments") or self._get_attr(payload, "arguments") or "" @@ -706,6 +718,9 @@ class OpenAIProvider(ModelProvider): arguments_str = str(arguments) else: arguments_str = str(arguments) + call_id = self._get_attr(payload, "call_id") or self._get_attr(payload, "id") or "" + if not call_id: + call_id = self._build_tool_call_id(function_name, arguments_str) return ToolCallPayload( id=call_id, function_name=function_name, @@ -713,6 +728,12 @@ class OpenAIProvider(ModelProvider): type="function", ) + def _build_tool_call_id(self, function_name: str, arguments: str, *, fallback_prefix: str = "tool_call") -> str: + base = function_name or fallback_prefix + payload = f"{base}:{arguments or ''}".encode("utf-8") + digest = hashlib.md5(payload).hexdigest()[:8] + return f"{base}_{digest}" + def _get_attr(self, payload: Any, key: str) -> Any: if hasattr(payload, key): return getattr(payload, key)