mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-06-01 21:20:36 +00:00
[fix] support chat interface
This commit is contained in:
parent
1079544652
commit
82dcf4a587
@ -4,7 +4,7 @@ import base64
|
|||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from urllib.parse import unquote_to_bytes
|
from urllib.parse import unquote_to_bytes
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
@ -46,7 +46,6 @@ class OpenAIProvider(ModelProvider):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return OpenAI(
|
return OpenAI(
|
||||||
base_url="https://api.openai.com/v1",
|
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -60,26 +59,43 @@ class OpenAIProvider(ModelProvider):
|
|||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""
|
"""
|
||||||
Call the OpenAI model with the given messages and parameters.
|
Call the OpenAI model with the given messages and parameters.
|
||||||
|
|
||||||
Args:
|
|
||||||
client: OpenAI client instance
|
|
||||||
conversation: List of messages in the conversation
|
|
||||||
tool_specs: Optional tool specifications
|
|
||||||
**kwargs: Additional parameters for the model call
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelResponse containing content and potentially tool calls
|
|
||||||
"""
|
"""
|
||||||
|
# 1. Determine if we should use Chat Completions directly
|
||||||
|
is_chat = self._is_chat_completions_mode(client)
|
||||||
|
|
||||||
|
if is_chat:
|
||||||
|
request_payload = self._build_chat_payload(timeline, tool_specs, kwargs)
|
||||||
|
response = client.chat.completions.create(**request_payload)
|
||||||
|
self._track_token_usage(response)
|
||||||
|
self._append_chat_response_output(timeline, response)
|
||||||
|
message = self._deserialize_chat_response(response)
|
||||||
|
return ModelResponse(message=message, raw_response=response)
|
||||||
|
|
||||||
|
# 2. Try Responses API with fallback
|
||||||
request_payload = self._build_request_payload(timeline, tool_specs, kwargs)
|
request_payload = self._build_request_payload(timeline, tool_specs, kwargs)
|
||||||
# print(request_payload)
|
try:
|
||||||
|
response = client.responses.create(**request_payload)
|
||||||
|
self._track_token_usage(response)
|
||||||
|
self._append_response_output(timeline, response)
|
||||||
|
message = self._deserialize_response(response)
|
||||||
|
return ModelResponse(message=message, raw_response=response)
|
||||||
|
except Exception as e:
|
||||||
|
new_request_payload = self._build_chat_payload(timeline, tool_specs, kwargs)
|
||||||
|
response = client.chat.completions.create(**new_request_payload)
|
||||||
|
self._track_token_usage(response)
|
||||||
|
self._append_chat_response_output(timeline, response)
|
||||||
|
message = self._deserialize_chat_response(response)
|
||||||
|
return ModelResponse(message=message, raw_response=response)
|
||||||
|
|
||||||
response = client.responses.create(**request_payload)
|
def _is_chat_completions_mode(self, client: Any) -> bool:
|
||||||
# print(response)
|
"""Determine if we should use standard chat completions instead of responses API."""
|
||||||
self._track_token_usage(response)
|
protocol = self.params.get("protocol")
|
||||||
self._append_response_output(timeline, response)
|
if protocol == "chat":
|
||||||
|
return True
|
||||||
message = self._deserialize_response(response)
|
if protocol == "responses":
|
||||||
return ModelResponse(message=message, raw_response=response)
|
return False
|
||||||
|
# Default to Responses API only if it exists on the client
|
||||||
|
return not hasattr(client, "responses")
|
||||||
|
|
||||||
def extract_token_usage(self, response: Any) -> TokenUsage:
|
def extract_token_usage(self, response: Any) -> TokenUsage:
|
||||||
"""
|
"""
|
||||||
@ -208,6 +224,185 @@ class OpenAIProvider(ModelProvider):
|
|||||||
payload.update(params)
|
payload.update(params)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
def _build_chat_payload(
|
||||||
|
self,
|
||||||
|
timeline: List[Any],
|
||||||
|
tool_specs: Optional[List[ToolSpec]],
|
||||||
|
raw_params: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Construct standard Chat Completions API payload."""
|
||||||
|
params = dict(raw_params)
|
||||||
|
max_output_tokens = params.pop("max_output_tokens", None)
|
||||||
|
max_tokens = params.pop("max_tokens", None)
|
||||||
|
if max_tokens is None and max_output_tokens is not None:
|
||||||
|
max_tokens = max_output_tokens
|
||||||
|
|
||||||
|
messages: List[Any] = []
|
||||||
|
for item in timeline:
|
||||||
|
serialized = self._serialize_timeline_item_for_chat(item)
|
||||||
|
if serialized is not None:
|
||||||
|
messages.append(serialized)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
messages = [{"role": "user", "content": ""}]
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": params.pop("temperature", 0.7),
|
||||||
|
}
|
||||||
|
if max_tokens is not None:
|
||||||
|
payload["max_tokens"] = max_tokens
|
||||||
|
elif self.params.get("max_tokens"):
|
||||||
|
payload["max_tokens"] = self.params["max_tokens"]
|
||||||
|
|
||||||
|
user_tools = params.pop("tools", None)
|
||||||
|
merged_tools: List[Any] = []
|
||||||
|
if isinstance(user_tools, list):
|
||||||
|
merged_tools.extend(user_tools)
|
||||||
|
|
||||||
|
if tool_specs:
|
||||||
|
for spec in tool_specs:
|
||||||
|
merged_tools.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": spec.name,
|
||||||
|
"description": spec.description,
|
||||||
|
"parameters": spec.parameters or {"type": "object", "properties": {}},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if merged_tools:
|
||||||
|
payload["tools"] = merged_tools
|
||||||
|
|
||||||
|
tool_choice = params.pop("tool_choice", None)
|
||||||
|
if tool_choice is not None:
|
||||||
|
payload["tool_choice"] = tool_choice
|
||||||
|
elif tool_specs:
|
||||||
|
payload.setdefault("tool_choice", "auto")
|
||||||
|
|
||||||
|
payload.update(params)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _serialize_timeline_item_for_chat(self, item: Any) -> Optional[Any]:
|
||||||
|
if isinstance(item, Message):
|
||||||
|
return self._serialize_message_for_chat(item)
|
||||||
|
if isinstance(item, FunctionCallOutputEvent):
|
||||||
|
return self._serialize_function_call_output_event_for_chat(item)
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# basic conversion if it looks like a Responses output
|
||||||
|
role = item.get("role")
|
||||||
|
content = item.get("content")
|
||||||
|
tool_calls = item.get("tool_calls")
|
||||||
|
if role and (content or tool_calls):
|
||||||
|
return {
|
||||||
|
"role": role,
|
||||||
|
"content": self._transform_blocks_for_chat(content) if isinstance(content, list) else content,
|
||||||
|
"tool_calls": tool_calls
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _serialize_message_for_chat(self, message: Message) -> Dict[str, Any]:
|
||||||
|
"""Convert internal Message to standard Chat Completions schema."""
|
||||||
|
role_value = message.role.value
|
||||||
|
blocks = message.blocks()
|
||||||
|
if not blocks:
|
||||||
|
content = message.text_content()
|
||||||
|
else:
|
||||||
|
content = self._transform_blocks_for_chat(self._serialize_blocks(blocks, message.role))
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"role": role_value,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
if message.name:
|
||||||
|
payload["name"] = message.name
|
||||||
|
if message.tool_call_id:
|
||||||
|
payload["tool_call_id"] = message.tool_call_id
|
||||||
|
if message.tool_calls:
|
||||||
|
payload["tool_calls"] = [tc.to_openai_dict() for tc in message.tool_calls]
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _serialize_function_call_output_event_for_chat(self, event: FunctionCallOutputEvent) -> Dict[str, Any]:
|
||||||
|
"""Convert tool result to standard Chat Completions schema."""
|
||||||
|
text = event.output_text or ""
|
||||||
|
if event.output_blocks:
|
||||||
|
# simple concatenation for tool output in chat mode
|
||||||
|
text = "\n".join(b.describe() for b in event.output_blocks)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": event.call_id or "tool_call",
|
||||||
|
"content": text,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _transform_blocks_for_chat(self, blocks: List[Dict[str, Any]]) -> Union[str, List[Dict[str, Any]]]:
|
||||||
|
"""Convert Responses block types to Chat block types (e.g., input_text -> text)."""
|
||||||
|
transformed: List[Dict[str, Any]] = []
|
||||||
|
for block in blocks:
|
||||||
|
b_type = block.get("type", "")
|
||||||
|
if b_type in ("input_text", "output_text"):
|
||||||
|
transformed.append({"type": "text", "text": block.get("text", "")})
|
||||||
|
elif b_type in ("input_image", "output_image"):
|
||||||
|
transformed.append({"type": "image_url", "image_url": {"url": block.get("image_url", "")}})
|
||||||
|
else:
|
||||||
|
# Keep as is or drop if complex
|
||||||
|
transformed.append(block)
|
||||||
|
|
||||||
|
# If only one text block, return as string for better compatibility
|
||||||
|
if len(transformed) == 1 and transformed[0]["type"] == "text":
|
||||||
|
return transformed[0]["text"]
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
def _deserialize_chat_response(self, response: Any) -> Message:
|
||||||
|
"""Convert Chat Completions output to internal Message."""
|
||||||
|
choices = self._get_attr(response, "choices") or []
|
||||||
|
if not choices:
|
||||||
|
return Message(role=MessageRole.ASSISTANT, content="")
|
||||||
|
|
||||||
|
choice = choices[0]
|
||||||
|
msg = self._get_attr(choice, "message")
|
||||||
|
|
||||||
|
tool_calls: List[ToolCallPayload] = []
|
||||||
|
tc_data = self._get_attr(msg, "tool_calls")
|
||||||
|
if tc_data:
|
||||||
|
for tc in tc_data:
|
||||||
|
f_data = self._get_attr(tc, "function") or {}
|
||||||
|
tool_calls.append(ToolCallPayload(
|
||||||
|
id=self._get_attr(tc, "id"),
|
||||||
|
function_name=self._get_attr(f_data, "name"),
|
||||||
|
arguments=self._get_attr(f_data, "arguments"),
|
||||||
|
type="function"
|
||||||
|
))
|
||||||
|
|
||||||
|
return Message(
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
content=self._get_attr(msg, "content") or "",
|
||||||
|
tool_calls=tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
def _append_chat_response_output(self, timeline: List[Any], response: Any) -> None:
|
||||||
|
"""Add chat response to timeline, preserving tool_calls (Chat API compatible)."""
|
||||||
|
msg = response.choices[0].message
|
||||||
|
assistant_msg = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": msg.content or ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if getattr(msg, "tool_calls", None):
|
||||||
|
assistant_msg["tool_calls"] = []
|
||||||
|
for tc in msg.tool_calls:
|
||||||
|
assistant_msg["tool_calls"].append({
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": tc.function.arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
timeline.append(assistant_msg)
|
||||||
|
|
||||||
def _serialize_timeline_item(self, item: Any) -> Optional[Any]:
|
def _serialize_timeline_item(self, item: Any) -> Optional[Any]:
|
||||||
if isinstance(item, Message):
|
if isinstance(item, Message):
|
||||||
return self._serialize_message_for_responses(item)
|
return self._serialize_message_for_responses(item)
|
||||||
|
|||||||
@ -283,7 +283,6 @@ class AgentNodeExecutor(NodeExecutor):
|
|||||||
self._record_model_call(node, last_input, None, CallStage.BEFORE)
|
self._record_model_call(node, last_input, None, CallStage.BEFORE)
|
||||||
response = self._execute_with_retry(node, retry_policy, _call_provider)
|
response = self._execute_with_retry(node, retry_policy, _call_provider)
|
||||||
self.log_manager.debug(response.str_raw_response())
|
self.log_manager.debug(response.str_raw_response())
|
||||||
# print(timeline)
|
|
||||||
self._record_model_call(node, last_input, response, CallStage.AFTER)
|
self._record_model_call(node, last_input, response, CallStage.AFTER)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user