jake committed on
Commit
668aead
·
1 Parent(s): e7b4b89
Files changed (2)
  1. MMaDA/app.py +370 -866
  2. app.py +369 -869
MMaDA/app.py CHANGED
@@ -1,894 +1,398 @@
1
  import gradio as gr
2
- import torch
3
- import numpy as np
4
- import torch.nn.functional as F
5
- from transformers import AutoTokenizer
6
- from torchvision import transforms
7
- from models import MAGVITv2, get_mask_schedule, MMadaModelLM
8
- from training.prompting_utils import UniversalPrompting
9
- from PIL import Image
10
-
11
- def image_transform(image, resolution=256, normalize=True):
12
- image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
13
- image = transforms.CenterCrop((resolution, resolution))(image)
14
- image = transforms.ToTensor()(image)
15
- if normalize:
16
- image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
17
- return image
18
-
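For reference, the removed image_transform helper is a standard torchvision pipeline: resize the short side to `resolution`, center-crop to a square, convert to a tensor, and map pixel values from [0, 1] to [-1, 1]. A minimal, self-contained sketch (not part of the commit):

from PIL import Image
from torchvision import transforms

# Same steps as image_transform(resolution=256, normalize=True), written as a Compose.
pipeline = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
dummy = Image.new("RGB", (640, 480), (128, 64, 32))   # placeholder image
out = pipeline(dummy)
print(out.shape)   # torch.Size([3, 256, 256]); values now lie in [-1, 1]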
19
- def add_gumbel_noise(logits, temperature):
20
  """
21
- Adds Gumbel noise to logits for stochastic sampling.
22
- Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).
23
- This version is more numerically stable than a version involving exp() and division.
24
  """
25
- if abs(temperature) < 1e-9: # Effectively zero temperature
26
- return logits
27
- # Ensure logits are float64 for precision with noise, as suggested by user context
28
- if DEVICE == "mps":
29
- logits = logits.to(torch.float32)
30
- else:
31
- logits = logits.to(torch.float64)
32
- # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1)
33
- # Add small epsilon for numerical stability inside logs
34
- if DEVICE == "mps":
35
- noise = torch.rand_like(logits, dtype=torch.float32)
36
- else:
37
- noise = torch.rand_like(logits, dtype=torch.float64)
38
- standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
39
- return logits + temperature * standard_gumbel_noise
40
-
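The removed add_gumbel_noise relies on the Gumbel-max trick: taking argmax of logits + temperature * G with G ~ Gumbel(0, 1) draws samples from softmax(logits / temperature). A minimal sketch (not part of the commit) that checks this empirically with plain torch:

import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5])
temperature = 1.0

# Gumbel(0, 1) noise via -log(-log(U)), with the same epsilons used above.
u = torch.rand(100_000, logits.numel())
gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
samples = torch.argmax(logits + temperature * gumbel, dim=-1)

empirical = torch.bincount(samples, minlength=logits.numel()).float() / samples.numel()
print(empirical)                                     # roughly tensor([0.63, 0.23, 0.14])
print(torch.softmax(logits / temperature, dim=-1))   # tensor([0.6285, 0.2312, 0.1402])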
41
- def get_num_transfer_tokens(mask_index, steps):
42
- mask_num = mask_index.sum(dim=1, keepdim=True)
43
- # Ensure steps is at least 1 to avoid division by zero if mask_num is also 0 (though sum should be >=0)
44
- steps = max(1, int(steps)) # Ensure steps is a positive integer
45
- base = mask_num // steps
46
- remainder = mask_num % steps
47
- num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
48
- for i in range(mask_num.size(0)): # Iterate over batch
49
- if remainder[i] > 0 : # Ensure remainder is positive before indexing
50
- num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int
51
- return num_transfer_tokens
52
-
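get_num_transfer_tokens splits the number of masked positions evenly across the sampling steps and gives the remainder to the earliest steps. A tiny worked example (not part of the commit): 10 masks over 4 steps yields the per-step schedule [3, 3, 2, 2].

import torch

mask_index = torch.tensor([[True] * 10])                  # one sequence with 10 masked positions
steps = 4
mask_num = mask_index.sum(dim=1, keepdim=True)            # tensor([[10]])
base, remainder = mask_num // steps, mask_num % steps     # 2 and 2
schedule = torch.zeros(1, steps, dtype=torch.long) + base
schedule[0, :remainder.item()] += 1
print(schedule)                                           # tensor([[3, 3, 2, 2]])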
53
- MODEL = None
54
- TOKENIZER = None
55
- DEVICE = (
56
- "cuda"
57
- if torch.cuda.is_available()
58
- else "mps" if torch.backends.mps.is_available() else "cpu"
59
- )
60
- MASK_ID = None
61
- uni_prompting = None
62
- VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
63
-
64
- DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-Base" # Default
65
- CURRENT_MODEL_PATH = None
66
-
67
- MODEL_CHOICES = [
68
- "MMaDA-8B-Base",
69
- "MMaDA-8B-MixCoT (coming soon)",
70
- "MMaDA-8B-Max (coming soon)"
71
- ]
72
- MODEL_ACTUAL_PATHS = {
73
- "MMaDA-8B-Base": DEFAULT_MODEL_PATH,
74
- }
75
-
76
- def clear_outputs_action():
77
- return None, None
78
-
79
- def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
80
- global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
81
-
82
- if MODEL is not None and CURRENT_MODEL_PATH == model_path_to_load:
83
- return f"Model '{model_display_name_for_status}' from '{model_path_to_load}' is already loaded. MASK_ID: {MASK_ID}"
84
-
85
- CURRENT_MODEL_PATH = model_path_to_load
86
-
87
- status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
88
- try:
89
- TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
90
- status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
91
-
92
- MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
93
- status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
94
-
95
- uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
96
-
97
- if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None:
98
- MASK_ID = TOKENIZER.mask_token_id
99
- status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.")
100
- else:
101
- MASK_ID = 126336
102
- status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
103
-
104
- if TOKENIZER.pad_token_id is None:
105
- if TOKENIZER.eos_token_id is not None:
106
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
107
- TOKENIZER.pad_token = TOKENIZER.eos_token
108
- status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
109
- else:
110
- status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")
111
-
112
- if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization
113
- status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")
114
-
115
- TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
116
-
117
- return " ".join(status_msg_parts)
118
- except Exception as e:
119
- MODEL = None
120
- TOKENIZER = None
121
- MASK_ID = None
122
- CURRENT_MODEL_PATH = None
123
- return f"Error loading model '{model_display_name_for_status}': {str(e)}"
124
-
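The chat template assigned in _load_model_and_tokenizer_core is a Llama-3-style header format. The plain-Python sketch below (not part of the commit) mirrors what TOKENIZER.apply_chat_template(..., add_generation_prompt=True, tokenize=False) produces for a single user turn; "<BOS>" is a placeholder for the tokenizer's actual bos_token.

def render_chat(messages, bos_token="<BOS>"):
    out = []
    for i, message in enumerate(messages):
        content = (
            "<|start_header_id|>" + message["role"] + "<|end_header_id|>\n"
            + message["content"].strip() + "<|eot_id|>"
        )
        if i == 0:
            content = bos_token + content   # the template prepends bos_token to the first message
        out.append(content)
    out.append("<|start_header_id|>assistant<|end_header_id|>\n")
    return "".join(out)

print(render_chat([{"role": "user", "content": "What is 2 + 2?"}]))
# <BOS><|start_header_id|>user<|end_header_id|>
# What is 2 + 2?<|eot_id|><|start_header_id|>assistant<|end_header_id|>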
125
- def handle_model_selection_change(selected_model_name_ui):
126
- if "coming soon" in selected_model_name_ui.lower():
127
- global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
128
- MODEL = None
129
- TOKENIZER = None
130
- MASK_ID = None
131
- CURRENT_MODEL_PATH = None
132
- return f"'{selected_model_name_ui}' is not yet available. Please select 'Model A'."
133
-
134
- actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
135
- if not actual_path:
136
- return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
137
-
138
- return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
139
-
140
-
141
- def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
142
- if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
143
- return [("Error in sequence data for visualization.", "ERROR")]
144
- # only answer part
145
- current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
146
- seq_ids = current_x_ids_batch[0].tolist()
147
- eos_token_id = tk.eos_token_id # Get EOS token ID
148
-
149
- # Stage 1: Build initial list of tuples with (token_str, label, token_id_int)
150
- # This helps in identifying EOS tokens later without re-checking the type.
151
- intermediate_tuples = []
152
- for j, token_id_int in enumerate(seq_ids):
153
- try:
154
- token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
155
- except Exception: # Handle cases where a token ID might be problematic (e.g. with mock)
156
- token_str = f"[ID:{token_id_int}]"
157
-
158
- label = "ERROR"
159
- if token_id_int == current_mask_id:
160
- token_str = "[MASK]"
161
- label = "MASK"
162
- else:
163
- label = "GEN"
164
- intermediate_tuples.append((token_str, label, token_id_int))
165
-
166
- return intermediate_tuples
167
-
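get_highlighted_text_tuples labels every generated position as either "MASK" (still hidden) or "GEN" (already revealed); the Blocks UI further down renders those labels through gr.HighlightedText with color_map_config ({"MASK": "lightgrey", "GEN": "#DCABFA"}). A toy sketch of that (text, label) pairing (not part of the commit):

import gradio as gr

pairs = [("The volume is ", "GEN"), ("60", "GEN"), ("[MASK]", "MASK"), ("[MASK]", "MASK")]
viz = gr.HighlightedText(value=pairs, color_map={"MASK": "lightgrey", "GEN": "#DCABFA"})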
168
- @torch.no_grad()
169
- def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
170
- global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
171
-
172
- if MODEL is None or TOKENIZER is None or MASK_ID is None:
173
- yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
174
- return
175
- steps = int(steps)
176
- guidance_scale = float(guidance_scale)
177
-
178
- image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
179
- prompt_text = [prompt_text]
180
- input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
181
-
182
- if guidance_scale > 0:
183
- uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
184
- else:
185
- uncond_input_ids, uncond_attention_mask = None, None
186
-
187
- mask_schedule = get_mask_schedule(mask_schedule)
188
- blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
189
- yield blank_image, "Starting generation..."
190
- for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
191
- input_ids = input_ids,
192
- uncond_input_ids = uncond_input_ids,
193
- attention_mask = attention_mask,
194
- uncond_attention_mask = uncond_attention_mask,
195
- temperature=1.0,
196
- timesteps = steps,
197
- guidance_scale = guidance_scale,
198
- noise_schedule = mask_schedule,
199
- noise_type = "mask",
200
- seq_len = 1024,
201
- vq_model = VQ_MODEL,
202
- uni_prompting=uni_prompting):
203
- yield image_step, status_msg_step
204
-
205
-
206
-
207
-
208
- @torch.no_grad()
209
- def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
210
- cfg_scale, remasking_strategy, thinking_mode_lm):
211
- global MODEL, TOKENIZER, MASK_ID, DEVICE
212
- print(f"thinking_mode_lm: {thinking_mode_lm}")
213
- if MODEL is None or TOKENIZER is None or MASK_ID is None:
214
- yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
215
- return
216
-
217
- steps = int(steps)
218
- gen_length = int(gen_length)
219
- block_length = int(block_length)
220
-
221
- if thinking_mode_lm:
222
- prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
223
-
224
- try:
225
- m = [{"role": "user", "content": prompt_text}]
226
- processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
227
- except Exception as e:
228
- yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
229
- processed_prompt_text = prompt_text
230
- try:
231
- if TOKENIZER.pad_token_id is None:
232
- if TOKENIZER.eos_token_id is not None:
233
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
234
- else: # Should have been caught by load_model, but double check
235
- yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
236
- return
237
-
238
- input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
239
- raw_prompt_attention_mask = None
240
-
241
- except Exception as e:
242
- yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
243
- return
244
-
245
-
246
-
247
- batch_size = input_ids.shape[0]
248
- prompt_len = input_ids.shape[1]
249
-
250
- x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
251
- x[:, :prompt_len] = input_ids.clone()
252
-
253
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
254
-
255
- if gen_length == 0:
256
- final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
257
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
258
- return
259
-
260
- if block_length <= 0 or gen_length % block_length != 0 :
261
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
262
- f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
263
- return
264
- num_blocks = gen_length // block_length
265
-
266
- if steps <=0 or steps % num_blocks != 0:
267
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
268
- f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
269
- return
270
- steps_per_block = steps // num_blocks
271
-
272
- for num_block_iter in range(num_blocks):
273
- current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
274
- current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
275
-
276
- block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
277
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
278
- (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
279
-
280
- num_transfer_tokens_for_this_block = get_num_transfer_tokens(
281
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
282
- steps_per_block
283
- )
284
 
285
- for i_step_in_block in range(steps_per_block):
286
- mask_index_global = (x == MASK_ID)
287
-
288
- if cfg_scale > 0.:
289
- un_x = x.clone()
290
- # For unconditional pass, mask out the original prompt tokens that are not padding
291
- # raw_prompt_attention_mask is (B, prompt_len)
292
- prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
293
- un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
294
-
295
- x_cfg_input = torch.cat([x, un_x], dim=0)
296
- # Pass attention_mask for CFG if model expects it, covering both parts
297
- # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
298
- model_output = MODEL(x_cfg_input)
299
- logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
300
- logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
301
- else:
302
- # Not passing explicit attention_mask here; relies on model's internal handling.
303
- model_output = MODEL(x)
304
- logits = model_output.logits
305
-
306
- logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
307
- x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
308
-
309
- if remasking_strategy == 'low_confidence':
310
- if DEVICE == "mps":
311
- probs = F.softmax(logits.to(torch.float32), dim=-1)
312
- else:
313
- probs = F.softmax(logits.to(torch.float64), dim=-1)
314
- x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
315
- elif remasking_strategy == 'random':
316
- if DEVICE == "mps":
317
- x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float32)
318
- else:
319
- x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
320
- else:
321
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
322
- return
323
-
324
- confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
325
- candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
326
- confidence_for_selection = torch.where(
327
- candidate_positions_for_unmasking,
328
- x0_probs,
329
- -torch.inf
330
- )
331
-
332
- x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
333
-
334
- transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
335
- num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
336
-
337
- for j_batch_idx in range(batch_size):
338
- k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
339
- candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
340
-
341
- if k_val > 0:
342
- # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
343
- conf_slice = confidence_for_selection[j_batch_idx]
344
- if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
345
-
346
- # Check if there are enough valid (non -inf) confidences
347
- valid_conf_count = (conf_slice > -torch.inf).sum().item()
348
- actual_k = min(k_val, valid_conf_count)
349
-
350
- if actual_k > 0:
351
- _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
352
- transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
353
-
354
- x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
355
-
356
- current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
357
- total_overall_steps = num_blocks * steps_per_block
358
- status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
359
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
360
-
361
- final_generated_ids = x[:, prompt_len:]
362
- final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
363
-
364
- final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
365
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
366
-
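Each refinement step above predicts all masked tokens at once, then reveals only the num_transfer_tokens most confident predictions inside the current block and leaves the rest masked. A toy sketch of that 'low_confidence' selection (not part of the commit; no model involved):

import torch

MASK_ID = -1                                                   # placeholder mask id
x = torch.tensor([[7, MASK_ID, MASK_ID, MASK_ID]])             # prompt token 7 plus 3 masks
x0_pred = torch.tensor([[7, 11, 12, 13]])                      # argmax predictions for every position
x0_probs = torch.tensor([[0.99, 0.20, 0.90, 0.60]])            # confidence of each prediction

mask_index = x == MASK_ID
confidence = torch.where(mask_index, x0_probs, torch.full_like(x0_probs, float("-inf")))
k = 2                                                          # tokens to reveal this step
_, topk = torch.topk(confidence[0], k=k)
x[0, topk] = x0_pred[0, topk]
print(x)                                                       # tensor([[ 7, -1, 12, 13]])

Only the two most confident masked positions are filled in; the remaining mask is left for a later step.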
367
- @torch.no_grad()
368
- def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
369
- cfg_scale, remasking_strategy, thinking_mode_mmu):
370
- global MODEL, TOKENIZER, MASK_ID, DEVICE
371
-
372
- if MODEL is None or TOKENIZER is None or MASK_ID is None:
373
- yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
374
- return
375
-
376
- steps = int(steps)
377
- gen_length = int(gen_length)
378
- block_length = int(block_length)
379
-
380
- if thinking_mode_mmu:
381
- prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
382
-
383
- try:
384
- m = [{"role": "user", "content": prompt_text}]
385
- processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
386
- except Exception as e:
387
- yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
388
- processed_prompt_text = prompt_text
389
-
390
- image_vq_ids_tensor = None
391
- if uploaded_image_pil is not None:
392
- try:
393
 
394
- image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
395
- image = image.unsqueeze(0)
396
- image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
397
- except Exception as e:
398
- yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
399
- return
400
-
401
-
402
- try:
403
- if TOKENIZER.pad_token_id is None:
404
- if TOKENIZER.eos_token_id is not None:
405
- TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
406
- else:
407
- yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
408
- return
409
-
410
- input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
411
- raw_prompt_attention_mask = None
412
- if image_vq_ids_tensor is not None:
413
- if image_vq_ids_tensor.ndim == 1:
414
- image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)
415
-
416
- input_ids = torch.cat([
417
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
418
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
419
- image_vq_ids_tensor,
420
- (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
421
- input_ids
422
- ], dim=1).long()
423
-
424
- else:
425
- input_ids = input_ids
426
-
427
-
428
- except Exception as e:
429
- yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
430
- return
431
-
432
-
433
-
434
- batch_size = input_ids.shape[0]
435
- prompt_len = input_ids.shape[1]
436
-
437
- x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
438
- x[:, :prompt_len] = input_ids.clone()
439
-
440
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
441
-
442
- if gen_length == 0:
443
- final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
444
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
445
- return
446
-
447
- if block_length <= 0 or gen_length % block_length != 0 :
448
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
449
- f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
450
- return
451
- num_blocks = gen_length // block_length
452
-
453
- if steps <=0 or steps % num_blocks != 0:
454
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
455
- f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
456
- return
457
- steps_per_block = steps // num_blocks
458
-
459
- for num_block_iter in range(num_blocks):
460
- current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
461
- current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
462
-
463
- block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
464
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
465
- (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
466
-
467
- num_transfer_tokens_for_this_block = get_num_transfer_tokens(
468
- block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
469
- steps_per_block
470
- )
471
 
472
- for i_step_in_block in range(steps_per_block):
473
- mask_index_global = (x == MASK_ID)
474
-
475
- if cfg_scale > 0.:
476
- un_x = x.clone()
477
- # For unconditional pass, mask out the original prompt tokens that are not padding
478
- # raw_prompt_attention_mask is (B, prompt_len)
479
- prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
480
- un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
481
-
482
- x_cfg_input = torch.cat([x, un_x], dim=0)
483
- # Pass attention_mask for CFG if model expects it, covering both parts
484
- # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
485
- model_output = MODEL(x_cfg_input)
486
- logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
487
- logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
488
- else:
489
- # Not passing explicit attention_mask here; relies on model's internal handling.
490
- model_output = MODEL(x)
491
- logits = model_output.logits
492
-
493
- logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
494
- x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
495
-
496
- if remasking_strategy == 'low_confidence':
497
- if DEVICE == "mps":
498
- probs = F.softmax(logits.to(torch.float32), dim=-1)
499
- else:
500
- probs = F.softmax(logits.to(torch.float64), dim=-1)
501
- x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
502
- elif remasking_strategy == 'random':
503
- if DEVICE == "mps":
504
- x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float32)
505
- else:
506
- x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
507
- else:
508
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
509
- return
510
-
511
- confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
512
- candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
513
- confidence_for_selection = torch.where(
514
- candidate_positions_for_unmasking,
515
- x0_probs,
516
- -torch.inf
517
- )
518
-
519
- x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
520
-
521
- transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
522
- num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
523
-
524
- for j_batch_idx in range(batch_size):
525
- k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
526
- candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
527
-
528
- if k_val > 0:
529
- # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
530
- conf_slice = confidence_for_selection[j_batch_idx]
531
- if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
532
-
533
- # Check if there are enough valid (non -inf) confidences
534
- valid_conf_count = (conf_slice > -torch.inf).sum().item()
535
- actual_k = min(k_val, valid_conf_count)
536
-
537
- if actual_k > 0:
538
- _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
539
- transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
540
-
541
- x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
542
-
543
- current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
544
- total_overall_steps = num_blocks * steps_per_block
545
- status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
546
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
547
-
548
- final_generated_ids = x[:, prompt_len:]
549
- final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
550
-
551
- final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
552
- yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
553
-
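When an image is uploaded, the multimodal wrapper above prepends the VQ token block to the text prompt in the layout [126089][126084][image VQ tokens][126085][text tokens] (the VQ codes themselves are offset by +126349). A structural sketch with tiny placeholder tensors (not part of the commit):

import torch

text_ids = torch.tensor([[101, 102, 103]])                  # stand-in text token ids
image_vq = torch.tensor([[50001, 50002, 50003, 50004]])     # stand-in VQ codes

def tag(token_id, batch=1):
    return torch.full((batch, 1), token_id, dtype=torch.long)

mmu_input = torch.cat([tag(126089), tag(126084), image_vq, tag(126085), text_ids], dim=1)
print(mmu_input)
# tensor([[126089, 126084, 50001, 50002, 50003, 50004, 126085, 101, 102, 103]])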
554
-
555
- css_styles = """
556
- .gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
557
- .gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
558
- .gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}
559
-
560
- .highlighted-text span{
561
- padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;
562
- }
563
-
564
- footer{display:none !important}
565
-
566
- #live-update-scrollable-box {
567
- max-height: 800px; /* Adjust this maximum height as needed, e.g. '300px', '50vh', etc. */
568
- overflow-y: auto !important; /* Show a vertical scrollbar when content exceeds max-height */
569
- display: block; /* Ensure the element is block-level so max-height takes effect */
570
-
571
- }
572
- #think_btn {
573
- background-color: #f3f4f6 !important;
574
- border: 1px solid #d0d0d0 !important;
575
- color: #111827 !important;
576
- font-size: 16px !important;
577
- font-weight: bold !important;
578
- }
579
- #think_btn:hover {
580
- background-color: #e0e0e0 !important;
581
- border: 1px solid #c0c0c0 !important;
582
- color: #222 !important;
583
- }
584
- #think_btn:active {
585
- background-color: #2563eb !important;
586
- border: 1px solid #b0b0b0 !important;
587
- color: white !important;
588
- }
589
- """
590
-
591
-
592
- # thinking_mode_t2i = gr.State(False)
593
- def toggle_thinking_mode_lm(current_thinking_mode):
594
- # print(f"current_thinking_mode: {current_thinking_mode}")
595
- new_state = not current_thinking_mode
596
- new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
597
- return new_state, gr.update(value=new_label)
598
-
599
- def toggle_thinking_mode_mmu(current_thinking_mode):
600
- new_state = not current_thinking_mode
601
- new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
602
- return new_state, gr.update(value=new_label)
603
-
604
-
605
- color_map_config = {
606
- "MASK": "lightgrey",
607
- "GEN": "#DCABFA",
608
- }
609
-
610
- theme = gr.themes.Ocean(
611
- primary_hue="fuchsia",
612
- )
613
- with gr.Blocks(css=css_styles, theme=theme) as demo:
614
- # with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
615
- # with gr.Blocks() as demo:
616
- thinking_mode_lm = gr.State(False)
617
- thinking_mode_mmu = gr.State(False)
618
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>MMaDA: Multimodal Large Diffusion Language Models</h1>")
619
- gr.Markdown("MMaDA is a novel class of multimodal diffusion foundation models designed to achieve superior performance across diverse domains such as textual reasoning, multimodal understanding, and text-to-image generation")
620
- gr.Markdown("Github: [Gen-Verse/MMaDA](https://github.com/Gen-Verse/MMaDA)")
621
- gr.Markdown("Paper: [MMaDA: Multimodal Large Diffusion Language Models]()")
622
- gr.Markdown("### Select Model")
623
- with gr.Row():
624
- model_select_radio = gr.Radio(
625
- label="Select Text Generation Model",
626
- choices=MODEL_CHOICES,
627
- value=MODEL_CHOICES[0]
628
- )
629
- model_load_status_box = gr.Textbox(
630
- label="Model Load Status",
631
- interactive=False,
632
- lines=3,
633
- max_lines=5
634
- )
635
 
636
- gr.Markdown("## Part 1. Text Generation")
637
- with gr.Row():
638
- with gr.Column(scale=2):
639
- prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?")
640
- think_button_lm = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
641
- with gr.Accordion("Generation Parameters", open=True):
642
- with gr.Row():
643
- gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
644
- steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
645
- with gr.Row():
646
- block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
647
- remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
648
- with gr.Row():
649
- cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
650
- temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
651
-
652
-
653
- with gr.Row():
654
- run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3)
655
- clear_button_ui_lm = gr.Button("Clear Outputs", scale=1)
656
-
657
- with gr.Column(scale=3):
658
- # gr.Markdown("## Live Generation Process")
659
- output_visualization_box_lm = gr.HighlightedText(
660
- label="Live Generation Process",
661
- show_legend=True,
662
- color_map=color_map_config,
663
- combine_adjacent=False,
664
- interactive=False,
665
- elem_id="live-update-scrollable-box",
666
- )
667
- # gr.Markdown("## Final Generated Text")
668
- output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
669
-
670
-
671
-
672
- gr.Examples(
673
- examples=[
674
- ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
675
- ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
676
- ],
677
- inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
678
- outputs=[output_visualization_box_lm, output_final_text_box_lm],
679
- fn=generate_viz_wrapper_lm,
680
- )
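The defaults above satisfy the two divisibility rules the generator enforces: gen_length must be divisible by block_length, and steps must be divisible by the resulting number of blocks. A quick check with the default slider values (not part of the commit):

gen_length, block_length, steps = 512, 128, 256
num_blocks = gen_length // block_length            # 4 blocks
assert gen_length % block_length == 0
assert steps % num_blocks == 0
steps_per_block = steps // num_blocks              # 64 refinement steps per block
print(num_blocks, steps_per_block)                 # 4 64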
681
 
682
- gr.Markdown("---")
683
- gr.Markdown("## Part 2. Multimodal Understanding")
684
- with gr.Row():
685
- with gr.Column(scale=2):
686
- prompt_input_box_mmu = gr.Textbox(
687
- label="Enter your prompt:",
688
- lines=3,
689
- value="Please describe this image in detail."
690
- )
691
- think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
692
- with gr.Accordion("Generation Parameters", open=True):
693
- with gr.Row():
694
- gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
695
- steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
696
- with gr.Row():
697
- block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
698
- remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
699
- with gr.Row():
700
- cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
701
- temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
702
-
703
- with gr.Row():
704
- image_upload_box = gr.Image(type="pil", label="Upload Image")
705
-
706
- with gr.Row():
707
- run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3)
708
- clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1)
709
-
710
- with gr.Column(scale=3):
711
- gr.Markdown("## Live Generation Process")
712
- output_visualization_box_mmu = gr.HighlightedText(
713
- label="Token Sequence (Live Update)",
714
- show_legend=True,
715
- color_map=color_map_config,
716
- combine_adjacent=False,
717
- interactive=False,
718
- elem_id="live-update-scrollable-box",
719
- )
720
- gr.Markdown("## Final Generated Text")
721
- output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
722
-
723
-
724
- gr.Examples(
725
- examples=[
726
- [
727
- "mmu_validation_2/sunflower.jpg",
728
- "Please describe this image in detail.",
729
- 256,
730
- 512,
731
- 128,
732
- 1,
733
- 0,
734
- "low_confidence"
735
- ],
736
- [
737
- "mmu_validation_2/woman.jpg",
738
- "Please describe this image in detail.",
739
- 256,
740
- 512,
741
- 128,
742
- 1,
743
- 0,
744
- "low_confidence"
745
- ]
746
- ],
747
- inputs=[
748
- image_upload_box,
749
- prompt_input_box_mmu,
750
- steps_slider_mmu,
751
- gen_length_slider_mmu,
752
- block_length_slider_mmu,
753
- temperature_slider_mmu,
754
- cfg_scale_slider_mmu,
755
- remasking_dropdown_mmu
756
- ],
757
- outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
758
- fn=generate_viz_wrapper,
759
- )
760
 
761
- gr.Markdown("---")
762
- gr.Markdown("## Part 3. Text-to-Image Generation")
763
- with gr.Row():
764
- with gr.Column(scale=2):
765
- prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.")
766
-
767
- with gr.Accordion("Generation Parameters", open=True):
768
- with gr.Row():
769
- steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
770
- guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale", info="Classifier-Free Guidance. 0 disables it.")
771
-
772
-
773
- with gr.Row():
774
- scheduler_radio_t2i = gr.Radio(
775
- choices=["cosine", "sigmoid", "linear"],
776
- value="cosine",
777
- label="Scheduler",
778
  )
779
 
780
- with gr.Row():
781
- run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3)
782
- clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1)
783
784
 
785
- with gr.Column(scale=3):
786
- # gr.Markdown("## Live Generation Process")
787
- output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil")
788
- output_status_t2i = gr.Textbox(label="Generation Status", interactive=False)
789
 
790
- gr.Examples(
791
- examples=[
792
- ["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"],
793
- ["A beautiful sunset over a calm ocean, with a few clouds in the sky.", 15, 3.5, "cosine"]
794
- ],
795
- inputs=[prompt_input_box_t2i, steps_slider_t2i, guidance_scale_slider_t2i, scheduler_radio_t2i],
796
- outputs=[output_image_t2i, output_status_t2i],
797
- fn=generate_viz_wrapper_t2i,
798
- )
799
-
800
- run_button_ui_t2i.click(
801
- fn=generate_viz_wrapper_t2i,
802
- inputs=[
803
- prompt_input_box_t2i,
804
- steps_slider_t2i,
805
- guidance_scale_slider_t2i,
806
- scheduler_radio_t2i
807
- ],
808
- outputs=[output_image_t2i, output_status_t2i]
809
- )
810
 
811
- clear_button_ui_t2i.click(
812
- fn=lambda: (None, ""),
813
- inputs=None,
814
- outputs=[output_image_t2i, output_status_t2i],
815
- queue=False
816
- )
817
 
818
- think_button_lm.click(
819
- fn=toggle_thinking_mode_lm,
820
- inputs=[thinking_mode_lm],
821
- outputs=[thinking_mode_lm, think_button_lm]
822
- )
823
 
824
- think_button_mmu.click(
825
- fn=toggle_thinking_mode_mmu,
826
- inputs=[thinking_mode_mmu],
827
- outputs=[thinking_mode_mmu, think_button_mmu]
828
- )
829
-
830
 
831
 
832
- def initialize_default_model():
833
- default_model = "MMaDA-8B-Base"
834
- result = handle_model_selection_change(default_model)
835
- return default_model, result
836
 
837
- demo.load(
838
- fn=initialize_default_model,
839
- inputs=None,
840
- outputs=[model_select_radio, model_load_status_box],
841
- queue=True
842
- )
843
 
844
- def clear_outputs():
845
- return None, None, None # Clear image, visualization, and final text
846
 
847
- clear_button_ui_lm.click(
848
- fn=clear_outputs,
849
- inputs=None,
850
- outputs=[image_upload_box, output_visualization_box_lm, output_final_text_box_lm],
851
- queue=False
852
- )
853
- clear_button_ui_mmu.click(
854
- fn=clear_outputs,
855
- inputs=None,
856
- outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu],
857
- queue=False
858
- )
859
 
860
- run_button_ui_lm.click(
861
- fn=generate_viz_wrapper_lm,
862
- inputs=[
863
- prompt_input_box_lm,
864
- steps_slider_lm,
865
- gen_length_slider_lm,
866
- block_length_slider_lm,
867
- temperature_slider_lm,
868
- cfg_scale_slider_lm,
869
- remasking_dropdown_lm,
870
- thinking_mode_lm
871
- ],
872
- outputs=[output_visualization_box_lm, output_final_text_box_lm]
873
- )
874
 
875
- run_button_ui_mmu.click(
876
- fn=generate_viz_wrapper,
877
- inputs=[
878
- image_upload_box,
879
- prompt_input_box_mmu,
880
- steps_slider_mmu,
881
- gen_length_slider_mmu,
882
- block_length_slider_mmu,
883
- temperature_slider_mmu,
884
- cfg_scale_slider_mmu,
885
- remasking_dropdown_mmu,
886
- thinking_mode_mmu
887
- ],
888
- outputs=[output_visualization_box_mmu, output_final_text_box_mmu]
889
  )
890
 
891
 
892
  if __name__ == "__main__":
893
- print(f"Starting Gradio App. Attempting to use device: {DEVICE}")
894
- demo.launch(share=True)
 
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+ import spaces
5
+
6
+ # === Import project modules ===
7
+ PROJECT_ROOT = Path(__file__).resolve().parent
8
+ MMADA_ROOT = PROJECT_ROOT / "MMaDA"
9
+ if str(MMADA_ROOT) not in sys.path:
10
+ sys.path.insert(0, str(MMADA_ROOT))
11
+
12
+ from inference.gradio_multimodal_demo_inst import OmadaDemo
13
  import gradio as gr
14
+
15
+
16
+ # ----------------------------------------------------------------------
17
+ # 1. Asset Loading (Downloaded by entrypoint)
18
+ # ----------------------------------------------------------------------
19
+
20
+ ASSET_ROOT = PROJECT_ROOT / "_asset_cache" / "AIDAS-Omni-Modal-Diffusion-assets"
21
+ DEMO_ROOT = ASSET_ROOT # asset repo already modality-split
22
+
23
+
24
+ # ----------------------------------------------------------------------
25
+ # 2. GPU Handler Wrapper
26
+ # ----------------------------------------------------------------------
27
+
28
+ def gpu_handler(fn):
29
  """
30
+ Wrap an inference function so it runs under ZeroGPU (a GPU is attached only for the duration of the call).
31
  """
32
+ @spaces.GPU
33
+ def inner(*args, **kwargs):
34
+ return fn(*args, **kwargs)
35
+ return inner
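gpu_handler wraps each Gradio callback in spaces.GPU, so a ZeroGPU Space attaches a GPU only while that call is running. A self-contained usage sketch (not part of the commit; run_t2s here is a stand-in for the real OmadaDemo method, and it assumes the spaces package that Hugging Face Spaces provide):

import spaces

def gpu_handler(fn):
    @spaces.GPU
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

def run_t2s(text):
    # heavy model inference would happen here
    return None, f"would synthesize speech for {text!r}"

wrapped = gpu_handler(run_t2s)          # same as decorating run_t2s with @spaces.GPU
print(wrapped("Hello"))                 # (None, "would synthesize speech for 'Hello'")

Since inner does not copy fn's name or signature, adding functools.wraps(fn) would be one possible refinement if Gradio's introspection of the callback ever matters.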
36
 
37
 
38
+ # ----------------------------------------------------------------------
39
+ # 3. Build Demo UI With Examples
40
+ # ----------------------------------------------------------------------
41
 
42
+ def build_zero_gpu_demo(app: OmadaDemo):
43
 
44
+ with gr.Blocks(title="AIDAS Omni-Modal Diffusion (ZeroGPU)") as demo:
45
 
46
+ # ---------------- Header ----------------
47
+ gr.Markdown(
48
+ "<h1 style='text-align:center'>AIDAS Omni-Modal Diffusion Model</h1>"
49
+ )
50
 
51
+ try:
52
+ logo_path = "/mnt/data/A2E36E9F-F389-487D-9984-FFF21C9228E3.png"
53
+ gr.Image(logo_path, elem_id="logo", show_label=False, height=120)
54
+ except Exception:
55
+ pass
56
+
57
+ gr.Markdown("### Multimodal Inference Demo (ZeroGPU Optimized)")
58
+ gr.Markdown("---")
59
+
60
+ # ---------------- Tabs ----------------
61
+
62
+ with gr.Tabs():
63
+
64
+ # ============================================================
65
+ # 1) TEXT → SPEECH (T2S)
66
+ # ============================================================
67
+ with gr.Tab("Text → Speech (T2S)"):
68
+
69
+ t2s_in = gr.Textbox(label="Input Text")
70
+ t2s_btn = gr.Button("Generate")
71
+ t2s_audio = gr.Audio(label="Speech Output")
72
+ t2s_status = gr.Textbox(label="Status", interactive=False)
73
+
74
+ t2s_examples = []
75
+ t2s_dir = DEMO_ROOT / "t2s"
76
+ if t2s_dir.exists():
77
+ for f in t2s_dir.glob("*.txt"):
78
+ txt = f.read_text().strip()
79
+ t2s_examples.append([txt])
80
+
81
+ if len(t2s_examples) > 0:
82
+ gr.Examples(
83
+ examples=t2s_examples,
84
+ inputs=[t2s_in],
85
+ outputs=[t2s_audio, t2s_status],
86
+ fn=gpu_handler(app.run_t2s),
87
+ )
88
+
89
+ t2s_btn.click(
90
+ gpu_handler(app.run_t2s),
91
+ inputs=[t2s_in],
92
+ outputs=[t2s_audio, t2s_status],
93
  )
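This tab (and each one below) follows the same discovery pattern: look for a modality subfolder under DEMO_ROOT, glob for asset files, and register a gr.Examples block only when something was found. The same logic, factored into one hypothetical helper (not part of the commit):

from pathlib import Path

def collect_examples(root: Path, subdir: str, pattern: str, as_text: bool = False):
    """Return one example row per matching file, or [] if the folder is missing."""
    folder = root / subdir
    if not folder.exists():
        return []
    rows = []
    for f in sorted(folder.glob(pattern)):
        rows.append([f.read_text().strip()] if as_text else [str(f)])
    return rows

# Mirrors the T2S tab above: text prompts come from *.txt files.
# t2s_examples = collect_examples(DEMO_ROOT, "t2s", "*.txt", as_text=True)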
94
 
95
+ # ============================================================
96
+ # 2) SPEECH → SPEECH (S2S)
97
+ # ============================================================
98
+ with gr.Tab("Speech → Speech (S2S)"):
99
+
100
+ s2s_in = gr.Audio(type="filepath", label="Input Speech")
101
+ s2s_btn = gr.Button("Generate")
102
+ s2s_audio = gr.Audio(label="Output Speech")
103
+ s2s_status = gr.Textbox(label="Status", interactive=False)
104
+
105
+ s2s_examples = []
106
+ s2s_dir = DEMO_ROOT / "s2s"
107
+ if s2s_dir.exists():
108
+ for f in s2s_dir.glob("*.wav"):
109
+ s2s_examples.append([str(f)])
110
+
111
+ if len(s2s_examples) > 0:
112
+ gr.Examples(
113
+ examples=s2s_examples,
114
+ inputs=[s2s_in],
115
+ outputs=[s2s_audio, s2s_status],
116
+ fn=gpu_handler(app.run_s2s),
117
+ )
118
+
119
+ s2s_btn.click(
120
+ gpu_handler(app.run_s2s),
121
+ inputs=[s2s_in],
122
+ outputs=[s2s_audio, s2s_status]
123
+ )
124
 
125
+ # ============================================================
126
+ # 3) SPEECH → TEXT (S2T)
127
+ # ============================================================
128
+ with gr.Tab("Speech → Text (S2T)"):
129
+
130
+ s2t_in = gr.Audio(type="filepath", label="Input Speech")
131
+ s2t_btn = gr.Button("Transcribe")
132
+ s2t_text = gr.Textbox(label="Transcribed Text")
133
+ s2t_status = gr.Textbox(label="Status", interactive=False)
134
+
135
+ s2t_examples = []
136
+ s2t_dir = DEMO_ROOT / "s2t"
137
+ if s2t_dir.exists():
138
+ for f in s2t_dir.glob("*.wav"):
139
+ s2t_examples.append([str(f)])
140
+
141
+ if len(s2t_examples) > 0:
142
+ gr.Examples(
143
+ examples=s2t_examples,
144
+ inputs=[s2t_in],
145
+ outputs=[s2t_text, s2t_status],
146
+ fn=gpu_handler(app.run_s2t),
147
+ )
148
+
149
+ s2t_btn.click(
150
+ gpu_handler(app.run_s2t),
151
+ inputs=[s2t_in],
152
+ outputs=[s2t_text, s2t_status],
153
+ )
154
 
155
+ # ============================================================
156
+ # 4) VIDEO → TEXT (V2T)
157
+ # ============================================================
158
+ with gr.Tab("Video Text (V2T)"):
159
+
160
+ v2t_in = gr.Video(type="filepath", label="Input Video")
161
+ v2t_btn = gr.Button("Generate Caption")
162
+ v2t_text = gr.Textbox(label="Caption")
163
+ v2t_status = gr.Textbox(label="Status")
164
+
165
+ v2t_examples = []
166
+ v2t_dir = DEMO_ROOT / "v2t"
167
+ if v2t_dir.exists():
168
+ for f in v2t_dir.glob("*.mp4"):
169
+ v2t_examples.append([str(f)])
170
+
171
+ if len(v2t_examples) > 0:
172
+ gr.Examples(
173
+ examples=v2t_examples,
174
+ inputs=[v2t_in],
175
+ outputs=[v2t_text, v2t_status],
176
+ fn=gpu_handler(app.run_v2t),
177
+ )
178
+
179
+ v2t_btn.click(
180
+ gpu_handler(app.run_v2t),
181
+ inputs=[v2t_in],
182
+ outputs=[v2t_text, v2t_status],
183
+ )
184
 
185
+ # ============================================================
186
+ # 5) VIDEO → SPEECH (V2S)
187
+ # ============================================================
188
+ with gr.Tab("Video Speech (V2S)"):
189
+
190
+ v2s_in = gr.Video(type="filepath", label="Input Video")
191
+ v2s_btn = gr.Button("Generate Speech")
192
+ v2s_audio = gr.Audio(label="Speech Output")
193
+ v2s_status = gr.Textbox(label="Status")
194
+
195
+ v2s_examples = []
196
+ v2s_dir = DEMO_ROOT / "v2s"
197
+ if v2s_dir.exists():
198
+ for f in v2s_dir.glob("*.mp4"):
199
+ v2s_examples.append([str(f)])
200
+
201
+ if len(v2s_examples) > 0:
202
+ gr.Examples(
203
+ examples=v2s_examples,
204
+ inputs=[v2s_in],
205
+ outputs=[v2s_audio, v2s_status],
206
+ fn=gpu_handler(app.run_v2s),
207
+ )
208
+
209
+ v2s_btn.click(
210
+ gpu_handler(app.run_v2s),
211
+ inputs=[v2s_in],
212
+ outputs=[v2s_audio, v2s_status],
213
+ )
214
 
215
+ # ============================================================
216
+ # 6) IMAGE → SPEECH (I2S)
217
+ # ============================================================
218
+ with gr.Tab("Image → Speech (I2S)"):
219
+
220
+ i2s_in = gr.Image(type="filepath", label="Input Image")
221
+ i2s_btn = gr.Button("Generate Speech")
222
+ i2s_audio = gr.Audio(label="Speech")
223
+ i2s_status = gr.Textbox(label="Status")
224
+
225
+ # Only if folder exists
226
+ i2s_examples = []
227
+ i2s_dir = DEMO_ROOT / "i2s"
228
+ if i2s_dir.exists():
229
+ for f in i2s_dir.glob("*.*"):
230
+ i2s_examples.append([str(f)])
231
+
232
+ if len(i2s_examples) > 0:
233
+ gr.Examples(
234
+ examples=i2s_examples,
235
+ inputs=[i2s_in],
236
+ outputs=[i2s_audio, i2s_status],
237
+ fn=gpu_handler(app.run_i2s),
238
+ )
239
+
240
+ i2s_btn.click(
241
+ gpu_handler(app.run_i2s),
242
+ inputs=[i2s_in],
243
+ outputs=[i2s_audio, i2s_status],
244
+ )
245
 
246
+ # ============================================================
247
+ # 7) CHAT
248
+ # ============================================================
249
+ with gr.Tab("Chat (Text)"):
250
+
251
+ chat_in = gr.Textbox(label="Message")
252
+ chat_btn = gr.Button("Send")
253
+ chat_out = gr.Textbox(label="Response")
254
+ chat_status = gr.Textbox(label="Status")
255
+
256
+ chat_examples = []
257
+ chat_dir = DEMO_ROOT / "chat"
258
+ if chat_dir.exists():
259
+ for f in chat_dir.glob("*.txt"):
260
+ txt = f.read_text().strip()
261
+ chat_examples.append([txt])
262
+
263
+ if len(chat_examples) > 0:
264
+ gr.Examples(
265
+ examples=chat_examples,
266
+ inputs=[chat_in],
267
+ outputs=[chat_out, chat_status],
268
+ fn=gpu_handler(app.run_chat),
269
+ )
270
+
271
+ chat_btn.click(
272
+ gpu_handler(app.run_chat),
273
+ inputs=[chat_in],
274
+ outputs=[chat_out, chat_status],
275
+ )
276
 
277
+ # ============================================================
278
+ # 8) MMU (2 images → text)
279
+ # ============================================================
280
+ with gr.Tab("MMU (Dual-Image Reasoning)"):
281
+
282
+ mmu_img1 = gr.Image(type="filepath", label="Image 1")
283
+ mmu_img2 = gr.Image(type="filepath", label="Image 2")
284
+ mmu_prompt = gr.Textbox(label="Prompt")
285
+ mmu_btn = gr.Button("Run MMU")
286
+ mmu_out = gr.Textbox(label="Output")
287
+ mmu_status = gr.Textbox(label="Status")
288
+
289
+ mmu_examples = []
290
+ mmu_dir = DEMO_ROOT / "mmu"
291
+ if mmu_dir.exists():
292
+ imgs = list(mmu_dir.glob("*.png"))
293
+ if len(imgs) >= 2:
294
+ mmu_examples.append([
295
+ str(imgs[0]),
296
+ str(imgs[1]),
297
+ "Describe the relation between two objects."
298
+ ])
299
+
300
+ if len(mmu_examples) > 0:
301
+ gr.Examples(
302
+ examples=mmu_examples,
303
+ inputs=[mmu_img1, mmu_img2, mmu_prompt],
304
+ outputs=[mmu_out, mmu_status],
305
+ fn=gpu_handler(app.run_mmu_dual),
306
+ )
307
+
308
+ mmu_btn.click(
309
+ gpu_handler(app.run_mmu_dual),
310
+ inputs=[mmu_img1, mmu_img2, mmu_prompt],
311
+ outputs=[mmu_out, mmu_status]
312
+ )
313
 
314
+ # ============================================================
315
+ # 9) TEXT → IMAGE (T2I)
316
+ # ============================================================
317
+ with gr.Tab("Text → Image (T2I)"):
318
+
319
+ t2i_in = gr.Textbox(label="Prompt")
320
+ t2i_btn = gr.Button("Generate Image")
321
+ t2i_img = gr.Image(label="Generated Image")
322
+ t2i_status = gr.Textbox(label="Status")
323
+
324
+ t2i_examples = []
325
+ t2i_dir = DEMO_ROOT / "t2i"
326
+ if t2i_dir.exists():
327
+ for f in t2i_dir.glob("*.txt"):
328
+ txt = f.read_text().strip()
329
+ t2i_examples.append([txt])
330
+
331
+ if len(t2i_examples) > 0:
332
+ gr.Examples(
333
+ examples=t2i_examples,
334
+ inputs=[t2i_in],
335
+ outputs=[t2i_img, t2i_status],
336
+ fn=gpu_handler(app.run_t2i),
337
+ )
338
+
339
+ t2i_btn.click(
340
+ gpu_handler(app.run_t2i),
341
+ inputs=[t2i_in],
342
+ outputs=[t2i_img, t2i_status],
343
+ )
344
 
345
+ # ============================================================
346
+ # 10) IMAGE EDITING (I2I)
347
+ # ============================================================
348
+ with gr.Tab("Image Editing (I2I)"):
349
+
350
+ i2i_in = gr.Image(type="filepath", label="Input Image")
351
+ i2i_prompt = gr.Textbox(label="Edit Instruction")
352
+ i2i_btn = gr.Button("Apply Edit")
353
+ i2i_img = gr.Image(label="Edited Image")
354
+ i2i_status = gr.Textbox(label="Status")
355
+
356
+ i2i_examples = []
357
+ i2i_dir = DEMO_ROOT / "i2i"
358
+ if i2i_dir.exists():
359
+ for f in i2i_dir.glob("*.*"):
360
+ i2i_examples.append([str(f), "Make it more vibrant."])
361
+
362
+ if len(i2i_examples) > 0:
363
+ gr.Examples(
364
+ examples=i2i_examples,
365
+ inputs=[i2i_in, i2i_prompt],
366
+ outputs=[i2i_img, i2i_status],
367
+ fn=gpu_handler(app.run_i2i),
368
+ )
369
+
370
+ i2i_btn.click(
371
+ gpu_handler(app.run_i2i),
372
+ inputs=[i2i_in, i2i_prompt],
373
+ outputs=[i2i_img, i2i_status]
374
+ )
375
 
376
+ # End Tabs
377
 
378
+ return demo
 
379
 
380
 
381
+ # ----------------------------------------------------------------------
382
+ # 4. Entry Point for Space
383
+ # ----------------------------------------------------------------------
384
 
385
+ @spaces.GPU
386
+ def main():
387
+ app = OmadaDemo(
388
+ train_config=str(MMADA_ROOT / "inference/demo/demo.yaml"),
389
+ checkpoint=os.getenv("MODEL_CHECKPOINT_DIR", "_ckpt_cache/omada"),
390
+ device="cpu"
391
  )
392
 
393
+ demo = build_zero_gpu_demo(app)
394
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
395
+
396
 
397
  if __name__ == "__main__":
398
+ main()
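main() reads a single environment knob, MODEL_CHECKPOINT_DIR, and otherwise hard-codes the config path, device, and port. A minimal sketch of pointing it at a different checkpoint before launch (the directory below is a placeholder, not something shipped with the repo):

import os

os.environ.setdefault("MODEL_CHECKPOINT_DIR", "/data/checkpoints/omada/unwrapped_model")  # placeholder path
print(os.getenv("MODEL_CHECKPOINT_DIR", "_ckpt_cache/omada"))
# main() would then build OmadaDemo from that directory and serve on 0.0.0.0:7860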
 
app.py CHANGED
@@ -1,898 +1,398 @@
1
- # """
2
- # Gradio Space entrypoint mirroring `MMaDA/inference/gradio_multimodal_demo_inst.py`.
3
- # It downloads the published checkpoint once via huggingface_hub, wires it into
4
- # OmadaDemo, and launches the existing Blocks UI.
5
-
6
- # Environment overrides:
7
- # MODEL_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion)
8
- # MODEL_REVISION (default: main)
9
- # ASSET_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion-assets)
10
- # ASSET_REVISION (default: main)
11
- # STYLE_REPO_ID (default: jaeikkim/aidas-style-centroid)
12
- # STYLE_REVISION (default: main)
13
- # HF_TOKEN (optional, for private model/dataset)
14
- # TRAIN_CONFIG_PATH (default: MMaDA/inference/demo/demo.yaml)
15
- # DEVICE (default: auto cuda/cpu)
16
- # PORT (default: 7860; Space sets this)
17
- # """
18
-
19
- # import os
20
- # import sys
21
- # import subprocess
22
- # import importlib
23
- # import spaces
24
- # from pathlib import Path
25
-
26
- # from packaging.version import parse as parse_version
27
-
28
- # # Ensure local project is importable
29
- # PROJECT_ROOT = Path(__file__).resolve().parent
30
- # MMADA_ROOT = PROJECT_ROOT / "MMaDA"
31
- # if str(MMADA_ROOT) not in sys.path:
32
- # sys.path.insert(0, str(MMADA_ROOT))
33
- # EMOVA_ROOT = PROJECT_ROOT / "EMOVA_speech_tokenizer"
34
- # if str(EMOVA_ROOT) not in sys.path:
35
- # sys.path.insert(0, str(EMOVA_ROOT))
36
-
37
-
38
- # def ensure_hf_hub(target: str = "0.36.0"):
39
- # """
40
- # Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
41
- # The Space base image installs gradio which may upgrade it to 1.x; we downgrade here.
42
- # """
43
- # try:
44
- # import huggingface_hub as hub
45
- # except ImportError:
46
- # subprocess.check_call(
47
- # [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
48
- # )
49
- # import huggingface_hub as hub
50
-
51
- # if parse_version(hub.__version__) >= parse_version("1.0.0"):
52
- # subprocess.check_call(
53
- # [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
54
- # )
55
- # hub = importlib.reload(hub)
56
- # # Backfill missing constants in older hub versions to avoid AttributeError.
57
- # try:
58
- # import huggingface_hub.constants as hub_consts # type: ignore
59
- # except Exception:
60
- # hub_consts = None
61
- # if hub_consts and not hasattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER"):
62
- # setattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER", False)
63
- # return hub
64
-
65
-
66
- # snapshot_download = ensure_hf_hub().snapshot_download
67
-
68
- # from inference.gradio_multimodal_demo_inst import OmadaDemo, build_demo # noqa: E402
69
-
70
-
71
- # def download_assets() -> Path:
72
- # """Download demo assets (logo + sample prompts/media) and return the root path."""
73
- # repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets")
74
- # revision = os.getenv("ASSET_REVISION", "main")
75
- # token = os.getenv("HF_TOKEN")
76
- # cache_dir = PROJECT_ROOT / "_asset_cache"
77
- # cache_dir.mkdir(parents=True, exist_ok=True)
78
-
79
- # return Path(
80
- # snapshot_download(
81
- # repo_id=repo_id,
82
- # revision=revision,
83
- # repo_type="dataset",
84
- # local_dir=cache_dir,
85
- # local_dir_use_symlinks=False,
86
- # token=token,
87
- # )
88
- # )
89
-
90
-
91
- # def download_style() -> Path:
92
- # """Download style centroid dataset and return the root path."""
93
- # repo_id = os.getenv("STYLE_REPO_ID", "jaeikkim/aidas-style-centroid")
94
- # revision = os.getenv("STYLE_REVISION", "main")
95
- # token = os.getenv("HF_TOKEN")
96
- # cache_dir = PROJECT_ROOT / "_style_cache"
97
- # cache_dir.mkdir(parents=True, exist_ok=True)
98
-
99
- # return Path(
100
- # snapshot_download(
101
- # repo_id=repo_id,
102
- # revision=revision,
103
- # repo_type="dataset",
104
- # local_dir=cache_dir,
105
- # local_dir_use_symlinks=False,
106
- # token=token,
107
- # )
108
- # )
109
-
110
-
111
- # def download_checkpoint() -> Path:
112
- # """Download checkpoint snapshot and return an `unwrapped_model` directory."""
113
- # repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
114
- # revision = os.getenv("MODEL_REVISION", "main")
115
- # token = os.getenv("HF_TOKEN")
116
- # cache_dir = PROJECT_ROOT / "_ckpt_cache"
117
- # cache_dir.mkdir(parents=True, exist_ok=True)
118
-
119
- # snapshot_path = Path(
120
- # snapshot_download(
121
- # repo_id=repo_id,
122
- # revision=revision,
123
- # repo_type="model",
124
- # local_dir=cache_dir,
125
- # local_dir_use_symlinks=False,
126
- # token=token,
127
- # )
128
- # )
129
-
130
- # # If snapshot itself is unwrapped_model, return it; otherwise point a symlink to it.
131
- # if snapshot_path.name == "unwrapped_model":
132
- # return snapshot_path
133
- # nested = snapshot_path / "unwrapped_model"
134
- # if nested.is_dir():
135
- # return nested
136
- # aliased = snapshot_path.parent / "unwrapped_model"
137
- # if not aliased.exists():
138
- # aliased.symlink_to(snapshot_path, target_is_directory=True)
139
- # return aliased
140
-
141
-
142
- # @spaces.GPU
143
- # def main():
144
- # checkpoint_dir = download_checkpoint()
145
- # asset_root = download_assets()
146
- # style_root = download_style()
147
-
148
- # # Symlink style centroid npy files to expected locations
149
- # style_targets = [
150
- # MMADA_ROOT / "models" / "speech_tokenization" / "condition_style_centroid",
151
- # PROJECT_ROOT
152
- # / "EMOVA_speech_tokenizer"
153
- # / "emova_speech_tokenizer"
154
- # / "speech_tokenization"
155
- # / "condition_style_centroid",
156
- # ]
157
- # for starget in style_targets:
158
- # if starget.exists():
159
- # continue
160
- # starget.parent.mkdir(parents=True, exist_ok=True)
161
- # starget.symlink_to(style_root, target_is_directory=True)
162
-
163
- # # Point demo assets (logo, sample prompts/media) to the downloaded dataset
164
- # from inference import gradio_multimodal_demo_inst as demo_mod # noqa: WPS433
165
-
166
- # demo_root = asset_root / "demo"
167
- # demo_mod.DEMO_ROOT = demo_root
168
- # demo_mod.LOGO_PATH = demo_root / "logo.png"
169
- # demo_mod.T2S_TEXT_PATH = demo_root / "t2s" / "text.txt"
170
- # demo_mod.CHAT_TEXT_PATH = demo_root / "chat" / "text.txt"
171
- # demo_mod.T2I_TEXT_PATH = demo_root / "t2i" / "text.txt"
172
-
173
- # default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
174
- # legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
175
- # train_config = os.getenv("TRAIN_CONFIG_PATH")
176
- # if not train_config:
177
- # # Prefer configs/mmada_demo.yaml (in repo), fallback to legacy path if restored.
178
- # train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
179
- # device = os.getenv("DEVICE")
180
- # port = int(os.getenv("PORT", "7860"))
181
-
182
- # app = OmadaDemo(train_config=train_config, checkpoint=str(checkpoint_dir), device=device)
183
- # build_demo(app, share=False, server_name="0.0.0.0", server_port=port)
184
-
185
-
186
- # if __name__ == "__main__":
187
- # main()
188
-
189
- """
190
- ZeroGPU-friendly Gradio entrypoint for OMada demo.
191
-
192
- - Downloads checkpoint + assets + style centroids from Hugging Face Hub
193
- - Instantiates OmadaDemo once (global)
194
- - Exposes 10 modalities via Gradio tabs
195
- - Uses @spaces.GPU only on inference handlers so GPU is allocated per request
196
-
197
- Environment overrides:
198
- MODEL_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion)
199
- MODEL_REVISION (default: main)
200
- ASSET_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion-assets)
201
- ASSET_REVISION (default: main)
202
- STYLE_REPO_ID (default: jaeikkim/aidas-style-centroid)
203
- STYLE_REVISION (default: main)
204
- HF_TOKEN (optional, for private model/dataset)
205
- TRAIN_CONFIG_PATH (default: MMaDA/inference/demo/demo.yaml)
206
- DEVICE (default: cuda)
207
- """
208
-
209
  import os
210
  import sys
211
- import subprocess
212
- import importlib
213
  from pathlib import Path
214
-
215
- import gradio as gr
216
  import spaces
217
- from packaging.version import parse as parse_version
218
-
219
- # ---------------------------
220
- # Project roots & sys.path
221
- # ---------------------------
222
 
 
223
  PROJECT_ROOT = Path(__file__).resolve().parent
224
  MMADA_ROOT = PROJECT_ROOT / "MMaDA"
225
  if str(MMADA_ROOT) not in sys.path:
226
  sys.path.insert(0, str(MMADA_ROOT))
227
 
228
- EMOVA_ROOT = PROJECT_ROOT / "EMOVA_speech_tokenizer"
229
- if str(EMOVA_ROOT) not in sys.path:
230
- sys.path.insert(0, str(EMOVA_ROOT))
231
-
232
-
233
- # ---------------------------
234
- # HuggingFace Hub helper
235
- # ---------------------------
236
-
237
- def ensure_hf_hub(target: str = "0.36.0"):
238
- """
239
- Make sure huggingface_hub stays <1.0 to satisfy transformers/tokenizers.
240
-
241
- The Spaces base image may pull in a newer version via gradio, so we pin it.
242
- """
243
- try:
244
- import huggingface_hub as hub
245
- except ImportError:
246
- subprocess.check_call(
247
- [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
248
- )
249
- import huggingface_hub as hub
250
-
251
- if parse_version(hub.__version__) >= parse_version("1.0.0"):
252
- subprocess.check_call(
253
- [sys.executable, "-m", "pip", "install", f"huggingface-hub=={target}", "--no-cache-dir"]
254
- )
255
- hub = importlib.reload(hub)
256
-
257
- # Backfill missing constants in older hub versions to avoid AttributeError.
258
- try:
259
- import huggingface_hub.constants as hub_consts # type: ignore
260
- except Exception:
261
- hub_consts = None
262
- if hub_consts and not hasattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER"):
263
- setattr(hub_consts, "HF_HUB_ENABLE_HF_TRANSFER", False)
264
- return hub
265
-
266
-
267
- snapshot_download = ensure_hf_hub().snapshot_download
268
-
269
-
270
- # ---------------------------
271
- # Imports from OMada demo
272
- # ---------------------------
273
-
274
- from inference.gradio_multimodal_demo_inst import ( # noqa: E402
275
- OmadaDemo,
276
- CUSTOM_CSS,
277
- FORCE_LIGHT_MODE_JS,
278
- )
279
-
280
-
281
- # ---------------------------
282
- # HF download helpers
283
- # ---------------------------
284
-
285
- def download_assets() -> Path:
286
- """Download demo assets (logo + sample prompts/media) and return the root path."""
287
- repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets")
288
- revision = os.getenv("ASSET_REVISION", "main")
289
- token = os.getenv("HF_TOKEN")
290
- cache_dir = PROJECT_ROOT / "_asset_cache"
291
- cache_dir.mkdir(parents=True, exist_ok=True)
292
-
293
- return Path(
294
- snapshot_download(
295
- repo_id=repo_id,
296
- revision=revision,
297
- repo_type="dataset",
298
- local_dir=cache_dir,
299
- local_dir_use_symlinks=False,
300
- token=token,
301
- )
302
- )
303
-
304
-
305
- def download_style() -> Path:
306
- """Download style centroid dataset and return the root path."""
307
- repo_id = os.getenv("STYLE_REPO_ID", "jaeikkim/aidas-style-centroid")
308
- revision = os.getenv("STYLE_REVISION", "main")
309
- token = os.getenv("HF_TOKEN")
310
- cache_dir = PROJECT_ROOT / "_style_cache"
311
- cache_dir.mkdir(parents=True, exist_ok=True)
312
-
313
- return Path(
314
- snapshot_download(
315
- repo_id=repo_id,
316
- revision=revision,
317
- repo_type="dataset",
318
- local_dir=cache_dir,
319
- local_dir_use_symlinks=False,
320
- token=token,
321
- )
322
- )
323
-
324
-
325
- def download_checkpoint() -> Path:
326
- """Download checkpoint snapshot and return an `unwrapped_model` directory."""
327
- repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion")
328
- revision = os.getenv("MODEL_REVISION", "main")
329
- token = os.getenv("HF_TOKEN")
330
- cache_dir = PROJECT_ROOT / "_ckpt_cache"
331
- cache_dir.mkdir(parents=True, exist_ok=True)
332
-
333
- snapshot_path = Path(
334
- snapshot_download(
335
- repo_id=repo_id,
336
- revision=revision,
337
- repo_type="model",
338
- local_dir=cache_dir,
339
- local_dir_use_symlinks=False,
340
- token=token,
341
- )
342
- )
343
-
344
- # If snapshot itself is unwrapped_model, return it; otherwise look for nested dir,
345
- # and finally alias via symlink.
346
- if snapshot_path.name == "unwrapped_model":
347
- return snapshot_path
348
-
349
- nested = snapshot_path / "unwrapped_model"
350
- if nested.is_dir():
351
- return nested
352
-
353
- aliased = snapshot_path.parent / "unwrapped_model"
354
- if not aliased.exists():
355
- aliased.symlink_to(snapshot_path, target_is_directory=True)
356
- return aliased
357
-
358
-
359
- # ---------------------------
360
- # Global OmadaDemo instance
361
- # ---------------------------
362
-
363
- APP = None # type: ignore
364
-
365
-
366
- def get_app() -> OmadaDemo:
367
- global APP
368
- if APP is not None:
369
- return APP
370
-
371
- # Download everything once
372
- ckpt_dir = download_checkpoint()
373
- asset_root = download_assets()
374
- style_root = download_style()
375
-
376
- # Wire style centroids to expected locations
377
- style_targets = [
378
- MMADA_ROOT / "models" / "speech_tokenization" / "condition_style_centroid",
379
- PROJECT_ROOT
380
- / "EMOVA_speech_tokenizer"
381
- / "emova_speech_tokenizer"
382
- / "speech_tokenization"
383
- / "condition_style_centroid",
384
- ]
385
- for starget in style_targets:
386
- if not starget.exists():
387
- starget.parent.mkdir(parents=True, exist_ok=True)
388
- starget.symlink_to(style_root, target_is_directory=True)
389
-
390
- # Choose train config
391
- default_cfg = PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"
392
- legacy_cfg = PROJECT_ROOT / "MMaDA" / "configs" / "mmada_demo.yaml"
393
- train_config = os.getenv("TRAIN_CONFIG_PATH")
394
- if not train_config:
395
- train_config = str(default_cfg if default_cfg.exists() else legacy_cfg)
396
-
397
- # Device: in ZeroGPU environment, "cuda" is virtualized and only actually
398
- # attached inside @spaces.GPU handlers.
399
- device = os.getenv("DEVICE", "cuda")
400
-
401
- APP = OmadaDemo(train_config=train_config, checkpoint=str(ckpt_dir), device=device)
402
- return APP
403
-
404
-
405
- # ---------------------------
406
- # ZeroGPU-wrapped handlers
407
- # ---------------------------
408
-
409
- @spaces.GPU
410
- def t2s_handler(
411
- text,
412
- max_tokens,
413
- steps,
414
- block_len,
415
- temperature,
416
- cfg_scale,
417
- gender,
418
- emotion,
419
- speed,
420
- pitch,
421
- ):
422
- app = get_app()
423
- audio, status = app.run_t2s(
424
- text=text,
425
- max_new_tokens=int(max_tokens),
426
- steps=int(steps),
427
- block_length=int(block_len),
428
- temperature=float(temperature),
429
- cfg_scale=float(cfg_scale),
430
- gender_choice=gender,
431
- emotion_choice=emotion,
432
- speed_choice=speed,
433
- pitch_choice=pitch,
434
- )
435
- return audio, status
436
-
437
-
438
- @spaces.GPU
439
- def s2s_handler(
440
- audio_path,
441
- max_tokens,
442
- steps,
443
- block_len,
444
- temperature,
445
- cfg_scale,
446
- ):
447
- app = get_app()
448
- audio, status = app.run_s2s(
449
- audio_path=audio_path,
450
- max_new_tokens=int(max_tokens),
451
- steps=int(steps),
452
- block_length=int(block_len),
453
- temperature=float(temperature),
454
- cfg_scale=float(cfg_scale),
455
- )
456
- return audio, status
457
-
458
-
459
- @spaces.GPU
460
- def s2t_handler(
461
- audio_path,
462
- steps,
463
- block_len,
464
- max_tokens,
465
- remasking,
466
- ):
467
- app = get_app()
468
- text, status = app.run_s2t(
469
- audio_path=audio_path,
470
- steps=int(steps),
471
- block_length=int(block_len),
472
- max_new_tokens=int(max_tokens),
473
- remasking=str(remasking),
474
- )
475
- return text, status
476
-
477
 
478
- @spaces.GPU
479
- def v2t_handler(
480
- video,
481
- steps,
482
- block_len,
483
- max_tokens,
484
- ):
485
- app = get_app()
486
- text, status = app.run_v2t(
487
- video_path=video,
488
- steps=int(steps),
489
- block_length=int(block_len),
490
- max_new_tokens=int(max_tokens),
491
- )
492
- return text, status
493
 
 
 
 
494
 
495
- @spaces.GPU
496
- def v2s_handler(
497
- video,
498
- message,
499
- max_tokens,
500
- steps,
501
- block_len,
502
- temperature,
503
- cfg_scale,
504
- ):
505
- app = get_app()
506
- audio, status = app.run_v2s(
507
- video_path=video,
508
- message=message,
509
- max_new_tokens=int(max_tokens),
510
- steps=int(steps),
511
- block_length=int(block_len),
512
- temperature=float(temperature),
513
- cfg_scale=float(cfg_scale),
514
- )
515
- return audio, status
516
 
517
 
518
- @spaces.GPU
519
- def i2s_handler(
520
- image,
521
- message,
522
- max_tokens,
523
- steps,
524
- block_len,
525
- temperature,
526
- cfg_scale,
527
- ):
528
- app = get_app()
529
- audio, status = app.run_i2s(
530
- image=image,
531
- message=message,
532
- max_new_tokens=int(max_tokens),
533
- steps=int(steps),
534
- block_length=int(block_len),
535
- temperature=float(temperature),
536
- cfg_scale=float(cfg_scale),
537
- )
538
- return audio, status
539
540
 
541
- @spaces.GPU
542
- def chat_handler(
543
- message,
544
- max_tokens,
545
- steps,
546
- block_len,
547
- temperature,
548
- ):
549
- app = get_app()
550
- text, status = app.run_chat(
551
- message=message,
552
- max_new_tokens=int(max_tokens),
553
- steps=int(steps),
554
- block_length=int(block_len),
555
- temperature=float(temperature),
556
- )
557
- return text, status
558
 
 
 
 
559
 
560
- @spaces.GPU
561
- def mmu_handler(
562
- image_a,
563
- image_b,
564
- question,
565
- max_tokens,
566
- steps,
567
- block_len,
568
- temperature,
569
- ):
570
- app = get_app()
571
- text, status = app.run_mmu_dual(
572
- image_a=image_a,
573
- image_b=image_b,
574
- message=question,
575
- max_new_tokens=int(max_tokens),
576
- steps=int(steps),
577
- block_length=int(block_len),
578
- temperature=float(temperature),
579
- )
580
- return text, status
581
 
 
582
 
583
- @spaces.GPU
584
- def t2i_handler(
585
- prompt,
586
- timesteps,
587
- temperature,
588
- guidance,
589
- ):
590
- app = get_app()
591
- image, status = app.run_t2i(
592
- prompt=prompt,
593
- timesteps=int(timesteps),
594
- temperature=float(temperature),
595
- guidance_scale=float(guidance),
596
- )
597
- return image, status
598
 
599
 
600
  @spaces.GPU
601
- def i2i_handler(
602
- instruction,
603
- image,
604
- timesteps,
605
- temperature,
606
- guidance,
607
- ):
608
- app = get_app()
609
- image_out, status = app.run_i2i(
610
- instruction=instruction,
611
- source_image=image,
612
- timesteps=int(timesteps),
613
- temperature=float(temperature),
614
- guidance_scale=float(guidance),
615
- )
616
- return image_out, status
617
-
618
-
619
- # ---------------------------
620
- # Gradio UI (10 tabs)
621
- # ---------------------------
622
-
623
- theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray")
624
-
625
- with gr.Blocks(
626
- title="AIDAS Lab @ SNU - OMni-modal Diffusion (ZeroGPU)",
627
- css=CUSTOM_CSS,
628
- theme=theme,
629
- js=FORCE_LIGHT_MODE_JS,
630
- ) as demo:
631
- gr.Markdown(
632
- "## Omni-modal Diffusion Foundation Model\n"
633
- "### ZeroGPU-compatible demo (AIDAS Lab @ SNU)"
634
  )
635
 
636
- with gr.Tab("Text → Speech (T2S)"):
637
- with gr.Row():
638
- t2s_text = gr.Textbox(
639
- label="Input text",
640
- lines=4,
641
- placeholder="Type the speech you want to synthesize...",
642
- )
643
- t2s_audio = gr.Audio(label="Generated speech", type="numpy")
644
- t2s_status = gr.Textbox(label="Status", interactive=False)
645
- with gr.Accordion("Advanced settings", open=False):
646
- t2s_max_tokens = gr.Slider(2, 512, value=384, step=2, label="Speech token length")
647
- t2s_steps = gr.Slider(2, 512, value=128, step=2, label="Total refinement steps")
648
- t2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
649
- t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
650
- t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, step=0.1, label="CFG scale")
651
- with gr.Row():
652
- t2s_gender = gr.Dropdown(["random", "female", "male"], value="random", label="Gender")
653
- t2s_emotion = gr.Dropdown(["random", "angry", "happy", "neutral", "sad"], value="random", label="Emotion")
654
- with gr.Row():
655
- t2s_speed = gr.Dropdown(["random", "normal", "fast", "slow"], value="random", label="Speed")
656
- t2s_pitch = gr.Dropdown(["random", "normal", "high", "low"], value="random", label="Pitch")
657
- t2s_btn = gr.Button("Generate speech", variant="primary")
658
- t2s_btn.click(
659
- t2s_handler,
660
- inputs=[
661
- t2s_text,
662
- t2s_max_tokens,
663
- t2s_steps,
664
- t2s_block,
665
- t2s_temperature,
666
- t2s_cfg,
667
- t2s_gender,
668
- t2s_emotion,
669
- t2s_speed,
670
- t2s_pitch,
671
- ],
672
- outputs=[t2s_audio, t2s_status],
673
- )
674
-
675
- with gr.Tab("Speech → Speech (S2S)"):
676
- s2s_audio_in = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"])
677
- s2s_audio_out = gr.Audio(type="numpy", label="Reply speech")
678
- s2s_status = gr.Textbox(label="Status", interactive=False)
679
- with gr.Accordion("Advanced settings", open=False):
680
- s2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
681
- s2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps")
682
- s2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
683
- s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="Sampling temperature")
684
- s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, step=0.1, label="CFG scale")
685
- s2s_btn = gr.Button("Generate reply speech", variant="primary")
686
- s2s_btn.click(
687
- s2s_handler,
688
- inputs=[
689
- s2s_audio_in,
690
- s2s_max_tokens,
691
- s2s_steps,
692
- s2s_block,
693
- s2s_temperature,
694
- s2s_cfg,
695
- ],
696
- outputs=[s2s_audio_out, s2s_status],
697
- )
698
-
699
- with gr.Tab("Speech → Text (S2T)"):
700
- s2t_audio_in = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"])
701
- s2t_text_out = gr.Textbox(label="Transcription", lines=4)
702
- s2t_status = gr.Textbox(label="Status", interactive=False)
703
- with gr.Accordion("Advanced settings", open=False):
704
- s2t_steps = gr.Slider(2, 512, value=128, step=2, label="Denoising steps")
705
- s2t_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
706
- s2t_max_tokens = gr.Slider(2, 512, value=128, step=2, label="Max new tokens")
707
- s2t_remasking = gr.Dropdown(
708
- ["low_confidence", "random"],
709
- value="low_confidence",
710
- label="Remasking strategy",
711
- )
712
- s2t_btn = gr.Button("Transcribe", variant="primary")
713
- s2t_btn.click(
714
- s2t_handler,
715
- inputs=[s2t_audio_in, s2t_steps, s2t_block, s2t_max_tokens, s2t_remasking],
716
- outputs=[s2t_text_out, s2t_status],
717
- )
718
-
719
- with gr.Tab("Video → Text (V2T)"):
720
- v2t_video_in = gr.Video(
721
- label="Upload or record video",
722
- height=256,
723
- sources=["upload", "webcam"],
724
- )
725
- v2t_text_out = gr.Textbox(label="Caption / answer", lines=4)
726
- v2t_status = gr.Textbox(label="Status", interactive=False)
727
- with gr.Accordion("Advanced settings", open=False):
728
- v2t_steps = gr.Slider(2, 512, value=64, step=2, label="Denoising steps")
729
- v2t_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
730
- v2t_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Max new tokens")
731
- v2t_btn = gr.Button("Generate caption", variant="primary")
732
- v2t_btn.click(
733
- v2t_handler,
734
- inputs=[v2t_video_in, v2t_steps, v2t_block, v2t_max_tokens],
735
- outputs=[v2t_text_out, v2t_status],
736
- )
737
-
738
- with gr.Tab("Video → Speech (V2S)"):
739
- v2s_video_in = gr.Video(
740
- label="Upload or record video",
741
- height=256,
742
- sources=["upload", "webcam"],
743
- )
744
- v2s_prompt = gr.Textbox(
745
- label="Optional instruction",
746
- placeholder="(Optional) e.g., 'Describe this scene in spoken form.'",
747
- )
748
- v2s_audio_out = gr.Audio(type="numpy", label="Generated speech")
749
- v2s_status = gr.Textbox(label="Status", interactive=False)
750
- with gr.Accordion("Advanced settings", open=False):
751
- v2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
752
- v2s_steps = gr.Slider(2, 512, value=128, step=2, label="Refinement steps")
753
- v2s_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
754
- v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
755
- v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
756
- v2s_btn = gr.Button("Generate speech from video", variant="primary")
757
- v2s_btn.click(
758
- v2s_handler,
759
- inputs=[
760
- v2s_video_in,
761
- v2s_prompt,
762
- v2s_max_tokens,
763
- v2s_steps,
764
- v2s_block,
765
- v2s_temperature,
766
- v2s_cfg,
767
- ],
768
- outputs=[v2s_audio_out, v2s_status],
769
- )
770
-
771
- with gr.Tab("Image → Speech (I2S)"):
772
- i2s_image_in = gr.Image(type="pil", label="Image input", sources=["upload"])
773
- i2s_prompt = gr.Textbox(
774
- label="Optional question",
775
- placeholder="(Optional) e.g., 'Describe this image aloud.'",
776
- )
777
- i2s_audio_out = gr.Audio(type="numpy", label="Spoken description")
778
- i2s_status = gr.Textbox(label="Status", interactive=False)
779
- with gr.Accordion("Advanced settings", open=False):
780
- i2s_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Reply token length")
781
- i2s_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
782
- i2s_block = gr.Slider(2, 512, value=256, step=2, label="Block length")
783
- i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
784
- i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, step=0.1, label="CFG scale")
785
- i2s_btn = gr.Button("Generate spoken description", variant="primary")
786
- i2s_btn.click(
787
- i2s_handler,
788
- inputs=[
789
- i2s_image_in,
790
- i2s_prompt,
791
- i2s_max_tokens,
792
- i2s_steps,
793
- i2s_block,
794
- i2s_temperature,
795
- i2s_cfg,
796
- ],
797
- outputs=[i2s_audio_out, i2s_status],
798
- )
799
-
800
- with gr.Tab("Text Chat"):
801
- chat_in = gr.Textbox(
802
- label="Message",
803
- lines=4,
804
- placeholder="Ask anything. The model will reply in text.",
805
- )
806
- chat_out = gr.Textbox(label="Assistant reply", lines=6)
807
- chat_status = gr.Textbox(label="Status", interactive=False)
808
- with gr.Accordion("Advanced settings", open=False):
809
- chat_max_tokens = gr.Slider(2, 512, value=64, step=2, label="Reply max tokens")
810
- chat_steps = gr.Slider(2, 512, value=64, step=2, label="Refinement steps")
811
- chat_block = gr.Slider(2, 512, value=64, step=2, label="Block length")
812
- chat_temperature_slider = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="Sampling temperature")
813
- chat_btn = gr.Button("Send", variant="primary")
814
- chat_btn.click(
815
- chat_handler,
816
- inputs=[
817
- chat_in,
818
- chat_max_tokens,
819
- chat_steps,
820
- chat_block,
821
- chat_temperature_slider,
822
- ],
823
- outputs=[chat_out, chat_status],
824
- )
825
-
826
- with gr.Tab("MMU (2 images → text)"):
827
- mmu_img_a = gr.Image(type="pil", label="Image A", sources=["upload"])
828
- mmu_img_b = gr.Image(type="pil", label="Image B", sources=["upload"])
829
- mmu_question = gr.Textbox(
830
- label="Question",
831
- lines=3,
832
- placeholder="Ask about the relationship or differences between the two images.",
833
- )
834
- mmu_answer = gr.Textbox(label="Answer", lines=6)
835
- mmu_status = gr.Textbox(label="Status", interactive=False)
836
- with gr.Accordion("Advanced settings", open=False):
837
- mmu_max_tokens = gr.Slider(2, 512, value=256, step=2, label="Answer max tokens")
838
- mmu_steps = gr.Slider(2, 512, value=256, step=2, label="Refinement steps")
839
- mmu_block = gr.Slider(2, 512, value=128, step=2, label="Block length")
840
- mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Sampling temperature")
841
- mmu_btn = gr.Button("Answer about the two images", variant="primary")
842
- mmu_btn.click(
843
- mmu_handler,
844
- inputs=[
845
- mmu_img_a,
846
- mmu_img_b,
847
- mmu_question,
848
- mmu_max_tokens,
849
- mmu_steps,
850
- mmu_block,
851
- mmu_temperature,
852
- ],
853
- outputs=[mmu_answer, mmu_status],
854
- )
855
-
856
- with gr.Tab("Text → Image (T2I)"):
857
- t2i_prompt = gr.Textbox(
858
- label="Prompt",
859
- lines=4,
860
- placeholder="Describe the image you want to generate...",
861
- )
862
- t2i_image_out = gr.Image(label="Generated image")
863
- t2i_status = gr.Textbox(label="Status", interactive=False)
864
- with gr.Accordion("Advanced settings", open=False):
865
- t2i_timesteps = gr.Slider(4, 128, value=32, step=2, label="Timesteps")
866
- t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
867
- t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
868
- t2i_btn = gr.Button("Generate image", variant="primary")
869
- t2i_btn.click(
870
- t2i_handler,
871
- inputs=[t2i_prompt, t2i_timesteps, t2i_temperature, t2i_guidance],
872
- outputs=[t2i_image_out, t2i_status],
873
- )
874
-
875
- with gr.Tab("Image Editing (I2I)"):
876
- i2i_image_in = gr.Image(type="pil", label="Reference image", sources=["upload"])
877
- i2i_instr = gr.Textbox(
878
- label="Editing instruction",
879
- lines=4,
880
- placeholder="Describe how you want to edit the image...",
881
- )
882
- i2i_image_out = gr.Image(label="Edited image")
883
- i2i_status = gr.Textbox(label="Status", interactive=False)
884
- with gr.Accordion("Advanced settings", open=False):
885
- i2i_timesteps = gr.Slider(4, 128, value=18, step=2, label="Timesteps")
886
- i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.05, label="Sampling temperature")
887
- i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, step=0.1, label="CFG scale")
888
- i2i_btn = gr.Button("Apply edit", variant="primary")
889
- i2i_btn.click(
890
- i2i_handler,
891
- inputs=[i2i_instr, i2i_image_in, i2i_timesteps, i2i_temperature, i2i_guidance],
892
- outputs=[i2i_image_out, i2i_status],
893
- )
894
 
895
 
896
  if __name__ == "__main__":
897
- demo.launch()
898
-
1
  import os
2
  import sys
 
 
3
  from pathlib import Path
 
 
4
  import spaces
 
 
 
 
 
5
 
6
+ # === Import project modules ===
7
  PROJECT_ROOT = Path(__file__).resolve().parent
8
  MMADA_ROOT = PROJECT_ROOT / "MMaDA"
9
  if str(MMADA_ROOT) not in sys.path:
10
  sys.path.insert(0, str(MMADA_ROOT))
11
 
12
+ from inference.gradio_multimodal_demo_inst import OmadaDemo
13
+ import gradio as gr
14
15
 
16
+ # ----------------------------------------------------------------------
17
+ # 1. Asset Loading (Downloaded by entrypoint)
18
+ # ----------------------------------------------------------------------
19
 
20
+ ASSET_ROOT = PROJECT_ROOT / "_asset_cache" / "AIDAS-Omni-Modal-Diffusion-assets"
21
+ DEMO_ROOT = ASSET_ROOT  # the asset repo is already split into per-modality subfolders
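+ # e.g. t2s/*.txt, s2s/*.wav, v2t/*.mp4, i2i/* feed the gr.Examples blocks below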
 
22
 
23
 
24
+ # ----------------------------------------------------------------------
25
+ # 2. GPU Handler Wrapper
26
+ # ----------------------------------------------------------------------
 
27
 
28
+ def gpu_handler(fn):
29
+ """
30
+ Wrap an inference function so that each call requests a ZeroGPU slot via @spaces.GPU.
31
+ """
32
+ @spaces.GPU
33
+ def inner(*args, **kwargs):
34
+ return fn(*args, **kwargs)
35
+ return inner
36
 
37
 
38
+ # ----------------------------------------------------------------------
39
+ # 3. Build Demo UI With Examples
40
+ # ----------------------------------------------------------------------
41
 
42
+ def build_zero_gpu_demo(app: OmadaDemo):
43
 
44
+ with gr.Blocks(title="AIDAS Omni-Modal Diffusion") as demo:
45
 
46
+ # ---------------- Header ----------------
47
+ gr.Markdown(
48
+ "<h1 style='text-align:center'>AIDAS Omni-Modal Diffusion Model</h1>"
49
+ )
50
 
51
+ # Logo is assumed to ship with the downloaded asset snapshot (hypothetical location).
+ try:
+ logo_path = DEMO_ROOT / "logo.png"
+ if logo_path.exists():
+ gr.Image(str(logo_path), elem_id="logo", show_label=False, height=120)
+ except Exception:
+ pass
+
+ gr.Markdown("### Multimodal Inference Demo")
58
+ gr.Markdown("---")
59
+
60
+ # ---------------- Tabs ----------------
61
+
62
+ with gr.Tabs():
63
+
64
+ # ============================================================
65
+ # 1) TEXT → SPEECH (T2S)
66
+ # ============================================================
67
+ with gr.Tab("Text → Speech (T2S)"):
68
+
69
+ t2s_in = gr.Textbox(label="Input Text")
70
+ t2s_btn = gr.Button("Generate")
71
+ t2s_audio = gr.Audio(label="Speech Output")
72
+ t2s_status = gr.Textbox(label="Status", interactive=False)
73
+
74
+ t2s_examples = []
75
+ t2s_dir = DEMO_ROOT / "t2s"
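+ # every *.txt under t2s/ becomes a one-click example prompt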
76
+ if t2s_dir.exists():
77
+ for f in t2s_dir.glob("*.txt"):
78
+ txt = f.read_text().strip()
79
+ t2s_examples.append([txt])
80
+
81
+ if len(t2s_examples) > 0:
82
+ gr.Examples(
83
+ examples=t2s_examples,
84
+ inputs=[t2s_in],
85
+ outputs=[t2s_audio, t2s_status],
86
+ fn=gpu_handler(app.run_t2s),
87
+ )
88
+
89
+ t2s_btn.click(
90
+ gpu_handler(app.run_t2s),
91
+ inputs=[t2s_in],
92
+ outputs=[t2s_audio, t2s_status],
93
+ )
94
+
95
+ # ============================================================
96
+ # 2) SPEECH → SPEECH (S2S)
97
+ # ============================================================
98
+ with gr.Tab("Speech → Speech (S2S)"):
99
+
100
+ s2s_in = gr.Audio(type="filepath", label="Input Speech")
101
+ s2s_btn = gr.Button("Generate")
102
+ s2s_audio = gr.Audio(label="Output Speech")
103
+ s2s_status = gr.Textbox(label="Status", interactive=False)
104
+
105
+ s2s_examples = []
106
+ s2s_dir = DEMO_ROOT / "s2s"
107
+ if s2s_dir.exists():
108
+ for f in s2s_dir.glob("*.wav"):
109
+ s2s_examples.append([str(f)])
110
+
111
+ if len(s2s_examples) > 0:
112
+ gr.Examples(
113
+ examples=s2s_examples,
114
+ inputs=[s2s_in],
115
+ outputs=[s2s_audio, s2s_status],
116
+ fn=gpu_handler(app.run_s2s),
117
+ )
118
+
119
+ s2s_btn.click(
120
+ gpu_handler(app.run_s2s),
121
+ inputs=[s2s_in],
122
+ outputs=[s2s_audio, s2s_status]
123
+ )
124
+
125
+ # ============================================================
126
+ # 3) SPEECH → TEXT (S2T)
127
+ # ============================================================
128
+ with gr.Tab("Speech → Text (S2T)"):
129
+
130
+ s2t_in = gr.Audio(type="filepath", label="Input Speech")
131
+ s2t_btn = gr.Button("Transcribe")
132
+ s2t_text = gr.Textbox(label="Transcribed Text")
133
+ s2t_status = gr.Textbox(label="Status", interactive=False)
134
+
135
+ s2t_examples = []
136
+ s2t_dir = DEMO_ROOT / "s2t"
137
+ if s2t_dir.exists():
138
+ for f in s2t_dir.glob("*.wav"):
139
+ s2t_examples.append([str(f)])
140
+
141
+ if len(s2t_examples) > 0:
142
+ gr.Examples(
143
+ examples=s2t_examples,
144
+ inputs=[s2t_in],
145
+ outputs=[s2t_text, s2t_status],
146
+ fn=gpu_handler(app.run_s2t),
147
+ )
148
+
149
+ s2t_btn.click(
150
+ gpu_handler(app.run_s2t),
151
+ inputs=[s2t_in],
152
+ outputs=[s2t_text, s2t_status],
153
+ )
154
+
155
+ # ============================================================
156
+ # 4) VIDEO → TEXT (V2T)
157
+ # ============================================================
158
+ with gr.Tab("Video → Text (V2T)"):
159
+
160
+ v2t_in = gr.Video(label="Input Video")  # gr.Video has no "type" argument; it yields a file path by default
161
+ v2t_btn = gr.Button("Generate Caption")
162
+ v2t_text = gr.Textbox(label="Caption")
163
+ v2t_status = gr.Textbox(label="Status")
164
+
165
+ v2t_examples = []
166
+ v2t_dir = DEMO_ROOT / "v2t"
167
+ if v2t_dir.exists():
168
+ for f in v2t_dir.glob("*.mp4"):
169
+ v2t_examples.append([str(f)])
170
+
171
+ if len(v2t_examples) > 0:
172
+ gr.Examples(
173
+ examples=v2t_examples,
174
+ inputs=[v2t_in],
175
+ outputs=[v2t_text, v2t_status],
176
+ fn=gpu_handler(app.run_v2t),
177
+ )
178
+
179
+ v2t_btn.click(
180
+ gpu_handler(app.run_v2t),
181
+ inputs=[v2t_in],
182
+ outputs=[v2t_text, v2t_status],
183
+ )
184
+
185
+ # ============================================================
186
+ # 5) VIDEO → SPEECH (V2S)
187
+ # ============================================================
188
+ with gr.Tab("Video → Speech (V2S)"):
189
+
190
+ v2s_in = gr.Video(label="Input Video")
191
+ v2s_btn = gr.Button("Generate Speech")
192
+ v2s_audio = gr.Audio(label="Speech Output")
193
+ v2s_status = gr.Textbox(label="Status")
194
+
195
+ v2s_examples = []
196
+ v2s_dir = DEMO_ROOT / "v2s"
197
+ if v2s_dir.exists():
198
+ for f in v2s_dir.glob("*.mp4"):
199
+ v2s_examples.append([str(f)])
200
+
201
+ if len(v2s_examples) > 0:
202
+ gr.Examples(
203
+ examples=v2s_examples,
204
+ inputs=[v2s_in],
205
+ outputs=[v2s_audio, v2s_status],
206
+ fn=gpu_handler(app.run_v2s),
207
+ )
208
+
209
+ v2s_btn.click(
210
+ gpu_handler(app.run_v2s),
211
+ inputs=[v2s_in],
212
+ outputs=[v2s_audio, v2s_status],
213
+ )
214
+
215
+ # ============================================================
216
+ # 6) IMAGE → SPEECH (I2S)
217
+ # ============================================================
218
+ with gr.Tab("Image → Speech (I2S)"):
219
+
220
+ i2s_in = gr.Image(type="filepath", label="Input Image")
221
+ i2s_btn = gr.Button("Generate Speech")
222
+ i2s_audio = gr.Audio(label="Speech")
223
+ i2s_status = gr.Textbox(label="Status")
224
+
225
+ # Only if folder exists
226
+ i2s_examples = []
227
+ i2s_dir = DEMO_ROOT / "i2s"
228
+ if i2s_dir.exists():
229
+ for f in i2s_dir.glob("*.*"):
230
+ i2s_examples.append([str(f)])
231
+
232
+ if len(i2s_examples) > 0:
233
+ gr.Examples(
234
+ examples=i2s_examples,
235
+ inputs=[i2s_in],
236
+ outputs=[i2s_audio, i2s_status],
237
+ fn=gpu_handler(app.run_i2s),
238
+ )
239
+
240
+ i2s_btn.click(
241
+ gpu_handler(app.run_i2s),
242
+ inputs=[i2s_in],
243
+ outputs=[i2s_audio, i2s_status],
244
+ )
245
+
246
+ # ============================================================
247
+ # 7) CHAT
248
+ # ============================================================
249
+ with gr.Tab("Chat (Text)"):
250
+
251
+ chat_in = gr.Textbox(label="Message")
252
+ chat_btn = gr.Button("Send")
253
+ chat_out = gr.Textbox(label="Response")
254
+ chat_status = gr.Textbox(label="Status")
255
+
256
+ chat_examples = []
257
+ chat_dir = DEMO_ROOT / "chat"
258
+ if chat_dir.exists():
259
+ for f in chat_dir.glob("*.txt"):
260
+ txt = f.read_text().strip()
261
+ chat_examples.append([txt])
262
+
263
+ if len(chat_examples) > 0:
264
+ gr.Examples(
265
+ examples=chat_examples,
266
+ inputs=[chat_in],
267
+ outputs=[chat_out, chat_status],
268
+ fn=gpu_handler(app.run_chat),
269
+ )
270
+
271
+ chat_btn.click(
272
+ gpu_handler(app.run_chat),
273
+ inputs=[chat_in],
274
+ outputs=[chat_out, chat_status],
275
+ )
276
+
277
+ # ============================================================
278
+ # 8) MMU (2 images → text)
279
+ # ============================================================
280
+ with gr.Tab("MMU (Dual-Image Reasoning)"):
281
+
282
+ mmu_img1 = gr.Image(type="filepath", label="Image 1")
283
+ mmu_img2 = gr.Image(type="filepath", label="Image 2")
284
+ mmu_prompt = gr.Textbox(label="Prompt")
285
+ mmu_btn = gr.Button("Run MMU")
286
+ mmu_out = gr.Textbox(label="Output")
287
+ mmu_status = gr.Textbox(label="Status")
288
+
289
+ mmu_examples = []
290
+ mmu_dir = DEMO_ROOT / "mmu"
291
+ if mmu_dir.exists():
292
+ imgs = list(mmu_dir.glob("*.png"))
293
+ if len(imgs) >= 2:
294
+ mmu_examples.append([
295
+ str(imgs[0]),
296
+ str(imgs[1]),
297
+ "Describe the relation between two objects."
298
+ ])
299
+
300
+ if len(mmu_examples) > 0:
301
+ gr.Examples(
302
+ examples=mmu_examples,
303
+ inputs=[mmu_img1, mmu_img2, mmu_prompt],
304
+ outputs=[mmu_out, mmu_status],
305
+ fn=gpu_handler(app.run_mmu_dual),
306
+ )
307
+
308
+ mmu_btn.click(
309
+ gpu_handler(app.run_mmu_dual),
310
+ inputs=[mmu_img1, mmu_img2, mmu_prompt],
311
+ outputs=[mmu_out, mmu_status]
312
+ )
313
+
314
+ # ============================================================
315
+ # 9) TEXT → IMAGE (T2I)
316
+ # ============================================================
317
+ with gr.Tab("Text → Image (T2I)"):
318
+
319
+ t2i_in = gr.Textbox(label="Prompt")
320
+ t2i_btn = gr.Button("Generate Image")
321
+ t2i_img = gr.Image(label="Generated Image")
322
+ t2i_status = gr.Textbox(label="Status")
323
+
324
+ t2i_examples = []
325
+ t2i_dir = DEMO_ROOT / "t2i"
326
+ if t2i_dir.exists():
327
+ for f in t2i_dir.glob("*.txt"):
328
+ txt = f.read_text().strip()
329
+ t2i_examples.append([txt])
330
+
331
+ if len(t2i_examples) > 0:
332
+ gr.Examples(
333
+ examples=t2i_examples,
334
+ inputs=[t2i_in],
335
+ outputs=[t2i_img, t2i_status],
336
+ fn=gpu_handler(app.run_t2i),
337
+ )
338
+
339
+ t2i_btn.click(
340
+ gpu_handler(app.run_t2i),
341
+ inputs=[t2i_in],
342
+ outputs=[t2i_img, t2i_status],
343
+ )
344
+
345
+ # ============================================================
346
+ # 10) IMAGE EDITING (I2I)
347
+ # ============================================================
348
+ with gr.Tab("Image Editing (I2I)"):
349
+
350
+ i2i_in = gr.Image(type="filepath", label="Input Image")
351
+ i2i_prompt = gr.Textbox(label="Edit Instruction")
352
+ i2i_btn = gr.Button("Apply Edit")
353
+ i2i_img = gr.Image(label="Edited Image")
354
+ i2i_status = gr.Textbox(label="Status")
355
+
356
+ i2i_examples = []
357
+ i2i_dir = DEMO_ROOT / "i2i"
358
+ if i2i_dir.exists():
359
+ for f in i2i_dir.glob("*.*"):
360
+ i2i_examples.append([str(f), "Make it more vibrant."])
361
+
362
+ if len(i2i_examples) > 0:
363
+ gr.Examples(
364
+ examples=i2i_examples,
365
+ inputs=[i2i_in, i2i_prompt],
366
+ outputs=[i2i_img, i2i_status],
367
+ fn=gpu_handler(app.run_i2i),
368
+ )
369
+
370
+ i2i_btn.click(
371
+ gpu_handler(app.run_i2i),
372
+ inputs=[i2i_in, i2i_prompt],
373
+ outputs=[i2i_img, i2i_status]
374
+ )
375
+
376
+ # End Tabs
377
+
378
+ return demo
379
+
380
+
381
+ # ----------------------------------------------------------------------
382
+ # 4. Entry Point for Space
383
+ # ----------------------------------------------------------------------
384
 
385
  @spaces.GPU
386
+ def main():
387
+ app = OmadaDemo(
388
+ train_config=str(MMADA_ROOT / "inference/demo/demo.yaml"),
389
+ checkpoint=os.getenv("MODEL_CHECKPOINT_DIR", "_ckpt_cache/omada"),
390
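+ # the demo is instantiated on CPU; ZeroGPU attaches a GPU only inside @spaces.GPU-wrapped calls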
+ device="cpu"
391
  )
392
 
393
+ demo = build_zero_gpu_demo(app)
394
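+ # 7860 is the port Hugging Face Spaces routes traffic to by default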
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
395
 
396
 
397
  if __name__ == "__main__":
398
+ main()