diff --git a/controllers/calculate_cost.py b/controllers/calculate_cost.py new file mode 100644 index 0000000..467a366 --- /dev/null +++ b/controllers/calculate_cost.py @@ -0,0 +1,20 @@ +from decimal import Decimal + +def calc_cost(tokens: tuple) -> Decimal : + """计算费用""" + input_tokens, output_tokens = tokens + if input_tokens <= 32000: + input_price = Decimal("0.001") + output_price = Decimal("0.01") + elif input_tokens <= 128000: + input_price = Decimal("0.0015") + output_price = Decimal("0.015") + else: + input_price = Decimal("0.003") + output_price = Decimal("0.03") + + # 计算费用 + input_cost = (Decimal(input_tokens) / Decimal(1000)) * input_price + output_cost = (Decimal(output_tokens) / Decimal(1000)) * output_price + + return input_cost + output_cost \ No newline at end of file diff --git a/main.py b/main.py index c342c90..a7682b5 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,7 @@ from models.generator import Generator from models.emoji_db import check_emoji, update_description from controllers.localfile_handles import get_emoji_count, yield_emoji_path from controllers.compress_image import compress_image +from controllers.calculate_cost import calc_cost from models.logger import setup_logger logger = setup_logger() @@ -19,6 +20,7 @@ emoji = yield_emoji_path(emoji_folder) gen = Generator(api_key) def main(): + count_cost = 0 for i in range(get_emoji_count(emoji_folder)): emoji_file = next(emoji) emoji_uuid = '' @@ -30,13 +32,16 @@ def main(): if check_emoji(emoji_uuid): try: - description = gen.process_local_image(str(image_path)) + description, tokens = gen.get_data(str(image_path)) + cost = calc_cost(tokens) + input_tokens, output_tokens = tokens except AttributeError: logger.warning("AI模型未响应,请检查是否欠费!") break update_description(emoji_uuid, description) if description: - logger.info(f"图片 {emoji_file} 的描述词生成完毕,序号:{i + 1}, 其 description 为: {description}") + logger.info(f"图片 {emoji_file} 的描述词生成完毕,序号:{i + 1}, 其 description 为: {description}") + logger.info(f"本次生成输入token:{input_tokens}, 输出token:{output_tokens}, 花费:{cost} 元") else: logger.warning(f"图片 {emoji_file} 的描述词生成失败!序号:{i + 1}") diff --git a/models/generator.py b/models/generator.py index 03c2e8f..2d447bc 100644 --- a/models/generator.py +++ b/models/generator.py @@ -1,4 +1,5 @@ from typing import Optional +from decimal import Decimal from dashscope import MultiModalConversation from models.logger import setup_logger @@ -10,8 +11,8 @@ class Generator: self.model = model self.text = text - def process_local_image(self, image_path: str) -> Optional[str]: - """处理本地单张图片,生成描述词""" + def process_local_image(self, image_path: str): + """处理本地单张图片""" try: image_url = f"file://{image_path}" messages = [ @@ -29,22 +30,14 @@ class Generator: messages=messages ) - if response and hasattr(response, 'output'): - if hasattr(response.output.choices[0].message.content[0], 'text'): - return response.output.choices[0].message.content[0]["text"] - elif isinstance(response.output.choices[0].message.content[0], dict): - return response.output.choices[0].message.content[0].get("text", "") + return response - return None - - except AttributeError as e: - raise AttributeError(e) except Exception as e: logger.error(f"API调用失败: {e}") return None - def process_link_image(self, image_url: str) -> Optional[str]: - """处理链接图片,生成描述词""" + def process_link_image(self, image_url: str): + """处理链接图片""" try: messages = [ { @@ -61,15 +54,22 @@ class Generator: messages=messages ) - if response and hasattr(response, 'output'): - if hasattr(response.output.choices[0].message.content[0], 'text'): - return response.output.choices[0].message.content[0]["text"] - elif isinstance(response.output.choices[0].message.content[0], dict): - return response.output.choices[0].message.content[0].get("text", "") - return None + return response - except AttributeError as e: - raise AttributeError(e) except Exception as e: logger.error(f"API调用失败: {e}") return None + + def get_data(self, image_string: str) -> Optional[tuple[str, tuple]]: + """处理图片""" + if image_string.startswith(('http://', 'https://')): + response = self.process_link_image(image_string) + else: + response = self.process_local_image(image_string) + + if isinstance(response.output.choices[0].message.content[0], dict): + text = response.output.choices[0].message.content[0].get("text", "") + tokens = (response.usage.input_tokens, response.usage.output_tokens) + return text, tokens + + return None diff --git a/test.py b/test.py index 11f8eea..9a0987a 100644 --- a/test.py +++ b/test.py @@ -3,6 +3,7 @@ import dashscope from models.generator import Generator from controllers.localfile_handles import get_emoji_count, yield_emoji_path from controllers.compress_image import compress_image +from controllers.calculate_cost import calc_cost from models.logger import setup_logger from models.emoji_db import get_emoji_url @@ -29,15 +30,18 @@ def generate_description_to_txt(): image = compress_image(emoji_file) image_path = Path(root_folder) / image - description = gen.process_local_image(str(image_path)) + description, tokens = gen.process_local_image(str(image_path)) + cost = calc_cost(tokens) + input_tokens, output_tokens = tokens logger.info(f"图片 {emoji_file} 的描述词生成完毕,序号:{i}, 其 description 为: {description}") + logger.info(f"本次生成输入token:{input_tokens}, 输出token:{output_tokens}, 花费:{cost} 元") fp.write(f"{emoji_uuid}: {description}\n") fp.close() generate_description_to_txt() content = get_emoji_url("c3151774-2c4c-48f6-a3ac-b8802cf95498") -print(content) +# print(content) url = 'https://pic1.arkoo.com/56D0B40F99F841DF8A2425762AE2565D/picture/o_1i4qop009177v1tgf14db15he1iaj1is.jpg' -print(gen.process_link_image(url)) \ No newline at end of file +print(gen.get_data(url)) \ No newline at end of file