#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Ovis-U1-3B multimodal DEMO (CPU / GPU adaptive version)
Dependencies: Python 3.10+, torch 2.*, transformers 4.41.*, gradio 4.*
"""
# ───────────────────────────────────────────────
# ① Prepare the environment before any transformers / flash_attn import
# ───────────────────────────────────────────────
import os, sys, types, subprocess, random, numpy as np, torch
import importlib.util  # used to fabricate a ModuleSpec for the flash_attn stub below
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
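# bfloat16 keeps GPU memory in check; on CPU we stay in float32, where bf16 kernels
# are often missing or slow.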
# -------- CPU environment: stub out flash-attn --------
if DEVICE == "cpu":
    # Uninstall any flash-attn that may already be present
    subprocess.run("pip uninstall -y flash-attn",
                   shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    # Build stub modules
    fake_flash_attn = types.ModuleType("flash_attn")
    fake_layers = types.ModuleType("flash_attn.layers")
    fake_rotary = types.ModuleType("flash_attn.layers.rotary")

    def _cpu_apply_rotary_emb(x, cos, sin):
        """Pure-CPU rotary position embedding (simplified, assumes interleaved layout)."""
        x1, x2 = x[..., ::2], x[..., 1::2]
        rot_x1 = x1 * cos - x2 * sin
        rot_x2 = x1 * sin + x2 * cos
        out = torch.empty_like(x)
        out[..., ::2] = rot_x1
        out[..., 1::2] = rot_x2
        return out

    fake_rotary.apply_rotary_emb = _cpu_apply_rotary_emb
    fake_layers.rotary = fake_rotary
    fake_flash_attn.layers = fake_layers
    # Give the stub module a valid __spec__ so importlib-based availability checks pass
    fake_flash_attn.__spec__ = importlib.util.spec_from_loader("flash_attn", loader=None)
    sys.modules.update({
        "flash_attn": fake_flash_attn,
        "flash_attn.layers": fake_layers,
        "flash_attn.layers.rotary": fake_rotary,
    })
else:
    # GPU environment: try to install flash-attn
    try:
        subprocess.run(
            "pip install flash-attn==2.6.3 --no-build-isolation",
            # Merge with os.environ so pip still sees PATH and friends
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True, check=True)
    except subprocess.CalledProcessError:
        print("[WARN] flash-attn installation failed; GPU acceleration will be degraded.")
# ───────────────────────────────────────────────
# ② Regular dependencies
# ───────────────────────────────────────────────
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM
from test_img_edit import pipe_img_edit
from test_img_to_txt import pipe_txt_gen
from test_txt_to_img import pipe_t2i
# ───────────────────────────────────────────────
# ③ Utility functions & constants
# ───────────────────────────────────────────────
MAX_SEED = 10_000

def set_global_seed(seed: int = 42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def randomize_seed_fn(seed: int, randomize: bool) -> int:
    return random.randint(0, MAX_SEED) if randomize else seed
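# Example: randomize_seed_fn(42, False) -> 42, while randomize_seed_fn(42, True) returns
# a fresh value in [0, MAX_SEED]; the .then() chains below write that value back into the
# seed slider before inference runs.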
# ───────────────────────────────────────────────
# ④ Load the model
# ───────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "AIDC-AI/Ovis-U1-3B"
print(f"[INFO] Loading {MODEL_ID} on {DEVICE} …")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
    device_map="auto",
    token=HF_TOKEN,
    trust_remote_code=True
).eval()
print("[INFO] Model ready!")
# ───────────────────────────────────────────────
# ⑤ Inference wrappers
# ───────────────────────────────────────────────
def process_txt_to_img(prompt, height, width, steps, seed, cfg,
                       progress=gr.Progress(track_tqdm=True)):
    set_global_seed(seed)
    return pipe_t2i(model, prompt, height, width, steps, cfg=cfg, seed=seed)

def process_img_to_txt(prompt, img, progress=gr.Progress(track_tqdm=True)):
    return pipe_txt_gen(model, img, prompt)

def process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg,
                           progress=gr.Progress(track_tqdm=True)):
    set_global_seed(seed)
    return pipe_img_edit(model, img, prompt, steps, txt_cfg, img_cfg, seed=seed)
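# Minimal smoke test (kept commented out): a sketch assuming pipe_t2i returns a list of
# PIL images and pipe_txt_gen returns a string, which is how the Gradio callbacks below
# consume them. OVIS_SMOKE_TEST is a hypothetical opt-in flag, not part of the app.
# if os.getenv("OVIS_SMOKE_TEST"):
#     set_global_seed(0)
#     imgs = process_txt_to_img("a watercolor fox", 512, 512, 50, seed=0, cfg=5)
#     print(process_img_to_txt("Describe this image.", imgs[0]))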
# ───────────────────────────────────────────────
# ⑥ Gradio UI (consistent with the previous version; change markers omitted here)
# ───────────────────────────────────────────────
with gr.Blocks(title="Ovis-U1-3B (CPU/GPU adaptive)") as demo:
    gr.Markdown("# Ovis-U1-3B\nMultimodal text-image DEMO (CPU/GPU adaptive)")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                # Tab 1: Image + Text → Image
                with gr.TabItem("Image + Text → Image"):
                    edit_image_input = gr.Image(label="Input Image", type="pil")
                    with gr.Row():
                        edit_prompt_input = gr.Textbox(show_label=False, placeholder="Describe the editing instruction…")
                        run_edit_image_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            edit_img_guidance = gr.Slider(label="Image Guidance", minimum=1, maximum=10, value=1.5, step=0.1)
                            edit_txt_guidance = gr.Slider(label="Text Guidance", minimum=1, maximum=30, value=6.0, step=0.5)
                        edit_steps = gr.Slider(label="Steps", minimum=40, maximum=100, value=50, step=1)
                        edit_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
                        edit_random = gr.Checkbox(label="Randomize seed", value=False)
                # Tab 2: Text → Image
                with gr.TabItem("Text → Image"):
                    prompt_gen = gr.Textbox(show_label=False, placeholder="Describe the image you want…")
                    run_gen_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height_slider = gr.Slider(label="height", minimum=256, maximum=1536, value=1024, step=32)
                            width_slider = gr.Slider(label="width", minimum=256, maximum=1536, value=1024, step=32)
                        guidance_slider = gr.Slider(label="Guidance Scale", minimum=1, maximum=30, value=5, step=0.5)
                        steps_slider = gr.Slider(label="Steps", minimum=40, maximum=100, value=50, step=1)
                        seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
                        random_check = gr.Checkbox(label="Randomize seed", value=False)
                # Tab 3: Image → Text
                with gr.TabItem("Image → Text"):
                    understand_img = gr.Image(label="Input Image", type="pil")
                    understand_prompt = gr.Textbox(show_label=False, placeholder="Describe the question about the image…")
                    run_understand = gr.Button("Run", scale=0)
            clear_btn = gr.Button("Clear All")
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images", columns=2, visible=True)
            txt_out = gr.Textbox(label="Generated Text", visible=False, lines=5, interactive=False)
    # Event bindings (same as the previous version; duplicate comments omitted)
    def run_tab1(prompt, img, steps, seed, txt_cfg, img_cfg, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)
        imgs = process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg, progress)
        return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)

    def run_tab2(prompt, h, w, steps, seed, guidance, progress=gr.Progress(track_tqdm=True)):
        imgs = process_txt_to_img(prompt, h, w, steps, seed, guidance, progress)
        return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)

    def run_tab3(img, prompt, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)
        text = process_img_to_txt(prompt, img, progress)
        return gr.update(value=[], visible=False), gr.update(value=text, visible=True)
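    # All three handlers share one output pair: the image tabs show `gallery` and hide
    # `txt_out`, while the understanding tab does the opposite, so the right column only
    # shows the result type of the most recent action.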
    # Tab 1 bindings
    run_edit_image_btn.click(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
        run_tab1,
        [edit_prompt_input, edit_image_input, edit_steps, edit_seed, edit_txt_guidance, edit_img_guidance],
        [gallery, txt_out])
    edit_prompt_input.submit(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
        run_tab1,
        [edit_prompt_input, edit_image_input, edit_steps, edit_seed, edit_txt_guidance, edit_img_guidance],
        [gallery, txt_out])
    # Tab 2 bindings
    run_gen_btn.click(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
        run_tab2,
        [prompt_gen, height_slider, width_slider, steps_slider, seed_slider, guidance_slider],
        [gallery, txt_out])
    prompt_gen.submit(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
        run_tab2,
        [prompt_gen, height_slider, width_slider, steps_slider, seed_slider, guidance_slider],
        [gallery, txt_out])
    # Tab 3 bindings
    run_understand.click(run_tab3, [understand_img, understand_prompt], [gallery, txt_out])
    understand_prompt.submit(run_tab3, [understand_img, understand_prompt], [gallery, txt_out])

    # Reset every input and output to its default value
    def clear_all():
        return (
            gr.update(value=None), gr.update(value=""), gr.update(value=1.5), gr.update(value=6.0),
            gr.update(value=50), gr.update(value=42), gr.update(value=False),
            gr.update(value=""), gr.update(value=1024), gr.update(value=1024),
            gr.update(value=5), gr.update(value=50), gr.update(value=42), gr.update(value=False),
            gr.update(value=None), gr.update(value=""),
            gr.update(value=[], visible=True), gr.update(value="", visible=False)
        )

    clear_btn.click(clear_all, [], [
        edit_image_input, edit_prompt_input, edit_img_guidance, edit_txt_guidance,
        edit_steps, edit_seed, edit_random, prompt_gen, height_slider, width_slider,
        guidance_slider, steps_slider, seed_slider, random_check, understand_img,
        understand_prompt, gallery, txt_out
    ])
# ───────────────────────────────────────────────
# ⑦ Launch
# ───────────────────────────────────────────────
if __name__ == "__main__":
    demo.launch()