srishticrai committed on
Commit
46a5af8
·
verified ·
1 Parent(s): 58c757a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +32 -7
README.md CHANGED
@@ -47,19 +47,44 @@ This model contains a **fine-tuned U-Net** from the `CompVis/stable-diffusion-v1
47
  This U-Net can be loaded into a standard Stable Diffusion pipeline to enhance image generation on descriptive prompts:
48
 
49
  ```python
50
- from diffusers import StableDiffusionPipeline, UNet2DConditionModel
 
51
  import torch
 
52
 
53
- # Load fine-tuned U-Net
54
- unet = UNet2DConditionModel.from_pretrained("srishticrai/unet-flickr8k")
 
 
55
 
56
- # Load pipeline with original components + fine-tuned U-Net
 
 
 
 
 
 
 
57
  pipe = StableDiffusionPipeline.from_pretrained(
58
  "CompVis/stable-diffusion-v1-4",
59
- unet=unet,
 
 
60
  torch_dtype=torch.float16
61
- ).to("cuda")
 
 
 
 
 
 
62
 
63
  # Generate image
64
- image = pipe("A child blowing bubbles in a park at sunset").images[0]
 
 
 
 
 
65
  image.show()
 
 
47
  This U-Net can be loaded into a standard Stable Diffusion pipeline to enhance image generation on descriptive prompts:
48
 
49
  ```python
50
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
51
+ from transformers import CLIPTextModel
52
  import torch
53
+ import matplotlib.pyplot as plt
54
 
55
+ # Load base components
+ device = "cuda"  # the float16 weights loaded below assume a CUDA GPU
56
+ print("Loading VAE and text encoder from base SD...")
57
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16).to(device)
58
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to(device)
59
 
60
+ # Load fine-tuned UNet from Hugging Face
61
+ print("Loading fine-tuned UNet from Hugging Face (srishticrai/unet-flickr8k)...")
62
+ fine_tuned_unet = UNet2DConditionModel.from_pretrained(
63
+ "srishticrai/unet-flickr8k",
64
+ torch_dtype=torch.float16
65
+ ).to(device)
66
+
67
+ # Rebuild the pipeline
68
  pipe = StableDiffusionPipeline.from_pretrained(
69
  "CompVis/stable-diffusion-v1-4",
70
+ unet=fine_tuned_unet,
71
+ vae=vae,
72
+ text_encoder=text_encoder,
73
  torch_dtype=torch.float16
74
+ ).to(device)
75
+
76
+ pipe.set_progress_bar_config(disable=False)
77
+ pipe.enable_attention_slicing()
78
+
79
+ # Ask for prompt
80
+ prompt = input("Enter a prompt to generate an image: ")
81
 
82
  # Generate image
83
+ image = pipe(
84
+ prompt,
85
+ guidance_scale=10.0,
86
+ num_inference_steps=50
87
+ ).images[0]
88
+
89
  image.show()
90
+ ```