Spaces:
Runtime error
Runtime error
File size: 4,010 Bytes
edb0494 ac03725 6405936 edb0494 6405936 ac03725 6405936 78a1133 e9280cf e4bd666 86120ee 6405936 7ca6fbe 86120ee 6405936 ac03725 e9280cf 7ca6fbe 6405936 e9280cf 6405936 4f3a8eb 6405936 9edde8e 97567b1 9edde8e e9280cf 6405936 86580f0 97567b1 6405936 97567b1 6405936 e9280cf 6405936 4f3a8eb 6405936 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import gradio as gr
#import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
import devicetorch
# Display name -> Hugging Face repo id for the SDXL base checkpoints
# offered in the UI dropdown.
MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}
# Best available torch device (cuda/mps/cpu) as chosen by devicetorch.
DEVICE = devicetorch.get(torch)
# Global pipeline handle; built lazily by init() on the first generation.
pipe = None
def init(model_repo=MODELS["RealVisXL V5.0 Lightning"]):
    """Lazily build the global fill pipeline; no-op if already loaded.

    Downloads the ControlNet-Union "promax" config and weights, an
    fp16-safe SDXL VAE, and the SDXL base checkpoint, then assembles a
    StableDiffusionXLFillPipeline on DEVICE with a TCD scheduler.

    Args:
        model_repo: Hugging Face repo id of the SDXL base model. Defaults
            to the single entry in MODELS (previously hard-coded inline).
    """
    global pipe
    if pipe is None:
        config_file = hf_hub_download(
            "xinsir/controlnet-union-sdxl-1.0",
            filename="config_promax.json",
        )
        config = ControlNetModel_Union.load_config(config_file)
        controlnet_model = ControlNetModel_Union.from_config(config)
        model_file = hf_hub_download(
            "xinsir/controlnet-union-sdxl-1.0",
            filename="diffusion_pytorch_model_promax.safetensors",
        )
        state_dict = load_state_dict(model_file)
        # NOTE(review): relies on a private diffusers API to load the
        # promax variant weights into the from_config() skeleton.
        model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
            controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
        )
        model.to(device=DEVICE, dtype=torch.float16)
        # fp16-safe VAE: the stock SDXL VAE produces NaNs in half precision.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
        ).to(DEVICE)
        pipe = StableDiffusionXLFillPipeline.from_pretrained(
            model_repo,
            torch_dtype=torch.float16,
            vae=vae,
            controlnet=model,
            variant="fp16",
        ).to(DEVICE)
        pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
#@spaces.GPU(duration=16)
def fill_image(prompt, image, model_selection):
    """Inpaint/outpaint the masked region of *image* guided by *prompt*.

    Args:
        prompt: Text prompt conditioning the generation.
        image: Gradio ImageMask dict with a "background" PIL image and
            mask "layers" (RGBA; painted pixels have non-zero alpha).
        model_selection: Currently unused — the checkpoint is fixed by
            init(); kept for UI wiring compatibility.

    Yields:
        (preview, conditioning_image) pairs while the pipeline runs, then
        a final (source, composited_result) pair for the ImageSlider.
    """
    init()  # lazy-load the pipeline on first use
    print(f"image {image}")
    source = image["background"]
    mask = image["layers"][0]
    # The painted mask lives in the alpha channel; threshold to hard 0/255.
    alpha_channel = mask.split()[3]
    binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)
    # Black out the masked area on a copy of the source: this becomes the
    # ControlNet conditioning image.
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt, DEVICE, True)
    # Stream intermediate previews; a distinct loop variable avoids
    # shadowing the `image` argument (the original reused the name).
    generated = None
    for generated in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=cnet_image,
    ):
        yield generated, cnet_image
    # Composite the final frame back onto the source only where the mask
    # was painted, so untouched pixels are preserved exactly.
    generated = generated.convert("RGBA")
    cnet_image.paste(generated, (0, 0), binary_mask)
    yield source, cnet_image
def clear_result():
    """Blank the result ImageSlider before a new generation starts."""
    return gr.update(value=None)
def resize(image, size=(1024, 1024)):
    """Shrink *image* in place to fit within *size*, keeping aspect ratio.

    Fixes a NameError in the original, which referenced an undefined
    ``size`` variable and the unimported ``Image`` module (so every
    upload crashed). The 1024x1024 default matches the editor canvas.

    Args:
        image: PIL image delivered by the Gradio upload event.
        size: Maximum (width, height) bounding box.

    Returns:
        The same image object, resized in place.
    """
    print(f"resize image={image}")
    # thumbnail() uses Pillow's default resampling filter; the original
    # intended Image.LANCZOS but never imported PIL.
    image.thumbnail(size)
    print(f"resized image={image}")
    return image
#css = """
#.gradio-container {
# width: 1024px !important;
#}
#"""
#with gr.Blocks(css=css, fill_width=True) as demo:
with gr.Blocks(fill_width=True) as demo:
    # Top row: prompt entry plus the generate trigger.
    with gr.Row():
        prompt = gr.Textbox(value="high quality", label="Prompt")
        run_button = gr.Button("Generate")
    # Second row: mask editor input and the before/after result slider.
    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            # crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            sources=["upload"],
        )
        result = ImageSlider(
            interactive=False,
            label="Generated Image",
        )
    # Checkpoint picker (currently a single entry; see MODELS).
    model_selection = gr.Dropdown(
        choices=list(MODELS.keys()),
        value="RealVisXL V5.0 Lightning",
        label="Model",
    )
    # Clear the previous result, then stream frames from fill_image.
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=fill_image,
        inputs=[prompt, input_image, model_selection],
        outputs=result,
    )
    # Downscale oversized uploads so they fit the editor canvas.
    input_image.upload(fn=resize, inputs=input_image, outputs=input_image)
demo.launch(share=False)
|