|
|
from fastapi import FastAPI, HTTPException, Response |
|
|
from pydantic import BaseModel |
|
|
from contextlib import asynccontextmanager |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import io |
|
|
import uuid |
|
|
from typing import List, Union |
|
|
|
|
|
import axengine |
|
|
import torch |
|
|
|
|
|
from transformers import CLIPTokenizer, PreTrainedTokenizer |
|
|
import time |
|
|
import argparse |
|
|
|
|
|
import os |
|
|
import traceback |
|
|
from diffusers import DPMSolverMultistepScheduler |
|
|
|
|
|
# Global debug switches consumed by debug_log():
# DEBUG_MODE gates all debug output; LOG_TIMESTAMP prefixes each debug
# line with a wall-clock timestamp.
DEBUG_MODE = True
LOG_TIMESTAMP = True
|
|
|
|
|
def debug_log(msg):
    """Print *msg* to stdout when DEBUG_MODE is on, optionally timestamped."""
    if not DEBUG_MODE:
        return
    prefix = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] " if LOG_TIMESTAMP else ""
    print(f"{prefix}[DEBUG] {msg}")
|
|
|
|
|
|
|
|
# Filesystem locations of the pre-exported model artifacts, all loaded
# once at startup by DiffusionModels.load_models().
MODEL_PATHS = {
    # HuggingFace-format CLIP tokenizer directory.
    "tokenizer": "./models/tokenizer",
    # axengine-compiled SD 1.5 text encoder.
    "text_encoder": "./models/text_encoder/sd15_text_encoder_sim.axmodel",
    # axengine-compiled UNet denoiser.
    "unet": "./models/unet.axmodel",
    # axengine-compiled VAE decoder (latents -> pixels).
    "vae": "./models/vae_decoder.axmodel",
    # Pre-computed time embeddings for a 20-step DPM++ schedule
    # (downsampled to 10 steps at load time).
    "time_embeddings": "./models/time_input_dpmpp_20steps.npy"
}
|
|
|
|
|
class DiffusionModels:
    """Holds every inference artifact (tokenizer, sessions, embeddings) in memory."""

    def __init__(self):
        # Nothing is loaded at construction time; call load_models() first.
        self.models_loaded = False
        self.tokenizer = None
        self.text_encoder = None
        self.unet = None
        self.vae = None
        self.time_embeddings = None

    def load_models(self):
        """Eagerly load all models into memory from MODEL_PATHS."""
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS["tokenizer"])
            self.text_encoder = axengine.InferenceSession(MODEL_PATHS["text_encoder"])
            self.unet = axengine.InferenceSession(MODEL_PATHS["unet"])
            self.vae = axengine.InferenceSession(MODEL_PATHS["vae"])

            # The .npy table holds 20 per-step time embeddings; keep every
            # second entry so the pipeline can run a 10-step schedule.
            full_time_embeddings = np.load(MODEL_PATHS["time_embeddings"])
            self.time_embeddings = full_time_embeddings[::2]
            debug_log(f"时间嵌入已从20步采样为10步,形状: {self.time_embeddings.shape}")

            self.models_loaded = True
            print("所有模型已成功加载到内存")
        except Exception as e:
            # Log, then propagate: the server must not start half-loaded.
            print(f"模型加载失败: {str(e)}")
            raise
|
|
|
|
|
# Module-level singleton shared by all request handlers; its models are
# loaded during application startup (see the lifespan handler below).
diffusion_models = DiffusionModels()
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load every model once before serving requests."""
    # Blocking load at startup — the server only begins accepting requests
    # after all sessions are resident in memory.
    diffusion_models.load_models()
    yield
    # No teardown: sessions live for the whole process lifetime.
|
|
|
|
|
|
|
|
|
|
|
# Register the lifespan handler so models are loaded before the first request.
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
class GenerationRequest(BaseModel):
    """Request body for POST /generate."""

    # Text describing the desired image content.
    positive_prompt: str
    # Text describing content to avoid; empty string disables negative guidance.
    negative_prompt: str = ""
    # Optional RNG seed for reproducible generation; None (or JSON null)
    # lets the server derive a time-based seed.  Annotated as
    # Union[int, None] — the previous `int = None` annotation made pydantic
    # reject an explicit null sent by clients.
    seed: Union[int, None] = None
|
|
|
|
|
@app.post("/generate") |
|
|
async def generate_image(request: GenerationRequest): |
|
|
try: |
|
|
|
|
|
if len(request.positive_prompt) > 1000: |
|
|
raise ValueError("提示词过长") |
|
|
|
|
|
|
|
|
image = generate_diffusion_image( |
|
|
positive_prompt=request.positive_prompt, |
|
|
negative_prompt=request.negative_prompt, |
|
|
num_steps=10, |
|
|
guidance_scale=5.4, |
|
|
seed=request.seed |
|
|
) |
|
|
|
|
|
|
|
|
img_byte_arr = io.BytesIO() |
|
|
image.save(img_byte_arr, format='PNG') |
|
|
|
|
|
return Response(content=img_byte_arr.getvalue(), media_type="image/png") |
|
|
|
|
|
except Exception as e: |
|
|
error_id = str(uuid.uuid4()) |
|
|
print(f"Error [{error_id}]: {str(e)}") |
|
|
raise HTTPException( |
|
|
status_code=500, |
|
|
detail=f"生成失败,错误ID:{error_id}" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def get_embeds(prompt, negative_prompt):
    """Return (negative, positive) text embeddings, each validated to (1, 77, 768).

    Both prompts are tokenized to CLIP's fixed 77-token context and run
    through the preloaded text-encoder session.

    Raises:
        RuntimeError: if tokenization/encoding fails or an embedding has an
            unexpected shape.  (The previous implementation called exit(1)
            here, which killed the entire server process on a single bad
            request — a request-path helper must raise, never exit.)
    """
    try:
        debug_log(f"开始处理提示词: {prompt}")
        start_time = time.time()

        def process_prompt(prompt_text):
            # CLIP uses a fixed context length of 77; pad/truncate to it.
            inputs = diffusion_models.tokenizer(
                prompt_text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            )
            debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")

            # The axengine session expects int32 ids (torch default is int64).
            outputs = diffusion_models.text_encoder.run(None, {"input_ids": inputs.input_ids.numpy().astype(np.int32)})[0]
            debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
            return outputs

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        # Sanity check: SD 1.5 text embeddings must be (batch=1, seq=77, dim=768).
        if neg_embeds.shape != (1, 77, 768) or pos_embeds.shape != (1, 77, 768):
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")

        return neg_embeds, pos_embeds
    except Exception as e:
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        # Propagate instead of exit(1): a single failed request must not
        # terminate the whole API server.
        raise RuntimeError(f"获取嵌入失败: {str(e)}") from e
|
|
|
|
|
|
|
|
def generate_diffusion_image(
    positive_prompt: str,
    negative_prompt: str,
    num_steps: int = 10,
    guidance_scale: float = 5.4,
    seed: int = None
) -> Image.Image:
    """Run the full diffusion pipeline on the NPU and return the decoded image.

    The pipeline is pinned to 10 DPM-Solver++ steps and CFG 5.4 because the
    pre-computed time-embedding table only holds entries for that exact
    schedule; `num_steps` and `guidance_scale` are accepted for interface
    compatibility but overridden below.

    Args:
        positive_prompt: text describing the desired image (must be non-empty).
        negative_prompt: text describing what to avoid ("" disables it).
        num_steps: requested step count (forced to 10).
        guidance_scale: requested CFG scale (forced to 5.4).
        seed: optional RNG seed; None derives one from the current time.

    Returns:
        PIL.Image.Image: the generated RGB image.

    Raises:
        RuntimeError: wrapping any failure during the pipeline (including the
            ValueError raised for an empty positive prompt).
    """
    try:
        if not positive_prompt:
            raise ValueError("正向提示词不能为空")

        # Forced values: the preloaded time-embedding table only matches
        # this schedule (see the length check below).
        num_steps = 10
        guidance_scale = 5.4

        debug_log(f"开始生成流程 (固定参数: 10步, CFG=5.4)...")
        start_time = time.time()

        # Seed both torch and numpy so every source of randomness agrees.
        seed = seed if seed is not None else int(time.time() * 1000) % 0xFFFFFFFF
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"初始随机种子: {seed}")

        # --- Text encoding --------------------------------------------------
        embed_start = time.time()
        neg_emb, pos_emb = get_embeds(
            positive_prompt,
            negative_prompt,
        )
        debug_log(f"文本编码完成 | 耗时: {time.time()-embed_start:.2f}s")

        # --- Scheduler setup (DPM-Solver++ with Karras sigmas, SD defaults) -
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True
        )
        scheduler.set_timesteps(num_steps)

        # Initial noise latent, scaled by the scheduler's starting sigma.
        # NOTE(review): (1, 4, 60, 40) latents imply a 480x320 output at the
        # usual 8x VAE factor — confirm against the exported VAE model.
        latents_shape = (1, 4, 60, 40)
        latent = torch.randn(latents_shape, generator=torch.Generator().manual_seed(seed))
        latent = latent * scheduler.init_noise_sigma
        latent = latent.numpy().astype(np.float32)
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} sigma:{scheduler.init_noise_sigma:.3f}")

        # The time-embedding table must line up 1:1 with the step schedule.
        if len(diffusion_models.time_embeddings) != num_steps:
            raise ValueError(f"时间嵌入步数不匹配: 需要{num_steps}步 当前{len(diffusion_models.time_embeddings)}步")
        time_steps = diffusion_models.time_embeddings
        debug_log(f"使用预处理的10步时间嵌入,形状: {time_steps.shape}")

        # --- Denoising loop -------------------------------------------------
        debug_log("开始10步采样循环...")
        for step_idx, timestep in enumerate(scheduler.timesteps.numpy().astype(np.int64)):
            step_start = time.time()

            # Pre-computed time embedding for this step, with a batch dim.
            time_emb = np.expand_dims(time_steps[step_idx], axis=0)

            # Unconditional (negative-prompt) noise prediction.  The odd
            # input name is the internal graph tensor the time embedding was
            # spliced into when the UNet was exported.
            noise_pred_neg = diffusion_models.unet.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                "encoder_hidden_states": neg_emb
            })[0]

            # Conditional (positive-prompt) noise prediction.
            noise_pred_pos = diffusion_models.unet.run(None, {
                "sample": latent,
                "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                "encoder_hidden_states": pos_emb
            })[0]

            # Classifier-free guidance.  Uses the guidance_scale variable
            # (previously a second hard-coded 5.4) so the constant lives in
            # exactly one place; the value is identical.
            noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)

            # The scheduler step runs in torch; convert, update, convert back.
            latent_tensor = torch.from_numpy(latent)
            noise_pred_tensor = torch.from_numpy(noise_pred)

            scheduler_start = time.time()
            latent_tensor = scheduler.step(
                model_output=noise_pred_tensor,
                timestep=timestep,
                sample=latent_tensor
            ).prev_sample
            debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

            latent = latent_tensor.numpy().astype(np.float32)
            debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

            debug_log(f"步骤 {step_idx+1}/{num_steps} | 耗时: {time.time()-step_start:.2f}s")

        # --- VAE decode -----------------------------------------------------
        debug_log("开始VAE解码...")
        vae_start = time.time()
        latent = latent / 0.18215  # undo SD's latent scaling factor
        image = diffusion_models.vae.run(None, {"latent": latent})[0]

        # CHW float in [-1, 1] -> HWC uint8 in [0, 255].
        image = np.transpose(image.squeeze(), (1, 2, 0))
        image = np.clip((image / 2 + 0.5) * 255, 0, 255).astype(np.uint8)
        pil_image = Image.fromarray(image[..., :3])
        # NOTE(review): debug copy written on every request; concurrent
        # requests will clobber each other's file — consider removing.
        pil_image.save("./api.png")
        debug_log(f"VAE解码完成 | 耗时: {time.time()-vae_start:.2f}s")
        debug_log(f"总耗时: {time.time()-start_time:.2f}s (10步优化版)")

        return pil_image

    except Exception as e:
        error_msg = f"生成失败: {str(e)}"
        debug_log(error_msg)
        traceback.print_exc()
        # Chain the original exception so the root cause survives in logs.
        raise RuntimeError(error_msg) from e