jaeikkim committed on
Commit 333ef29 · 1 Parent(s): 88f06d8

Add TI2TI UI without binaries

.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__/
+ *.pyc
+ MMaDA/inference/demo/ti2ti/
MMaDA/inference/__pycache__/common.cpython-310.pyc DELETED
Binary file (5.69 kB)
 
MMaDA/inference/gradio_multimodal_demo_inst.py CHANGED
@@ -495,6 +495,32 @@ def _load_i2i_examples():
 return examples


+ def _load_ti2ti_examples():
+ """Pair each sample##_src.png with its sample##_instr.txt under demo/ti2ti and expose the pairs as Examples."""
+ d = DEMO_ROOT / "ti2ti"
+ if not d.exists():
+ return []
+
+ src_files = sorted([p for p in d.iterdir() if p.is_file() and p.name.endswith("_src.png")])
+ txt_files = {
+ p.name.replace("_instr.txt", ""): p
+ for p in d.iterdir()
+ if p.is_file() and p.name.endswith("_instr.txt")
+ }
+
+ examples = []
+ for src in src_files:
+ stem = src.name.replace("_src.png", "")
+ txt = txt_files.get(stem)
+ if not txt:
+ continue
+ instruction = txt.read_text(encoding="utf-8").strip()
+ if not instruction:
+ continue
+ examples.append([str(src), instruction])
+ return examples
+
+
 def _load_media_examples(subdir: str, suffixes):
 target_dir = DEMO_ROOT / subdir
 if not target_dir.exists():
 
@@ -510,6 +536,7 @@ T2S_EXAMPLES = _load_t2s_examples()
 CHAT_EXAMPLES = _load_chat_examples()
 T2I_EXAMPLES = _load_t2i_examples()
 I2I_EXAMPLES = _load_i2i_examples()
+ TI2TI_EXAMPLES = _load_ti2ti_examples()
 S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
 V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
 S2S_EXAMPLES = _load_media_examples("s2s", {".wav", ".mp3", ".flac", ".ogg"})
 
@@ -629,6 +656,33 @@ def _render_image_message(status: str, image: Optional[Image.Image]) -> str:
 return _render_response(status, image_html)


+ def _render_image_text_message(status: str, image: Optional[Image.Image], text: str) -> str:
+ """Render combined text + image output for TI2TI."""
+ blocks = []
+ text_clean = (text or "").strip()
+ if text_clean:
+ safe_text = html.escape(text_clean).replace("\n", "<br>")
+ blocks.append(f"<div class='omada-response-block'>{safe_text}</div>")
+
+ if image is not None:
+ buffer = io.BytesIO()
+ try:
+ image.save(buffer, format="PNG")
+ encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
+ blocks.append(
+ "<div class='omada-response-block'>"
+ "<img src='data:image/png;base64,"
+ f"{encoded}"
+ "' alt='Generated image' style='max-width:100%;border-radius:12px;' />"
+ "</div>"
+ )
+ except Exception:
+ pass
+
+ body = "".join(blocks)
+ return _render_response(status, body if body else None)
+
+
 def _format_user_message(message: str) -> str:
 clean = html.escape(message or "")
 return clean.replace("\n", "<br>")
 
@@ -1180,6 +1234,146 @@ class OmadaDemo:
 image = self._decode_image_tokens(gen_tokens[0])
 return image, "Edited image generated."

+ # ------------------------------------------------------------------
+ # Text+Image → Text+Image (TI2TI)
+ # ------------------------------------------------------------------
+ def run_ti2ti(
+ self,
+ instruction: str,
+ source_image: Optional[Image.Image],
+ text_tokens: int,
+ timesteps_image: int,
+ timesteps_text: int,
+ temperature: float,
+ guidance_scale: float,
+ ) -> Tuple[Optional[Image.Image], str, str]:
+ instruction_clean = (instruction or "").strip()
+ if source_image is None:
+ return None, "", "Please upload a source image."
+ if not instruction_clean:
+ return None, "", "Provide an editing instruction for TI2TI."
+
+ try:
+ src_tokens = self._prepare_image_tokens(source_image)
+ except Exception as exc:
+ return None, "", f"Failed to encode source image: {exc}"
+
+ text_tokens = max(4, min(int(text_tokens), self.max_text_len))
+ prompt_ids = self.uni_prompting.text_tokenizer(instruction_clean)['input_ids']
+ if isinstance(prompt_ids, list) and prompt_ids and isinstance(prompt_ids[0], list):
+ prompt_ids = prompt_ids[0]
+ if len(prompt_ids) == 0 or prompt_ids[0] != self.uni_prompting.text_tokenizer.bos_token_id:
+ prompt_ids = [self.uni_prompting.text_tokenizer.bos_token_id] + prompt_ids
+ prompt_ids = prompt_ids + [self.uni_prompting.text_tokenizer.eos_token_id]
+ prompt_tensor = torch.tensor(prompt_ids, device=self.device, dtype=torch.long)
+
+ ti2ti_id = int(self.uni_prompting.sptids_dict['<|ti2ti|>'][0].item())
+ soi_id = int(self.uni_prompting.sptids_dict['<|soi|>'][0].item())
+ eoi_id = int(self.uni_prompting.sptids_dict['<|eoi|>'][0].item())
+ pad_raw = getattr(self.uni_prompting, "pad_id", 0)
+ pad_id = int(pad_raw if pad_raw is not None else 0)
+
+ img_placeholder = torch.full(
+ (self.image_seq_len,),
+ self.mask_token_id,
+ dtype=torch.long,
+ device=self.device,
+ )
+ text_placeholder = torch.full(
+ (text_tokens,),
+ self.mask_token_id,
+ dtype=torch.long,
+ device=self.device,
+ )
+
+ src_flat = src_tokens.view(-1)
+ prompt_len = prompt_tensor.numel()
+ img_len = img_placeholder.numel()
+ text_len = text_placeholder.numel()
+
+ prompt_start = 2 + src_flat.numel() + 1
+ prompt_end = prompt_start + prompt_len
+ img_start = prompt_end + 1
+ img_end = img_start + img_len
+ text_start = img_end + 1
+ text_end = text_start + text_len
+
+ seq_parts = [
+ torch.tensor([ti2ti_id, soi_id], device=self.device, dtype=torch.long),
+ src_flat,
+ torch.tensor([eoi_id], device=self.device, dtype=torch.long),
+ prompt_tensor,
+ torch.tensor([soi_id], device=self.device, dtype=torch.long),
+ img_placeholder,
+ torch.tensor([eoi_id], device=self.device, dtype=torch.long),
+ text_placeholder,
+ ]
+ seq = torch.cat(seq_parts, dim=0).unsqueeze(0)
+ attn = torch.ones_like(seq, dtype=torch.long, device=self.device)
+
+ uncond_seq = seq.clone()
+ uncond_attn = attn.clone()
+ uncond_seq[:, prompt_start:prompt_end] = pad_id
+ uncond_attn[:, prompt_start:prompt_end] = 0
+
+ with torch.no_grad():
+ filled_tokens, _ = self.model.ti2ti_generate(
+ input_ids=seq.to(self.device),
+ uncond_input_ids=uncond_seq.to(self.device),
+ attention_mask=attn.to(self.device),
+ uncond_attention_mask=uncond_attn.to(self.device),
+ temperature=float(temperature),
+ timesteps=int(timesteps_image),
+ timesteps_text=int(timesteps_text),
+ timesteps_image=int(timesteps_image),
+ guidance_scale=float(guidance_scale),
+ noise_schedule=self.image_noise_schedule,
+ seq_len=self.image_seq_len,
+ mask_token_id=self.mask_token_id,
+ codebook_size=self.codebook_size,
+ uni_prompting=self.uni_prompting,
+ config=self.train_cfg,
+ )
+
+ if filled_tokens is None:
+ return None, "", "TI2TI generation failed."
+
+ filled_tokens = torch.clamp(
+ filled_tokens,
+ min=0,
+ max=self.codebook_size + self.text_vocab_size - 1,
+ )
+ pred_img_tokens = filled_tokens[:, img_start:img_end] - self.text_vocab_size
+ pred_img_tokens = torch.clamp(pred_img_tokens, min=0, max=self.codebook_size - 1)
+ try:
+ image_out = self._decode_image_tokens(pred_img_tokens[0])
+ except Exception as exc:
+ return None, "", f"Failed to decode generated image: {exc}"
+
+ text_slice = slice(text_start, min(text_end, filled_tokens.shape[1]))
+ text_block = filled_tokens[:, text_slice]
+ text_vocab = self.text_vocab_size
+ mask_id = int(self.mask_token_id)
+ eos_id = int(self.uni_prompting.text_tokenizer.eos_token_id)
+ eot_id = int(self.uni_prompting.sptids_dict.get("<|eot_id|>", torch.tensor([eos_id], device=self.device))[0].item())
+ pad_token_id = int(pad_id)
+
+ pred_texts = []
+ for row in text_block:
+ seq_list = []
+ for t in row.tolist():
+ if t in (pad_token_id, mask_id):
+ continue
+ if t == eos_id or t == eot_id:
+ break
+ if 0 <= t < text_vocab:
+ seq_list.append(int(t))
+ pred_texts.append(self.uni_prompting.text_tokenizer.decode(seq_list, skip_special_tokens=True))
+ pred_text = pred_texts[0] if pred_texts else ""
+
+ status = "TI2TI generated image and text."
+ return image_out, pred_text, status
+
 # ------------------------------------------------------------------
 # Video-to-Speech
 # ------------------------------------------------------------------
 
@@ -1866,7 +2060,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 group_to_modes = {
 "Any → Speech": ["Text → Speech", "Speech → Speech", "Video → Speech", "Image → Speech"],
 "Any → Text": ["Speech → Text", "Video → Text", "Chat", "MMU (2 Images → Text)"],
- "Image Generation": ["Text → Image", "Image Editing"],
+ "Image Generation": ["Text → Image", "Image Editing", "Text+Image → Text+Image (TI2TI)"],
 }
 default_group = "Any → Speech"
 default_mode = group_to_modes[default_group][0]
 
@@ -1881,6 +2075,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 "MMU (2 Images → Text)": "Ask a question about the two uploaded images.",
 "Text → Image": "Describe the image you want to generate...",
 "Image Editing": "Describe how you want to edit the uploaded image...",
+ "Text+Image → Text+Image (TI2TI)": "Upload an image and describe how you want it edited and captioned.",
 }
 with gr.Row(elem_classes=["omada-layout"], equal_height=False):
 with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]):
 
@@ -2075,6 +2270,27 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 inputs=[chat_input],
 examples_per_page=4,
 )
+ with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as ti2ti_panel:
+ with gr.Group(elem_classes=["omada-card"]):
+ gr.Markdown("### Text+Image → Text+Image (TI2TI)")
+ ti2ti_image = gr.Image(type="pil", label="Source image", sources=["upload"])
+ with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]):
+ ti2ti_text_tokens = gr.Slider(8, 256, value=64, label="Text placeholder tokens", step=4)
+ with gr.Row():
+ ti2ti_img_timesteps = gr.Slider(4, 128, value=64, label="Image timesteps", step=2)
+ ti2ti_text_timesteps = gr.Slider(4, 128, value=64, label="Text timesteps", step=2)
+ with gr.Row():
+ ti2ti_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05)
+ ti2ti_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1)
+ if TI2TI_EXAMPLES:
+ with gr.Group(elem_classes=["omada-card", "omada-examples-card"]):
+ gr.Markdown("**Sample edits**")
+ with gr.Column(elem_classes=["omada-examples"]):
+ gr.Examples(
+ examples=TI2TI_EXAMPLES,
+ inputs=[ti2ti_image, chat_input],
+ examples_per_page=4,
+ )
 with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as chat_panel:
 with gr.Group(elem_classes=["omada-card"]):
 gr.Markdown("### Chat Controls")
 
@@ -2123,7 +2339,8 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 show_v2t = group == "Any → Text" and mode == "Video → Text"
 show_chat = group == "Any → Text" and mode == "Chat"
 show_mmu = group == "Any → Text" and mode == "MMU (2 Images → Text)"
- show_image = group == "Image Generation"
+ show_image = group == "Image Generation" and mode in ("Text → Image", "Image Editing")
+ show_ti2ti = group == "Image Generation" and mode == "Text+Image → Text+Image (TI2TI)"
 placeholder = placeholder_map.get(mode, chat_input.placeholder)
 image_mode_value = "Generation" if mode == "Text → Image" else "Editing"
 t2i_visible = show_image and mode == "Text → Image"
 
@@ -2140,6 +2357,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 gr.update(visible=show_chat),
 gr.update(visible=show_mmu),
 gr.update(visible=show_image),
+ gr.update(visible=show_ti2ti),
 image_mode_update,
 gr.update(visible=t2i_visible),
 gr.update(visible=i2i_visible),
 
@@ -2169,6 +2387,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 chat_panel,
 mmu_panel,
 image_panel,
+ ti2ti_panel,
 image_mode_selector,
 t2i_settings,
 i2i_settings,
 
@@ -2189,6 +2408,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 chat_panel,
 mmu_panel,
 image_panel,
+ ti2ti_panel,
 image_mode_selector,
 t2i_settings,
 i2i_settings,
 
@@ -2250,6 +2470,12 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 i2i_timesteps,
 i2i_temperature,
 i2i_guidance,
+ ti2ti_image,
+ ti2ti_text_tokens,
+ ti2ti_img_timesteps,
+ ti2ti_text_timesteps,
+ ti2ti_temperature,
+ ti2ti_guidance,
 v2s_video_path,
 v2s_max_tokens,
 v2s_steps,
 
@@ -2390,7 +2616,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 )
 response = _render_image_message(status, image_result)
 display_user_raw = message or "[Image generation request]"
- else: # Image Editing
+ elif mode == "Image Editing":
 image_result, status = app.run_i2i(
 message,
 i2i_image,
 
@@ -2400,6 +2626,18 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 )
 response = _render_image_message(status, image_result)
 display_user_raw = message or "[Image editing request]"
+ else: # TI2TI
+ image_result, text_result, status = app.run_ti2ti(
+ message,
+ ti2ti_image,
+ ti2ti_text_tokens,
+ ti2ti_img_timesteps,
+ ti2ti_text_timesteps,
+ ti2ti_temperature,
+ ti2ti_guidance,
+ )
+ response = _render_image_text_message(status, image_result, text_result)
+ display_user_raw = message or "[TI2TI request]"

 if not response:
 status = f"Mode '{mode}' is not supported."
 
@@ -2453,6 +2691,12 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 i2i_timesteps,
 i2i_temperature,
 i2i_guidance,
+ ti2ti_image,
+ ti2ti_text_tokens,
+ ti2ti_img_timesteps,
+ ti2ti_text_timesteps,
+ ti2ti_temperature,
+ ti2ti_guidance,
 v2s_video,
 v2s_max_tokens,
 v2s_steps,
 
@@ -2487,6 +2731,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 gr.update(value=None),
 gr.update(value=None),
 gr.update(value=None),
+ gr.update(value=None),
 )

 clear_button.click(
 
@@ -2501,6 +2746,7 @@ def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optio
 i2s_image,
 v2s_video,
 i2i_image,
+ ti2ti_image,
 mmu_image_a,
 mmu_image_b,
 ],
MMaDA/models/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (469 Bytes)
 
MMaDA/models/__pycache__/common_modules.cpython-310.pyc DELETED
Binary file (10.2 kB)
 
MMaDA/models/__pycache__/configuration_emova_speech_tokenizer.cpython-310.pyc DELETED
Binary file (9.62 kB)
 
MMaDA/models/__pycache__/configuration_llada.cpython-310.pyc DELETED
Binary file (6.19 kB)
 
MMaDA/models/__pycache__/misc.cpython-310.pyc DELETED
Binary file (1.49 kB)
 
MMaDA/models/__pycache__/modeling_emova_speech_tokenizer.cpython-310.pyc DELETED
Binary file (3.34 kB)
 
MMaDA/models/__pycache__/modeling_llada.cpython-310.pyc DELETED
Binary file (40.3 kB)
 
MMaDA/models/__pycache__/modeling_magvitv2.cpython-310.pyc DELETED
Binary file (11.1 kB)
 
MMaDA/models/__pycache__/modeling_mmada.cpython-310.pyc DELETED
Binary file (20.2 kB)
 
MMaDA/models/__pycache__/modeling_omada.cpython-310.pyc DELETED
Binary file (31.9 kB)
 
MMaDA/models/__pycache__/modeling_utils.cpython-310.pyc DELETED
Binary file (39.7 kB)
 
MMaDA/models/__pycache__/modeling_video_encoder.cpython-310.pyc DELETED
Binary file (1.15 kB)
 
MMaDA/models/__pycache__/sampling.cpython-310.pyc DELETED
Binary file (4.19 kB)
 
MMaDA/training/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (182 Bytes)
 
MMaDA/training/__pycache__/data.cpython-310.pyc DELETED
Binary file (73 kB)
 
MMaDA/training/__pycache__/prompting_utils.cpython-310.pyc DELETED
Binary file (35.3 kB)
 
MMaDA/training/__pycache__/utils.cpython-310.pyc DELETED
Binary file (5.97 kB)
 
app.py CHANGED
@@ -219,11 +219,35 @@ def _load_i2i_examples():
219
  examples.append([str(img_path), instruction])
220
  return examples
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  # text-based examples
223
  T2S_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2s" / "text.txt")
224
  CHAT_EXAMPLES = _load_text_examples(ASSET_ROOT / "chat" / "text.txt")
225
  T2I_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2i" / "text.txt")
226
  I2I_EXAMPLES = _load_i2i_examples()
 
227
 
228
  # audio / video / image examples
229
  S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
@@ -419,6 +443,20 @@ def i2i_handler(instruction, image, timesteps, temperature, guidance):
419
  )
420
  return image_out, status
421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
 
423
  # ---------------------------
424
  # Gradio UI (10 tabs + examples)
@@ -678,6 +716,45 @@ with gr.Blocks(
678
  outputs=[i2i_image_out, i2i_status],
679
  )
680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681
  # ---- I2S ----
682
  with gr.Tab("Image β†’ Speech (I2S)"):
683
  i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
 
@@ -219,11 +219,35 @@ def _load_i2i_examples():
 examples.append([str(img_path), instruction])
 return examples

+ def _load_ti2ti_examples():
+ """Load TI2TI examples: pairs of source image + instruction text."""
+ d = ASSET_ROOT / "ti2ti"
+ if not d.exists():
+ return []
+
+ src_files = sorted(
+ [p for p in d.iterdir() if p.is_file() and p.name.endswith("_src.png")],
+ )
+ txt_files = {p.name.replace("_instr.txt", ""): p for p in d.iterdir() if p.is_file() and p.name.endswith("_instr.txt")}
+
+ examples = []
+ for src in src_files:
+ stem = src.name.replace("_src.png", "")
+ txt = txt_files.get(stem)
+ if not txt:
+ continue
+ instruction = txt.read_text(encoding="utf-8").strip()
+ if not instruction:
+ continue
+ examples.append([str(src), instruction])
+ return examples
+
 # text-based examples
 T2S_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2s" / "text.txt")
 CHAT_EXAMPLES = _load_text_examples(ASSET_ROOT / "chat" / "text.txt")
 T2I_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2i" / "text.txt")
 I2I_EXAMPLES = _load_i2i_examples()
+ TI2TI_EXAMPLES = _load_ti2ti_examples()

 # audio / video / image examples
 S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
 
@@ -419,6 +443,20 @@ def i2i_handler(instruction, image, timesteps, temperature, guidance):
 )
 return image_out, status

+ @spaces.GPU
+ def ti2ti_handler(instruction, image, text_tokens, timesteps_image, timesteps_text, temperature, guidance):
+ app = get_app()
+ image_out, text_out, status = app.run_ti2ti(
+ instruction=instruction,
+ source_image=image,
+ text_tokens=int(text_tokens),
+ timesteps_image=int(timesteps_image),
+ timesteps_text=int(timesteps_text),
+ temperature=float(temperature),
+ guidance_scale=float(guidance),
+ )
+ return image_out, text_out, status
+

 # ---------------------------
 # Gradio UI (10 tabs + examples)
 
@@ -678,6 +716,45 @@ with gr.Blocks(
 outputs=[i2i_image_out, i2i_status],
 )

+ # ---- TI2TI ----
+ with gr.Tab("Text+Image → Text+Image (TI2TI)"):
+ ti2ti_image_in = gr.Image(type="pil", label="Source image", sources=["upload"])
+ ti2ti_instr = gr.Textbox(
+ label="Editing instruction",
+ lines=4,
+ placeholder="Describe how you want the image edited and what to say about it...",
+ )
+ ti2ti_image_out = gr.Image(label="Edited image")
+ ti2ti_text_out = gr.Textbox(label="Generated text", lines=4)
+ ti2ti_status = gr.Textbox(label="Status", interactive=False)
+ with gr.Accordion("Advanced settings", open=False):
+ ti2ti_text_tokens = gr.Slider(8, 256, value=64, step=4, label="Text placeholder tokens")
+ ti2ti_img_steps = gr.Slider(4, 128, value=64, step=2, label="Image timesteps")
+ ti2ti_text_steps = gr.Slider(4, 128, value=64, step=2, label="Text timesteps")
+ ti2ti_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
+ ti2ti_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
+ if TI2TI_EXAMPLES:
+ with gr.Accordion("Sample edits", open=False):
+ gr.Examples(
+ examples=TI2TI_EXAMPLES,
+ inputs=[ti2ti_image_in, ti2ti_instr],
+ examples_per_page=4,
+ )
+ ti2ti_btn = gr.Button("Generate edited image + text", variant="primary")
+ ti2ti_btn.click(
+ ti2ti_handler,
+ inputs=[
+ ti2ti_instr,
+ ti2ti_image_in,
+ ti2ti_text_tokens,
+ ti2ti_img_steps,
+ ti2ti_text_steps,
+ ti2ti_temperature,
+ ti2ti_guidance,
+ ],
+ outputs=[ti2ti_image_out, ti2ti_text_out, ti2ti_status],
+ )
+
 # ---- I2S ----
 with gr.Tab("Image → Speech (I2S)"):
 i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])