jake committed on
Commit e7b4b89 · 1 Parent(s): d3b2d63

change app

Files changed (1)
  1. app.py +745 -34
app.py CHANGED
@@ -1,7 +1,198 @@
1
  """
2
- Gradio Space entrypoint mirroring `MMaDA/inference/gradio_multimodal_demo_inst.py`.
3
- It downloads the published checkpoint once via huggingface_hub, wires it into
4
- OmadaDemo, and launches the existing Blocks UI.
5
 
6
  Environment overrides:
7
  MODEL_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion)
@@ -12,33 +203,42 @@ Environment overrides:
12
  STYLE_REVISION (default: main)
13
  HF_TOKEN (optional, for private model/dataset)
14
  TRAIN_CONFIG_PATH (default: MMaDA/inference/demo/demo.yaml)
15
- DEVICE (default: auto cuda/cpu)
16
- PORT (default: 7860; Space sets this)
17
  """
18
 
19
  import os
20
  import sys
21
  import subprocess
22
  import importlib
23
- import spaces
24
  from pathlib import Path
25
 
26
  from packaging.version import parse as parse_version
27
 
28
- # Ensure local project is importable
29
  PROJECT_ROOT = Path(__file__).resolve().parent
30
  MMADA_ROOT = PROJECT_ROOT / "MMaDA"
31
  if str(MMADA_ROOT) not in sys.path:
32
  sys.path.insert(0, str(MMADA_ROOT))
 
33
  EMOVA_ROOT = PROJECT_ROOT / "EMOVA_speech_tokenizer"
34
  if str(EMOVA_ROOT) not in sys.path:
35
  sys.path.insert(0, str(EMOVA_ROOT))
36
 
37
 
38
  def ensure_hf_hub(target: str = "0.36.0"):
39
  """
40
  Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
41
- The Space base image installs gradio which may upgrade it to 1.x; we downgrade here.
 
42
  """
43
  try:
44
  import huggingface_hub as hub
@@ -53,6 +253,7 @@ def ensure_hf_hub(target: str = "0.36.0"):
53
  [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
54
  )
55
  hub = importlib.reload(hub)
 
56
  # Backfill missing constants in older hub versions to avoid AttributeError.
57
  try:
58
  import huggingface_hub.constants as hub_consts # type: ignore
@@ -65,9 +266,22 @@ def ensure_hf_hub(target: str = "0.36.0"):
65
 
66
  snapshot_download = ensure_hf_hub().snapshot_download
67
 
68
- from inference.gradio_multimodal_demo_inst import OmadaDemo, build_demo # noqa: E402
69
 
70
 
71
  def download_assets() -> Path:
72
  """Download demo assets (logo + sample prompts/media) and return the root path."""
73
  repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets")
@@ -127,25 +341,39 @@ def download_checkpoint() -> Path:
127
  )
128
  )
129
 
130
- # If snapshot itself is unwrapped_model, return it; otherwise point a symlink to it.
 
131
  if snapshot_path.name == "unwrapped_model":
132
  return snapshot_path
 
133
  nested = snapshot_path / "unwrapped_model"
134
  if nested.is_dir():
135
  return nested
 
136
  aliased = snapshot_path.parent / "unwrapped_model"
137
  if not aliased.exists():
138
  aliased.symlink_to(snapshot_path, target_is_directory=True)
139
  return aliased
140
 
141
 
142
- @spaces.GPU
143
- def main():
144
- checkpoint_dir = download_checkpoint()
145
  asset_root = download_assets()
146
  style_root = download_style()
147
 
148
- # Symlink style centroid npy files to expected locations
149
  style_targets = [
150
  MMADA_ROOT / "models" / "speech_tokenization" / "condition_style_centroid",
151
  PROJECT_ROOT
@@ -155,33 +383,516 @@ def main():
155
  / "condition_style_centroid",
156
  ]
157
  for starget in style_targets:
158
- if starget.exists():
159
- continue
160
- starget.parent.mkdir(parents=True, exist_ok=True)
161
- starget.symlink_to(style_root, target_is_directory=True)
162
-
163
- # Point demo assets (logo, sample prompts/media) to the downloaded dataset
164
- from inference import gradio_multimodal_demo_inst as demo_mod # noqa: WPS433
165
-
166
- demo_root = asset_root / "demo"
167
- demo_mod.DEMO_ROOT = demo_root
168
- demo_mod.LOGO_PATH = demo_root / "logo.png"
169
- demo_mod.T2S_TEXT_PATH = demo_root / "t2s" / "text.txt"
170
- demo_mod.CHAT_TEXT_PATH = demo_root / "chat" / "text.txt"
171
- demo_mod.T2I_TEXT_PATH = demo_root / "t2i" / "text.txt"
172
 
 
173
  default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
174
  legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
175
  train_config = os.getenv("TRAIN_CONFIG_PATH")
176
  if not train_config:
177
- # Prefer configs/mmada_demo.yaml (in repo), fallback to legacy path if restored.
178
  train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
179
- device = os.getenv("DEVICE")
180
- port = int(os.getenv("PORT", "7860"))
181
 
182
- app = OmadaDemo(train_config=train_config, checkpoint=str(checkpoint_dir), device=device)
183
- build_demo(app, share=False, server_name="0.0.0.0", server_port=port)
184
 
185
 
186
  if __name__ == "__main__":
187
- main()
 
 
1
+ # """
2
+ # Gradio Space entrypoint mirroring `MMaDA/inference/gradio_multimodal_demo_inst.py`.
3
+ # It downloads the published checkpoint once via huggingface_hub, wires it into
4
+ # OmadaDemo, and launches the existing Blocks UI.
5
+
6
+ # Environment overrides:
7
+ # MODEL_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion)
8
+ # MODEL_REVISION (default: main)
9
+ # ASSET_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion-assets)
10
+ # ASSET_REVISION (default: main)
11
+ # STYLE_REPO_ID (default: jaeikkim/aidas-style-centroid)
12
+ # STYLE_REVISION (default: main)
13
+ # HF_TOKEN (optional, for private model/dataset)
14
+ # TRAIN_CONFIG_PATH (default: MMaDA/inference/demo/demo.yaml)
15
+ # DEVICE (default: auto cuda/cpu)
16
+ # PORT (default: 7860; Space sets this)
17
+ # """
18
+
19
+ # import os
20
+ # import sys
21
+ # import subprocess
22
+ # import importlib
23
+ # import spaces
24
+ # from pathlib import Path
25
+
26
+ # from packaging.version import parse as parse_version
27
+
28
+ # # Ensure local project is importable
29
+ # PROJECT_ROOT = Path(__file__).resolve().parent
30
+ # MMADA_ROOT = PROJECT_ROOT / "MMaDA"
31
+ # if str(MMADA_ROOT) not in sys.path:
32
+ # sys.path.insert(0, str(MMADA_ROOT))
33
+ # EMOVA_ROOT = PROJECT_ROOT / "EMOVA_speech_tokenizer"
34
+ # if str(EMOVA_ROOT) not in sys.path:
35
+ # sys.path.insert(0, str(EMOVA_ROOT))
36
+
37
+
38
+ # def ensure_hf_hub(target: str = "0.36.0"):
39
+ # """
40
+ # Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
41
+ # The Space base image installs gradio which may upgrade it to 1.x; we downgrade here.
42
+ # """
43
+ # try:
44
+ # import huggingface_hub as hub
45
+ # except ImportError:
46
+ # subprocess.check_call(
47
+ # [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
48
+ # )
49
+ # import huggingface_hub as hub
50
+
51
+ # if parse_version(hub.__version__) >= parse_version("1.0.0"):
52
+ # subprocess.check_call(
53
+ # [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
54
+ # )
55
+ # hub = importlib.reload(hub)
56
+ # # Backfill missing constants in older hub versions to avoid AttributeError.
57
+ # try:
58
+ # import huggingface_hub.constants as hub_consts # type: ignore
59
+ # except Exception:
60
+ # hub_consts = None
61
+ # if hub_consts and not hasattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER"):
62
+ # setattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER", False)
63
+ # return hub
64
+
65
+
66
+ # snapshot_download = ensure_hf_hub().snapshot_download
67
+
68
+ # from inference.gradio_multimodal_demo_inst import OmadaDemo, build_demo # noqa: E402
69
+
70
+
71
+ # def download_assets() -> Path:
72
+ # """Download demo assets (logo + sample prompts/media) and return the root path."""
73
+ # repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets")
74
+ # revision = os.getenv("ASSET_REVISION", "main")
75
+ # token = os.getenv("HF_TOKEN")
76
+ # cache_dir = PROJECT_ROOT / "_asset_cache"
77
+ # cache_dir.mkdir(parents=True, exist_ok=True)
78
+
79
+ # return Path(
80
+ # snapshot_download(
81
+ # repo_id=repo_id,
82
+ # revision=revision,
83
+ # repo_type="dataset",
84
+ # local_dir=cache_dir,
85
+ # local_dir_use_symlinks=False,
86
+ # token=token,
87
+ # )
88
+ # )
89
+
90
+
91
+ # def download_style() -> Path:
92
+ # """Download style centroid dataset and return the root path."""
93
+ # repo_id = os.getenv("STYLE_REPO_ID", "jaeikkim/aidas-style-centroid")
94
+ # revision = os.getenv("STYLE_REVISION", "main")
95
+ # token = os.getenv("HF_TOKEN")
96
+ # cache_dir = PROJECT_ROOT / "_style_cache"
97
+ # cache_dir.mkdir(parents=True, exist_ok=True)
98
+
99
+ # return Path(
100
+ # snapshot_download(
101
+ # repo_id=repo_id,
102
+ # revision=revision,
103
+ # repo_type="dataset",
104
+ # local_dir=cache_dir,
105
+ # local_dir_use_symlinks=False,
106
+ # token=token,
107
+ # )
108
+ # )
109
+
110
+
111
+ # def download_checkpoint() -> Path:
112
+ # """Download checkpoint snapshot and return an `unwrapped_model` directory."""
113
+ # repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
114
+ # revision = os.getenv("MODEL_REVISION", "main")
115
+ # token = os.getenv("HF_TOKEN")
116
+ # cache_dir = PROJECT_ROOT / "_ckpt_cache"
117
+ # cache_dir.mkdir(parents=True, exist_ok=True)
118
+
119
+ # snapshot_path = Path(
120
+ # snapshot_download(
121
+ # repo_id=repo_id,
122
+ # revision=revision,
123
+ # repo_type="model",
124
+ # local_dir=cache_dir,
125
+ # local_dir_use_symlinks=False,
126
+ # token=token,
127
+ # )
128
+ # )
129
+
130
+ # # If snapshot itself is unwrapped_model, return it; otherwise point a symlink to it.
131
+ # if snapshot_path.name == "unwrapped_model":
132
+ # return snapshot_path
133
+ # nested = snapshot_path / "unwrapped_model"
134
+ # if nested.is_dir():
135
+ # return nested
136
+ # aliased = snapshot_path.parent / "unwrapped_model"
137
+ # if not aliased.exists():
138
+ # aliased.symlink_to(snapshot_path, target_is_directory=True)
139
+ # return aliased
140
+
141
+
142
+ # @spaces.GPU
143
+ # def main():
144
+ # checkpoint_dir = download_checkpoint()
145
+ # asset_root = download_assets()
146
+ # style_root = download_style()
147
+
148
+ # # Symlink style centroid npy files to expected locations
149
+ # style_targets = [
150
+ # MMADA_ROOT / "models" / "speech_tokenization" / "condition_style_centroid",
151
+ # PROJECT_ROOT
152
+ # / "EMOVA_speech_tokenizer"
153
+ # / "emova_speech_tokenizer"
154
+ # / "speech_tokenization"
155
+ # / "condition_style_centroid",
156
+ # ]
157
+ # for starget in style_targets:
158
+ # if starget.exists():
159
+ # continue
160
+ # starget.parent.mkdir(parents=True, exist_ok=True)
161
+ # starget.symlink_to(style_root, target_is_directory=True)
162
+
163
+ # # Point demo assets (logo, sample prompts/media) to the downloaded dataset
164
+ # from inference import gradio_multimodal_demo_inst as demo_mod # noqa: WPS433
165
+
166
+ # demo_root = asset_root / "demo"
167
+ # demo_mod.DEMO_ROOT = demo_root
168
+ # demo_mod.LOGO_PATH = demo_root / "logo.png"
169
+ # demo_mod.T2S_TEXT_PATH = demo_root / "t2s" / "text.txt"
170
+ # demo_mod.CHAT_TEXT_PATH = demo_root / "chat" / "text.txt"
171
+ # demo_mod.T2I_TEXT_PATH = demo_root / "t2i" / "text.txt"
172
+
173
+ # default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
174
+ # legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
175
+ # train_config = os.getenv("TRAIN_CONFIG_PATH")
176
+ # if not train_config:
177
+ # # Prefer configs/mmada_demo.yaml (in repo), fallback to legacy path if restored.
178
+ # train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
179
+ # device = os.getenv("DEVICE")
180
+ # port = int(os.getenv("PORT", "7860"))
181
+
182
+ # app = OmadaDemo(train_config=train_config, checkpoint=str(checkpoint_dir), device=device)
183
+ # build_demo(app, share=False, server_name="0.0.0.0", server_port=port)
184
+
185
+
186
+ # if __name__ == "__main__":
187
+ # main()
188
+
189
  """
190
+ ZeroGPU-friendly Gradio entrypoint for OMada demo.
191
+
192
+ - Downloads checkpoint + assets + style centroids from Hugging Face Hub
193
+ - Instantiates OmadaDemo once (global)
194
+ - Exposes 10 modalities via Gradio tabs
195
+ - Uses @spaces.GPU only on inference handlers so GPU is allocated per request
196
 
197
  Environment overrides:
198
  MODEL_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion)
 
203
  STYLE_REVISION (default: main)
204
  HF_TOKEN (optional, for private model/dataset)
205
  TRAIN_CONFIG_PATH (default: MMaDA/inference/demo/demo.yaml)
206
+ DEVICE (default: cuda)
 
207
  """
208
 
209
  import os
210
  import sys
211
  import subprocess
212
  import importlib
 
213
  from pathlib import Path
214
 
215
+ import gradio as gr
216
+ import spaces
217
  from packaging.version import parse as parse_version
218
 
219
+ # ---------------------------
220
+ # Project roots & sys.path
221
+ # ---------------------------
222
+
223
  PROJECT_ROOT = Path(__file__).resolve().parent
224
  MMADA_ROOT = PROJECT_ROOT / "MMaDA"
225
  if str(MMADA_ROOT) not in sys.path:
226
  sys.path.insert(0, str(MMADA_ROOT))
227
+
228
  EMOVA_ROOT = PROJECT_ROOT / "EMOVA_speech_tokenizer"
229
  if str(EMOVA_ROOT) not in sys.path:
230
  sys.path.insert(0, str(EMOVA_ROOT))
231
 
232
 
233
+ # ---------------------------
234
+ # HuggingFace Hub helper
235
+ # ---------------------------
236
+
237
  def ensure_hf_hub(target: str = "0.36.0"):
238
  """
239
  Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
240
+
241
+ The Spaces base image may pull in a newer version via gradio, so we pin it.
242
  """
243
  try:
244
  import huggingface_hub as hub
 
253
  [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
254
  )
255
  hub = importlib.reload(hub)
256
+
257
  # Backfill missing constants in older hub versions to avoid AttributeError.
258
  try:
259
  import huggingface_hub.constants as hub_consts # type: ignore
 
266
 
267
  snapshot_download = ensure_hf_hub().snapshot_download
268
 
269
+
270
+ # ---------------------------
271
+ # Imports from OMada demo
272
+ # ---------------------------
273
+
274
+ from inference.gradio_multimodal_demo_inst import ( # noqa: E402
275
+ OmadaDemo,
276
+ CUSTOM_CSS,
277
+ FORCE_LIGHT_MODE_JS,
278
+ )
279
 
280
 
281
+ # ---------------------------
282
+ # HF download helpers
283
+ # ---------------------------
284
+
285
  def download_assets() -> Path:
286
  """Download demo assets (logo + sample prompts/media) and return the root path."""
287
  repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets")
 
341
  )
342
  )
343
 
344
+ # If snapshot itself is unwrapped_model, return it; otherwise look for nested dir,
345
+ # and finally alias via symlink.
346
  if snapshot_path.name == "unwrapped_model":
347
  return snapshot_path
348
+
349
  nested = snapshot_path / "unwrapped_model"
350
  if nested.is_dir():
351
  return nested
352
+
353
  aliased = snapshot_path.parent / "unwrapped_model"
354
  if not aliased.exists():
355
  aliased.symlink_to(snapshot_path, target_is_directory=True)
356
  return aliased
357
 
358
 
359
+ # ---------------------------
360
+ # Global OmadaDemo instance
361
+ # ---------------------------
362
+
363
+ APP = None # type: ignore
364
+
365
+
366
+ def get_app() -> OmadaDemo:
367
+ global APP
368
+ if APP is not None:
369
+ return APP
370
+
371
+ # Download everything once
372
+ ckpt_dir = download_checkpoint()
373
  asset_root = download_assets()
374
  style_root = download_style()
375
 
376
+ # Wire style centroids to expected locations
377
  style_targets = [
378
  MMADA_ROOT / "models" / "speech_tokenization" / "condition_style_centroid",
379
  PROJECT_ROOT
 
383
  / "condition_style_centroid",
384
  ]
385
  for starget in style_targets:
386
+ if not starget.exists():
387
+ starget.parent.mkdir(parents=True, exist_ok=True)
388
+ starget.symlink_to(style_root, target_is_directory=True)
389
 
390
+ # Choose train config
391
  default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
392
  legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
393
  train_config = os.getenv("TRAIN_CONFIG_PATH")
394
  if not train_config:
 
395
  train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
 
 
396
 
397
+ # Device: in ZeroGPU environment, "cuda" is virtualized and only actually
398
+ # attached inside @spaces.GPU handlers.
399
+ device = os.getenv("DEVICE", "cuda")
400
+
401
+ APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device)
402
+ return APP
403
+
404
+
405
+ # ---------------------------
406
+ # ZeroGPU-wrapped handlers
407
+ # ---------------------------
408
+
409
+ @spaces.GPU
410
+ def t2s_handler(
411
+ text,
412
+ max_tokens,
413
+ steps,
414
+ block_len,
415
+ temperature,
416
+ cfg_scale,
417
+ gender,
418
+ emotion,
419
+ speed,
420
+ pitch,
421
+ ):
422
+ app = get_app()
423
+ audio, status = app.run_t2s(
424
+ text=text,
425
+ max_new_tokens=int(max_tokens),
426
+ steps=int(steps),
427
+ block_length=int(block_len),
428
+ temperature=float(temperature),
429
+ cfg_scale=float(cfg_scale),
430
+ gender_choice=gender,
431
+ emotion_choice=emotion,
432
+ speed_choice=speed,
433
+ pitch_choice=pitch,
434
+ )
435
+ return audio, status
436
+
437
+
438
+ @spaces.GPU
439
+ def s2s_handler(
440
+ audio_path,
441
+ max_tokens,
442
+ steps,
443
+ block_len,
444
+ temperature,
445
+ cfg_scale,
446
+ ):
447
+ app = get_app()
448
+ audio, status = app.run_s2s(
449
+ audio_path=audio_path,
450
+ max_new_tokens=int(max_tokens),
451
+ steps=int(steps),
452
+ block_length=int(block_len),
453
+ temperature=float(temperature),
454
+ cfg_scale=float(cfg_scale),
455
+ )
456
+ return audio, status
457
+
458
+
459
+ @spaces.GPU
460
+ def s2t_handler(
461
+ audio_path,
462
+ steps,
463
+ block_len,
464
+ max_tokens,
465
+ remasking,
466
+ ):
467
+ app = get_app()
468
+ text, status = app.run_s2t(
469
+ audio_path=audio_path,
470
+ steps=int(steps),
471
+ block_length=int(block_len),
472
+ max_new_tokens=int(max_tokens),
473
+ remasking=str(remasking),
474
+ )
475
+ return text, status
476
+
477
+
478
+ @spaces.GPU
479
+ def v2t_handler(
480
+ video,
481
+ steps,
482
+ block_len,
483
+ max_tokens,
484
+ ):
485
+ app = get_app()
486
+ text, status = app.run_v2t(
487
+ video_path=video,
488
+ steps=int(steps),
489
+ block_length=int(block_len),
490
+ max_new_tokens=int(max_tokens),
491
+ )
492
+ return text, status
493
+
494
+
495
+ @spaces.GPU
496
+ def v2s_handler(
497
+ video,
498
+ message,
499
+ max_tokens,
500
+ steps,
501
+ block_len,
502
+ temperature,
503
+ cfg_scale,
504
+ ):
505
+ app = get_app()
506
+ audio, status = app.run_v2s(
507
+ video_path=video,
508
+ message=message,
509
+ max_new_tokens=int(max_tokens),
510
+ steps=int(steps),
511
+ block_length=int(block_len),
512
+ temperature=float(temperature),
513
+ cfg_scale=float(cfg_scale),
514
+ )
515
+ return audio, status
516
+
517
+
518
+ @spaces.GPU
519
+ def i2s_handler(
520
+ image,
521
+ message,
522
+ max_tokens,
523
+ steps,
524
+ block_len,
525
+ temperature,
526
+ cfg_scale,
527
+ ):
528
+ app = get_app()
529
+ audio, status = app.run_i2s(
530
+ image=image,
531
+ message=message,
532
+ max_new_tokens=int(max_tokens),
533
+ steps=int(steps),
534
+ block_length=int(block_len),
535
+ temperature=float(temperature),
536
+ cfg_scale=float(cfg_scale),
537
+ )
538
+ return audio, status
539
+
540
+
541
+ @spaces.GPU
542
+ def chat_handler(
543
+ message,
544
+ max_tokens,
545
+ steps,
546
+ block_len,
547
+ temperature,
548
+ ):
549
+ app = get_app()
550
+ text, status = app.run_chat(
551
+ message=message,
552
+ max_new_tokens=int(max_tokens),
553
+ steps=int(steps),
554
+ block_length=int(block_len),
555
+ temperature=float(temperature),
556
+ )
557
+ return text, status
558
+
559
+
560
+ @spaces.GPU
561
+ def mmu_handler(
562
+ image_a,
563
+ image_b,
564
+ question,
565
+ max_tokens,
566
+ steps,
567
+ block_len,
568
+ temperature,
569
+ ):
570
+ app = get_app()
571
+ text, status = app.run_mmu_dual(
572
+ image_a=image_a,
573
+ image_b=image_b,
574
+ message=question,
575
+ max_new_tokens=int(max_tokens),
576
+ steps=int(steps),
577
+ block_length=int(block_len),
578
+ temperature=float(temperature),
579
+ )
580
+ return text, status
581
+
582
+
583
+ @spaces.GPU
584
+ def t2i_handler(
585
+ prompt,
586
+ timesteps,
587
+ temperature,
588
+ guidance,
589
+ ):
590
+ app = get_app()
591
+ image, status = app.run_t2i(
592
+ prompt=prompt,
593
+ timesteps=int(timesteps),
594
+ temperature=float(temperature),
595
+ guidance_scale=float(guidance),
596
+ )
597
+ return image, status
598
+
599
+
600
+ @spaces.GPU
601
+ def i2i_handler(
602
+ instruction,
603
+ image,
604
+ timesteps,
605
+ temperature,
606
+ guidance,
607
+ ):
608
+ app = get_app()
609
+ image_out, status = app.run_i2i(
610
+ instruction=instruction,
611
+ source_image=image,
612
+ timesteps=int(timesteps),
613
+ temperature=float(temperature),
614
+ guidance_scale=float(guidance),
615
+ )
616
+ return image_out, status
617
+
618
+
619
+ # ---------------------------
620
+ # Gradio UI (10 tabs)
621
+ # ---------------------------
622
+
623
+ theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
624
+
625
+ with gr.Blocks(
626
+ title="AIDAS Lab @ SNU - OMni-modal Diffusion (ZeroGPU)",
627
+ css=CUSTOM_CSS,
628
+ theme=theme,
629
+ js=FORCE_LIGHT_MODE_JS,
630
+ ) as demo:
631
+ gr.Markdown(
632
+ "## Omni-modal Diffusion Foundation Model\n"
633
+ "### ZeroGPU-compatible demo (AIDAS Lab @ SNU)"
634
+ )
635
+
636
+ with gr.Tab("Text β†’ Speech (T2S)"):
637
+ with gr.Row():
638
+ t2s_text = gr.Textbox(
639
+ label="Input text",
640
+ lines=4,
641
+ placeholder="Type the speech you want to synthesize...",
642
+ )
643
+ t2s_audio = gr.Audio(label="Generated speech", type="numpy")
644
+ t2s_status = gr.Textbox(label="Status", interactive=False)
645
+ with gr.Accordion("Advanced settings", open=False):
646
+ t2s_max_tokens = gr.Slider(2, 512, value=384, step=2, label="Speech token length")
647
+ t2s_steps = gr.Slider(2, 512, value=128, step=2, label="Total refinement steps")
648
+ t2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
649
+ t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
650
+ t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, step=0.1, label="CFG scale")
651
+ with gr.Row():
652
+ t2s_gender = gr.Dropdown(["random", "female", "male"], value="random", label="Gender")
653
+ t2s_emotion = gr.Dropdown(["random", "angry", "happy", "neutral", "sad"], value="random", label="Emotion")
654
+ with gr.Row():
655
+ t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="Speed")
656
+ t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="Pitch")
657
+ t2s_btn = gr.Button("Generate speech", variant="primary")
658
+ t2s_btn.click(
659
+ t2s_handler,
660
+ inputs=[
661
+ t2s_text,
662
+ t2s_max_tokens,
663
+ t2s_steps,
664
+ t2s_block,
665
+ t2s_temperature,
666
+ t2s_cfg,
667
+ t2s_gender,
668
+ t2s_emotion,
669
+ t2s_speed,
670
+ t2s_pitch,
671
+ ],
672
+ outputs=[t2s_audio, t2s_status],
673
+ )
674
+
675
+ with gr.Tab("Speech β†’ Speech (S2S)"):
676
+ s2s_audio_in = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"])
677
+ s2s_audio_out = gr.Audio(type="numpy", label="Reply speech")
678
+ s2s_status = gr.Textbox(label="Status", interactive=False)
679
+ with gr.Accordion("Advanced settings", open=False):
680
+ s2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
681
+ s2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps")
682
+ s2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
683
+ s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="Sampling temperature")
684
+ s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, step=0.1, label="CFG scale")
685
+ s2s_btn = gr.Button("Generate reply speech", variant="primary")
686
+ s2s_btn.click(
687
+ s2s_handler,
688
+ inputs=[
689
+ s2s_audio_in,
690
+ s2s_max_tokens,
691
+ s2s_steps,
692
+ s2s_block,
693
+ s2s_temperature,
694
+ s2s_cfg,
695
+ ],
696
+ outputs=[s2s_audio_out, s2s_status],
697
+ )
698
+
699
+ with gr.Tab("Speech β†’ Text (S2T)"):
700
+ s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
701
+ s2t_text_out = gr.Textbox(label="Transcription", lines=4)
702
+ s2t_status = gr.Textbox(label="Status", interactive=False)
703
+ with gr.Accordion("Advanced settings", open=False):
704
+ s2t_steps = gr.Slider(2, 512, value=128, step=2, label="Denoising steps")
705
+ s2t_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
706
+ s2t_max_tokens = gr.Slider(2, 512, value=128, step=2, label="Max new tokens")
707
+ s2t_remasking = gr.Dropdown(
708
+ ["low_confidence", "random"],
709
+ value="low_confidence",
710
+ label="Remasking strategy",
711
+ )
712
+ s2t_btn = gr.Button("Transcribe", variant="primary")
713
+ s2t_btn.click(
714
+ s2t_handler,
715
+ inputs=[s2t_audio_in, s2t_steps, s2t_block, s2t_max_tokens, s2t_remasking],
716
+ outputs=[s2t_text_out, s2t_status],
717
+ )
718
+
719
+ with gr.Tab("Video β†’ Text (V2T)"):
720
+ v2t_video_in = gr.Video(
721
+ label="Upload or record video",
722
+ height=256,
723
+ sources=["upload", "webcam"],
724
+ )
725
+ v2t_text_out = gr.Textbox(label="Caption / answer", lines=4)
726
+ v2t_status = gr.Textbox(label="Status", interactive=False)
727
+ with gr.Accordion("Advanced settings", open=False):
728
+ v2t_steps = gr.Slider(2, 512, value=64, step=2, label="Denoising steps")
729
+ v2t_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
730
+ v2t_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Max new tokens")
731
+ v2t_btn = gr.Button("Generate caption", variant="primary")
732
+ v2t_btn.click(
733
+ v2t_handler,
734
+ inputs=[v2t_video_in, v2t_steps, v2t_block, v2t_max_tokens],
735
+ outputs=[v2t_text_out, v2t_status],
736
+ )
737
+
738
+ with gr.Tab("Video β†’ Speech (V2S)"):
739
+ v2s_video_in = gr.Video(
740
+ label="Upload or record video",
741
+ height=256,
742
+ sources=["upload", "webcam"],
743
+ )
744
+ v2s_prompt = gr.Textbox(
745
+ label="Optional instruction",
746
+ placeholder="(Optional) e.g., 'Describe this scene in spoken form.'",
747
+ )
748
+ v2s_audio_out = gr.Audio(type="numpy", label="Generated speech")
749
+ v2s_status = gr.Textbox(label="Status", interactive=False)
750
+ with gr.Accordion("Advanced settings", open=False):
751
+ v2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
752
+ v2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps")
753
+ v2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
754
+ v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
755
+ v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
756
+ v2s_btn = gr.Button("Generate speech from video", variant="primary")
757
+ v2s_btn.click(
758
+ v2s_handler,
759
+ inputs=[
760
+ v2s_video_in,
761
+ v2s_prompt,
762
+ v2s_max_tokens,
763
+ v2s_steps,
764
+ v2s_block,
765
+ v2s_temperature,
766
+ v2s_cfg,
767
+ ],
768
+ outputs=[v2s_audio_out, v2s_status],
769
+ )
770
+
771
+ with gr.Tab("Image β†’ Speech (I2S)"):
772
+ i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
773
+ i2s_prompt = gr.Textbox(
774
+ label="Optional question",
775
+ placeholder="(Optional) e.g., 'Describe this image aloud.'",
776
+ )
777
+ i2s_audio_out = gr.Audio(type="numpy", label="Spoken description")
778
+ i2s_status = gr.Textbox(label="Status", interactive=False)
779
+ with gr.Accordion("Advanced settings", open=False):
780
+ i2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
781
+ i2s_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
782
+ i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length")
783
+ i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
784
+ i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
785
+ i2s_btn = gr.Button("Generate spoken description", variant="primary")
786
+ i2s_btn.click(
787
+ i2s_handler,
788
+ inputs=[
789
+ i2s_image_in,
790
+ i2s_prompt,
791
+ i2s_max_tokens,
792
+ i2s_steps,
793
+ i2s_block,
794
+ i2s_temperature,
795
+ i2s_cfg,
796
+ ],
797
+ outputs=[i2s_audio_out, i2s_status],
798
+ )
799
+
800
+ with gr.Tab("Text Chat"):
801
+ chat_in = gr.Textbox(
802
+ label="Message",
803
+ lines=4,
804
+ placeholder="Ask anything. The model will reply in text.",
805
+ )
806
+ chat_out = gr.Textbox(label="Assistant reply", lines=6)
807
+ chat_status = gr.Textbox(label="Status", interactive=False)
808
+ with gr.Accordion("Advanced settings", open=False):
809
+ chat_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Reply max tokens")
810
+ chat_steps = gr.Slider(2, 512, value=64, step=2, label="Refinement steps")
811
+ chat_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
812
+ chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Sampling temperature")
813
+ chat_btn = gr.Button("Send", variant="primary")
814
+ chat_btn.click(
815
+ chat_handler,
816
+ inputs=[
817
+ chat_in,
818
+ chat_max_tokens,
819
+ chat_steps,
820
+ chat_block,
821
+ chat_temperature_slider,
822
+ ],
823
+ outputs=[chat_out, chat_status],
824
+ )
825
+
826
+ with gr.Tab("MMU (2 images β†’ text)"):
827
+ mmu_img_a = gr.Image(type="pil", label="Image A", sources=["upload"])
828
+ mmu_img_b = gr.Image(type="pil", label="Image B", sources=["upload"])
829
+ mmu_question = gr.Textbox(
830
+ label="Question",
831
+ lines=3,
832
+ placeholder="Ask about the relationship or differences between the two images.",
833
+ )
834
+ mmu_answer = gr.Textbox(label="Answer", lines=6)
835
+ mmu_status = gr.Textbox(label="Status", interactive=False)
836
+ with gr.Accordion("Advanced settings", open=False):
837
+ mmu_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Answer max tokens")
838
+ mmu_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
839
+ mmu_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
840
+ mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Sampling temperature")
841
+ mmu_btn = gr.Button("Answer about the two images", variant="primary")
842
+ mmu_btn.click(
843
+ mmu_handler,
844
+ inputs=[
845
+ mmu_img_a,
846
+ mmu_img_b,
847
+ mmu_question,
848
+ mmu_max_tokens,
849
+ mmu_steps,
850
+ mmu_block,
851
+ mmu_temperature,
852
+ ],
853
+ outputs=[mmu_answer, mmu_status],
854
+ )
855
+
856
+ with gr.Tab("Text β†’ Image (T2I)"):
857
+ t2i_prompt = gr.Textbox(
858
+ label="Prompt",
859
+ lines=4,
860
+ placeholder="Describe the image you want to generate...",
861
+ )
862
+ t2i_image_out = gr.Image(label="Generated image")
863
+ t2i_status = gr.Textbox(label="Status", interactive=False)
864
+ with gr.Accordion("Advanced settings", open=False):
865
+ t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
866
+ t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
867
+ t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
868
+ t2i_btn = gr.Button("Generate image", variant="primary")
869
+ t2i_btn.click(
870
+ t2i_handler,
871
+ inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
872
+ outputs=[t2i_image_out, t2i_status],
873
+ )
874
+
875
+ with gr.Tab("Image Editing (I2I)"):
876
+ i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
877
+ i2i_instr = gr.Textbox(
878
+ label="Editing instruction",
879
+ lines=4,
880
+ placeholder="Describe how you want to edit the image...",
881
+ )
882
+ i2i_image_out = gr.Image(label="Edited image")
883
+ i2i_status = gr.Textbox(label="Status", interactive=False)
884
+ with gr.Accordion("Advanced settings", open=False):
885
+ i2i_timesteps = gr.Slider(4, 128, value=18, step=2, label="Timesteps")
886
+ i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
887
+ i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
888
+ i2i_btn = gr.Button("Apply edit", variant="primary")
889
+ i2i_btn.click(
890
+ i2i_handler,
891
+ inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
892
+ outputs=[i2i_image_out, i2i_status],
893
+ )
894
 
895
 
896
  if __name__ == "__main__":
897
+ demo.launch()
898
+
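
The rewritten app.py caches a single OmadaDemo in a module-level global (get_app) and applies @spaces.GPU only to the per-request inference handlers, so, as its docstring notes, the GPU is allocated per request rather than for the whole process. Below is a minimal sketch of that pattern under the same assumptions (gradio plus the spaces package available on ZeroGPU Spaces); DummyModel is a hypothetical stand-in for OmadaDemo, not part of the actual commit.

import gradio as gr
import spaces  # provided on Hugging Face ZeroGPU Spaces


class DummyModel:
    """Hypothetical stand-in for OmadaDemo; only illustrates the caching pattern."""

    def __init__(self, device: str = "cuda"):
        self.device = device

    def run(self, prompt: str) -> str:
        return f"[{self.device}] {prompt}"


_MODEL = None  # built once on first request, reused afterwards


def get_model() -> DummyModel:
    global _MODEL
    if _MODEL is None:
        # Heavy setup (snapshot downloads, weight loading) would happen here and is
        # cached so it only runs on the first request, mirroring get_app() in app.py.
        _MODEL = DummyModel(device="cuda")
    return _MODEL


@spaces.GPU  # ZeroGPU allocates a GPU only while this handler runs
def generate(prompt: str) -> str:
    return get_model().run(prompt)


with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(generate, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()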