ginipick committed
Commit 09d623f · verified · 1 Parent(s): 0312016

Update app.py

Files changed (1):
  app.py +67 -18
app.py CHANGED
@@ -197,25 +197,65 @@ TOPIC_DICT_EN = load_json_dict("story_en.json", DEFAULT_TOPICS_EN_DICT)
 CATEGORY_LIST = list(TOPIC_DICT_KO.keys())
 
 # ────────────────────────────────────────────────────────────────
-# 4. Initialize Video Models
+# 4. Global Model Variables (Will be initialized in GPU functions)
 # ────────────────────────────────────────────────────────────────
-vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
-wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
-transformer = NagWanTransformer3DModel.from_single_file(wan_path, torch_dtype=torch.bfloat16)
-pipe = NAGWanPipeline.from_pretrained(
-    MODEL_ID, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
-)
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
-pipe.to("cuda")
-
-pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
-pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
-pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward
+# Video models
+vae = None
+transformer = None
+pipe = None
+
+# Audio models
+audio_net = None
+audio_feature_utils = None
+audio_seq_cfg = None
+
+# Model initialization flags
+video_models_initialized = False
+audio_models_initialized = False
 
 # ────────────────────────────────────────────────────────────────
-# 5. Initialize Audio Model
+# 5. Model Initialization Functions (Called inside GPU functions)
 # ────────────────────────────────────────────────────────────────
-def get_mmaudio_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
+def initialize_video_models():
+    """Initialize video generation models - must be called inside GPU function"""
+    global vae, transformer, pipe, video_models_initialized
+
+    if video_models_initialized:
+        return
+
+    logger.info("Initializing video models inside GPU function...")
+
+    # Initialize VAE
+    vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
+
+    # Download and initialize transformer
+    wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
+    transformer = NagWanTransformer3DModel.from_single_file(wan_path, torch_dtype=torch.bfloat16)
+
+    # Initialize pipeline
+    pipe = NAGWanPipeline.from_pretrained(
+        MODEL_ID, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
+    )
+    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
+    pipe.to("cuda")
+
+    # Set attn processors
+    pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
+    pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
+    pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward
+
+    video_models_initialized = True
+    logger.info("Video models initialized successfully")
+
+def initialize_audio_models():
+    """Initialize audio generation models - must be called inside GPU function"""
+    global audio_net, audio_feature_utils, audio_seq_cfg, audio_models_initialized
+
+    if audio_models_initialized:
+        return
+
+    logger.info("Initializing audio models inside GPU function...")
+
     seq_cfg = audio_model_config.seq_cfg
 
     net: MMAudio = get_my_mmaudio(audio_model_config.model_name).to(device, dtype).eval()
@@ -230,9 +270,11 @@ def get_mmaudio_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
         need_vae_encoder=False)
     feature_utils = feature_utils.to(device, dtype).eval()
 
-    return net, feature_utils, seq_cfg
-
-audio_net, audio_feature_utils, audio_seq_cfg = get_mmaudio_model()
+    audio_net = net
+    audio_feature_utils = feature_utils
+    audio_seq_cfg = seq_cfg
+    audio_models_initialized = True
+    logger.info("Audio models initialized successfully")
 
 # ────────────────────────────────────────────────────────────────
 # 6. Story Seed Functions
@@ -345,9 +387,13 @@ Write all elements as one flowing paragraph that video creators can immediately
 # ────────────────────────────────────────────────────────────────
 # 8. Video/Audio Generation Functions
 # ────────────────────────────────────────────────────────────────
+@spaces.GPU()
 @torch.inference_mode()
 def add_audio_to_video(video_path, prompt, audio_negative_prompt, audio_steps, audio_cfg_strength, duration):
     """Generate and add audio to video using MMAudio"""
+    # Initialize audio models if needed
+    initialize_audio_models()
+
     rng = torch.Generator(device=device)
     rng.seed()
     fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=audio_steps)
@@ -393,6 +439,9 @@ def generate_video_with_audio(
     enable_audio=True, audio_negative_prompt=DEFAULT_AUDIO_NEGATIVE_PROMPT,
     audio_steps=25, audio_cfg_strength=4.5,
 ):
+    # Initialize video models if needed
+    initialize_video_models()
+
     target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
 
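
Taken together, the diff replaces eager, import-time model construction with lazy initialization guarded by module-level flags, with spaces.GPU() marking the entry points. This is the usual pattern for Hugging Face ZeroGPU Spaces, where CUDA is attached only while a GPU-decorated function runs, so any .to("cuda") at import time fails. A minimal, self-contained sketch of the pattern, assuming a ZeroGPU Space with the spaces SDK installed (the load_model stand-in and infer function are hypothetical, not from this repo):

import spaces   # Hugging Face Spaces SDK; provides the GPU decorator
import torch

model = None                # built lazily, not at import time
model_initialized = False

def load_model():
    # Hypothetical stand-in for the heavy construction done in
    # initialize_video_models() / initialize_audio_models().
    return torch.nn.Linear(4, 4)

def ensure_model():
    global model, model_initialized
    if model_initialized:   # flag makes repeat calls a cheap no-op
        return
    model = load_model().to("cuda")   # safe: only runs inside @spaces.GPU
    model_initialized = True

@spaces.GPU()               # on ZeroGPU, CUDA exists only inside this call
@torch.inference_mode()
def infer(x: torch.Tensor) -> torch.Tensor:
    ensure_model()          # first call pays the load cost; later calls reuse
    return model(x.to("cuda"))

The initialization flag matters because the worker process may invoke the decorated function many times; rebuilding the pipeline on every call would dominate the runtime, while the guarded globals load once and are reused.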