Spaces: Running on Zero

Add TI2TI UI without binaries

- .gitignore +3 -0
- MMaDA/inference/__pycache__/common.cpython-310.pyc +0 -0
- MMaDA/inference/gradio_multimodal_demo_inst.py +249 -3
- MMaDA/models/__pycache__/__init__.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/common_modules.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/configuration_emova_speech_tokenizer.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/configuration_llada.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/misc.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/modeling_emova_speech_tokenizer.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/modeling_llada.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/modeling_magvitv2.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/modeling_mmada.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/modeling_omada.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/modeling_utils.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/modeling_video_encoder.cpython-310.pyc +0 -0
- MMaDA/models/__pycache__/sampling.cpython-310.pyc +0 -0
- MMaDA/training/__pycache__/__init__.cpython-310.pyc +0 -0
- MMaDA/training/__pycache__/data.cpython-310.pyc +0 -0
- MMaDA/training/__pycache__/prompting_utils.cpython-310.pyc +0 -0
- MMaDA/training/__pycache__/utils.cpython-310.pyc +0 -0
- app.py +77 -0
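
In short, this commit adds a Text+Image → Text+Image (TI2TI) mode to both the full multimodal demo (gradio_multimodal_demo_inst.py) and the Spaces entry point (app.py), stops tracking the committed __pycache__ binaries, and ignores them plus the local TI2TI demo assets going forward.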
.gitignore ADDED

@@ -0,0 +1,3 @@
+__pycache__/
+*.pyc
+MMaDA/inference/demo/ti2ti/
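
The last ignore rule keeps the local TI2TI demo assets out of the repository (the commit ships the UI "without binaries"); the example loaders added below expect paired files in that directory. A hypothetical layout, where only the _src.png / _instr.txt suffix convention comes from this commit:

    MMaDA/inference/demo/ti2ti/sample01_src.png    (source image shown in the example)
    MMaDA/inference/demo/ti2ti/sample01_instr.txt  (matching editing instruction)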
MMaDA/inference/__pycache__/common.cpython-310.pyc DELETED
Binary file (5.69 kB)
MMaDA/inference/gradio_multimodal_demo_inst.py CHANGED

@@ -495,6 +495,32 @@ def _load_i2i_examples():
     return examples


+def _load_ti2ti_examples():
+    """Pair demo/ti2ti sample##_src.png + sample##_instr.txt files into Examples."""
+    d = DEMO_ROOT / "ti2ti"
+    if not d.exists():
+        return []
+
+    src_files = sorted([p for p in d.iterdir() if p.is_file() and p.name.endswith("_src.png")])
+    txt_files = {
+        p.name.replace("_instr.txt", ""): p
+        for p in d.iterdir()
+        if p.is_file() and p.name.endswith("_instr.txt")
+    }
+
+    examples = []
+    for src in src_files:
+        stem = src.name.replace("_src.png", "")
+        txt = txt_files.get(stem)
+        if not txt:
+            continue
+        instruction = txt.read_text(encoding="utf-8").strip()
+        if not instruction:
+            continue
+        examples.append([str(src), instruction])
+    return examples
+
+
 def _load_media_examples(subdir: str, suffixes):
     target_dir = DEMO_ROOT / subdir
     if not target_dir.exists():

@@ -510,6 +536,7 @@ T2S_EXAMPLES = _load_t2s_examples()
 CHAT_EXAMPLES = _load_chat_examples()
 T2I_EXAMPLES = _load_t2i_examples()
 I2I_EXAMPLES = _load_i2i_examples()
+TI2TI_EXAMPLES = _load_ti2ti_examples()
 S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
 V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
 S2S_EXAMPLES = _load_media_examples("s2s", {".wav", ".mp3", ".flac", ".ogg"})

@@ -629,6 +656,33 @@ def _render_image_message(status: str, image: Optional[Image.Image]) -> str:
     return _render_response(status, image_html)


+def _render_image_text_message(status: str, image: Optional[Image.Image], text: str) -> str:
+    """Render combined text + image output for TI2TI."""
+    blocks = []
+    text_clean = (text or "").strip()
+    if text_clean:
+        safe_text = html.escape(text_clean).replace("\n", "<br>")
+        blocks.append(f"<div class='omada-response-block'>{safe_text}</div>")
+
+    if image is not None:
+        buffer = io.BytesIO()
+        try:
+            image.save(buffer, format="PNG")
+            encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
+            blocks.append(
+                "<div class='omada-response-block'>"
+                "<img src='data:image/png;base64,"
+                f"{encoded}"
+                "' alt='Generated image' style='max-width:100%;border-radius:12px;' />"
+                "</div>"
+            )
+        except Exception:
+            pass
+
+    body = "".join(blocks)
+    return _render_response(status, body if body else None)
+
+
 def _format_user_message(message: str) -> str:
     clean = html.escape(message or "")
     return clean.replace("\n", "<br>")

@@ -1180,6 +1234,146 @@ class OmadaDemo:
         image = self._decode_image_tokens(gen_tokens[0])
         return image, "Edited image generated."

+    # ------------------------------------------------------------------
+    # Text+Image → Text+Image (TI2TI)
+    # ------------------------------------------------------------------
+    def run_ti2ti(
+        self,
+        instruction: str,
+        source_image: Optional[Image.Image],
+        text_tokens: int,
+        timesteps_image: int,
+        timesteps_text: int,
+        temperature: float,
+        guidance_scale: float,
+    ) -> Tuple[Optional[Image.Image], str, str]:
+        instruction_clean = (instruction or "").strip()
+        if source_image is None:
+            return None, "", "Please upload a source image."
+        if not instruction_clean:
+            return None, "", "Provide an editing instruction for TI2TI."
+
+        try:
+            src_tokens = self._prepare_image_tokens(source_image)
+        except Exception as exc:
+            return None, "", f"Failed to encode source image: {exc}"
+
+        text_tokens = max(4, min(int(text_tokens), self.max_text_len))
+        prompt_ids = self.uni_prompting.text_tokenizer(instruction_clean)['input_ids']
+        if isinstance(prompt_ids, list) and prompt_ids and isinstance(prompt_ids[0], list):
+            prompt_ids = prompt_ids[0]
+        if len(prompt_ids) == 0 or prompt_ids[0] != self.uni_prompting.text_tokenizer.bos_token_id:
+            prompt_ids = [self.uni_prompting.text_tokenizer.bos_token_id] + prompt_ids
+        prompt_ids = prompt_ids + [self.uni_prompting.text_tokenizer.eos_token_id]
+        prompt_tensor = torch.tensor(prompt_ids, device=self.device, dtype=torch.long)
+
+        ti2ti_id = int(self.uni_prompting.sptids_dict['<|ti2ti|>'][0].item())
+        soi_id = int(self.uni_prompting.sptids_dict['<|soi|>'][0].item())
+        eoi_id = int(self.uni_prompting.sptids_dict['<|eoi|>'][0].item())
+        pad_raw = getattr(self.uni_prompting, "pad_id", 0)
+        pad_id = int(pad_raw if pad_raw is not None else 0)
+
+        img_placeholder = torch.full(
+            (self.image_seq_len,),
+            self.mask_token_id,
+            dtype=torch.long,
+            device=self.device,
+        )
+        text_placeholder = torch.full(
+            (text_tokens,),
+            self.mask_token_id,
+            dtype=torch.long,
+            device=self.device,
+        )
+
+        src_flat = src_tokens.view(-1)
+        prompt_len = prompt_tensor.numel()
+        img_len = img_placeholder.numel()
+        text_len = text_placeholder.numel()
+
+        prompt_start = 2 + src_flat.numel() + 1
+        prompt_end = prompt_start + prompt_len
+        img_start = prompt_end + 1
+        img_end = img_start + img_len
+        text_start = img_end + 1
+        text_end = text_start + text_len
+
+        seq_parts = [
+            torch.tensor([ti2ti_id, soi_id], device=self.device, dtype=torch.long),
+            src_flat,
+            torch.tensor([eoi_id], device=self.device, dtype=torch.long),
+            prompt_tensor,
+            torch.tensor([soi_id], device=self.device, dtype=torch.long),
+            img_placeholder,
+            torch.tensor([eoi_id], device=self.device, dtype=torch.long),
+            text_placeholder,
+        ]
+        seq = torch.cat(seq_parts, dim=0).unsqueeze(0)
+        attn = torch.ones_like(seq, dtype=torch.long, device=self.device)
+
+        uncond_seq = seq.clone()
+        uncond_attn = attn.clone()
+        uncond_seq[:, prompt_start:prompt_end] = pad_id
+        uncond_attn[:, prompt_start:prompt_end] = 0
+
+        with torch.no_grad():
+            filled_tokens, _ = self.model.ti2ti_generate(
+                input_ids=seq.to(self.device),
+                uncond_input_ids=uncond_seq.to(self.device),
+                attention_mask=attn.to(self.device),
+                uncond_attention_mask=uncond_attn.to(self.device),
+                temperature=float(temperature),
+                timesteps=int(timesteps_image),
+                timesteps_text=int(timesteps_text),
+                timesteps_image=int(timesteps_image),
+                guidance_scale=float(guidance_scale),
+                noise_schedule=self.image_noise_schedule,
+                seq_len=self.image_seq_len,
+                mask_token_id=self.mask_token_id,
+                codebook_size=self.codebook_size,
+                uni_prompting=self.uni_prompting,
+                config=self.train_cfg,
+            )
+
+        if filled_tokens is None:
+            return None, "", "TI2TI generation failed."
+
+        filled_tokens = torch.clamp(
+            filled_tokens,
+            min=0,
+            max=self.codebook_size + self.text_vocab_size - 1,
+        )
+        pred_img_tokens = filled_tokens[:, img_start:img_end] - self.text_vocab_size
+        pred_img_tokens = torch.clamp(pred_img_tokens, min=0, max=self.codebook_size - 1)
+        try:
+            image_out = self._decode_image_tokens(pred_img_tokens[0])
+        except Exception as exc:
+            return None, "", f"Failed to decode generated image: {exc}"
+
+        text_slice = slice(text_start, min(text_end, filled_tokens.shape[1]))
+        text_block = filled_tokens[:, text_slice]
+        text_vocab = self.text_vocab_size
+        mask_id = int(self.mask_token_id)
+        eos_id = int(self.uni_prompting.text_tokenizer.eos_token_id)
+        eot_id = int(self.uni_prompting.sptids_dict.get("<|eot_id|>", torch.tensor([eos_id], device=self.device))[0].item())
+        pad_token_id = int(pad_id)
+
+        pred_texts = []
+        for row in text_block:
+            seq_list = []
+            for t in row.tolist():
+                if t in (pad_token_id, mask_id):
+                    continue
+                if t == eos_id or t == eot_id:
+                    break
+                if 0 <= t < text_vocab:
+                    seq_list.append(int(t))
+            pred_texts.append(self.uni_prompting.text_tokenizer.decode(seq_list, skip_special_tokens=True))
+        pred_text = pred_texts[0] if pred_texts else ""
+
+        status = "TI2TI generated image and text."
+        return image_out, pred_text, status
+
     # ------------------------------------------------------------------
     # Video-to-Speech
     # ------------------------------------------------------------------

@@ -1866,7 +2060,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
     group_to_modes = {
         "Any → Speech": ["Text → Speech", "Speech → Speech", "Video → Speech", "Image → Speech"],
         "Any → Text": ["Speech → Text", "Video → Text", "Chat", "MMU (2 Images → Text)"],
-        "Image Generation": ["Text → Image", "Image Editing"],
+        "Image Generation": ["Text → Image", "Image Editing", "Text+Image → Text+Image (TI2TI)"],
     }
     default_group = "Any → Speech"
     default_mode = group_to_modes[default_group][0]

@@ -1881,6 +2075,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
         "MMU (2 Images → Text)": "Ask a question about the two uploaded images.",
         "Text → Image": "Describe the image you want to generate...",
         "Image Editing": "Describe how you want to edit the uploaded image...",
+        "Text+Image → Text+Image (TI2TI)": "Upload an image and describe how you want it edited and captioned.",
     }
     with gr.Row(elem_classes=["omada-layout"], equal_height=False):
         with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]):

@@ -2075,6 +2270,27 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
                                 inputs=[chat_input],
                                 examples_per_page=4,
                             )
+            with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as ti2ti_panel:
+                with gr.Group(elem_classes=["omada-card"]):
+                    gr.Markdown("### Text+Image → Text+Image (TI2TI)")
+                    ti2ti_image = gr.Image(type="pil", label="Source image", sources=["upload"])
+                    with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
+                        ti2ti_text_tokens = gr.Slider(8, 256, value=64, label="Text placeholder tokens", step=4)
+                        with gr.Row():
+                            ti2ti_img_timesteps = gr.Slider(4, 128, value=64, label="Image timesteps", step=2)
+                            ti2ti_text_timesteps = gr.Slider(4, 128, value=64, label="Text timesteps", step=2)
+                        with gr.Row():
+                            ti2ti_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
+                            ti2ti_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1)
+                if TI2TI_EXAMPLES:
+                    with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
+                        gr.Markdown("**Sample edits**")
+                        with gr.Column(elem_classes=["omada-examples"]):
+                            gr.Examples(
+                                examples=TI2TI_EXAMPLES,
+                                inputs=[ti2ti_image, chat_input],
+                                examples_per_page=4,
+                            )
             with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as chat_panel:
                 with gr.Group(elem_classes=["omada-card"]):
                     gr.Markdown("### Chat Controls")

@@ -2123,7 +2339,8 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
         show_v2t = group == "Any → Text" and mode == "Video → Text"
         show_chat = group == "Any → Text" and mode == "Chat"
         show_mmu = group == "Any → Text" and mode == "MMU (2 Images → Text)"
-        show_image = group == "Image Generation"
+        show_image = group == "Image Generation" and mode in ("Text → Image", "Image Editing")
+        show_ti2ti = group == "Image Generation" and mode == "Text+Image → Text+Image (TI2TI)"
         placeholder = placeholder_map.get(mode, chat_input.placeholder)
         image_mode_value = "Generation" if mode == "Text → Image" else "Editing"
         t2i_visible = show_image and mode == "Text → Image"

@@ -2140,6 +2357,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             gr.update(visible=show_chat),
             gr.update(visible=show_mmu),
             gr.update(visible=show_image),
+            gr.update(visible=show_ti2ti),
             image_mode_update,
             gr.update(visible=t2i_visible),
             gr.update(visible=i2i_visible),

@@ -2169,6 +2387,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             chat_panel,
             mmu_panel,
             image_panel,
+            ti2ti_panel,
             image_mode_selector,
             t2i_settings,
             i2i_settings,

@@ -2189,6 +2408,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             chat_panel,
             mmu_panel,
             image_panel,
+            ti2ti_panel,
             image_mode_selector,
             t2i_settings,
             i2i_settings,

@@ -2250,6 +2470,12 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             i2i_timesteps,
             i2i_temperature,
             i2i_guidance,
+            ti2ti_image,
+            ti2ti_text_tokens,
+            ti2ti_img_timesteps,
+            ti2ti_text_timesteps,
+            ti2ti_temperature,
+            ti2ti_guidance,
             v2s_video_path,
             v2s_max_tokens,
             v2s_steps,

@@ -2390,7 +2616,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             )
             response = _render_image_message(status, image_result)
             display_user_raw = message or "[Image generation request]"
-
+        elif mode == "Image Editing":
             image_result, status = app.run_i2i(
                 message,
                 i2i_image,

@@ -2400,6 +2626,18 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             )
             response = _render_image_message(status, image_result)
             display_user_raw = message or "[Image editing request]"
+        else:  # TI2TI
+            image_result, text_result, status = app.run_ti2ti(
+                message,
+                ti2ti_image,
+                ti2ti_text_tokens,
+                ti2ti_img_timesteps,
+                ti2ti_text_timesteps,
+                ti2ti_temperature,
+                ti2ti_guidance,
+            )
+            response = _render_image_text_message(status, image_result, text_result)
+            display_user_raw = message or "[TI2TI request]"

         if not response:
             status = f"Mode '{mode}' is not supported."

@@ -2453,6 +2691,12 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             i2i_timesteps,
             i2i_temperature,
             i2i_guidance,
+            ti2ti_image,
+            ti2ti_text_tokens,
+            ti2ti_img_timesteps,
+            ti2ti_text_timesteps,
+            ti2ti_temperature,
+            ti2ti_guidance,
             v2s_video,
             v2s_max_tokens,
             v2s_steps,

@@ -2487,6 +2731,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             gr.update(value=None),
             gr.update(value=None),
             gr.update(value=None),
+            gr.update(value=None),
         )

         clear_button.click(

@@ -2501,6 +2746,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
             i2s_image,
             v2s_video,
             i2i_image,
+            ti2ti_image,
             mmu_image_a,
             mmu_image_b,
         ],
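
For orientation, here is a minimal self-contained sketch of the sequence layout that run_ti2ti assembles before calling model.ti2ti_generate. The token ids below are made up for illustration; only the ordering and the index arithmetic mirror the diff above:

    # Hypothetical ids standing in for the real special tokens / vocab.
    TI2TI, SOI, EOI, MASK, PAD = 9001, 9002, 9003, 9999, 0

    src = [11, 12, 13, 14]          # flattened source-image tokens
    prompt = [1, 55, 56, 2]         # BOS + instruction ids + EOS
    image_seq_len, text_tokens = 6, 3

    seq = (
        [TI2TI, SOI] + src + [EOI]               # conditioning: task tag + source image
        + prompt                                  # conditioning: editing instruction
        + [SOI] + [MASK] * image_seq_len + [EOI]  # masked slots for the edited image
        + [MASK] * text_tokens                    # masked slots for the generated text
    )

    # Same offsets as the diff: the instruction starts after <|ti2ti|>, <|soi|>,
    # the source tokens and <|eoi|>.
    prompt_start = 2 + len(src) + 1
    prompt_end = prompt_start + len(prompt)
    img_start, img_end = prompt_end + 1, prompt_end + 1 + image_seq_len
    text_start, text_end = img_end + 1, img_end + 1 + text_tokens

    # Unconditional branch for classifier-free guidance: blank out the instruction span.
    uncond = list(seq)
    uncond[prompt_start:prompt_end] = [PAD] * len(prompt)

    assert seq[img_start:img_end] == [MASK] * image_seq_len
    assert seq[text_start:text_end] == [MASK] * text_tokens

After generation, the image slice is shifted back into the visual codebook range (subtracting text_vocab_size) and decoded, while the text slice is detokenized up to the first EOS/EOT token.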
MMaDA/models/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (469 Bytes)

MMaDA/models/__pycache__/common_modules.cpython-310.pyc DELETED
Binary file (10.2 kB)

MMaDA/models/__pycache__/configuration_emova_speech_tokenizer.cpython-310.pyc DELETED
Binary file (9.62 kB)

MMaDA/models/__pycache__/configuration_llada.cpython-310.pyc DELETED
Binary file (6.19 kB)

MMaDA/models/__pycache__/misc.cpython-310.pyc DELETED
Binary file (1.49 kB)

MMaDA/models/__pycache__/modeling_emova_speech_tokenizer.cpython-310.pyc DELETED
Binary file (3.34 kB)

MMaDA/models/__pycache__/modeling_llada.cpython-310.pyc DELETED
Binary file (40.3 kB)

MMaDA/models/__pycache__/modeling_magvitv2.cpython-310.pyc DELETED
Binary file (11.1 kB)

MMaDA/models/__pycache__/modeling_mmada.cpython-310.pyc DELETED
Binary file (20.2 kB)

MMaDA/models/__pycache__/modeling_omada.cpython-310.pyc DELETED
Binary file (31.9 kB)

MMaDA/models/__pycache__/modeling_utils.cpython-310.pyc DELETED
Binary file (39.7 kB)

MMaDA/models/__pycache__/modeling_video_encoder.cpython-310.pyc DELETED
Binary file (1.15 kB)

MMaDA/models/__pycache__/sampling.cpython-310.pyc DELETED
Binary file (4.19 kB)

MMaDA/training/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (182 Bytes)

MMaDA/training/__pycache__/data.cpython-310.pyc DELETED
Binary file (73 kB)

MMaDA/training/__pycache__/prompting_utils.cpython-310.pyc DELETED
Binary file (35.3 kB)

MMaDA/training/__pycache__/utils.cpython-310.pyc DELETED
Binary file (5.97 kB)
app.py CHANGED

@@ -219,11 +219,35 @@ def _load_i2i_examples():
         examples.append([str(img_path), instruction])
     return examples

+def _load_ti2ti_examples():
+    """Load TI2TI examples: pairs of source image + instruction text."""
+    d = ASSET_ROOT / "ti2ti"
+    if not d.exists():
+        return []
+
+    src_files = sorted(
+        [p for p in d.iterdir() if p.is_file() and p.name.endswith("_src.png")],
+    )
+    txt_files = {p.name.replace("_instr.txt", ""): p for p in d.iterdir() if p.is_file() and p.name.endswith("_instr.txt")}
+
+    examples = []
+    for src in src_files:
+        stem = src.name.replace("_src.png", "")
+        txt = txt_files.get(stem)
+        if not txt:
+            continue
+        instruction = txt.read_text(encoding="utf-8").strip()
+        if not instruction:
+            continue
+        examples.append([str(src), instruction])
+    return examples
+
 # text-based examples
 T2S_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2s" / "text.txt")
 CHAT_EXAMPLES = _load_text_examples(ASSET_ROOT / "chat" / "text.txt")
 T2I_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2i" / "text.txt")
 I2I_EXAMPLES = _load_i2i_examples()
+TI2TI_EXAMPLES = _load_ti2ti_examples()

 # audio / video / image examples
 S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})

@@ -419,6 +443,20 @@ def i2i_handler(instruction, image, timesteps, temperature, guidance):
     )
     return image_out, status

+@spaces.GPU
+def ti2ti_handler(instruction, image, text_tokens, timesteps_image, timesteps_text, temperature, guidance):
+    app = get_app()
+    image_out, text_out, status = app.run_ti2ti(
+        instruction=instruction,
+        source_image=image,
+        text_tokens=int(text_tokens),
+        timesteps_image=int(timesteps_image),
+        timesteps_text=int(timesteps_text),
+        temperature=float(temperature),
+        guidance_scale=float(guidance),
+    )
+    return image_out, text_out, status
+

 # ---------------------------
 # Gradio UI (10 tabs + examples)

@@ -678,6 +716,45 @@ with gr.Blocks(
             outputs=[i2i_image_out, i2i_status],
         )

+    # ---- TI2TI ----
+    with gr.Tab("Text+Image → Text+Image (TI2TI)"):
+        ti2ti_image_in = gr.Image(type="pil", label="Source image", sources=["upload"])
+        ti2ti_instr = gr.Textbox(
+            label="Editing instruction",
+            lines=4,
+            placeholder="Describe how you want the image edited and what to say about it...",
+        )
+        ti2ti_image_out = gr.Image(label="Edited image")
+        ti2ti_text_out = gr.Textbox(label="Generated text", lines=4)
+        ti2ti_status = gr.Textbox(label="Status", interactive=False)
+        with gr.Accordion("Advanced settings", open=False):
+            ti2ti_text_tokens = gr.Slider(8, 256, value=64, step=4, label="Text placeholder tokens")
+            ti2ti_img_steps = gr.Slider(4, 128, value=64, step=2, label="Image timesteps")
+            ti2ti_text_steps = gr.Slider(4, 128, value=64, step=2, label="Text timesteps")
+            ti2ti_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+            ti2ti_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
+        if TI2TI_EXAMPLES:
+            with gr.Accordion("Sample edits", open=False):
+                gr.Examples(
+                    examples=TI2TI_EXAMPLES,
+                    inputs=[ti2ti_image_in, ti2ti_instr],
+                    examples_per_page=4,
+                )
+        ti2ti_btn = gr.Button("Generate edited image + text", variant="primary")
+        ti2ti_btn.click(
+            ti2ti_handler,
+            inputs=[
+                ti2ti_instr,
+                ti2ti_image_in,
+                ti2ti_text_tokens,
+                ti2ti_img_steps,
+                ti2ti_text_steps,
+                ti2ti_temperature,
+                ti2ti_guidance,
+            ],
+            outputs=[ti2ti_image_out, ti2ti_text_out, ti2ti_status],
+        )
+
     # ---- I2S ----
     with gr.Tab("Image → Speech (I2S)"):
         i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])