stephenebert commited on
Commit
e5a5c47
·
verified ·
1 Parent(s): b171040

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +60 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import functools
4
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
5
+
6
+ MODEL_OPTS = {
7
+ "SD v1.5 (base)": "runwayml/stable-diffusion-v1-5",
8
+ "SDXL Base 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
9
+ "SD-Turbo (ultra-fast)": "stabilityai/sd-turbo"
10
+ }
11
+
12
+ DEVICE = (
13
+ "mps" if torch.backends.mps.is_available() else
14
+ "cuda" if torch.cuda.is_available() else
15
+ "cpu"
16
+ )
17
+ DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
18
+
19
+ @functools.lru_cache(maxsize=len(MODEL_OPTS))
20
+ def get_pipeline(model_id: str):
21
+ pipe = StableDiffusionPipeline.from_pretrained(
22
+ model_id,
23
+ torch_dtype=DTYPE,
24
+ safety_checker=None
25
+ ).to(DEVICE)
26
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
27
+ return pipe
28
+
29
+ def generate(prompt, steps, guidance, seed, model_name):
30
+ model_id = MODEL_OPTS[model_name]
31
+ if "Turbo" in model_name:
32
+ steps = min(int(steps), 4)
33
+ pipe = get_pipeline(model_id)
34
+ generator = None if seed == 0 else torch.manual_seed(int(seed))
35
+ imgs = pipe(
36
+ prompt,
37
+ num_inference_steps=int(steps),
38
+ guidance_scale=float(guidance),
39
+ generator=generator
40
+ ).images
41
+ return imgs
42
+
43
+ with gr.Blocks() as demo:
44
+ gr.Markdown("## Model-Switcher Stable Diffusion Demo")
45
+ prompt = gr.Textbox("Retro robot in neon city", label="Prompt")
46
+ checkpoint = gr.Dropdown(list(MODEL_OPTS.keys()), value="SD v1.5 (base)", label="Checkpoint")
47
+ steps = gr.Slider(1, 50, value=30, label="Inference Steps")
48
+ guidance = gr.Slider(1, 15, value=7.5, label="Guidance Scale")
49
+ seed = gr.Number(0, label="Seed (0=random)")
50
+ btn = gr.Button("Generate")
51
+ gallery = gr.Gallery(label="Gallery", columns=2, height="auto")
52
+
53
+ btn.click(
54
+ fn=generate,
55
+ inputs=[prompt, steps, guidance, seed, checkpoint],
56
+ outputs=gallery
57
+ )
58
+
59
+ if __name__ == "__main__":
60
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.2
2
+ diffusers>=0.28
3
+ transformers>=4.42
4
+ accelerate>=0.29
5
+ safetensors
6
+ gradio>=4.32