SD1.5-LLM8850 / api_10steps.py
LittleMouse
Add File
244baf9
from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel
from contextlib import asynccontextmanager
import numpy as np
from PIL import Image
import io
import uuid
from typing import List, Union
import axengine
import torch
from transformers import CLIPTokenizer, PreTrainedTokenizer
import time
import argparse
import os
import traceback
from diffusers import DPMSolverMultistepScheduler
# 配置日志格式
DEBUG_MODE = True
LOG_TIMESTAMP = True
def debug_log(msg):
if DEBUG_MODE:
timestamp = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] " if LOG_TIMESTAMP else ""
print(f"{timestamp}[DEBUG] {msg}")
# 服务配置
MODEL_PATHS = {
"tokenizer": "./models/tokenizer",
"text_encoder": "./models/text_encoder/sd15_text_encoder_sim.axmodel",
"unet": "./models/unet.axmodel",
"vae": "./models/vae_decoder.axmodel",
"time_embeddings": "./models/time_input_dpmpp_20steps.npy" # 仍使用20步数据,但只取其中10步
}
class DiffusionModels:
def __init__(self):
self.models_loaded = False
self.tokenizer = None
self.text_encoder = None
self.unet = None
self.vae = None
self.time_embeddings = None
def load_models(self):
"""预加载所有模型到内存"""
try:
# 初始化tokenizer和模型
self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS["tokenizer"])
self.text_encoder = axengine.InferenceSession(MODEL_PATHS["text_encoder"])
self.unet = axengine.InferenceSession(MODEL_PATHS["unet"])
self.vae = axengine.InferenceSession(MODEL_PATHS["vae"])
# 加载时间嵌入并间隔采样为10步
full_time_embeddings = np.load(MODEL_PATHS["time_embeddings"])
# 从20步中间隔取10步 (取索引 0, 2, 4, 6, 8, 10, 12, 14, 16, 18)
self.time_embeddings = full_time_embeddings[::2] # 间隔取值
debug_log(f"时间嵌入已从20步采样为10步,形状: {self.time_embeddings.shape}")
self.models_loaded = True
print("所有模型已成功加载到内存")
except Exception as e:
print(f"模型加载失败: {str(e)}")
raise
diffusion_models = DiffusionModels()
@asynccontextmanager
async def lifespan(app: FastAPI):
# 服务启动时加载模型
diffusion_models.load_models()
yield
# 服务关闭时清理资源
# (根据axengine的要求添加必要的清理逻辑)
app = FastAPI(lifespan=lifespan)
class GenerationRequest(BaseModel):
positive_prompt: str
negative_prompt: str = ""
# 移除这些参数,因为已经固定
# num_inference_steps: int = 10 # 固定为10步
# guidance_scale: float = 5.4 # 固定为5.4
seed: int = None
@app.post("/generate")
async def generate_image(request: GenerationRequest):
try:
# 输入验证
if len(request.positive_prompt) > 1000:
raise ValueError("提示词过长")
# 执行推理流程 - 固定参数
image = generate_diffusion_image(
positive_prompt=request.positive_prompt,
negative_prompt=request.negative_prompt,
num_steps=10, # 固定10步
guidance_scale=5.4, # 固定CFG=5.4
seed=request.seed
)
# 转换图像为字节流
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG')
return Response(content=img_byte_arr.getvalue(), media_type="image/png")
except Exception as e:
error_id = str(uuid.uuid4())
print(f"Error [{error_id}]: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"生成失败,错误ID:{error_id}"
)
def get_embeds(prompt, negative_prompt):
"""获取正负提示词的嵌入(带形状验证)"""
try:
debug_log(f"开始处理提示词: {prompt}")
start_time = time.time()
def process_prompt(prompt_text):
inputs = diffusion_models.tokenizer(
prompt_text,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt"
)
debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")
outputs = diffusion_models.text_encoder.run(None, {"input_ids": inputs.input_ids.numpy().astype(np.int32)})[0]
debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
return outputs
neg_start = time.time()
neg_embeds = process_prompt(negative_prompt)
pos_embeds = process_prompt(prompt)
debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")
# 验证形状
if neg_embeds.shape != (1, 77, 768) or pos_embeds.shape != (1, 77, 768):
raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")
return neg_embeds, pos_embeds
except Exception as e:
print(f"获取嵌入失败: {str(e)}")
traceback.print_exc()
exit(1)
def generate_diffusion_image(
positive_prompt: str,
negative_prompt: str,
num_steps: int = 10, # 固定默认值为10
guidance_scale: float = 5.4, # 固定默认值为5.4
seed: int = None
) -> Image.Image:
"""
生成扩散图像的优化版本(固定10步推理,CFG=5.4)
参数:
positive_prompt (str): 正向提示词
negative_prompt (str): 负向提示词
num_steps (int): 推理步数 (固定为10)
guidance_scale (float): 分类器自由引导系数 (固定为5.4)
seed (int): 随机种子 (可选)
返回:
PIL.Image.Image: 生成的图像
异常:
ValueError: 输入参数无效时抛出
RuntimeError: 推理过程中出现错误时抛出
"""
try:
# 参数验证和固定
if not positive_prompt:
raise ValueError("正向提示词不能为空")
# 强制使用优化后的固定参数
num_steps = 10
guidance_scale = 5.4
debug_log(f"开始生成流程 (固定参数: 10步, CFG=5.4)...")
start_time = time.time()
# =====================================================================
# 1. 初始化配置
# =====================================================================
seed = seed if seed is not None else int(time.time() * 1000) % 0xFFFFFFFF
torch.manual_seed(seed)
np.random.seed(seed)
debug_log(f"初始随机种子: {seed}")
# =====================================================================
# 2. 文本编码 (保持原有输入形状 [1, 77, 768])
# =====================================================================
embed_start = time.time()
neg_emb, pos_emb = get_embeds(
positive_prompt,
negative_prompt,
)
debug_log(f"文本编码完成 | 耗时: {time.time()-embed_start:.2f}s")
# =====================================================================
# 3. 初始化潜在变量 (固定形状 [1, 4, 60, 40])
# =====================================================================
scheduler = DPMSolverMultistepScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
algorithm_type="dpmsolver++",
use_karras_sigmas=True
)
scheduler.set_timesteps(num_steps) # 设置为10步
latents_shape = (1, 4, 60, 40)
latent = torch.randn(latents_shape, generator=torch.Generator().manual_seed(seed))
latent = latent * scheduler.init_noise_sigma
latent = latent.numpy().astype(np.float32)
debug_log(f"潜在变量初始化 | 形状: {latent.shape} sigma:{scheduler.init_noise_sigma:.3f}")
# =====================================================================
# 4. 准备时间嵌入 (使用预处理的10步数据)
# =====================================================================
if len(diffusion_models.time_embeddings) != num_steps:
raise ValueError(f"时间嵌入步数不匹配: 需要{num_steps}步 当前{len(diffusion_models.time_embeddings)}步")
time_steps = diffusion_models.time_embeddings
debug_log(f"使用预处理的10步时间嵌入,形状: {time_steps.shape}")
# =====================================================================
# 5. 采样主循环 (10步优化版)
# =====================================================================
debug_log("开始10步采样循环...")
for step_idx, timestep in enumerate(scheduler.timesteps.numpy().astype(np.int64)):
step_start = time.time()
# 准备时间嵌入 (形状 [1, 1])
time_emb = np.expand_dims(time_steps[step_idx], axis=0)
# -----------------------------------------
# UNET双推理流程 (CFG=5.4优化)
# -----------------------------------------
# 负面提示推理
noise_pred_neg = diffusion_models.unet.run(None, {
"sample": latent,
"/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
"encoder_hidden_states": neg_emb
})[0]
# 正面提示推理
noise_pred_pos = diffusion_models.unet.run(None, {
"sample": latent,
"/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
"encoder_hidden_states": pos_emb
})[0]
# CFG融合 (固定使用5.4的引导强度)
noise_pred = noise_pred_neg + 5.4 * (noise_pred_pos - noise_pred_neg)
# 转换为Tensor
latent_tensor = torch.from_numpy(latent)
noise_pred_tensor = torch.from_numpy(noise_pred)
# 调度器更新
scheduler_start = time.time()
latent_tensor = scheduler.step(
model_output=noise_pred_tensor,
timestep=timestep,
sample=latent_tensor
).prev_sample
debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")
# 转换回numpy
latent = latent_tensor.numpy().astype(np.float32)
debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")
debug_log(f"步骤 {step_idx+1}/{num_steps} | 耗时: {time.time()-step_start:.2f}s")
# =====================================================================
# 6. VAE解码 (强制输出形状为768x512)
# =====================================================================
debug_log("开始VAE解码...")
vae_start = time.time()
latent = latent / 0.18215
image = diffusion_models.vae.run(None, {"latent": latent})[0]
# 转换为PIL图像 (优化内存拷贝)
image = np.transpose(image.squeeze(), (1, 2, 0))
image = np.clip((image / 2 + 0.5) * 255, 0, 255).astype(np.uint8)
pil_image = Image.fromarray(image[..., :3]) # 移除alpha通道
pil_image.save("./api.png")
debug_log(f"VAE解码完成 | 耗时: {time.time()-vae_start:.2f}s")
debug_log(f"总耗时: {time.time()-start_time:.2f}s (10步优化版)")
return pil_image
except Exception as e:
error_msg = f"生成失败: {str(e)}"
debug_log(error_msg)
traceback.print_exc()
raise RuntimeError(error_msg)