cocktailpeanut commited on
Commit
e9280cf
·
1 Parent(s): 78a1133
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -13,9 +13,9 @@ import devicetorch
13
  MODELS = {
14
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
15
  }
 
16
  def init():
17
  global pipe
18
- DEVICE = devicetorch.get(torch)
19
 
20
  config_file = hf_hub_download(
21
  "xinsir/controlnet-union-sdxl-1.0",
@@ -48,17 +48,10 @@ def init():
48
 
49
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
50
 
51
- prompt = "high quality"
52
- (
53
- prompt_embeds,
54
- negative_prompt_embeds,
55
- pooled_prompt_embeds,
56
- negative_pooled_prompt_embeds,
57
- ) = pipe.encode_prompt(prompt, DEVICE, True)
58
-
59
 
60
  #@spaces.GPU(duration=16)
61
- def fill_image(image, model_selection):
 
62
  source = image["background"]
63
  mask = image["layers"][0]
64
 
@@ -67,6 +60,14 @@ def fill_image(image, model_selection):
67
  cnet_image = source.copy()
68
  cnet_image.paste(0, (0, 0), binary_mask)
69
 
 
 
 
 
 
 
 
 
70
  for image in pipe(
71
  prompt_embeds=prompt_embeds,
72
  negative_prompt_embeds=negative_prompt_embeds,
@@ -94,7 +95,9 @@ css = """
94
 
95
 
96
  with gr.Blocks(css=css) as demo:
97
- run_button = gr.Button("Generate")
 
 
98
 
99
  with gr.Row():
100
  input_image = gr.ImageMask(
@@ -123,7 +126,7 @@ with gr.Blocks(css=css) as demo:
123
  outputs=result,
124
  ).then(
125
  fn=fill_image,
126
- inputs=[input_image, model_selection],
127
  outputs=result,
128
  )
129
 
 
13
  MODELS = {
14
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
15
  }
16
+ DEVICE = devicetorch.get(torch)
17
  def init():
18
  global pipe
 
19
 
20
  config_file = hf_hub_download(
21
  "xinsir/controlnet-union-sdxl-1.0",
 
48
 
49
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
50
 
 
 
 
 
 
 
 
 
51
 
52
  #@spaces.GPU(duration=16)
53
+ def fill_image(prompt, image, model_selection):
54
+ init()
55
  source = image["background"]
56
  mask = image["layers"][0]
57
 
 
60
  cnet_image = source.copy()
61
  cnet_image.paste(0, (0, 0), binary_mask)
62
 
63
+ (
64
+ prompt_embeds,
65
+ negative_prompt_embeds,
66
+ pooled_prompt_embeds,
67
+ negative_pooled_prompt_embeds,
68
+ ) = pipe.encode_prompt(prompt, DEVICE, True)
69
+
70
+
71
  for image in pipe(
72
  prompt_embeds=prompt_embeds,
73
  negative_prompt_embeds=negative_prompt_embeds,
 
95
 
96
 
97
  with gr.Blocks(css=css) as demo:
98
+ with gr.Row():
99
+ prompt = gr.Textbox(value="high quality", label="Prompt")
100
+ run_button = gr.Button("Generate")
101
 
102
  with gr.Row():
103
  input_image = gr.ImageMask(
 
126
  outputs=result,
127
  ).then(
128
  fn=fill_image,
129
+ inputs=[prompt, input_image, model_selection],
130
  outputs=result,
131
  )
132