Ovis-U1-3B-cpu / app.py
innoai's picture
Update app.py
ed79f3e verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Ovis-U1-3B multimodal demo (CPU / GPU adaptive version).
Requirements: Python 3.10+, torch 2.*, transformers 4.41.*, gradio 4.*
"""
# ───────────────────────────────────────────────
# ① Prepare the environment before any transformers / flash_attn import
# ───────────────────────────────────────────────
import os, sys, types, subprocess, random, numpy as np, torch
import importlib.util # โ˜… ๆ–ฐๅขž๏ผš็”จไบŽ็”Ÿๆˆ ModuleSpec
# Run on CUDA with bfloat16 weights when PyTorch sees a GPU; otherwise fall
# back to the CPU in float32.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32
# -------- CPU environment: replace flash-attn with pure-PyTorch stubs --------
if DEVICE == "cpu":
    # Best-effort removal of any installed flash-attn wheel (presumably so the
    # CUDA-only package cannot be picked up on CPU). Use the running
    # interpreter's pip (a bare `pip` on PATH may belong to another Python) and
    # an argument list instead of a shell string. A non-zero exit, e.g.
    # "not installed", is deliberately ignored.
    subprocess.run(
        [sys.executable, "-m", "pip", "uninstall", "-y", "flash-attn"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=False,
    )

    def _cpu_apply_rotary_emb(x, cos, sin, interleaved=False, inplace=False,
                              seqlen_offsets=0, cu_seqlens=None, max_seqlen=None):
        """Pure-PyTorch stand-in for ``flash_attn.layers.rotary.apply_rotary_emb``.

        Follows the flash-attn contract:
        - ``x`` is (batch, seqlen, nheads, headdim).
        - ``cos``/``sin`` are (seqlen_rotary, rotary_dim / 2), with
          ``seqlen_rotary >= seqlen_offsets + seqlen`` and
          ``rotary_dim = 2 * cos.shape[-1] <= headdim``.

        Only the first ``rotary_dim`` channels of ``x`` are rotated; the
        remaining channels pass through unchanged. cos/sin are broadcast over
        the heads axis.

        With ``interleaved=False`` (the flash-attn default), the first half of
        the rotary channels is rotated against the second half (GPT-NeoX
        style). With ``interleaved=True``, even/odd channel pairs are rotated
        (GPT-J style).

        The result has ``x``'s dtype. With ``inplace=True`` it is also written
        back into ``x``, and ``x`` is returned.

        Raises:
            NotImplementedError: for the varlen path (``cu_seqlens``) or tensor
                ``seqlen_offsets``, which this CPU stub does not implement.
            ValueError: if ``rotary_dim`` exceeds ``headdim``.
        """
        if cu_seqlens is not None or not isinstance(seqlen_offsets, int):
            raise NotImplementedError(
                "CPU apply_rotary_emb stub supports neither cu_seqlens "
                "nor tensor seqlen_offsets")
        seqlen = x.shape[-3]
        # Pick the cos/sin rows for this chunk (the seqlen axis is dim -2 of cos/sin).
        cos = cos[..., seqlen_offsets:seqlen_offsets + seqlen, :]
        sin = sin[..., seqlen_offsets:seqlen_offsets + seqlen, :]
        rotary_dim = cos.shape[-1] * 2
        if rotary_dim > x.shape[-1]:
            raise ValueError(
                f"rotary_dim ({rotary_dim}) must be <= headdim ({x.shape[-1]})")
        # Expand (.., rotary_dim/2) -> (.., rotary_dim) in the channel layout
        # that matches the chosen rotation style.
        if interleaved:
            cos = cos.repeat_interleave(2, dim=-1)
            sin = sin.repeat_interleave(2, dim=-1)
        else:
            cos = torch.cat((cos, cos), dim=-1)
            sin = torch.cat((sin, sin), dim=-1)
        # Insert a heads axis so cos/sin broadcast over nheads.
        cos = cos.unsqueeze(-2)
        sin = sin.unsqueeze(-2)

        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        if interleaved:
            x1, x2 = x_rot[..., ::2], x_rot[..., 1::2]
            rotated_half = torch.stack((-x2, x1), dim=-1).flatten(-2)
        else:
            x1, x2 = x_rot.chunk(2, dim=-1)
            rotated_half = torch.cat((-x2, x1), dim=-1)
        out = torch.cat((x_rot * cos + rotated_half * sin, x_pass), dim=-1)
        if inplace:
            x.copy_(out)
            return x
        return out.to(x.dtype)

    # Empty modules standing in for the flash_attn package hierarchy.
    fake_flash_attn = types.ModuleType("flash_attn")
    fake_layers = types.ModuleType("flash_attn.layers")
    fake_rotary = types.ModuleType("flash_attn.layers.rotary")
    fake_rotary.apply_rotary_emb = _cpu_apply_rotary_emb
    fake_layers.rotary = fake_rotary
    fake_flash_attn.layers = fake_layers
    # importlib.util.find_spec() raises ValueError for a sys.modules entry
    # whose __spec__ is None, so give every stub a real ModuleSpec.
    for _stub in (fake_flash_attn, fake_layers, fake_rotary):
        _stub.__spec__ = importlib.util.spec_from_loader(_stub.__name__, loader=None)
    sys.modules.update({
        "flash_attn": fake_flash_attn,
        "flash_attn.layers": fake_layers,
        "flash_attn.layers.rotary": fake_rotary,
    })
else:
    # GPU environment: try to install flash-attn with the running interpreter's pip.
    # The child environment must EXTEND os.environ. Passing only the one
    # variable would drop PATH, HOME, CUDA settings, etc. for the pip process.
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install",
             "flash-attn==2.6.3", "--no-build-isolation"],
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            check=True,
        )
    except subprocess.CalledProcessError:
        print("[WARN] flash-attn 安装失败，GPU 加速功能受限。")
# ───────────────────────────────────────────────
# ② Regular dependencies
# ───────────────────────────────────────────────
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM
from test_img_edit import pipe_img_edit
from test_img_to_txt import pipe_txt_gen
from test_txt_to_img import pipe_t2i
# ───────────────────────────────────────────────
# ③ Helper functions & constants
# ───────────────────────────────────────────────
MAX_SEED = 10_000  # inclusive upper bound for random seeds and the UI seed sliders


def set_global_seed(seed: int = 42) -> None:
    """Seed Python's ``random``, NumPy's legacy global RNG and PyTorch.

    ``seed`` is converted to ``int`` first: UI sliders may deliver it as a
    float, and ``np.random.seed`` rejects floats with a ``TypeError``.
    """
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU as well (redundant with recent
    # torch.manual_seed, but harmless).
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def randomize_seed_fn(seed: int, randomize: bool) -> int:
    """Return a fresh seed in ``[0, MAX_SEED]`` if *randomize*, else *seed* unchanged."""
    return random.randint(0, MAX_SEED) if randomize else seed
# ───────────────────────────────────────────────
# ④ Load the model
# ───────────────────────────────────────────────
# Optional Hugging Face access token from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "AIDC-AI/Ovis-U1-3B"

print(f"[INFO] Loading {MODEL_ID} on {DEVICE} …")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,          # dtype chosen together with DEVICE above
    low_cpu_mem_usage=True,
    device_map="auto",          # let transformers/accelerate place the weights
    token=HF_TOKEN,
    trust_remote_code=True,     # allow the modeling code shipped in the Hub repo
).eval()                        # inference mode (disables dropout etc.)
print("[INFO] Model ready!")
# ───────────────────────────────────────────────
# ⑤ Inference wrappers
# ───────────────────────────────────────────────
def process_txt_to_img(prompt, height, width, steps, seed, cfg,
                       progress=gr.Progress(track_tqdm=True)):
    """Text → image: seed the global RNGs, then delegate to ``pipe_t2i``.

    ``progress`` is accepted for Gradio progress tracking; it is not used in
    the body itself.
    """
    set_global_seed(seed)
    generated = pipe_t2i(model, prompt, height, width, steps, cfg=cfg, seed=seed)
    return generated
def process_img_to_txt(prompt, img, progress=gr.Progress(track_tqdm=True)):
    """Image → text: delegate to ``pipe_txt_gen`` with the image and the prompt.

    ``progress`` is accepted for Gradio progress tracking; it is not used in
    the body itself.
    """
    generated_text = pipe_txt_gen(model, img, prompt)
    return generated_text
def process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg,
                           progress=gr.Progress(track_tqdm=True)):
    """Image + text → image: seed the global RNGs, then delegate to ``pipe_img_edit``.

    ``txt_cfg`` / ``img_cfg`` are the text and image guidance values from the UI.
    ``progress`` is accepted for Gradio progress tracking; it is not used in
    the body itself.
    """
    set_global_seed(seed)
    edited = pipe_img_edit(model, img, prompt, steps, txt_cfg, img_cfg, seed=seed)
    return edited
# ───────────────────────────────────────────────
# ⑥ Gradio UI (same as the previous version; change markers omitted here)
# ───────────────────────────────────────────────
# Initial widget values. They are shared by the component constructors and the
# "Clear All" handler so the reset state cannot drift from the start state.
DEFAULT_SEED = 42
DEFAULT_STEPS = 50
DEFAULT_EDIT_IMG_GUIDANCE = 1.5
DEFAULT_EDIT_TXT_GUIDANCE = 6.0
DEFAULT_T2I_SIZE = 1024
DEFAULT_T2I_GUIDANCE = 5

with gr.Blocks(title="Ovis-U1-3B (CPU/GPU adaptive)") as demo:
    gr.Markdown("# Ovis-U1-3B\n多模态文本-图像 DEMO（CPU/GPU 自适应版）")
    with gr.Row():
        # Left column: one tab per task, plus a global reset button.
        with gr.Column():
            with gr.Tabs():
                # Tab 1: Image + Text → Image
                with gr.TabItem("Image + Text → Image"):
                    edit_image_input = gr.Image(label="Input Image", type="pil")
                    with gr.Row():
                        edit_prompt_input = gr.Textbox(
                            show_label=False,
                            placeholder="Describe the editing instruction…")
                        run_edit_image_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            edit_img_guidance = gr.Slider(
                                label="Image Guidance", minimum=1, maximum=10,
                                value=DEFAULT_EDIT_IMG_GUIDANCE, step=0.1)
                            edit_txt_guidance = gr.Slider(
                                label="Text Guidance", minimum=1, maximum=30,
                                value=DEFAULT_EDIT_TXT_GUIDANCE, step=0.5)
                        edit_steps = gr.Slider(
                            label="Steps", minimum=40, maximum=100,
                            value=DEFAULT_STEPS, step=1)
                        edit_seed = gr.Slider(
                            label="Seed", minimum=0, maximum=MAX_SEED,
                            value=DEFAULT_SEED, step=1)
                        edit_random = gr.Checkbox(label="Randomize seed", value=False)
                # Tab 2: Text → Image
                with gr.TabItem("Text → Image"):
                    prompt_gen = gr.Textbox(
                        show_label=False,
                        placeholder="Describe the image you want…")
                    run_gen_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height_slider = gr.Slider(
                                label="height", minimum=256, maximum=1536,
                                value=DEFAULT_T2I_SIZE, step=32)
                            width_slider = gr.Slider(
                                label="width", minimum=256, maximum=1536,
                                value=DEFAULT_T2I_SIZE, step=32)
                        guidance_slider = gr.Slider(
                            label="Guidance Scale", minimum=1, maximum=30,
                            value=DEFAULT_T2I_GUIDANCE, step=0.5)
                        steps_slider = gr.Slider(
                            label="Steps", minimum=40, maximum=100,
                            value=DEFAULT_STEPS, step=1)
                        seed_slider = gr.Slider(
                            label="Seed", minimum=0, maximum=MAX_SEED,
                            value=DEFAULT_SEED, step=1)
                        random_check = gr.Checkbox(label="Randomize seed", value=False)
                # Tab 3: Image → Text
                with gr.TabItem("Image → Text"):
                    understand_img = gr.Image(label="Input Image", type="pil")
                    understand_prompt = gr.Textbox(
                        show_label=False,
                        placeholder="Describe the question about image…")
                    run_understand = gr.Button("Run", scale=0)
            clear_btn = gr.Button("Clear All")
        # Right column: shared outputs. Image results go to the gallery; text
        # results and input-validation hints go to the text box.
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images", columns=2, visible=True)
            txt_out = gr.Textbox(label="Generated Text", visible=False,
                                 lines=5, interactive=False)

    # ---- Event handlers: each returns (gallery update, text-box update) ----
    def run_tab1(prompt, img, steps, seed, txt_cfg, img_cfg,
                 progress=gr.Progress(track_tqdm=True)):
        """Image editing: show edited images, or a hint when no image was uploaded."""
        if img is None:
            return (gr.update(value=[], visible=False),
                    gr.update(value="Please upload an image.", visible=True))
        imgs = process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg, progress)
        return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)

    def run_tab2(prompt, h, w, steps, seed, guidance,
                 progress=gr.Progress(track_tqdm=True)):
        """Text-to-image: show the generated images and hide the text box."""
        imgs = process_txt_to_img(prompt, h, w, steps, seed, guidance, progress)
        return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)

    def run_tab3(img, prompt, progress=gr.Progress(track_tqdm=True)):
        """Image-to-text: show generated text (or a hint) and hide the gallery."""
        if img is None:
            return (gr.update(value=[], visible=False),
                    gr.update(value="Please upload an image.", visible=True))
        text = process_img_to_txt(prompt, img, progress)
        return gr.update(value=[], visible=False), gr.update(value=text, visible=True)

    outputs = [gallery, txt_out]

    # Tab 1 and Tab 2: the Run button and Enter in the prompt box behave the
    # same. First optionally re-roll the seed slider, then run with the
    # (possibly updated) seed.
    edit_inputs = [edit_prompt_input, edit_image_input, edit_steps, edit_seed,
                   edit_txt_guidance, edit_img_guidance]
    for trigger in (run_edit_image_btn.click, edit_prompt_input.submit):
        trigger(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
            run_tab1, edit_inputs, outputs)

    gen_inputs = [prompt_gen, height_slider, width_slider, steps_slider,
                  seed_slider, guidance_slider]
    for trigger in (run_gen_btn.click, prompt_gen.submit):
        trigger(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
            run_tab2, gen_inputs, outputs)

    # Tab 3: no seed involved.
    for trigger in (run_understand.click, understand_prompt.submit):
        trigger(run_tab3, [understand_img, understand_prompt], outputs)

    # (component, initial value) pairs restored by "Clear All".
    reset_defaults = [
        (edit_image_input, None),
        (edit_prompt_input, ""),
        (edit_img_guidance, DEFAULT_EDIT_IMG_GUIDANCE),
        (edit_txt_guidance, DEFAULT_EDIT_TXT_GUIDANCE),
        (edit_steps, DEFAULT_STEPS),
        (edit_seed, DEFAULT_SEED),
        (edit_random, False),
        (prompt_gen, ""),
        (height_slider, DEFAULT_T2I_SIZE),
        (width_slider, DEFAULT_T2I_SIZE),
        (guidance_slider, DEFAULT_T2I_GUIDANCE),
        (steps_slider, DEFAULT_STEPS),
        (seed_slider, DEFAULT_SEED),
        (random_check, False),
        (understand_img, None),
        (understand_prompt, ""),
    ]

    def clear_all():
        """Reset every input widget and restore the initial output layout.

        Returns a ``{component: update}`` dict, so correctness does not depend
        on the order of the ``outputs`` list passed to ``clear_btn.click``.
        """
        updates = {component: gr.update(value=value) for component, value in reset_defaults}
        updates[gallery] = gr.update(value=[], visible=True)
        updates[txt_out] = gr.update(value="", visible=False)
        return updates

    clear_btn.click(clear_all, [],
                    [component for component, _ in reset_defaults] + outputs)
# ───────────────────────────────────────────────
# ⑦ Launch
# ───────────────────────────────────────────────
def _main():
    """Start the Gradio server for the ``demo`` Blocks app."""
    demo.launch()


if __name__ == "__main__":
    _main()