Update README.md
README.md
This model contains a **fine-tuned U-Net** from the `CompVis/stable-diffusion-v1-4` base model.
This U-Net can be loaded into a standard Stable Diffusion pipeline to enhance image generation on descriptive prompts:

```python
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel
import torch

# Select a device; the float16 weights below are intended for GPU inference
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load base components
print("Loading VAE and text encoder from base SD...")
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.float16
).to(device)

# Load the fine-tuned UNet from Hugging Face
print("Loading fine-tuned UNet from Hugging Face (srishticrai/unet-flickr8k)...")
fine_tuned_unet = UNet2DConditionModel.from_pretrained(
    "srishticrai/unet-flickr8k",
    torch_dtype=torch.float16
).to(device)

# Rebuild the pipeline around the fine-tuned UNet
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    unet=fine_tuned_unet,
    vae=vae,
    text_encoder=text_encoder,
    torch_dtype=torch.float16
).to(device)

pipe.set_progress_bar_config(disable=False)
pipe.enable_attention_slicing()  # lowers peak VRAM usage during attention

# Ask for a prompt
prompt = input("Enter a prompt to generate an image: ")

# Generate the image; the pipeline returns a batch of PIL images
image = pipe(
    prompt,
    guidance_scale=10.0,
    num_inference_steps=50
).images[0]
image.show()
```
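`image.show()` opens a desktop image viewer, which is not available on headless servers or in some notebooks. As a minimal alternative sketch (assuming the block above has already run, so `pipe`, `prompt`, and `device` are defined), you can seed the sampler for reproducible output and save the result to disk; the seed value and filename below are arbitrary examples, not part of this model card:

```python
# Optional: fix the random seed so repeated runs give the same image (42 is arbitrary)
generator = torch.Generator(device=device).manual_seed(42)

image = pipe(
    prompt,
    guidance_scale=10.0,
    num_inference_steps=50,
    generator=generator  # pass the seeded generator to the sampler
).images[0]

# Save to disk instead of opening a viewer ("output.png" is an example filename)
image.save("output.png")
```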