""" NanoBanana模型客户端 """ import requests import json import re import time import logging import os from datetime import datetime # 获取项目根目录 project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) log_dir = os.path.join(project_root, 'logs') # 确保日志目录存在 os.makedirs(log_dir, exist_ok=True) # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(os.path.join(log_dir, 'app.log')), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) class NanoBananaClient: def __init__(self): self.api_url = "https://api.aimindsky.com/v1/chat/completions" self.api_key = "sk-39a5vNBNbkkMA6YHP03h663tBno6OqfJKngWmQy0oT7JCP1O" self.model_name = "gemini-3-pro-image-preview" self.timeout = 500 # 固定超时时间为500秒 def get_timestamp(self): """获取当前时间戳""" return datetime.now().isoformat() def _make_request(self, messages, temperature=0.7, max_retries=3): """发送请求到API,带重试机制""" import threading thread_id = threading.current_thread().ident headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}" } payload = { "model": self.model_name, "messages": messages, "temperature": temperature, "stream": True } # 记录请求参数 logger.info(f"[线程{thread_id}] API请求开始 - 模型: {self.model_name}") last_error = None for attempt in range(max_retries): try: # 记录请求详情(仅用于调试图片编辑) if any(isinstance(msg.get('content'), list) for msg in messages): logger.info(f"[线程{thread_id}] 图片编辑请求 - 消息类型: 多模态") for msg in messages: if isinstance(msg.get('content'), list): for content_item in msg['content']: if content_item.get('type') == 'image_url': image_url = content_item['image_url']['url'] logger.info(f"[线程{thread_id}] 图片数据长度: {len(image_url)}") response = requests.post( self.api_url, headers=headers, json=payload, stream=True, timeout=self.timeout ) logger.info(f"[线程{thread_id}] API响应状态码: {response.status_code}") if response.status_code == 200: logger.info(f"[线程{thread_id}] API响应成功,开始接收数据...") else: logger.error(f"[线程{thread_id}] API响应失败: {response.status_code}") logger.error(f"[线程{thread_id}] 响应头: {dict(response.headers)}") if response.status_code != 200: error_text = response.text logger.error(f"API请求失败,状态码: {response.status_code}") logger.error(f"错误响应内容: {error_text}") try: # 尝试解析错误响应 error_data = json.loads(error_text) logger.error(f"解析后的错误数据: {json.dumps(error_data, ensure_ascii=False, indent=2)}") if 'error' in error_data: error_msg = error_data['error'].get('message', '未知错误') error_code = error_data['error'].get('code', '') logger.error(f"错误代码: {error_code}, 错误消息: {error_msg}") # 某些错误不需要重试 if error_code in ['invalid_request_error', 'authentication_error']: raise Exception(f"API服务错误: {error_msg}") # 其他错误可以重试 last_error = Exception(f"API服务错误: {error_msg}") else: last_error = Exception(f"API请求失败: {response.status_code} - {error_text}") except json.JSONDecodeError: logger.error("无法解析错误响应为JSON格式") last_error = Exception(f"API请求失败: {response.status_code} - {error_text}") # 如果是最后一次尝试,抛出错误 if attempt == max_retries - 1: logger.error("所有重试都失败了") raise last_error # 等待后重试 wait_time = 2 ** attempt logger.info(f"等待 {wait_time} 秒后重试...") time.sleep(wait_time) # 指数退避 continue return self._parse_stream_response(response, thread_id) except requests.exceptions.Timeout: last_error = Exception(f"请求超时({self.timeout}秒),请稍后重试") logger.error(f"请求超时({self.timeout}秒),尝试 {attempt + 1}/{max_retries}") if attempt == max_retries - 1: logger.error(f"所有重试都超时,总共尝试了 {max_retries} 次") raise last_error wait_time = 2 ** attempt logger.info(f"等待 {wait_time} 秒后重试...") time.sleep(wait_time) continue except requests.exceptions.RequestException as e: last_error = Exception(f"网络请求错误: {str(e)}") logger.error(f"网络请求异常: {str(e)}, 尝试 {attempt + 1}/{max_retries}") if attempt == max_retries - 1: raise last_error time.sleep(2 ** attempt) continue # 如果所有重试都失败了 raise last_error or Exception("请求失败") def _parse_stream_response(self, response, thread_id=None): """解析流式响应""" if thread_id is None: import threading thread_id = threading.current_thread().ident content_parts = [] try: line_count = 0 for line in response.iter_lines(): if line: line_count += 1 line_text = line.decode('utf-8') # 记录前几行数据用于调试 if line_count <= 5: logger.info(f"[线程{thread_id}] 接收到第{line_count}行数据: {line_text[:200]}...") # 解析SSE格式 if line_text.startswith('data: '): data_content = line_text[6:] # 移除 'data: ' 前缀 if data_content.strip() == '[DONE]': logger.info(f"[线程{thread_id}] 接收到结束标记 [DONE]") break if data_content.strip(): # 确保不是空内容 try: data = json.loads(data_content) # 记录响应对象类型 if 'object' in data: logger.info(f"[线程{thread_id}] 响应对象类型: {data['object']}") # 检查是否有错误 if 'error' in data: error_msg = data['error'].get('message', '未知错误') logger.error(f"[线程{thread_id}] API响应错误: {error_msg}") raise Exception(f"API响应错误: {error_msg}") # 提取内容 if 'choices' in data and len(data['choices']) > 0: delta = data['choices'][0].get('delta', {}) if 'content' in delta: content = delta['content'] content_parts.append(content) if content: logger.info(f"[线程{thread_id}] 接收到内容片段,长度: {len(content)}") except json.JSONDecodeError as e: # 记录解析失败的数据片段 logger.error(f"[线程{thread_id}] JSON解析失败: {str(e)}") logger.error(f"[线程{thread_id}] 失败的数据: {data_content[:500]}...") continue logger.info(f"[线程{thread_id}] 总共接收了 {line_count} 行数据") # 合并所有内容 full_content = ''.join(content_parts) # 检查是否有有效内容 if not full_content.strip(): logger.error(f"[线程{thread_id}] API未返回有效内容") raise Exception("API未返回有效内容,请检查提示词或稍后重试") # 提取图片数据 image_data = self._extract_image_data(full_content, thread_id) if not image_data: logger.error(f"[线程{thread_id}] 未能从响应中提取到图片数据") raise Exception("未能从响应中提取到图片数据,请重试") logger.info(f"[线程{thread_id}] 图片生成成功,数据长度: {len(image_data)}") return image_data except Exception as e: logger.error(f"[线程{thread_id}] 响应解析失败: {str(e)}") raise Exception(f"响应解析失败: {str(e)}") def _extract_image_data(self, content, thread_id=None): """从响应内容中提取图片数据""" # 查找图片标记格式: ![image](data:image/jpeg;base64,...) image_pattern = r'!\[image\]\(data:image/[^;]+;base64,([^)]+)\)' matches = re.findall(image_pattern, content) if matches: # 返回第一个匹配的base64数据 base64_data = matches[0] return f"data:image/jpeg;base64,{base64_data}" else: # 如果没有找到标准格式,尝试直接查找base64数据 base64_pattern = r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)' direct_matches = re.findall(base64_pattern, content) if direct_matches: return f"data:image/jpeg;base64,{direct_matches[0]}" # 尝试查找其他可能的图片格式 other_patterns = [ r'base64,([A-Za-z0-9+/=]{100,})', # 查找长base64字符串 r'data:image/[^,]+,([A-Za-z0-9+/=]+)', # 更宽泛的data URL格式 ] for pattern in other_patterns: other_matches = re.findall(pattern, content) if other_matches: return f"data:image/jpeg;base64,{other_matches[0]}" return None def generate_text_to_image(self, prompt, temperature=0.7): """文生图功能""" import threading thread_id = threading.current_thread().ident logger.info(f"[线程{thread_id}] 开始文生图请求: {prompt[:50]}...") # 直接使用用户提示词,不再读取系统提示词文件 messages = [ { "role": "user", "content": prompt } ] return self._make_request(messages, temperature) def generate_with_image(self, prompt, image_base64, temperature=0.7): """图片编辑功能(单张图片,保留向后兼容)""" return self.generate_with_images(prompt, [image_base64], temperature) def generate_with_images(self, prompt, image_base64_list, temperature=0.7): """图片编辑功能(支持多张图片) Args: prompt: 用户提示词 image_base64_list: 图片base64数据列表,最多3张 temperature: 创意度参数 """ import threading thread_id = threading.current_thread().ident # 确保是列表 if not isinstance(image_base64_list, list): image_base64_list = [image_base64_list] # 限制最多3张图片 if len(image_base64_list) > 3: image_base64_list = image_base64_list[:3] logger.warning(f"[线程{thread_id}] 图片数量超过3张,已截取前3张") logger.info(f"[线程{thread_id}] 开始图片编辑请求: {prompt[:50]}...,图片数量: {len(image_base64_list)}") # 读取图片编辑系统提示词 try: prompt_file_path = os.path.join(project_root, 'model_prompt', 'image_edit.txt') with open(prompt_file_path, 'r', encoding='utf-8') as f: system_prompt = f.read().strip() # 替换用户提示词占位符 final_prompt = system_prompt.replace('{user_prompt}', prompt) except Exception as e: logger.warning(f"[线程{thread_id}] 读取系统提示词失败: {str(e)},使用原始提示词") final_prompt = prompt # 构建消息内容,包含文本和所有图片 content = [ { "type": "text", "text": final_prompt } ] # 动态添加所有图片 for i, image_base64 in enumerate(image_base64_list): content.append({ "type": "image_url", "image_url": { "url": image_base64 } }) logger.info(f"[线程{thread_id}] 添加第 {i + 1} 张参考图片,数据长度: {len(image_base64)}") messages = [ { "role": "user", "content": content } ] return self._make_request(messages, temperature)