# ChatDev/puppeteer/model/model_utils.py
from typing import Dict, Optional
import logging

from tenacity import retry
from tenacity.stop import stop_after_attempt
from tenacity.wait import wait_exponential

logger = logging.getLogger("model")


class APIConfig:
    """Module-level flags that steer retry behavior across calls."""
    # Set after a failed request so the next attempt truncates its messages.
    SLOW_FLAG = False
    # Incremented on each consecutive failure; the longest message is cut to
    # 1 / 2**TRUNCATE_FACTOR of its length on the next attempt.
    TRUNCATE_FACTOR = 0


def model_log_and_print(content):
    """Log `content` at INFO level and echo it to stdout."""
    if content is not None:
        logger.info(content)
        print(content)
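
# Note (illustrative, not part of the original module): the "model" logger
# emits nothing at INFO level unless the host application configures logging,
# e.g.:
#
#   import logging
#   logging.basicConfig(level=logging.INFO)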


def truncate_messages(messages):
    """Shrink the longest message in place to 1 / 2**TRUNCATE_FACTOR of its length."""
    max_length = 0
    max_index = 0
    for i, msg in enumerate(messages):
        if len(msg.get('content', '')) > max_length:
            max_length = len(msg['content'])
            max_index = i
    content = messages[max_index]['content']
    factor = 1 / (2 ** APIConfig.TRUNCATE_FACTOR)
    messages[max_index]['content'] = content[:int(len(content) * factor)]
    return messages
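
# Illustrative example (not part of the original module): each retry doubles
# the cut, so with TRUNCATE_FACTOR == 1 the longest message is halved, with
# 2 it is quartered, and so on:
#
#   msgs = [{"role": "user", "content": "x" * 100}]
#   APIConfig.TRUNCATE_FACTOR = 1
#   truncate_messages(msgs)
#   assert len(msgs[0]["content"]) == 50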


def calc_max_token(messages, max_tokens):
    """Estimate the completion budget remaining after the prompt, in tokens."""
    string = "\n".join([str(message["content"]) for message in messages])
    # Rough heuristic: ~1.8 characters per token.
    num_prompt_tokens = int(len(string) / 1.8)
    # Per-message overhead for role markers and separators.
    gap_between_send_receive = 15 * len(messages)
    num_prompt_tokens += gap_between_send_receive
    num_max_completion_tokens = max_tokens - num_prompt_tokens
    logger.info(f"num_prompt_tokens: {num_prompt_tokens}, "
                f"num_max_completion_tokens: {num_max_completion_tokens}")
    if num_max_completion_tokens < 0:
        logger.warning(f"num_max_completion_tokens is negative: {num_max_completion_tokens}")
        return 0
    return num_max_completion_tokens
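
# Worked example (illustrative): two messages totalling 360 characters give
# int(360 / 1.8) + 15 * 2 = 200 + 30 = 230 estimated prompt tokens, so with
# max_tokens=4096 the remaining completion budget is 4096 - 230 = 3866.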


@retry(wait=wait_exponential(min=5, max=10), stop=stop_after_attempt(10))
def chat_completion_request(messages, model, new_client, model_config_dict: Optional[Dict] = None):
    """Send a chat-completion request via `new_client`, retrying with
    exponential backoff and progressively truncating messages on failure."""
    if model_config_dict is None:
        model_config_dict = {
            "temperature": 0.1,
            "top_p": 1.0,
            "n": 1,
            "stream": False,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
            "logit_bias": {},
        }
    json_data = {
        "model": model,
        "messages": messages,
        "max_tokens": 4096,
        "temperature": model_config_dict["temperature"],
        "top_p": model_config_dict["top_p"],
        "n": model_config_dict["n"],
        "stream": model_config_dict["stream"],
        "frequency_penalty": model_config_dict["frequency_penalty"],
        "presence_penalty": model_config_dict["presence_penalty"],
        "logit_bias": model_config_dict["logit_bias"],
    }
    try:
        model_log_and_print("[Model Query] {}".format(messages))
        if APIConfig.SLOW_FLAG:
            # truncate_messages mutates the list in place, so json_data
            # (which holds the same list object) picks up the change.
            messages = truncate_messages(messages=messages)
        response = new_client.chat.completions.create(**json_data)
        completion_tokens = response.usage.completion_tokens
        prompt_tokens = response.usage.prompt_tokens
        total_tokens = response.usage.total_tokens
        if total_tokens == 0:
            total_tokens = prompt_tokens + completion_tokens
        if total_tokens == 0:
            # Last resort: fall back to the ~1.8 characters-per-token heuristic.
            total_tokens = int(len(response.choices[0].message.content) / 1.8)
        model_log_and_print(
            f"[Model Query] Token Usage: \n"
            f"Completion Tokens: {completion_tokens} \n"
            f"Prompt Tokens: {prompt_tokens} \n"
            f"Total Tokens: {total_tokens}"
        )
        # Success: reset the degradation flags for the next call.
        APIConfig.SLOW_FLAG = False
        APIConfig.TRUNCATE_FACTOR = 0
        return response, total_tokens
    except Exception as e:
        print(f"Unable to generate ChatCompletion response. OpenAI calling Exception: {e}")
        # Degrade the next attempt: flag slow mode and deepen the truncation.
        APIConfig.SLOW_FLAG = True
        APIConfig.TRUNCATE_FACTOR += 1
        model_log_and_print(f"[Model Query: ChatCompletion] query failed: {str(e)}")
        # Re-raise the original exception so tenacity's @retry schedules the
        # next attempt (the original raised a bare Exception(), losing it).
        raise
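

# Minimal usage sketch (an assumption, not part of the original module): any
# object exposing chat.completions.create(**kwargs) works as `new_client`;
# the OpenAI v1 SDK client and model name below are stand-ins.
if __name__ == "__main__":
    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    response, total = chat_completion_request(
        messages=[{"role": "user", "content": "Say hello in one word."}],
        model="gpt-4o-mini",
        new_client=client,
    )
    print(response.choices[0].message.content)
    print(f"total tokens: {total}")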