jake committed on
Commit 06b0a1f · 1 Parent(s): df89a6a
Files changed (1)
  1. app.py +192 -337
app.py CHANGED
@@ -5,17 +5,6 @@ ZeroGPU-friendly Gradio entrypoint for OMada demo.
5
  - Instantiates OmadaDemo once (global)
6
  - Exposes 10 modalities via Gradio tabs
7
  - Uses @spaces.GPU only on inference handlers so GPU is allocated per request
8
-
9
- Environment overrides:
10
- MODEL_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion)
11
- MODEL_REVISION (default: main)
12
- ASSET_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion-assets)
13
- ASSET_REVISION (default: main)
14
- STYLE_REPO_ID (default: jaeikkim/aidas-style-centroid)
15
- STYLE_REVISION (default: main)
16
- HF_TOKEN (optional, for private model/dataset)
17
- TRAIN_CONFIG_PATH (default: MMaDA/inference/demo/demo.yaml)
18
- DEVICE (default: cuda)
19
  """
20
 
21
  import os
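
For context, the environment overrides listed in the removed docstring lines above are plain environment variables with defaults. A minimal sketch of how such overrides are typically consumed (assumed wiring, not the exact code inside download_checkpoint()/download_assets()):

    import os
    from huggingface_hub import snapshot_download

    MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
    MODEL_REVISION = os.getenv("MODEL_REVISION", "main")
    HF_TOKEN = os.getenv("HF_TOKEN")  # optional, only needed for private repos

    # Resolve the checkpoint snapshot once at startup; ZeroGPU attaches the GPU
    # later, inside the @spaces.GPU handlers, so this part runs on CPU.
    ckpt_path = snapshot_download(
        repo_id=MODEL_REPO_ID,
        revision=MODEL_REVISION,
        token=HF_TOKEN,
    )
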
@@ -49,8 +38,6 @@ if str(EMOVA_ROOT) not in sys.path:
49
  def ensure_hf_hub(target: str = "0.36.0"):
50
  """
51
  Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
52
-
53
- The Spaces base image may pull in a newer version via gradio, so we pin it.
54
  """
55
  try:
56
  import huggingface_hub as hub
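
ensure_hf_hub() exists because the Spaces base image can pull huggingface_hub>=1.0 through gradio, which breaks transformers/tokenizers. A minimal sketch of such a guard, assuming pip is available in the runtime; the actual function in this commit may reinstall differently:

    import importlib
    import subprocess
    import sys

    def ensure_hf_hub(target: str = "0.36.0"):
        """Keep huggingface_hub below 1.0 so transformers/tokenizers keep importing."""
        try:
            import huggingface_hub as hub
            if hub.__version__.split(".")[0] == "0":
                return hub  # already a 0.x release, nothing to do
        except ImportError:
            pass
        # Reinstall the pinned version, then import/reload it in this process.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", f"huggingface_hub=={target}"]
        )
        import huggingface_hub as hub
        return importlib.reload(hub)
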
@@ -80,7 +67,7 @@ snapshot_download = ensure_hf_hub().snapshot_download
80
 
81
 
82
  # ---------------------------
83
- # Imports from OMada demo
84
  # ---------------------------
85
 
86
  from inference.gradio_multimodal_demo_inst import ( # noqa: E402
@@ -153,8 +140,6 @@ def download_checkpoint() -> Path:
153
  )
154
  )
155
 
156
- # If snapshot itself is unwrapped_model, return it; otherwise look for nested dir,
157
- # and finally alias via symlink.
158
  if snapshot_path.name == "unwrapped_model":
159
  return snapshot_path
160
 
@@ -169,82 +154,57 @@ def download_checkpoint() -> Path:
169
 
170
 
171
  # ---------------------------
172
- # Assets & examples from HF dataset
173
  # ---------------------------
174
 
175
  ASSET_ROOT = download_assets()
176
- DEMO_ROOT = ASSET_ROOT / "demo"
177
-
178
- LOGO_PATH = DEMO_ROOT / "logo.png"
179
- T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt"
180
- CHAT_TEXT_PATH = DEMO_ROOT / "chat" / "text.txt"
181
- T2I_TEXT_PATH = DEMO_ROOT / "t2i" / "text.txt"
182
-
183
 
184
  def _load_text_examples(path: Path):
185
  if not path.exists():
186
  return []
187
- try:
188
- lines = [
189
- line.strip()
190
- for line in path.read_text(encoding="utf-8").splitlines()
191
- if line.strip()
192
- ]
193
- except Exception:
194
- return []
195
- return [[line] for line in lines]
196
 
197
 
198
  def _load_media_examples(subdir: str, suffixes):
199
- d = DEMO_ROOT / subdir
200
  if not d.exists():
201
  return []
202
- examples = []
203
  for p in sorted(d.iterdir()):
204
  if p.is_file() and p.suffix.lower() in suffixes:
205
- examples.append([str(p)])
206
- return examples
207
-
208
-
209
- # Text-based examples
210
- T2S_EXAMPLES = _load_text_examples(T2S_TEXT_PATH)
211
- CHAT_EXAMPLES = _load_text_examples(CHAT_TEXT_PATH)
212
- T2I_EXAMPLES = _load_text_examples(T2I_TEXT_PATH)
213
-
214
- # Audio / video / image examples
215
- _AUDIO_SUFFIXES = {".wav", ".mp3", ".flac", ".ogg"}
216
- _VIDEO_SUFFIXES = {".mp4", ".mov", ".avi", ".webm"}
217
- _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}
218
-
219
- S2T_EXAMPLES = _load_media_examples("s2t", _AUDIO_SUFFIXES)
220
- V2T_EXAMPLES = _load_media_examples("v2t", _VIDEO_SUFFIXES)
221
- S2S_EXAMPLES = _load_media_examples("s2s", _AUDIO_SUFFIXES)
222
- if not S2S_EXAMPLES and S2T_EXAMPLES:
223
- S2S_EXAMPLES = S2T_EXAMPLES[: min(4, len(S2T_EXAMPLES))]
224
-
225
- V2S_EXAMPLES = _load_media_examples("v2s", _VIDEO_SUFFIXES)
226
- if not V2S_EXAMPLES and V2T_EXAMPLES:
227
- V2S_EXAMPLES = V2T_EXAMPLES[: min(4, len(V2T_EXAMPLES))]
228
-
229
- I2S_EXAMPLES = _load_media_examples("i2s", _IMAGE_SUFFIXES)
230
-
231
- # MMU: 2 images + question
232
- MMU_DIR = DEMO_ROOT / "mmu"
233
- MMU_EXAMPLES = []
234
- if MMU_DIR.exists():
235
- mmu_imgs = [
236
- p for p in sorted(MMU_DIR.iterdir())
237
- if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
238
- ]
239
- if len(mmu_imgs) >= 2:
240
- MMU_EXAMPLES = [[
241
- str(mmu_imgs[0]),
242
- str(mmu_imgs[1]),
243
- "What are the differences between the two images?"
244
- ]]
245
-
246
- # If there are no i2s examples but MMU examples exist, reuse the first image as an image example
247
  if not I2S_EXAMPLES and MMU_EXAMPLES:
 
248
  I2S_EXAMPLES = [[MMU_EXAMPLES[0][0]]]
249
 
250
 
@@ -260,9 +220,7 @@ def get_app() -> OmadaDemo:
260
  if APP is not None:
261
  return APP
262
 
263
- # Download ckpt + style centroids once
264
  ckpt_dir = download_checkpoint()
265
- style_root = download_style()
266
 
267
  # Wire style centroids to expected locations
268
  style_targets = [
@@ -276,19 +234,15 @@ def get_app() -> OmadaDemo:
276
  for starget in style_targets:
277
  if not starget.exists():
278
  starget.parent.mkdir(parents=True, exist_ok=True)
279
- starget.symlink_to(style_root, target_is_directory=True)
280
 
281
- # Choose train config
282
  default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
283
  legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
284
  train_config = os.getenv("TRAIN_CONFIG_PATH")
285
  if not train_config:
286
  train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
287
 
288
- # Device: in ZeroGPU environment, "cuda" is virtualized and only actually
289
- # attached inside @spaces.GPU handlers.
290
  device = os.getenv("DEVICE", "cuda")
291
-
292
  APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device)
293
  return APP
294
 
@@ -296,20 +250,9 @@ def get_app() -> OmadaDemo:
296
  # ---------------------------
297
  # ZeroGPU-wrapped handlers
298
  # ---------------------------
299
-
300
  @spaces.GPU
301
- def t2s_handler(
302
- text,
303
- max_tokens,
304
- steps,
305
- block_len,
306
- temperature,
307
- cfg_scale,
308
- gender,
309
- emotion,
310
- speed,
311
- pitch,
312
- ):
313
  app = get_app()
314
  audio, status = app.run_t2s(
315
  text=text,
@@ -325,16 +268,8 @@ def t2s_handler(
325
  )
326
  return audio, status
327
 
328
-
329
  @spaces.GPU
330
- def s2s_handler(
331
- audio_path,
332
- max_tokens,
333
- steps,
334
- block_len,
335
- temperature,
336
- cfg_scale,
337
- ):
338
  app = get_app()
339
  audio, status = app.run_s2s(
340
  audio_path=audio_path,
@@ -346,15 +281,8 @@ def s2s_handler(
346
  )
347
  return audio, status
348
 
349
-
350
  @spaces.GPU
351
- def s2t_handler(
352
- audio_path,
353
- steps,
354
- block_len,
355
- max_tokens,
356
- remasking,
357
- ):
358
  app = get_app()
359
  text, status = app.run_s2t(
360
  audio_path=audio_path,
@@ -365,14 +293,8 @@ def s2t_handler(
365
  )
366
  return text, status
367
 
368
-
369
  @spaces.GPU
370
- def v2t_handler(
371
- video,
372
- steps,
373
- block_len,
374
- max_tokens,
375
- ):
376
  app = get_app()
377
  text, status = app.run_v2t(
378
  video_path=video,
@@ -382,17 +304,8 @@ def v2t_handler(
382
  )
383
  return text, status
384
 
385
-
386
  @spaces.GPU
387
- def v2s_handler(
388
- video,
389
- message,
390
- max_tokens,
391
- steps,
392
- block_len,
393
- temperature,
394
- cfg_scale,
395
- ):
396
  app = get_app()
397
  audio, status = app.run_v2s(
398
  video_path=video,
@@ -405,17 +318,8 @@ def v2s_handler(
405
  )
406
  return audio, status
407
 
408
-
409
  @spaces.GPU
410
- def i2s_handler(
411
- image,
412
- message,
413
- max_tokens,
414
- steps,
415
- block_len,
416
- temperature,
417
- cfg_scale,
418
- ):
419
  app = get_app()
420
  audio, status = app.run_i2s(
421
  image=image,
@@ -428,15 +332,8 @@ def i2s_handler(
428
  )
429
  return audio, status
430
 
431
-
432
  @spaces.GPU
433
- def chat_handler(
434
- message,
435
- max_tokens,
436
- steps,
437
- block_len,
438
- temperature,
439
- ):
440
  app = get_app()
441
  text, status = app.run_chat(
442
  message=message,
@@ -447,17 +344,8 @@ def chat_handler(
447
  )
448
  return text, status
449
 
450
-
451
  @spaces.GPU
452
- def mmu_handler(
453
- image_a,
454
- image_b,
455
- question,
456
- max_tokens,
457
- steps,
458
- block_len,
459
- temperature,
460
- ):
461
  app = get_app()
462
  text, status = app.run_mmu_dual(
463
  image_a=image_a,
@@ -470,14 +358,8 @@ def mmu_handler(
470
  )
471
  return text, status
472
 
473
-
474
  @spaces.GPU
475
- def t2i_handler(
476
- prompt,
477
- timesteps,
478
- temperature,
479
- guidance,
480
- ):
481
  app = get_app()
482
  image, status = app.run_t2i(
483
  prompt=prompt,
@@ -487,15 +369,8 @@ def t2i_handler(
487
  )
488
  return image, status
489
 
490
-
491
  @spaces.GPU
492
- def i2i_handler(
493
- instruction,
494
- image,
495
- timesteps,
496
- temperature,
497
- guidance,
498
- ):
499
  app = get_app()
500
  image_out, status = app.run_i2i(
501
  instruction=instruction,
@@ -508,32 +383,31 @@ def i2i_handler(
508
 
509
 
510
  # ---------------------------
511
- # Gradio UI (10 tabs)
512
  # ---------------------------
513
 
514
  theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
515
 
516
  with gr.Blocks(
517
- title="AIDAS Lab @ SNU - OMni-modal Diffusion",
518
  css=CUSTOM_CSS,
519
  theme=theme,
520
  js=FORCE_LIGHT_MODE_JS,
521
  ) as demo:
522
- # Logo (if present)
523
- if LOGO_PATH.exists():
524
- gr.Image(
525
- value=str(LOGO_PATH),
526
- show_label=False,
527
- height=140,
528
- interactive=False,
529
  )
530
 
531
- gr.Markdown(
532
- "## Omni-modal Diffusion Foundation Model\n"
533
- "### AIDAS Lab @ SNU"
534
- )
535
-
536
- # ---------- T2S ----------
537
  with gr.Tab("Text β†’ Speech (T2S)"):
538
  with gr.Row():
539
  t2s_text = gr.Textbox(
@@ -555,6 +429,13 @@ with gr.Blocks(
555
  with gr.Row():
556
  t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="Speed")
557
  t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="Pitch")
558
  t2s_btn = gr.Button("Generate speech", variant="primary")
559
  t2s_btn.click(
560
  t2s_handler,
@@ -573,15 +454,7 @@ with gr.Blocks(
573
  outputs=[t2s_audio, t2s_status],
574
  )
575
 
576
- if T2S_EXAMPLES:
577
- gr.Markdown("**Sample prompts**")
578
- gr.Examples(
579
- examples=T2S_EXAMPLES,
580
- inputs=[t2s_text],
581
- examples_per_page=4,
582
- )
583
-
584
- # ---------- S2S ----------
585
  with gr.Tab("Speech β†’ Speech (S2S)"):
586
  s2s_audio_in = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"])
587
  s2s_audio_out = gr.Audio(type="numpy", label="Reply speech")
@@ -592,6 +465,13 @@ with gr.Blocks(
592
  s2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
593
  s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="Sampling temperature")
594
  s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, step=0.1, label="CFG scale")
595
  s2s_btn = gr.Button("Generate reply speech", variant="primary")
596
  s2s_btn.click(
597
  s2s_handler,
@@ -606,15 +486,7 @@ with gr.Blocks(
606
  outputs=[s2s_audio_out, s2s_status],
607
  )
608
 
609
- if S2S_EXAMPLES:
610
- gr.Markdown("**Sample S2S clips**")
611
- gr.Examples(
612
- examples=S2S_EXAMPLES,
613
- inputs=[s2s_audio_in],
614
- examples_per_page=4,
615
- )
616
-
617
- # ---------- S2T ----------
618
  with gr.Tab("Speech β†’ Text (S2T)"):
619
  s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
620
  s2t_text_out = gr.Textbox(label="Transcription", lines=4)
@@ -628,6 +500,13 @@ with gr.Blocks(
628
  value="low_confidence",
629
  label="Remasking strategy",
630
  )
631
  s2t_btn = gr.Button("Transcribe", variant="primary")
632
  s2t_btn.click(
633
  s2t_handler,
@@ -635,15 +514,7 @@ with gr.Blocks(
635
  outputs=[s2t_text_out, s2t_status],
636
  )
637
 
638
- if S2T_EXAMPLES:
639
- gr.Markdown("**Sample S2T clips**")
640
- gr.Examples(
641
- examples=S2T_EXAMPLES,
642
- inputs=[s2t_audio_in],
643
- examples_per_page=4,
644
- )
645
-
646
- # ---------- V2T ----------
647
  with gr.Tab("Video β†’ Text (V2T)"):
648
  v2t_video_in = gr.Video(
649
  label="Upload or record video",
@@ -656,6 +527,13 @@ with gr.Blocks(
656
  v2t_steps = gr.Slider(2, 512, value=64, step=2, label="Denoising steps")
657
  v2t_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
658
  v2t_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Max new tokens")
659
  v2t_btn = gr.Button("Generate caption", variant="primary")
660
  v2t_btn.click(
661
  v2t_handler,
@@ -663,15 +541,7 @@ with gr.Blocks(
663
  outputs=[v2t_text_out, v2t_status],
664
  )
665
 
666
- if V2T_EXAMPLES:
667
- gr.Markdown("**Sample videos**")
668
- gr.Examples(
669
- examples=V2T_EXAMPLES,
670
- inputs=[v2t_video_in],
671
- examples_per_page=4,
672
- )
673
-
674
- # ---------- V2S ----------
675
  with gr.Tab("Video β†’ Speech (V2S)"):
676
  v2s_video_in = gr.Video(
677
  label="Upload or record video",
@@ -690,6 +560,7 @@ with gr.Blocks(
690
  v2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
691
  v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
692
  v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
 
693
  v2s_btn = gr.Button("Generate speech from video", variant="primary")
694
  v2s_btn.click(
695
  v2s_handler,
@@ -705,100 +576,7 @@ with gr.Blocks(
705
  outputs=[v2s_audio_out, v2s_status],
706
  )
707
 
708
- if V2S_EXAMPLES:
709
- gr.Markdown("**Sample videos**")
710
- gr.Examples(
711
- examples=V2S_EXAMPLES,
712
- inputs=[v2s_video_in],
713
- examples_per_page=4,
714
- )
715
-
716
- # ---------- T2I ----------
717
- with gr.Tab("Text β†’ Image (T2I)"):
718
- t2i_prompt = gr.Textbox(
719
- label="Prompt",
720
- lines=4,
721
- placeholder="Describe the image you want to generate...",
722
- )
723
- t2i_image_out = gr.Image(label="Generated image")
724
- t2i_status = gr.Textbox(label="Status", interactive=False)
725
- with gr.Accordion("Advanced settings", open=False):
726
- t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
727
- t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
728
- t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
729
- t2i_btn = gr.Button("Generate image", variant="primary")
730
- t2i_btn.click(
731
- t2i_handler,
732
- inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
733
- outputs=[t2i_image_out, t2i_status],
734
- )
735
-
736
- if T2I_EXAMPLES:
737
- gr.Markdown("**Sample prompts**")
738
- gr.Examples(
739
- examples=T2I_EXAMPLES,
740
- inputs=[t2i_prompt],
741
- examples_per_page=4,
742
- )
743
-
744
- # ---------- I2I ----------
745
- with gr.Tab("Image Editing (I2I)"):
746
- i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
747
- i2i_instr = gr.Textbox(
748
- label="Editing instruction",
749
- lines=4,
750
- placeholder="Describe how you want to edit the image...",
751
- )
752
- i2i_image_out = gr.Image(label="Edited image")
753
- i2i_status = gr.Textbox(label="Status", interactive=False)
754
- with gr.Accordion("Advanced settings", open=False):
755
- i2i_timesteps = gr.Slider(4, 128, value=18, step=2, label="Timesteps")
756
- i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
757
- i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
758
- i2i_btn = gr.Button("Apply edit", variant="primary")
759
- i2i_btn.click(
760
- i2i_handler,
761
- inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
762
- outputs=[i2i_image_out, i2i_status],
763
- )
764
-
765
- # ---------- Chat ----------
766
- with gr.Tab("Text Chat"):
767
- chat_in = gr.Textbox(
768
- label="Message",
769
- lines=4,
770
- placeholder="Ask anything. The model will reply in text.",
771
- )
772
- chat_out = gr.Textbox(label="Assistant reply", lines=6)
773
- chat_status = gr.Textbox(label="Status", interactive=False)
774
- with gr.Accordion("Advanced settings", open=False):
775
- chat_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Reply max tokens")
776
- chat_steps = gr.Slider(2, 512, value=64, step=2, label="Refinement steps")
777
- chat_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
778
- chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Sampling temperature")
779
- chat_btn = gr.Button("Send", variant="primary")
780
- chat_btn.click(
781
- chat_handler,
782
- inputs=[
783
- chat_in,
784
- chat_max_tokens,
785
- chat_steps,
786
- chat_block,
787
- chat_temperature_slider,
788
- ],
789
- outputs=[chat_out, chat_status],
790
- )
791
-
792
- if CHAT_EXAMPLES:
793
- gr.Markdown("**Sample prompts**")
794
- gr.Examples(
795
- examples=CHAT_EXAMPLES,
796
- inputs=[chat_in],
797
- examples_per_page=4,
798
- )
799
-
800
-
801
- # ---------- I2S ----------
802
  with gr.Tab("Image β†’ Speech (I2S)"):
803
  i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
804
  i2s_prompt = gr.Textbox(
@@ -813,6 +591,13 @@ with gr.Blocks(
813
  i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length")
814
  i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
815
  i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
816
  i2s_btn = gr.Button("Generate spoken description", variant="primary")
817
  i2s_btn.click(
818
  i2s_handler,
@@ -828,16 +613,41 @@ with gr.Blocks(
828
  outputs=[i2s_audio_out, i2s_status],
829
  )
830
 
831
- if I2S_EXAMPLES:
832
- gr.Markdown("**Sample images**")
833
- gr.Examples(
834
- examples=I2S_EXAMPLES,
835
- inputs=[i2s_image_in],
836
- examples_per_page=4,
837
- )
838
-
839
 
840
- # ---------- MMU ----------
841
  with gr.Tab("MMU (2 images β†’ text)"):
842
  mmu_img_a = gr.Image(type="pil", label="Image A", sources=["upload"])
843
  mmu_img_b = gr.Image(type="pil", label="Image B", sources=["upload"])
@@ -853,6 +663,13 @@ with gr.Blocks(
853
  mmu_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
854
  mmu_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
855
  mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Sampling temperature")
856
  mmu_btn = gr.Button("Answer about the two images", variant="primary")
857
  mmu_btn.click(
858
  mmu_handler,
@@ -868,16 +685,54 @@ with gr.Blocks(
868
  outputs=[mmu_answer, mmu_status],
869
  )
870
 
871
- if MMU_EXAMPLES:
872
- gr.Markdown("**Sample MMU example**")
873
- gr.Examples(
874
- examples=MMU_EXAMPLES,
875
- inputs=[mmu_img_a, mmu_img_b, mmu_question],
876
- examples_per_page=1,
877
- )
878
 
879
- # I2I examples are skipped for now; the separate example text/image layout for them is unclear
880
- # (if needed, split into demo/i2i_prompt.txt + demo/i2i_images/ and wire them up)
881
 
882
  if __name__ == "__main__":
883
  demo.launch()
 
5
  - Instantiates OmadaDemo once (global)
6
  - Exposes 10 modalities via Gradio tabs
7
  - Uses @spaces.GPU only on inference handlers so GPU is allocated per request
8
  """
9
 
10
  import os
 
38
  def ensure_hf_hub(target: str = "0.36.0"):
39
  """
40
  Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
 
 
41
  """
42
  try:
43
  import huggingface_hub as hub
 
67
 
68
 
69
  # ---------------------------
70
+ # OMada demo imports
71
  # ---------------------------
72
 
73
  from inference.gradio_multimodal_demo_inst import ( # noqa: E402
 
140
  )
141
  )
142
 
 
 
143
  if snapshot_path.name == "unwrapped_model":
144
  return snapshot_path
145
 
 
154
 
155
 
156
  # ---------------------------
157
+ # Assets (for examples + logo)
158
  # ---------------------------
159
 
160
  ASSET_ROOT = download_assets()
161
+ STYLE_ROOT = download_style()
162
+ LOGO_PATH = ASSET_ROOT / "logo.png" # optional
163
 
164
  def _load_text_examples(path: Path):
165
  if not path.exists():
166
  return []
167
+ lines = [
168
+ ln.strip()
169
+ for ln in path.read_text(encoding="utf-8").splitlines()
170
+ if ln.strip()
171
+ ]
172
+ return [[ln] for ln in lines]
 
 
 
173
 
174
 
175
  def _load_media_examples(subdir: str, suffixes):
176
+ d = ASSET_ROOT / subdir
177
  if not d.exists():
178
  return []
179
+ ex = []
180
  for p in sorted(d.iterdir()):
181
  if p.is_file() and p.suffix.lower() in suffixes:
182
+ ex.append([str(p)])
183
+ return ex
184
+
185
+
186
+ # text-based examples
187
+ T2S_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2s" / "text.txt")
188
+ CHAT_EXAMPLES = _load_text_examples(ASSET_ROOT / "chat" / "text.txt")
189
+ T2I_EXAMPLES = _load_text_examples(ASSET_ROOT / "t2i" / "text.txt")
190
+
191
+ # audio / video / image examples
192
+ S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"})
193
+ S2S_EXAMPLES = _load_media_examples("s2s", {".wav", ".mp3", ".flac", ".ogg"})
194
+ V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"})
195
+
196
+ # MMU images (and fallback for I2S)
197
+ MMU_IMAGE_A = ASSET_ROOT / "mmu" / "1.jpg"
198
+ MMU_IMAGE_B = ASSET_ROOT / "mmu" / "2.jpg"
199
+ if MMU_IMAGE_A.exists() and MMU_IMAGE_B.exists():
200
+ MMU_EXAMPLES = [[str(MMU_IMAGE_A), str(MMU_IMAGE_B),
201
+ "What are the differences in coloring and physical features between animal1 and animal2 in the bird images?"]]
202
+ else:
203
+ MMU_EXAMPLES = []
204
+
205
+ I2S_EXAMPLES = _load_media_examples("i2s", {".png", ".jpg", ".jpeg", ".webp"})
206
  if not I2S_EXAMPLES and MMU_EXAMPLES:
207
+ # use image A from MMU as sample I2S input
208
  I2S_EXAMPLES = [[MMU_EXAMPLES[0][0]]]
209
 
210
 
 
220
  if APP is not None:
221
  return APP
222
 
 
223
  ckpt_dir = download_checkpoint()
 
224
 
225
  # Wire style centroids to expected locations
226
  style_targets = [
 
234
  for starget in style_targets:
235
  if not starget.exists():
236
  starget.parent.mkdir(parents=True, exist_ok=True)
237
+ starget.symlink_to(STYLE_ROOT, target_is_directory=True)
238
 
 
239
  default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
240
  legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
241
  train_config = os.getenv("TRAIN_CONFIG_PATH")
242
  if not train_config:
243
  train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
244
 
 
 
245
  device = os.getenv("DEVICE", "cuda")
 
246
  APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device)
247
  return APP
248
 
 
250
  # ---------------------------
251
  # ZeroGPU-wrapped handlers
252
  # ---------------------------
253
+ # (== unchanged section, kept in full ==)
254
  @spaces.GPU
255
+ def t2s_handler(text, max_tokens, steps, block_len, temperature, cfg_scale, gender, emotion, speed, pitch):
256
  app = get_app()
257
  audio, status = app.run_t2s(
258
  text=text,
 
268
  )
269
  return audio, status
270
 
 
271
  @spaces.GPU
272
+ def s2s_handler(audio_path, max_tokens, steps, block_len, temperature, cfg_scale):
273
  app = get_app()
274
  audio, status = app.run_s2s(
275
  audio_path=audio_path,
 
281
  )
282
  return audio, status
283
 
 
284
  @spaces.GPU
285
+ def s2t_handler(audio_path, steps, block_len, max_tokens, remasking):
286
  app = get_app()
287
  text, status = app.run_s2t(
288
  audio_path=audio_path,
 
293
  )
294
  return text, status
295
 
 
296
  @spaces.GPU
297
+ def v2t_handler(video, steps, block_len, max_tokens):
298
  app = get_app()
299
  text, status = app.run_v2t(
300
  video_path=video,
 
304
  )
305
  return text, status
306
 
 
307
  @spaces.GPU
308
+ def v2s_handler(video, message, max_tokens, steps, block_len, temperature, cfg_scale):
309
  app = get_app()
310
  audio, status = app.run_v2s(
311
  video_path=video,
 
318
  )
319
  return audio, status
320
 
 
321
  @spaces.GPU
322
+ def i2s_handler(image, message, max_tokens, steps, block_len, temperature, cfg_scale):
323
  app = get_app()
324
  audio, status = app.run_i2s(
325
  image=image,
 
332
  )
333
  return audio, status
334
 
 
335
  @spaces.GPU
336
+ def chat_handler(message, max_tokens, steps, block_len, temperature):
337
  app = get_app()
338
  text, status = app.run_chat(
339
  message=message,
 
344
  )
345
  return text, status
346
 
 
347
  @spaces.GPU
348
+ def mmu_handler(image_a, image_b, question, max_tokens, steps, block_len, temperature):
349
  app = get_app()
350
  text, status = app.run_mmu_dual(
351
  image_a=image_a,
 
358
  )
359
  return text, status
360
 
 
361
  @spaces.GPU
362
+ def t2i_handler(prompt, timesteps, temperature, guidance):
363
  app = get_app()
364
  image, status = app.run_t2i(
365
  prompt=prompt,
 
369
  )
370
  return image, status
371
 
 
372
  @spaces.GPU
373
+ def i2i_handler(instruction, image, timesteps, temperature, guidance):
374
  app = get_app()
375
  image_out, status = app.run_i2i(
376
  instruction=instruction,
 
383
 
384
 
385
  # ---------------------------
386
+ # Gradio UI (10 tabs + examples)
387
  # ---------------------------
388
 
389
  theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
390
 
391
  with gr.Blocks(
392
+ title="AIDAS Lab @ SNU - Omni-modal Diffusion",
393
  css=CUSTOM_CSS,
394
  theme=theme,
395
  js=FORCE_LIGHT_MODE_JS,
396
  ) as demo:
397
+ with gr.Row():
398
+ if LOGO_PATH.exists():
399
+ gr.Image(
400
+ value=str(LOGO_PATH),
401
+ show_label=False,
402
+ height=80,
403
+ interactive=False,
404
+ )
405
+ gr.Markdown(
406
+ "## Omni-modal Diffusion Foundation Model\n"
407
+ "### AIDAS Lab @ SNU"
408
  )
409
 
410
+ # ---- T2S ----
411
  with gr.Tab("Text β†’ Speech (T2S)"):
412
  with gr.Row():
413
  t2s_text = gr.Textbox(
 
429
  with gr.Row():
430
  t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="Speed")
431
  t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="Pitch")
432
+ if T2S_EXAMPLES:
433
+ with gr.Accordion("Sample prompts", open=False):
434
+ gr.Examples(
435
+ examples=T2S_EXAMPLES,
436
+ inputs=[t2s_text],
437
+ examples_per_page=6,
438
+ )
439
  t2s_btn = gr.Button("Generate speech", variant="primary")
440
  t2s_btn.click(
441
  t2s_handler,
 
454
  outputs=[t2s_audio, t2s_status],
455
  )
456
 
457
+ # ---- S2S ----
458
  with gr.Tab("Speech β†’ Speech (S2S)"):
459
  s2s_audio_in = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"])
460
  s2s_audio_out = gr.Audio(type="numpy", label="Reply speech")
 
465
  s2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
466
  s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="Sampling temperature")
467
  s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, step=0.1, label="CFG scale")
468
+ if S2S_EXAMPLES:
469
+ with gr.Accordion("Sample clips", open=False):
470
+ gr.Examples(
471
+ examples=S2S_EXAMPLES,
472
+ inputs=[s2s_audio_in],
473
+ examples_per_page=4,
474
+ )
475
  s2s_btn = gr.Button("Generate reply speech", variant="primary")
476
  s2s_btn.click(
477
  s2s_handler,
 
486
  outputs=[s2s_audio_out, s2s_status],
487
  )
488
 
489
+ # ---- S2T ----
490
  with gr.Tab("Speech β†’ Text (S2T)"):
491
  s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
492
  s2t_text_out = gr.Textbox(label="Transcription", lines=4)
 
500
  value="low_confidence",
501
  label="Remasking strategy",
502
  )
503
+ if S2T_EXAMPLES:
504
+ with gr.Accordion("Sample clips", open=False):
505
+ gr.Examples(
506
+ examples=S2T_EXAMPLES,
507
+ inputs=[s2t_audio_in],
508
+ examples_per_page=4,
509
+ )
510
  s2t_btn = gr.Button("Transcribe", variant="primary")
511
  s2t_btn.click(
512
  s2t_handler,
 
514
  outputs=[s2t_text_out, s2t_status],
515
  )
516
 
517
+ # ---- V2T ----
518
  with gr.Tab("Video β†’ Text (V2T)"):
519
  v2t_video_in = gr.Video(
520
  label="Upload or record video",
 
527
  v2t_steps = gr.Slider(2, 512, value=64, step=2, label="Denoising steps")
528
  v2t_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
529
  v2t_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Max new tokens")
530
+ if V2T_EXAMPLES:
531
+ with gr.Accordion("Sample videos", open=False):
532
+ gr.Examples(
533
+ examples=V2T_EXAMPLES,
534
+ inputs=[v2t_video_in],
535
+ examples_per_page=4,
536
+ )
537
  v2t_btn = gr.Button("Generate caption", variant="primary")
538
  v2t_btn.click(
539
  v2t_handler,
 
541
  outputs=[v2t_text_out, v2t_status],
542
  )
543
 
544
545
  with gr.Tab("Video β†’ Speech (V2S)"):
546
  v2s_video_in = gr.Video(
547
  label="Upload or record video",
 
560
  v2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
561
  v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
562
  v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
563
+ # (optional v2s examples: if a 'v2s' folder is added later, wire it up with the same pattern)
564
  v2s_btn = gr.Button("Generate speech from video", variant="primary")
565
  v2s_btn.click(
566
  v2s_handler,
 
576
  outputs=[v2s_audio_out, v2s_status],
577
  )
578
 
579
+ # ---- I2S ----
580
  with gr.Tab("Image β†’ Speech (I2S)"):
581
  i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
582
  i2s_prompt = gr.Textbox(
 
591
  i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length")
592
  i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
593
  i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
594
+ if I2S_EXAMPLES:
595
+ with gr.Accordion("Sample images", open=False):
596
+ gr.Examples(
597
+ examples=I2S_EXAMPLES,
598
+ inputs=[i2s_image_in],
599
+ examples_per_page=4,
600
+ )
601
  i2s_btn = gr.Button("Generate spoken description", variant="primary")
602
  i2s_btn.click(
603
  i2s_handler,
 
613
  outputs=[i2s_audio_out, i2s_status],
614
  )
615
 
616
+ # ---- Chat ----
617
+ with gr.Tab("Text Chat"):
618
+ chat_in = gr.Textbox(
619
+ label="Message",
620
+ lines=4,
621
+ placeholder="Ask anything. The model will reply in text.",
622
+ )
623
+ chat_out = gr.Textbox(label="Assistant reply", lines=6)
624
+ chat_status = gr.Textbox(label="Status", interactive=False)
625
+ with gr.Accordion("Advanced settings", open=False):
626
+ chat_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Reply max tokens")
627
+ chat_steps = gr.Slider(2, 512, value=64, step=2, label="Refinement steps")
628
+ chat_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
629
+ chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Sampling temperature")
630
+ if CHAT_EXAMPLES:
631
+ with gr.Accordion("Sample prompts", open=False):
632
+ gr.Examples(
633
+ examples=CHAT_EXAMPLES,
634
+ inputs=[chat_in],
635
+ examples_per_page=6,
636
+ )
637
+ chat_btn = gr.Button("Send", variant="primary")
638
+ chat_btn.click(
639
+ chat_handler,
640
+ inputs=[
641
+ chat_in,
642
+ chat_max_tokens,
643
+ chat_steps,
644
+ chat_block,
645
+ chat_temperature_slider,
646
+ ],
647
+ outputs=[chat_out, chat_status],
648
+ )
649
 
650
+ # ---- MMU ----
651
  with gr.Tab("MMU (2 images β†’ text)"):
652
  mmu_img_a = gr.Image(type="pil", label="Image A", sources=["upload"])
653
  mmu_img_b = gr.Image(type="pil", label="Image B", sources=["upload"])
 
663
  mmu_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
664
  mmu_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
665
  mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Sampling temperature")
666
+ if MMU_EXAMPLES:
667
+ with gr.Accordion("Sample MMU pair", open=False):
668
+ gr.Examples(
669
+ examples=MMU_EXAMPLES,
670
+ inputs=[mmu_img_a, mmu_img_b, mmu_question],
671
+ examples_per_page=1,
672
+ )
673
  mmu_btn = gr.Button("Answer about the two images", variant="primary")
674
  mmu_btn.click(
675
  mmu_handler,
 
685
  outputs=[mmu_answer, mmu_status],
686
  )
687
 
688
+ # ---- T2I ----
689
+ with gr.Tab("Text β†’ Image (T2I)"):
690
+ t2i_prompt = gr.Textbox(
691
+ label="Prompt",
692
+ lines=4,
693
+ placeholder="Describe the image you want to generate...",
694
+ )
695
+ t2i_image_out = gr.Image(label="Generated image")
696
+ t2i_status = gr.Textbox(label="Status", interactive=False)
697
+ with gr.Accordion("Advanced settings", open=False):
698
+ t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
699
+ t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
700
+ t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
701
+ if T2I_EXAMPLES:
702
+ with gr.Accordion("Sample prompts", open=False):
703
+ gr.Examples(
704
+ examples=T2I_EXAMPLES,
705
+ inputs=[t2i_prompt],
706
+ examples_per_page=6,
707
+ )
708
+ t2i_btn = gr.Button("Generate image", variant="primary")
709
+ t2i_btn.click(
710
+ t2i_handler,
711
+ inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
712
+ outputs=[t2i_image_out, t2i_status],
713
+ )
714
+
715
+ # ---- I2I ----
716
+ with gr.Tab("Image Editing (I2I)"):
717
+ i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
718
+ i2i_instr = gr.Textbox(
719
+ label="Editing instruction",
720
+ lines=4,
721
+ placeholder="Describe how you want to edit the image...",
722
+ )
723
+ i2i_image_out = gr.Image(label="Edited image")
724
+ i2i_status = gr.Textbox(label="Status", interactive=False)
725
+ with gr.Accordion("Advanced settings", open=False):
726
+ i2i_timesteps = gr.Slider(4, 128, value=18, step=2, label="Timesteps")
727
+ i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
728
+ i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
729
+ i2i_btn = gr.Button("Apply edit", variant="primary")
730
+ i2i_btn.click(
731
+ i2i_handler,
732
+ inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
733
+ outputs=[i2i_image_out, i2i_status],
734
+ )
735
 
 
 
736
 
737
  if __name__ == "__main__":
738
  demo.launch()
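
The handlers above follow the ZeroGPU pattern the module docstring describes: everything heavy is set up once globally, and only the inference entry points are decorated so a GPU is attached per request. A minimal standalone sketch of that pattern, with illustrative names rather than code from this commit:

    import gradio as gr
    import spaces  # Hugging Face Spaces helper that provides the ZeroGPU decorator

    @spaces.GPU  # GPU is allocated only while this handler runs
    def generate(prompt: str) -> str:
        # The real app would call the globally instantiated model here.
        return prompt.upper()

    with gr.Blocks() as sketch_demo:
        box_in = gr.Textbox(label="Prompt")
        box_out = gr.Textbox(label="Output")
        gr.Button("Run").click(generate, inputs=[box_in], outputs=[box_out])

    if __name__ == "__main__":
        sketch_demo.launch()
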