xcheng20 committed on
Commit 0803fb7 · verified · 1 Parent(s): 4413470

Upload stable_diffusion_loader.py

Files changed (1)
  1. stable_diffusion_loader.py +130 -0
stable_diffusion_loader.py ADDED
@@ -0,0 +1,130 @@
import os
import base64
from io import BytesIO
from typing import Optional

import torch
from PIL import Image

from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import (
    StableDiffusionPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    DDIMScheduler,
)


def load_custom_pipeline(
    model_path: str = "./fine-tuned-model",
    use_mps_if_available: bool = True,
):
    """
    Loads a custom fine-tuned Stable Diffusion model from a local folder
    structure (tokenizer/, text_encoder/, unet/, vae/ subfolders).
    Returns a pipeline object ready for inference.
    """
    # Load tokenizer
    tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))

    # Load text encoder
    text_encoder = CLIPTextModel.from_pretrained(
        os.path.join(model_path, "text_encoder"),
        torch_dtype=torch.float32,
    )

    # Load UNet
    unet = UNet2DConditionModel.from_pretrained(
        os.path.join(model_path, "unet"),
        torch_dtype=torch.float32,
    )

    # Load VAE
    vae = AutoencoderKL.from_pretrained(
        os.path.join(model_path, "vae"),
        torch_dtype=torch.float32,
    )

    # Load scheduler from the SD v1.4 base repo
    scheduler = DDIMScheduler.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        subfolder="scheduler",
    )

    # Assemble the pipeline from the individually loaded components
    pipe = StableDiffusionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        scheduler=scheduler,
        safety_checker=None,  # Disable safety checker
        feature_extractor=None,
        requires_safety_checker=False,  # Suppress the warning for a None safety checker
    )

    # Prefer the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU
    device = (
        torch.device("mps")
        if (torch.backends.mps.is_available() and use_mps_if_available)
        else torch.device("cpu")
    )
    pipe.to(device)

    # Optional: reduce memory usage by computing attention in slices
    pipe.enable_attention_slicing()

    return pipe

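# Expected local layout for load_custom_pipeline (an assumption read off the
# from_pretrained calls above, not documented elsewhere in the repo):
#
#   fine-tuned-model/
#   ├── tokenizer/
#   ├── text_encoder/
#   ├── unet/
#   └── vae/
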
def load_base_pipeline(
    model_id: str = "CompVis/stable-diffusion-v1-4",
    use_mps_if_available: bool = True,
):
    """
    Loads the original Stable Diffusion v1.4 model from Hugging Face.
    Returns a pipeline object ready for inference.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,  # Suppress the warning for a None safety checker
    )
    device = (
        torch.device("mps")
        if (torch.backends.mps.is_available() and use_mps_if_available)
        else torch.device("cpu")
    )
    pipe.to(device)
    pipe.enable_attention_slicing()
    return pipe


def generate_image(
    pipe: StableDiffusionPipeline,
    prompt: str,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    seed: Optional[int] = None,
):
    """
    Generates a single image from the provided pipeline and prompt.
    Optionally accepts a 'seed' for reproducibility.
    """
    if seed is not None:
        # A CPU generator is used for broad compatibility: MPS generators are
        # not supported on older PyTorch builds, and diffusers moves the
        # sampled noise to the pipeline's device.
        generator = torch.Generator(device="cpu").manual_seed(seed)
    else:
        generator = None

    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
    return result.images[0]


def pil_image_to_base64_str(img: Image.Image) -> str:
    """
    Converts a PIL Image into a Base64-encoded PNG string.
    """
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
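
For reference, a minimal usage sketch (hypothetical prompt and file names; it assumes the fine-tuned weights sit under ./fine-tuned-model in the layout shown above):

# Hypothetical usage -- not part of the uploaded file.
from stable_diffusion_loader import (
    load_custom_pipeline,
    generate_image,
    pil_image_to_base64_str,
)

pipe = load_custom_pipeline("./fine-tuned-model")
img = generate_image(pipe, "a watercolor painting of a fox", seed=42)
img.save("fox.png")                      # save to disk, or...
b64_png = pil_image_to_base64_str(img)   # ...embed as Base64 (e.g. in a JSON response)

On CPU this is slow (50 DDIM steps at the default resolution can take several minutes); with MPS available it runs considerably faster.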