# Mirror of https://github.com/OpenBMB/ChatDev.git

import logging
from typing import Dict, Optional

from tenacity import retry
from tenacity.stop import stop_after_attempt
from tenacity.wait import wait_exponential

logger = logging.getLogger("model")


class APIConfig:
    """Cross-call flags: after a failed request, SLOW_FLAG triggers prompt
    truncation on the next attempt and TRUNCATE_FACTOR deepens the cut."""

    SLOW_FLAG = False
    TRUNCATE_FACTOR = 0


def model_log_and_print(content):
    """Log content via the module logger and echo it to stdout."""
    if content is not None:
        logger.info(content)
        print(content)


def truncate_messages(messages):
    """Shrink the longest message in place to 1 / 2**TRUNCATE_FACTOR of its length."""
    max_length = 0
    max_index = 0
    for i, msg in enumerate(messages):
        if len(msg.get('content', '')) > max_length:
            max_length = len(msg['content'])
            max_index = i

    content = messages[max_index]['content']
    factor = 1 / (2 ** APIConfig.TRUNCATE_FACTOR)
    messages[max_index]['content'] = content[:int(len(content) * factor)]
    return messages
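
# Illustration (not in the original file): with TRUNCATE_FACTOR == 1 the longest
# message keeps half of its characters, with 2 a quarter, and so on:
# >>> APIConfig.TRUNCATE_FACTOR = 1
# >>> truncate_messages([{"role": "user", "content": "abcdefgh"}])
# [{'role': 'user', 'content': 'abcd'}]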


def calc_max_token(messages, max_tokens):
    """Estimate the completion-token budget left under the model's max_tokens limit."""
    string = "\n".join([str(message["content"]) for message in messages])
    num_prompt_tokens = int(len(string) // 1.8)  # rough estimate: ~1.8 characters per token
    gap_between_send_receive = 15 * len(messages)  # per-message overhead for chat formatting
    num_prompt_tokens += gap_between_send_receive

    num_max_completion_tokens = max_tokens - num_prompt_tokens
    logger.info(f"num_prompt_tokens: {num_prompt_tokens}, num_max_completion_tokens: {num_max_completion_tokens}")
    if num_max_completion_tokens < 0:
        logger.warning(f"num_max_completion_tokens is negative: {num_max_completion_tokens}")
        return 0
    return num_max_completion_tokens
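
# Worked example (assumed values): a single 100-character message estimates to
# int(100 // 1.8) + 15 = 70 prompt tokens, so calc_max_token(messages, 4096)
# would return 4096 - 70 = 4026.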


@retry(wait=wait_exponential(min=5, max=10), stop=stop_after_attempt(10))
def chat_completion_request(messages, model, new_client, model_config_dict: Optional[Dict] = None):
    """Send a chat-completion request and return (response, total_tokens).

    On failure, set the slow-path flags so the next attempt truncates the
    prompt, then re-raise so tenacity retries with exponential backoff.
    """
    if model_config_dict is None:
        model_config_dict = {
            "temperature": 0.1,
            "top_p": 1.0,
            "n": 1,
            "stream": False,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
            "logit_bias": {},
        }

    json_data = {
        "model": model,
        "messages": messages,
        "max_tokens": 4096,
        "temperature": model_config_dict["temperature"],
        "top_p": model_config_dict["top_p"],
        "n": model_config_dict["n"],
        "stream": model_config_dict["stream"],
        "frequency_penalty": model_config_dict["frequency_penalty"],
        "presence_penalty": model_config_dict["presence_penalty"],
        "logit_bias": model_config_dict["logit_bias"],
    }

    try:
        model_log_and_print("[Model Query] {}".format(messages))
        if APIConfig.SLOW_FLAG:
            # truncate_messages mutates the list in place, so json_data["messages"],
            # which references the same list, is truncated as well.
            messages = truncate_messages(messages=messages)

        response = new_client.chat.completions.create(**json_data)

        completion_tokens = response.usage.completion_tokens
        prompt_tokens = response.usage.prompt_tokens
        total_tokens = response.usage.total_tokens
        if total_tokens == 0:
            total_tokens = prompt_tokens + completion_tokens
        if total_tokens == 0:
            # Fall back to the ~1.8 characters-per-token heuristic.
            total_tokens = int(len(response.choices[0].message.content) // 1.8)
        model_log_and_print(f"[Model Query] Token Usage: \nCompletion Tokens: {completion_tokens} \nPrompt Tokens: {prompt_tokens} \nTotal Tokens: {total_tokens}")
        APIConfig.SLOW_FLAG = False
        APIConfig.TRUNCATE_FACTOR = 0
        return response, total_tokens

    except Exception as e:
        print("Unable to generate ChatCompletion response. " + f"OpenAI calling Exception: {e}")
        APIConfig.SLOW_FLAG = True
        APIConfig.TRUNCATE_FACTOR += 1
        model_log_and_print(f"[Model Query: ChatCompletion] query failed: {str(e)}")
        raise
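

# Usage sketch (an assumption, not part of the mirrored file): any client object
# exposing chat.completions.create(**kwargs), such as the official OpenAI client,
# can drive chat_completion_request.
if __name__ == "__main__":
    from openai import OpenAI  # assumed dependency

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    demo_messages = [{"role": "user", "content": "Say hello in one sentence."}]
    response, total = chat_completion_request(demo_messages, "gpt-4o-mini", client)
    print(response.choices[0].message.content, f"(total tokens: {total})")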