Spaces:
Runtime error
Runtime error
Commit
·
86120ee
1
Parent(s):
ac03725
update
Browse files
app.py
CHANGED
@@ -10,48 +10,51 @@ from controlnet_union import ControlNetModel_Union
|
|
10 |
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
|
11 |
import devicetorch
|
12 |
DEVICE = devicetorch.get(torch)
|
13 |
-
MODELS = {
|
14 |
-
"RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
|
15 |
-
}
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
|
57 |
#@spaces.GPU(duration=16)
|
|
|
10 |
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
|
11 |
import devicetorch
|
12 |
DEVICE = devicetorch.get(torch)
|
|
|
|
|
|
|
13 |
|
14 |
+
def init():
|
15 |
+
global pipe
|
16 |
+
MODELS = {
|
17 |
+
"RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
|
18 |
+
}
|
19 |
+
|
20 |
+
config_file = hf_hub_download(
|
21 |
+
"xinsir/controlnet-union-sdxl-1.0",
|
22 |
+
filename="config_promax.json",
|
23 |
+
)
|
24 |
+
|
25 |
+
config = ControlNetModel_Union.load_config(config_file)
|
26 |
+
controlnet_model = ControlNetModel_Union.from_config(config)
|
27 |
+
model_file = hf_hub_download(
|
28 |
+
"xinsir/controlnet-union-sdxl-1.0",
|
29 |
+
filename="diffusion_pytorch_model_promax.safetensors",
|
30 |
+
)
|
31 |
+
state_dict = load_state_dict(model_file)
|
32 |
+
model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
|
33 |
+
controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
|
34 |
+
)
|
35 |
+
model.to(device=DEVICE, dtype=torch.float16)
|
36 |
+
|
37 |
+
vae = AutoencoderKL.from_pretrained(
|
38 |
+
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
|
39 |
+
).to(DEVICE)
|
40 |
+
|
41 |
+
pipe = StableDiffusionXLFillPipeline.from_pretrained(
|
42 |
+
"SG161222/RealVisXL_V5.0_Lightning",
|
43 |
+
torch_dtype=torch.float16,
|
44 |
+
vae=vae,
|
45 |
+
controlnet=model,
|
46 |
+
variant="fp16",
|
47 |
+
).to(DEVICE)
|
48 |
+
|
49 |
+
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
50 |
+
|
51 |
+
prompt = "high quality"
|
52 |
+
(
|
53 |
+
prompt_embeds,
|
54 |
+
negative_prompt_embeds,
|
55 |
+
pooled_prompt_embeds,
|
56 |
+
negative_pooled_prompt_embeds,
|
57 |
+
) = pipe.encode_prompt(prompt, DEVICE, True)
|
58 |
|
59 |
|
60 |
#@spaces.GPU(duration=16)
|