添加获取token数和计费
This commit is contained in:
20
controllers/calculate_cost.py
Normal file
20
controllers/calculate_cost.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
def calc_cost(tokens: tuple) -> Decimal :
|
||||||
|
"""计算费用"""
|
||||||
|
input_tokens, output_tokens = tokens
|
||||||
|
if input_tokens <= 32000:
|
||||||
|
input_price = Decimal("0.001")
|
||||||
|
output_price = Decimal("0.01")
|
||||||
|
elif input_tokens <= 128000:
|
||||||
|
input_price = Decimal("0.0015")
|
||||||
|
output_price = Decimal("0.015")
|
||||||
|
else:
|
||||||
|
input_price = Decimal("0.003")
|
||||||
|
output_price = Decimal("0.03")
|
||||||
|
|
||||||
|
# 计算费用
|
||||||
|
input_cost = (Decimal(input_tokens) / Decimal(1000)) * input_price
|
||||||
|
output_cost = (Decimal(output_tokens) / Decimal(1000)) * output_price
|
||||||
|
|
||||||
|
return input_cost + output_cost
|
||||||
9
main.py
9
main.py
@@ -4,6 +4,7 @@ from models.generator import Generator
|
|||||||
from models.emoji_db import check_emoji, update_description
|
from models.emoji_db import check_emoji, update_description
|
||||||
from controllers.localfile_handles import get_emoji_count, yield_emoji_path
|
from controllers.localfile_handles import get_emoji_count, yield_emoji_path
|
||||||
from controllers.compress_image import compress_image
|
from controllers.compress_image import compress_image
|
||||||
|
from controllers.calculate_cost import calc_cost
|
||||||
from models.logger import setup_logger
|
from models.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@@ -19,6 +20,7 @@ emoji = yield_emoji_path(emoji_folder)
|
|||||||
gen = Generator(api_key)
|
gen = Generator(api_key)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
count_cost = 0
|
||||||
for i in range(get_emoji_count(emoji_folder)):
|
for i in range(get_emoji_count(emoji_folder)):
|
||||||
emoji_file = next(emoji)
|
emoji_file = next(emoji)
|
||||||
emoji_uuid = ''
|
emoji_uuid = ''
|
||||||
@@ -30,13 +32,16 @@ def main():
|
|||||||
|
|
||||||
if check_emoji(emoji_uuid):
|
if check_emoji(emoji_uuid):
|
||||||
try:
|
try:
|
||||||
description = gen.process_local_image(str(image_path))
|
description, tokens = gen.get_data(str(image_path))
|
||||||
|
cost = calc_cost(tokens)
|
||||||
|
input_tokens, output_tokens = tokens
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.warning("AI模型未响应,请检查是否欠费!")
|
logger.warning("AI模型未响应,请检查是否欠费!")
|
||||||
break
|
break
|
||||||
update_description(emoji_uuid, description)
|
update_description(emoji_uuid, description)
|
||||||
if description:
|
if description:
|
||||||
logger.info(f"图片 {emoji_file} 的描述词生成完毕,序号:{i + 1}, 其 description 为: {description}")
|
logger.info(f"图片 {emoji_file} 的描述词生成完毕,序号:{i + 1}, 其 description 为: {description}")
|
||||||
|
logger.info(f"本次生成输入token:{input_tokens}, 输出token:{output_tokens}, 花费:{cost} 元")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"图片 {emoji_file} 的描述词生成失败!序号:{i + 1}")
|
logger.warning(f"图片 {emoji_file} 的描述词生成失败!序号:{i + 1}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from decimal import Decimal
|
||||||
from dashscope import MultiModalConversation
|
from dashscope import MultiModalConversation
|
||||||
from models.logger import setup_logger
|
from models.logger import setup_logger
|
||||||
|
|
||||||
@@ -10,8 +11,8 @@ class Generator:
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.text = text
|
self.text = text
|
||||||
|
|
||||||
def process_local_image(self, image_path: str) -> Optional[str]:
|
def process_local_image(self, image_path: str):
|
||||||
"""处理本地单张图片,生成描述词"""
|
"""处理本地单张图片"""
|
||||||
try:
|
try:
|
||||||
image_url = f"file://{image_path}"
|
image_url = f"file://{image_path}"
|
||||||
messages = [
|
messages = [
|
||||||
@@ -29,22 +30,14 @@ class Generator:
|
|||||||
messages=messages
|
messages=messages
|
||||||
)
|
)
|
||||||
|
|
||||||
if response and hasattr(response, 'output'):
|
return response
|
||||||
if hasattr(response.output.choices[0].message.content[0], 'text'):
|
|
||||||
return response.output.choices[0].message.content[0]["text"]
|
|
||||||
elif isinstance(response.output.choices[0].message.content[0], dict):
|
|
||||||
return response.output.choices[0].message.content[0].get("text", "")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
except AttributeError as e:
|
|
||||||
raise AttributeError(e)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"API调用失败: {e}")
|
logger.error(f"API调用失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def process_link_image(self, image_url: str) -> Optional[str]:
|
def process_link_image(self, image_url: str):
|
||||||
"""处理链接图片,生成描述词"""
|
"""处理链接图片"""
|
||||||
try:
|
try:
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -61,15 +54,22 @@ class Generator:
|
|||||||
messages=messages
|
messages=messages
|
||||||
)
|
)
|
||||||
|
|
||||||
if response and hasattr(response, 'output'):
|
return response
|
||||||
if hasattr(response.output.choices[0].message.content[0], 'text'):
|
|
||||||
return response.output.choices[0].message.content[0]["text"]
|
|
||||||
elif isinstance(response.output.choices[0].message.content[0], dict):
|
|
||||||
return response.output.choices[0].message.content[0].get("text", "")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except AttributeError as e:
|
|
||||||
raise AttributeError(e)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"API调用失败: {e}")
|
logger.error(f"API调用失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_data(self, image_string: str) -> Optional[tuple[str, tuple]]:
|
||||||
|
"""处理图片"""
|
||||||
|
if image_string.startswith(('http://', 'https://')):
|
||||||
|
response = self.process_link_image(image_string)
|
||||||
|
else:
|
||||||
|
response = self.process_local_image(image_string)
|
||||||
|
|
||||||
|
if isinstance(response.output.choices[0].message.content[0], dict):
|
||||||
|
text = response.output.choices[0].message.content[0].get("text", "")
|
||||||
|
tokens = (response.usage.input_tokens, response.usage.output_tokens)
|
||||||
|
return text, tokens
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
10
test.py
10
test.py
@@ -3,6 +3,7 @@ import dashscope
|
|||||||
from models.generator import Generator
|
from models.generator import Generator
|
||||||
from controllers.localfile_handles import get_emoji_count, yield_emoji_path
|
from controllers.localfile_handles import get_emoji_count, yield_emoji_path
|
||||||
from controllers.compress_image import compress_image
|
from controllers.compress_image import compress_image
|
||||||
|
from controllers.calculate_cost import calc_cost
|
||||||
from models.logger import setup_logger
|
from models.logger import setup_logger
|
||||||
from models.emoji_db import get_emoji_url
|
from models.emoji_db import get_emoji_url
|
||||||
|
|
||||||
@@ -29,15 +30,18 @@ def generate_description_to_txt():
|
|||||||
image = compress_image(emoji_file)
|
image = compress_image(emoji_file)
|
||||||
image_path = Path(root_folder) / image
|
image_path = Path(root_folder) / image
|
||||||
|
|
||||||
description = gen.process_local_image(str(image_path))
|
description, tokens = gen.process_local_image(str(image_path))
|
||||||
|
cost = calc_cost(tokens)
|
||||||
|
input_tokens, output_tokens = tokens
|
||||||
logger.info(f"图片 {emoji_file} 的描述词生成完毕,序号:{i}, 其 description 为: {description}")
|
logger.info(f"图片 {emoji_file} 的描述词生成完毕,序号:{i}, 其 description 为: {description}")
|
||||||
|
logger.info(f"本次生成输入token:{input_tokens}, 输出token:{output_tokens}, 花费:{cost} 元")
|
||||||
fp.write(f"{emoji_uuid}: {description}\n")
|
fp.write(f"{emoji_uuid}: {description}\n")
|
||||||
fp.close()
|
fp.close()
|
||||||
|
|
||||||
generate_description_to_txt()
|
generate_description_to_txt()
|
||||||
|
|
||||||
content = get_emoji_url("c3151774-2c4c-48f6-a3ac-b8802cf95498")
|
content = get_emoji_url("c3151774-2c4c-48f6-a3ac-b8802cf95498")
|
||||||
print(content)
|
# print(content)
|
||||||
|
|
||||||
url = 'https://pic1.arkoo.com/56D0B40F99F841DF8A2425762AE2565D/picture/o_1i4qop009177v1tgf14db15he1iaj1is.jpg'
|
url = 'https://pic1.arkoo.com/56D0B40F99F841DF8A2425762AE2565D/picture/o_1i4qop009177v1tgf14db15he1iaj1is.jpg'
|
||||||
print(gen.process_link_image(url))
|
print(gen.get_data(url))
|
||||||
Reference in New Issue
Block a user