#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Ovis-U1-3B multimodal DEMO (CPU / GPU adaptive version)
Dependencies: Python 3.10+, torch 2.*, transformers 4.41.*, gradio 4.*
"""
# ───────────────────────────────────────────────
# ① Set up the environment before any transformers / flash_attn import
# ───────────────────────────────────────────────
import os, sys, types, subprocess, random, numpy as np, torch
import importlib.util  # ✅ added: used to create a ModuleSpec for the stub module
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# -------- CPU environment: stub out flash-attn --------
if DEVICE == "cpu":
    # Uninstall any flash-attn that may already be installed
subprocess.run("pip uninstall -y flash-attn",
shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    # Build stub modules that mimic the flash_attn package layout
fake_flash_attn = types.ModuleType("flash_attn")
fake_layers = types.ModuleType("flash_attn.layers")
fake_rotary = types.ModuleType("flash_attn.layers.rotary")
def _cpu_apply_rotary_emb(x, cos, sin):
"""็บฏ CPU ็ๆ่ฝฌไฝ็ฝฎ็ผ็ ๏ผ็ฎๆๅฎ็ฐ๏ผ"""
x1, x2 = x[..., ::2], x[..., 1::2]
rot_x1 = x1 * cos - x2 * sin
rot_x2 = x1 * sin + x2 * cos
out = torch.empty_like(x)
out[..., ::2] = rot_x1
out[..., 1::2] = rot_x2
return out
fake_rotary.apply_rotary_emb = _cpu_apply_rotary_emb
fake_layers.rotary = fake_rotary
fake_flash_attn.layers = fake_layers
    # ✅ added: give the stub module a valid __spec__ so importlib checks pass
fake_flash_attn.__spec__ = importlib.util.spec_from_loader("flash_attn", loader=None)
sys.modules.update({
"flash_attn": fake_flash_attn,
"flash_attn.layers": fake_layers,
"flash_attn.layers.rotary": fake_rotary,
})
else:
    # GPU environment: try to install flash-attn
try:
        subprocess.run(
            "pip install flash-attn==2.6.3 --no-build-isolation",
            # Merge with the current environment; a bare env dict would drop PATH and break pip
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True, check=True)
except subprocess.CalledProcessError:
print("[WARN] flash-attn ๅฎ่ฃ
ๅคฑ่ดฅ๏ผGPU ๅ ้ๅ่ฝๅ้ใ")
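
# Optional sanity check for the CPU stub above (illustrative sketch; the helper is
# hypothetical and never called by the demo). With cos=1 and sin=0 the rotation is
# the identity, so the output must equal the input. It targets the stub only; the
# real flash-attn kernel has a different signature and shape contract.
def _sanity_check_rotary_stub():
    if DEVICE != "cpu":
        return
    from flash_attn.layers.rotary import apply_rotary_emb  # resolves to the stub above
    x = torch.randn(2, 4, 8)   # toy (batch, seq, dim) tensor
    cos = torch.ones(4, 4)     # broadcasts over the even/odd channel pairs
    sin = torch.zeros(4, 4)
    out = apply_rotary_emb(x, cos, sin)
    assert torch.allclose(out, x), "identity rotation should return the input unchanged"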
# ───────────────────────────────────────────────
# ② Regular dependencies
# ───────────────────────────────────────────────
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM
from test_img_edit import pipe_img_edit
from test_img_to_txt import pipe_txt_gen
from test_txt_to_img import pipe_t2i
# ───────────────────────────────────────────────
# ③ Utility functions & constants
# ───────────────────────────────────────────────
MAX_SEED = 10_000
def set_global_seed(seed: int = 42):
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def randomize_seed_fn(seed: int, randomize: bool) -> int:
return random.randint(0, MAX_SEED) if randomize else seed
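
# Illustrative, hypothetical helper (not used by the demo): re-seeding with the same
# value makes the random / numpy / torch generators reproduce identical draws.
def _seed_repro_example(seed: int = 42) -> bool:
    set_global_seed(seed)
    a = torch.rand(3)
    set_global_seed(seed)
    b = torch.rand(3)
    return bool(torch.equal(a, b))  # True: same seed, same sequence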
# ───────────────────────────────────────────────
# ④ Load the model
# ───────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "AIDC-AI/Ovis-U1-3B"
print(f"[INFO] Loading {MODEL_ID} on {DEVICE} โฆ")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=DTYPE,
low_cpu_mem_usage=True,
device_map="auto",
token=HF_TOKEN,
trust_remote_code=True
).eval()
print("[INFO] Model ready!")
# ───────────────────────────────────────────────
# ⑤ Inference wrappers
# ───────────────────────────────────────────────
def process_txt_to_img(prompt, height, width, steps, seed, cfg,
progress=gr.Progress(track_tqdm=True)):
set_global_seed(seed)
return pipe_t2i(model, prompt, height, width, steps, cfg=cfg, seed=seed)
def process_img_to_txt(prompt, img, progress=gr.Progress(track_tqdm=True)):
return pipe_txt_gen(model, img, prompt)
def process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg,
progress=gr.Progress(track_tqdm=True)):
set_global_seed(seed)
return pipe_img_edit(model, img, prompt, steps, txt_cfg, img_cfg, seed=seed)
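
# Illustrative, hypothetical usage (never executed): the wrappers above can also be
# called headlessly, assuming pipe_t2i returns a list of PIL images as the Gallery
# below expects.
def _headless_t2i_example():
    images = process_txt_to_img("a watercolor fox", 512, 512, steps=50, seed=42, cfg=5.0)
    for i, img in enumerate(images):
        img.save(f"t2i_example_{i}.png")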
# ───────────────────────────────────────────────
# ⑥ Gradio UI (identical to the original version; modification markers omitted here)
# ───────────────────────────────────────────────
with gr.Blocks(title="Ovis-U1-3B (CPU/GPU adaptive)") as demo:
gr.Markdown("# Ovis-U1-3B\nๅคๆจกๆๆๆฌ-ๅพๅ DEMO๏ผCPU/GPU ่ช้ๅบ็๏ผ")
with gr.Row():
with gr.Column():
with gr.Tabs():
# Tab 1: Image + Text โ Image
with gr.TabItem("Image + Text โ Image"):
edit_image_input = gr.Image(label="Input Image", type="pil")
with gr.Row():
edit_prompt_input = gr.Textbox(show_label=False, placeholder="Describe the editing instructionโฆ")
run_edit_image_btn = gr.Button("Run", scale=0)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
edit_img_guidance = gr.Slider(label="Image Guidance", minimum=1, maximum=10, value=1.5, step=0.1)
edit_txt_guidance = gr.Slider(label="Text Guidance", minimum=1, maximum=30, value=6.0, step=0.5)
edit_steps = gr.Slider(label="Steps", minimum=40, maximum=100, value=50, step=1)
edit_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
edit_random = gr.Checkbox(label="Randomize seed", value=False)
# Tab 2: Text โ Image
with gr.TabItem("Text โ Image"):
prompt_gen = gr.Textbox(show_label=False, placeholder="Describe the image you wantโฆ")
run_gen_btn = gr.Button("Run", scale=0)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
height_slider = gr.Slider(label="height", minimum=256, maximum=1536, value=1024, step=32)
width_slider = gr.Slider(label="width", minimum=256, maximum=1536, value=1024, step=32)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1, maximum=30, value=5, step=0.5)
steps_slider = gr.Slider(label="Steps", minimum=40, maximum=100, value=50, step=1)
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
random_check = gr.Checkbox(label="Randomize seed", value=False)
# Tab 3: Image โ Text
with gr.TabItem("Image โ Text"):
understand_img = gr.Image(label="Input Image", type="pil")
understand_prompt = gr.Textbox(show_label=False, placeholder="Describe the question about imageโฆ")
run_understand = gr.Button("Run", scale=0)
clear_btn = gr.Button("Clear All")
with gr.Column():
gallery = gr.Gallery(label="Generated Images", columns=2, visible=True)
txt_out = gr.Textbox(label="Generated Text", visible=False, lines=5, interactive=False)
    # Event bindings (same pattern for every tab; duplicate comments omitted)
def run_tab1(prompt, img, steps, seed, txt_cfg, img_cfg, progress=gr.Progress(track_tqdm=True)):
if img is None:
return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)
imgs = process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg, progress)
return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)
def run_tab2(prompt, h, w, steps, seed, guidance, progress=gr.Progress(track_tqdm=True)):
imgs = process_txt_to_img(prompt, h, w, steps, seed, guidance, progress)
return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)
def run_tab3(img, prompt, progress=gr.Progress(track_tqdm=True)):
if img is None:
return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)
text = process_img_to_txt(prompt, img, progress)
return gr.update(value=[], visible=False), gr.update(value=text, visible=True)
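
    # Each trigger below is a two-step chain: randomize_seed_fn first resolves the
    # seed (keep it or draw a fresh one), then .then() invokes the tab's runner.
    # All runners write into the shared gallery / txt_out outputs and toggle their
    # visibility depending on whether the result is images or text.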
    # Tab 1 bindings
run_edit_image_btn.click(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
run_tab1,
[edit_prompt_input, edit_image_input, edit_steps, edit_seed, edit_txt_guidance, edit_img_guidance],
[gallery, txt_out])
edit_prompt_input.submit(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
run_tab1,
[edit_prompt_input, edit_image_input, edit_steps, edit_seed, edit_txt_guidance, edit_img_guidance],
[gallery, txt_out])
    # Tab 2 bindings
run_gen_btn.click(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
run_tab2,
[prompt_gen, height_slider, width_slider, steps_slider, seed_slider, guidance_slider],
[gallery, txt_out])
prompt_gen.submit(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
run_tab2,
[prompt_gen, height_slider, width_slider, steps_slider, seed_slider, guidance_slider],
[gallery, txt_out])
    # Tab 3 bindings
run_understand.click(run_tab3, [understand_img, understand_prompt], [gallery, txt_out])
understand_prompt.submit(run_tab3, [understand_img, understand_prompt], [gallery, txt_out])
    # Clear all inputs and outputs
def clear_all():
return (
gr.update(value=None), gr.update(value=""), gr.update(value=1.5), gr.update(value=6.0),
gr.update(value=50), gr.update(value=42), gr.update(value=False),
gr.update(value=""), gr.update(value=1024), gr.update(value=1024),
gr.update(value=5), gr.update(value=50), gr.update(value=42), gr.update(value=False),
gr.update(value=None), gr.update(value=""),
gr.update(value=[], visible=True), gr.update(value="", visible=False)
)
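
    # NOTE: the 18 gr.update values above are positional; they must stay in the same
    # order as the 18 components listed in clear_btn.click(...) just below.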
clear_btn.click(clear_all, [], [
edit_image_input, edit_prompt_input, edit_img_guidance, edit_txt_guidance,
edit_steps, edit_seed, edit_random, prompt_gen, height_slider, width_slider,
guidance_slider, steps_slider, seed_slider, random_check, understand_img,
understand_prompt, gallery, txt_out
])
# ───────────────────────────────────────────────
# ⑦ Launch
# ───────────────────────────────────────────────
if __name__ == "__main__":
demo.launch()