File size: 4,079 Bytes
edb0494
ac03725
6405936
 
 
 
 
edb0494
6405936
 
ac03725
6405936
78a1133
 
 
e9280cf
e4bd666
86120ee
 
6405936
7ca6fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86120ee
6405936
ac03725
e9280cf
 
7ca6fbe
6405936
 
 
 
 
 
 
 
e9280cf
 
 
 
 
 
 
 
6405936
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f3a8eb
 
b8f155b
 
4f3a8eb
b8f155b
4f3a8eb
 
6405936
9edde8e
 
 
 
 
97567b1
 
9edde8e
 
e9280cf
 
 
6405936
 
 
 
 
86580f0
97567b1
6405936
 
 
 
 
 
 
 
 
97567b1
 
 
 
 
 
6405936
 
 
 
 
 
e9280cf
6405936
 
4f3a8eb
6405936
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import gradio as gr
#import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download

from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
import devicetorch

# Display label -> Hugging Face repo id. The labels feed the "Model" dropdown
# in the UI; init() passes the same repo id as a literal rather than reading
# it from this mapping.
MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}
# Device handle obtained from devicetorch; init() moves every model onto it.
# NOTE(review): presumably resolves to cuda/mps/cpu depending on what is
# available — confirm against the devicetorch package.
DEVICE = devicetorch.get(torch)
# Inpainting pipeline, built on demand: stays None until init() first runs.
pipe = None
def init():
    """Build the module-level inpainting pipeline on first use.

    When ``pipe`` is already set this returns immediately. Otherwise it
    downloads the ControlNet-Union "promax" config and weights, loads the
    ``madebyollin/sdxl-vae-fp16-fix`` VAE, assembles a
    ``StableDiffusionXLFillPipeline`` from them on ``DEVICE`` in float16,
    stores it in the global ``pipe`` and replaces its scheduler with a
    ``TCDScheduler`` built from the existing scheduler config.
    """
    global pipe

    if pipe is not None:
        return

    controlnet_repo = "xinsir/controlnet-union-sdxl-1.0"
    half = torch.float16

    # Instantiate an empty ControlNet from its published config ...
    config_path = hf_hub_download(controlnet_repo, filename="config_promax.json")
    controlnet_config = ControlNetModel_Union.load_config(config_path)
    empty_controlnet = ControlNetModel_Union.from_config(controlnet_config)

    # ... then fill it with the downloaded weights.
    weights_path = hf_hub_download(
        controlnet_repo,
        filename="diffusion_pytorch_model_promax.safetensors",
    )
    weights = load_state_dict(weights_path)
    # The loader returns five values; only the first (the model) is used.
    controlnet, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
        empty_controlnet, weights, weights_path, controlnet_repo
    )
    controlnet.to(device=DEVICE, dtype=half)

    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=half
    )
    vae = vae.to(DEVICE)

    pipe = StableDiffusionXLFillPipeline.from_pretrained(
        "SG161222/RealVisXL_V5.0_Lightning",
        torch_dtype=half,
        vae=vae,
        controlnet=controlnet,
        variant="fp16",
    ).to(DEVICE)

    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)


#@spaces.GPU(duration=16)
def fill_image(prompt, image, model_selection):
    """Inpaint the masked area of the editor image, streaming results.

    Args:
        prompt: Text prompt handed to ``pipe.encode_prompt``.
        image: Editor payload. ``image["background"]`` is the source picture
            and ``image["layers"][0]`` is the drawn mask; its fourth channel
            is treated as alpha and marks the region to fill.
        model_selection: Received from the UI but not read by this function.

    Yields:
        ``(item, cnet_image)`` for every item the pipeline produces, then a
        final ``(source, cnet_image)`` after the last item has been pasted
        into ``cnet_image`` through the mask.
    """
    init()
    print(f"image {image}")

    source = image["background"]
    alpha = image["layers"][0].split()[3]
    # Any non-zero alpha counts as painted: collapse the channel to 0 / 255.
    binary_mask = alpha.point(lambda level: 255 if level > 0 else 0)

    # Work on a copy with the masked region blacked out; this copy is what
    # gets passed to the pipeline as ``image``.
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)

    positive, negative, pooled, negative_pooled = pipe.encode_prompt(
        prompt, DEVICE, True
    )
    generation = pipe(
        prompt_embeds=positive,
        negative_prompt_embeds=negative,
        pooled_prompt_embeds=pooled,
        negative_pooled_prompt_embeds=negative_pooled,
        image=cnet_image,
    )

    # Start from the editor payload so the name is bound even if the
    # pipeline produces nothing.
    latest = image
    for latest in generation:
        yield latest, cnet_image

    # Composite the last generated item back into the masked region only.
    cnet_image.paste(latest.convert("RGBA"), (0, 0), binary_mask)

    yield source, cnet_image


def clear_result():
    """Return a Gradio update that sets the target component's value to None."""
    cleared = gr.update(value=None)
    return cleared

def resize(image):
    """Downscale the uploaded background in place to fit a 1024x1024 box.

    Args:
        image: Editor payload whose ``"background"`` entry is the uploaded
            picture (an object exposing ``thumbnail``), or ``None``.

    Returns:
        The background after ``thumbnail((1024, 1024), LANCZOS)`` has been
        applied to it, or ``None`` when there is no payload or the payload
        has no background.
    """
    print(f"resize image={image}")
    if image is None:
        return None

    background = image["background"]
    if background is None:
        return None

    # The file's import block never brings PIL's Image into scope, so the
    # resampling constant is imported here; without it this handler raised
    # NameError on every upload.
    from PIL import Image

    background.thumbnail((1024, 1024), Image.LANCZOS)
    print(f"resized image={image}")
    return background



#css = """
#.gradio-container {
#    width: 1024px !important;
#}
#"""


#with gr.Blocks(css=css, fill_width=True) as demo:
# UI definition: components are created in display order inside the Blocks
# context, then the event handlers are wired up and the app is launched.
with gr.Blocks(fill_width=True) as demo:
    # Top row: prompt text box next to the Generate button.
    with gr.Row():
        prompt = gr.Textbox(value="high quality", label="Prompt")
        run_button = gr.Button("Generate")

    # Second row: mask editor on the left, result comparison on the right.
    with gr.Row():
        # type="pil" means the handlers receive PIL objects; fill_image and
        # resize read the payload's "background" and "layers" entries.
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
#            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            sources=["upload"],
        )

        # Read-only slider; it receives the 2-tuples yielded by fill_image.
        result = ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    # Choices come from MODELS. The selection is passed to fill_image, which
    # does not read it.
    model_selection = gr.Dropdown(
        choices=list(MODELS.keys()),
        value="RealVisXL V5.0 Lightning",
        label="Model",
    )

    # Generate: clear the previous result first, then run the inpainting
    # generator, whose yields update the slider.
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=fill_image,
        inputs=[prompt, input_image, model_selection],
        outputs=result,
    )
    # On upload, pass the editor value through resize and write the returned
    # image back into the same component.
    input_image.upload(fn=resize, inputs=input_image, outputs=input_image)


# Start the Gradio server with share=False.
demo.launch(share=False)