# Hugging Face Space status at capture time: Running
# -*- coding: utf-8 -*-
# ZenCtrl Inpainting Playground (Baseten backend)
# import spaces
import base64
import json
import os
from io import BytesIO

import gradio as gr
import replicate
import requests
from PIL import Image, ImageDraw
# ────────── Secrets & endpoints ──────────
# Each value is read from the environment and is None when its variable is unset.
BASETEN_MODEL_URL = os.getenv("BASETEN_MODEL_URL")  # inference endpoint that receives the POST
BASETEN_API_KEY = os.getenv("BASETEN_API_KEY")  # sent as an "Api-Key" Authorization header when set
# NOTE(review): REPLICATE_TOKEN is not referenced elsewhere in the visible file;
# presumably the replicate client reads REPLICATE_API_TOKEN itself - confirm.
REPLICATE_TOKEN = os.getenv("REPLICATE_API_TOKEN")

# ────────── Globals ──────────
# Default edge length, in pixels, that both input images are resized to.
ADAPTER_SIZE = 1024
# Page CSS passed to gr.Blocks: centres the input column and caps its width.
css = "#col-container {margin:0 auto; max-width:960px;}"
# Background generation via Replicate
def _gen_bg(prompt: str):
    """Generate a square background image from a text prompt via Replicate.

    Args:
        prompt: Description of the desired background. An empty (or None)
            prompt falls back to "cinematic background".

    Returns:
        The generated picture as an RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: If the model call yields no output to download.
        requests.HTTPError: If downloading the generated image fails.
    """
    output = replicate.run(
        "google/imagen-4-fast",
        input={"prompt": prompt or "cinematic background", "aspect_ratio": "1:1"},
    )
    # The call may hand back a single result or a list of them; use the first.
    if isinstance(output, list):
        output = output[0] if output else None
    if not output:
        raise gr.Error("Background model returned no image.")

    response = requests.get(output, timeout=120)
    # Fail on HTTP errors here instead of feeding an error page to PIL.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# Main processing function
def process_image_and_text(subject_image, adapter_dict, prompt, _unused1, _unused2, size=ADAPTER_SIZE, rank=10.0):
    """Send the subject and background/mask images to the Baseten model.

    Args:
        subject_image: PIL image of the object to place in the scene.
        adapter_dict: Either a PIL image, or a sketch-style dict with an
            "image" entry and an optional "mask" entry. When a mask is
            present, its bounding box is painted solid green (#00FF00) onto
            the image before upload.
        prompt: Text prompt forwarded to the model.
        _unused1: Ignored; kept so existing event wiring keeps working.
        _unused2: Ignored; kept so existing event wiring keeps working.
        size: Edge length both images are resized to; also sent as the
            requested output height and width.
        rank: Forwarded unchanged in the payload as "rank".

    Returns:
        A ``(image, image)`` tuple holding the same decoded RGB result twice:
        one for the output widget and one for the state slot.

    Raises:
        gr.Error: If an input image or the endpoint URL is missing, or the
            response lacks a decodable image.
        requests.HTTPError: If the endpoint answers with an error status.
    """
    seed, guidance_scale, steps = 42, 2.5, 28

    if subject_image is None:
        raise gr.Error("Please upload a subject image.")
    if not BASETEN_MODEL_URL:
        raise gr.Error("BASETEN_MODEL_URL is not configured.")

    if isinstance(adapter_dict, dict):  # Image/sketch input
        adapter_image = adapter_dict.get("image")
        adapter_mask = adapter_dict.get("mask")
        if adapter_image is not None and adapter_mask is not None:
            # Binarise the drawn strokes, then widen them to their bounding
            # box so the model gets a solid green rectangle.
            mask = adapter_mask.convert("L").point(lambda p: 255 if p else 0)
            bbox = mask.getbbox()
            if bbox:
                rect = Image.new("L", mask.size, 0)
                ImageDraw.Draw(rect).rectangle(bbox, fill=255)
                mask = rect
            green = Image.new("RGB", adapter_image.size, "#00FF00")
            adapter_image = Image.composite(green, adapter_image, mask)
    else:
        adapter_image = adapter_dict

    if adapter_image is None:
        raise gr.Error("Please upload a background / mask image.")

    def prep(img: Image.Image):
        # Plain resize to a square; the aspect ratio is not preserved.
        return img.resize((size, size), Image.LANCZOS)

    def b64(img):
        # PNG-encode the image and return it as base64 text.
        buf = BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode()

    payload = {
        "prompt": prompt,
        "subject_image": b64(prep(subject_image)),
        "adapter_image": b64(prep(adapter_image)),
        "height": size,
        "width": size,
        "steps": steps,
        "seed": seed,
        "guidance_scale": guidance_scale,
        "rank": rank,
    }
    headers = {"Content-Type": "application/json"}
    if BASETEN_API_KEY:
        headers["Authorization"] = f"Api-Key {BASETEN_API_KEY}"

    resp = requests.post(BASETEN_MODEL_URL, headers=headers, json=payload, timeout=180)
    resp.raise_for_status()
    data = resp.json()

    if not isinstance(data, dict) or "blended" not in data:
        raise gr.Error("Baseten response missing 'blended' image.")

    # NOTE(review): the original guarded on "blended" but decoded
    # "raw_result". "raw_result" is still preferred when present, so a
    # response carrying both yields the same picture as before; "blended"
    # is the fallback, so a blended-only response no longer fails.
    # Confirm which field is actually intended.
    encoded = data.get("raw_result") or data["blended"]
    try:
        result = Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
    except Exception as exc:
        raise gr.Error("Failed to decode 'blended' image from Baseten response.") from exc
    return result, result
# ────────── Header HTML ──────────
# Page title followed by three centred badge links: Discord, landing page, X.
header_html = """
<h1>ZenCtrl Inpainting Beta</h1>
<div align="center" style="line-height: 1;">
<a href="https://discord.com/invite/b9RuYQ3F8k" target="_blank" style="margin: 10px;" name="discord_link"><img src="https://img.shields.io/badge/Discord-Join-7289da.svg?logo=discord" alt="Discord" style="display: inline-block; vertical-align: middle;"></a>
<a href="https://fotographer.ai/zen-control" target="_blank" style="margin: 10px;" name="lp_link"><img src="https://img.shields.io/badge/Website-Landing_Page-blue" alt="LP" style="display: inline-block; vertical-align: middle;"></a>
<a href="https://x.com/FotographerAI" target="_blank" style="margin: 10px;" name="twitter_link"><img src="https://img.shields.io/twitter/follow/FotographerAI?style=social" alt="X" style="display: inline-block; vertical-align: middle;"></a>
</div>
"""
# ────────── Gradio UI ──────────
with gr.Blocks(css=css, title="ZenCtrl Inpainting") as demo:
    # Receives the second value returned by process_image_and_text.
    raw_state = gr.State()

    gr.HTML(header_html)
    gr.Markdown(
        "**Generate context-aware images of your subject with ZenCtrl’s inpainting playground.** Upload a subject + optional mask, write a prompt, and hit **Generate**. \n"
        "Open *Advanced Settings* for an AI-generated background. \n\n"
        "**Note:** The model was trained mainly on interior scenes and other *rigid* objects. Results on people or highly deformable items may contain visual distortions. \n"
        "In case of High traffic , your requests might be queued and processed one by one by our backend server"
    )

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=2, elem_id="col-container"):
            subj_img = gr.Image(type="pil", label="Subject image")
            # Sketch tool: the value is a dict with "image" and "mask" entries.
            ref_img = gr.Image(type="pil", label="Background / Mask image", tool="sketch", brush_color="#00FF00")
            # Hidden component; not wired to any event in this block.
            ref_img_ex = gr.Image(type="pil", visible=False)
            # Removed Florence-SAM
            promptbox = gr.Textbox(label="Generation prompt", value="furniture", lines=2)
            run_btn = gr.Button("Generate", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                bgprompt = gr.Textbox(label="Background Prompt", value="Scandinavian living room …")
                bg_btn = gr.Button("Generate BG")
        # Right column: results.
        with gr.Column(scale=2):
            output_img = gr.Image(label="Output Image")
            bg_img = gr.Image(label="Background", visible=True)

    # ---------- Example wrapper ---------------------------------
    def _load_and_show(subj_path, bg_path, prompt_text, out_path=None):
        """Load one Examples row from disk into widget-ready values.

        Args:
            subj_path: Path of the subject image.
            bg_path: Path of the background image.
            prompt_text: Prompt string for the row.
            out_path: Path of the pre-rendered result. Optional so older
                three-argument calls still work; when omitted it is derived
                from ``subj_path`` by swapping ".png" for "_out.png", which
                does not match every bundled example (e.g. sofa1_1.png).

        Returns:
            A 4-tuple, one value per output widget:
                1. subject PIL image            -> subj_img
                2. dict for the sketch component -> ref_img
                3. prompt string                -> promptbox
                4. pre-rendered result PIL      -> output_img
        """
        if out_path is None:
            out_path = subj_path.replace(".png", "_out.png")
        return (
            Image.open(subj_path),                        # 1. subject
            {"image": Image.open(bg_path), "mask": None},  # 2. sketch dict
            prompt_text,                                  # 3. prompt
            Image.open(out_path),                         # 4. output image
        )

    def ex(subj, bg, prompt):
        """Load a subject/background pair from disk without a result image.

        Not referenced by any component or event in this block; kept so the
        module-level name stays available.
        """
        return [
            Image.open(subj),
            {"image": Image.open(bg), "mask": None},
            prompt,
        ]

    # ---------- Examples ----------------------------------------
    # Each row supplies one value per component in `inputs`, so the fourth
    # column is the pre-rendered result shown in output_img.
    # NOTE(review): with cache_examples=False and run_on_click left unset,
    # Gradio presumably never calls `fn` and only fills the inputs - confirm
    # against the installed Gradio version.
    gr.Examples(
        examples=[
            ["examples/sofa1_1.png", "examples/sofa1_bg.png", "add the sofa", "examples/sofa1_out.png"],
            ["examples/sofa2.png", "examples/sofa2_bg.png", "add this sofa", "examples/sofa2_out.png"],
            ["examples/chair1.png", "examples/chair1_bg.png", "add the chair", "examples/chair1_out.png"],
            ["examples/console_table.png", "examples/console_table_bg.png", "Scandinavian console table against a gallery-style wall filled with abstract framed art,", "examples/console_table_out.png"],
            ["examples/office_chair.png", "examples/office_chair_bg.png", "office chair", "examples/office_chair_out.png"],
            ["examples/office_chair1.png", "examples/office_chair1_bg.png", "Executive mesh chair in a modern home office, with matte black wall panels, built-in shelves, ", "examples/office_chair1_out.png"],
            ["examples/bed.png", "examples/bed_in.png", "Low platform bed in a Japandi-style bedroom, surrounded by floating nightstands", "examples/bed_out.png"],
            ["examples/car.png", "examples/car_bg.png", "car on the road", "examples/car_out.png"],
        ],
        inputs=[subj_img, ref_img, promptbox, output_img],
        outputs=[subj_img, ref_img, promptbox, output_img],
        fn=_load_and_show,
        # preprocess=False,  # data already widget-ready
        cache_examples=False,
    )

    # ---------- Buttons & interactions --------------------------
    # The two gr.State values fill the function's ignored _unused1/_unused2.
    run_btn.click(
        process_image_and_text,
        inputs=[subj_img, ref_img, promptbox, gr.State(False), gr.State("")],
        outputs=[output_img, raw_state],
    )
    bg_btn.click(_gen_bg, inputs=[bgprompt], outputs=[bg_img])
# ---------------- Launch ---------------------------------------
if __name__ == "__main__":
    # Start the Gradio server only when the file is run as a script.
    demo.launch()