#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Ovis-U1-3B multimodal demo (CPU / GPU adaptive version)
Dependencies: Python 3.10+, torch 2.*, transformers 4.41.*, gradio 4.*
"""

# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# โ‘  ๅœจไปปไฝ• transformers / flash_attn ๅฏผๅ…ฅไน‹ๅ‰ๅค„็†็Žฏๅขƒ
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
import os, sys, types, subprocess, random, numpy as np, torch
import importlib.util  # ★ new: used to create a ModuleSpec for the flash_attn stub

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# -------- CPU ็Žฏๅขƒ๏ผšๅฑ่”ฝ flash-attn --------
if DEVICE == "cpu":
    # ๅธ่ฝฝๆฝœๅœจ็š„ flash-attn
    subprocess.run("pip uninstall -y flash-attn",
                   shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    # ๆž„้€ ็ฉบๅฃณๆจกๅ—
    fake_flash_attn   = types.ModuleType("flash_attn")
    fake_layers       = types.ModuleType("flash_attn.layers")
    fake_rotary       = types.ModuleType("flash_attn.layers.rotary")

    def _cpu_apply_rotary_emb(x, cos, sin, *args, **kwargs):
        """Pure-CPU rotary position embedding (simplified, interleaved even/odd pairing); extra flash-attn kwargs are ignored."""
        x1, x2 = x[..., ::2], x[..., 1::2]
        rot_x1 = x1 * cos - x2 * sin
        rot_x2 = x1 * sin + x2 * cos
        out = torch.empty_like(x)
        out[..., ::2] = rot_x1
        out[..., 1::2] = rot_x2
        return out

    fake_rotary.apply_rotary_emb = _cpu_apply_rotary_emb
    fake_layers.rotary           = fake_rotary
    fake_flash_attn.layers       = fake_layers

    # โ˜… ๆ–ฐๅขž๏ผšไธบ็ฉบๅฃณๆจกๅ—่กฅๅ……ๅˆๆณ•็š„ __spec__
    fake_flash_attn.__spec__ = importlib.util.spec_from_loader("flash_attn", loader=None)

    sys.modules.update({
        "flash_attn":               fake_flash_attn,
        "flash_attn.layers":        fake_layers,
        "flash_attn.layers.rotary": fake_rotary,
    })
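
    # Optional sanity check (illustrative addition, not from the original script):
    # with cos == 1 and sin == 0 the rotation is the identity, so the stub must
    # return its input unchanged. OVIS_DEMO_SELFTEST is a hypothetical env var.
    if os.getenv("OVIS_DEMO_SELFTEST"):
        _x   = torch.randn(2, 4, 8)
        _cos = torch.ones(2, 4, 4)
        _sin = torch.zeros(2, 4, 4)
        assert torch.allclose(_cpu_apply_rotary_emb(_x, _cos, _sin), _x)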
else:
    # GPU ็Žฏๅขƒ๏ผšๅฐ่ฏ•ๅฎ‰่ฃ… flash-attn
    try:
        subprocess.run(
            "pip install flash-attn==2.6.3 --no-build-isolation",
            # keep the existing environment (PATH etc.) and only add the skip-build flag
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True, check=True)
    except subprocess.CalledProcessError:
        print("[WARN] flash-attn ๅฎ‰่ฃ…ๅคฑ่ดฅ๏ผŒGPU ๅŠ ้€ŸๅŠŸ่ƒฝๅ—้™ใ€‚")

# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# โ‘ก ๅธธ่ง„ไพ่ต–
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM

from test_img_edit   import pipe_img_edit
from test_img_to_txt import pipe_txt_gen
from test_txt_to_img import pipe_t2i

# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# โ‘ข ๅทฅๅ…ทๅ‡ฝๆ•ฐ & ๅธธ้‡
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
MAX_SEED = 10_000

def set_global_seed(seed: int = 42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def randomize_seed_fn(seed: int, randomize: bool) -> int:
    return random.randint(0, MAX_SEED) if randomize else seed
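
# Illustrative usage of the two helpers above (hypothetical values): draw a
# fresh seed only when "Randomize seed" is checked, then apply it everywhere.
#   seed = randomize_seed_fn(seed=42, randomize=True)
#   set_global_seed(seed)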

# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# โ‘ฃ ๅŠ ่ฝฝๆจกๅž‹
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "AIDC-AI/Ovis-U1-3B"

print(f"[INFO] Loading {MODEL_ID} on {DEVICE} โ€ฆ")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
    device_map="auto",
    token=HF_TOKEN,
    trust_remote_code=True
).eval()
print("[INFO] Model ready!")

# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# โ‘ค ๆŽจ็†ๅฐ่ฃ…
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def process_txt_to_img(prompt, height, width, steps, seed, cfg,
                       progress=gr.Progress(track_tqdm=True)):
    set_global_seed(seed)
    return pipe_t2i(model, prompt, height, width, steps, cfg=cfg, seed=seed)

def process_img_to_txt(prompt, img, progress=gr.Progress(track_tqdm=True)):
    return pipe_txt_gen(model, img, prompt)

def process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg,
                           progress=gr.Progress(track_tqdm=True)):
    set_global_seed(seed)
    return pipe_img_edit(model, img, prompt, steps, txt_cfg, img_cfg, seed=seed)
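
# Illustrative smoke tests (hypothetical prompt/size values; assume the pipe_*
# helpers accept the arguments exactly as wired above), run manually if needed:
#   imgs = process_txt_to_img("a watercolor fox", 512, 512, 40, 42, 5.0)
#   text = process_img_to_txt("What is in this picture?", Image.open("sample.jpg"))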

# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# โ‘ฅ Gradio UI๏ผˆไธŽๅ‰็‰ˆไธ€่‡ด๏ผŒๆญคๅค„็œ็•ฅไฟฎๆ”นๆ ‡่ฎฐ๏ผ‰
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
with gr.Blocks(title="Ovis-U1-3B (CPU/GPU adaptive)") as demo:
    gr.Markdown("# Ovis-U1-3B\nๅคšๆจกๆ€ๆ–‡ๆœฌ-ๅ›พๅƒ DEMO๏ผˆCPU/GPU ่‡ช้€‚ๅบ”็‰ˆ๏ผ‰")

    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                # Tab 1: Image + Text โ†’ Image
                with gr.TabItem("Image + Text โ†’ Image"):
                    edit_image_input = gr.Image(label="Input Image", type="pil")
                    with gr.Row():
                        edit_prompt_input = gr.Textbox(show_label=False, placeholder="Describe the editing instructionโ€ฆ")
                        run_edit_image_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            edit_img_guidance = gr.Slider(label="Image Guidance", minimum=1, maximum=10, value=1.5, step=0.1)
                            edit_txt_guidance = gr.Slider(label="Text Guidance", minimum=1, maximum=30, value=6.0, step=0.5)
                        edit_steps = gr.Slider(label="Steps", minimum=40, maximum=100, value=50, step=1)
                        edit_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
                        edit_random = gr.Checkbox(label="Randomize seed", value=False)
                # Tab 2: Text โ†’ Image
                with gr.TabItem("Text โ†’ Image"):
                    prompt_gen = gr.Textbox(show_label=False, placeholder="Describe the image you wantโ€ฆ")
                    run_gen_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height_slider = gr.Slider(label="height", minimum=256, maximum=1536, value=1024, step=32)
                            width_slider  = gr.Slider(label="width",  minimum=256, maximum=1536, value=1024, step=32)
                        guidance_slider = gr.Slider(label="Guidance Scale", minimum=1, maximum=30, value=5, step=0.5)
                        steps_slider    = gr.Slider(label="Steps", minimum=40, maximum=100, value=50, step=1)
                        seed_slider     = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
                        random_check    = gr.Checkbox(label="Randomize seed", value=False)
                # Tab 3: Image โ†’ Text
                with gr.TabItem("Image โ†’ Text"):
                    understand_img = gr.Image(label="Input Image", type="pil")
                    understand_prompt = gr.Textbox(show_label=False, placeholder="Ask a question about the image…")
                    run_understand = gr.Button("Run", scale=0)
            clear_btn = gr.Button("Clear All")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images", columns=2, visible=True)
            txt_out = gr.Textbox(label="Generated Text", visible=False, lines=5, interactive=False)

    # ไบ‹ไปถ็ป‘ๅฎš๏ผˆไธŽไธŠไธ€็‰ˆ็›ธๅŒ๏ผŒ็œ็•ฅ้‡ๅคๆณจ้‡Š๏ผ‰
    def run_tab1(prompt, img, steps, seed, txt_cfg, img_cfg, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)
        imgs = process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg, progress)
        return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)

    def run_tab2(prompt, h, w, steps, seed, guidance, progress=gr.Progress(track_tqdm=True)):
        imgs = process_txt_to_img(prompt, h, w, steps, seed, guidance, progress)
        return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)

    def run_tab3(img, prompt, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)
        text = process_img_to_txt(prompt, img, progress)
        return gr.update(value=[], visible=False), gr.update(value=text, visible=True)

    # Tab1 ็ป‘ๅฎš
    run_edit_image_btn.click(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
        run_tab1,
        [edit_prompt_input, edit_image_input, edit_steps, edit_seed, edit_txt_guidance, edit_img_guidance],
        [gallery, txt_out])

    edit_prompt_input.submit(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
        run_tab1,
        [edit_prompt_input, edit_image_input, edit_steps, edit_seed, edit_txt_guidance, edit_img_guidance],
        [gallery, txt_out])

    # Tab2 ็ป‘ๅฎš
    run_gen_btn.click(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
        run_tab2,
        [prompt_gen, height_slider, width_slider, steps_slider, seed_slider, guidance_slider],
        [gallery, txt_out])

    prompt_gen.submit(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
        run_tab2,
        [prompt_gen, height_slider, width_slider, steps_slider, seed_slider, guidance_slider],
        [gallery, txt_out])

    # Tab3 ็ป‘ๅฎš
    run_understand.click(run_tab3, [understand_img, understand_prompt], [gallery, txt_out])
    understand_prompt.submit(run_tab3, [understand_img, understand_prompt], [gallery, txt_out])

    # ๆธ…็ฉบ
    def clear_all():
        return (
            gr.update(value=None), gr.update(value=""), gr.update(value=1.5), gr.update(value=6.0),
            gr.update(value=50), gr.update(value=42), gr.update(value=False),
            gr.update(value=""), gr.update(value=1024), gr.update(value=1024),
            gr.update(value=5), gr.update(value=50), gr.update(value=42), gr.update(value=False),
            gr.update(value=None), gr.update(value=""),
            gr.update(value=[], visible=True), gr.update(value="", visible=False)
        )
    clear_btn.click(clear_all, [], [
        edit_image_input, edit_prompt_input, edit_img_guidance, edit_txt_guidance,
        edit_steps, edit_seed, edit_random, prompt_gen, height_slider, width_slider,
        guidance_slider, steps_slider, seed_slider, random_check, understand_img,
        understand_prompt, gallery, txt_out
    ])

# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# โ‘ฆ ๅฏๅŠจ
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
if __name__ == "__main__":
    demo.launch()
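    # For container or remote deployments one could instead use standard Gradio
    # launch options (shown as a hedged alternative, not part of the original):
    # demo.launch(server_name="0.0.0.0", server_port=7860, share=False)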