jbilcke-hf HF Staff commited on
Commit
0ce1e5e
·
verified ·
1 Parent(s): 54eccd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -222
app.py CHANGED
@@ -53,20 +53,6 @@ import numpy as np
53
 
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
 
56
- # LoRA Storage Configuration
57
- STORAGE_PATH = Path(DATA_ROOT) / "storage"
58
- LORA_PATH = STORAGE_PATH / "loras"
59
- OUTPUT_PATH = STORAGE_PATH / "output"
60
-
61
- # Create necessary directories
62
- STORAGE_PATH.mkdir(parents=True, exist_ok=True)
63
- LORA_PATH.mkdir(parents=True, exist_ok=True)
64
- OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
65
-
66
- # Global variables for LoRA management
67
- current_lora_id = None
68
- current_lora_path = None
69
-
70
  # --- Argument Parsing ---
71
  parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
72
  parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
@@ -129,89 +115,6 @@ if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
129
  APP_STATE["torch_compile_applied"] = True
130
  print("✅ torch.compile applied to transformer")
131
 
132
- def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
133
- """Upload a LoRA file and return a hash-based ID for future reference"""
134
- if file is None:
135
- return "", ""
136
-
137
- try:
138
- # Calculate SHA256 hash of the file
139
- sha256_hash = hashlib.sha256()
140
- with open(file.name, "rb") as f:
141
- for chunk in iter(lambda: f.read(4096), b""):
142
- sha256_hash.update(chunk)
143
- file_hash = sha256_hash.hexdigest()
144
-
145
- # Create destination path using hash
146
- dest_path = LORA_PATH / f"{file_hash}.safetensors"
147
-
148
- # Check if file already exists
149
- if dest_path.exists():
150
- print(f"LoRA file already exists!")
151
- return file_hash, file_hash
152
-
153
- # Copy the file to the destination
154
- shutil.copy(file.name, dest_path)
155
-
156
- print(f"LoRA file uploaded!")
157
- return file_hash, file_hash
158
- except Exception as e:
159
- print(f"Error uploading LoRA file: {e}")
160
- raise gr.Error(f"Failed to upload LoRA file: {str(e)}")
161
-
162
- def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
163
- """Get the path to a LoRA file from its hash-based ID"""
164
- if not lora_id:
165
- return None
166
-
167
- # Check if file exists
168
- lora_path = LORA_PATH / f"{lora_id}.safetensors"
169
- if lora_path.exists():
170
- return lora_path
171
-
172
- return None
173
-
174
- def manage_lora_weights(lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
175
- """Manage LoRA weights for the transformer model"""
176
- global current_lora_id, current_lora_path
177
-
178
- # Determine if we should use LoRA
179
- using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0
180
-
181
- # If not using LoRA but we have one loaded, clear it
182
- if not using_lora and current_lora_id is not None:
183
- print(f"Clearing current LoRA")
184
- current_lora_id = None
185
- current_lora_path = None
186
- return False, None
187
-
188
- # If using LoRA, check if we need to change weights
189
- if using_lora:
190
- lora_path = get_lora_file_path(lora_id)
191
-
192
- if not lora_path:
193
- print(f"A LoRA file with this ID was found. Using base model instead.")
194
-
195
- # If we had a LoRA loaded, clear it
196
- if current_lora_id is not None:
197
- print(f"Clearing current LoRA")
198
- current_lora_id = None
199
- current_lora_path = None
200
-
201
- return False, None
202
-
203
- # If LoRA ID changed, update
204
- if lora_id != current_lora_id:
205
- print(f"Loading LoRA..")
206
- current_lora_id = lora_id
207
- current_lora_path = lora_path
208
- else:
209
- print(f"Using a LoRA!")
210
-
211
- return True, lora_path
212
-
213
- return False, None
214
-
215
  def frames_to_ts_file(frames, filepath, fps = 15):
216
  """
217
  Convert frames directly to .ts file using PyAV.
@@ -327,7 +230,7 @@ pipeline = CausalInferencePipeline(
327
  pipeline.to(dtype=torch.float16).to(gpu)
328
 
329
  @torch.no_grad()
330
- def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5, lora_id=None, lora_weight=0.0):
331
  """
332
  Generator function that yields .ts video chunks using PyAV for streaming.
333
  """
@@ -336,13 +239,6 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
336
 
337
  # print(f"🎬 Starting PyAV streaming: seed: {seed}, duration: {duration}s")
338
 
339
- # Handle LoRA weights
340
- using_lora, lora_path = manage_lora_weights(lora_id, lora_weight)
341
- if using_lora:
342
- print(f"🎨 Using LoRA with weight factor {lora_weight}")
343
- else:
344
- print("🎨 Using base model (no LoRA)")
345
-
346
  # Setup
347
  conditional_dict = text_encoder(text_prompts=[prompt])
348
  for key, value in conditional_dict.items():
@@ -504,133 +400,93 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
504
  print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
505
 
506
  # --- Gradio UI Layout ---
507
- with gr.Blocks(title="Wan2.1 1.3B LoRA Self-Forcing streaming demo") as demo:
508
- gr.Markdown("# 🚀 Run Any LoRA in near real-time!")
509
- gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B and LoRA [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
510
 
511
- with gr.Tabs():
512
- # LoRA Upload Tab
513
- with gr.TabItem("1️⃣ Upload LoRA"):
514
- gr.Markdown("## Upload LoRA Weights")
515
- gr.Markdown("Upload your custom LoRA weights file to use for generation. The file will be automatically stored and you'll receive a unique hash-based ID.")
516
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  with gr.Row():
518
- lora_file = gr.File(label="LoRA File (safetensors format)")
519
-
 
 
 
 
 
 
 
520
  with gr.Row():
521
- lora_id_output = gr.Textbox(label="LoRA Hash ID (use this in the generation tab)", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
- # Video Generation Tab
524
- with gr.TabItem("2️⃣ Generate Video"):
525
- with gr.Row():
526
- with gr.Column(scale=2):
527
- with gr.Group():
528
- prompt = gr.Textbox(
529
- label="Prompt",
530
- placeholder="A stylish woman walks down a Tokyo street...",
531
- lines=4,
532
- value=""
533
- )
534
-
535
- start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
536
-
537
- gr.Markdown("### ⚙️ Settings")
538
- with gr.Row():
539
- seed = gr.Number(
540
- label="Seed",
541
- value=-1,
542
- info="Use -1 for random seed",
543
- precision=0
544
- )
545
- fps = gr.Slider(
546
- label="Playback FPS",
547
- minimum=1,
548
- maximum=30,
549
- value=args.fps,
550
- step=1,
551
- visible=False,
552
- info="Frames per second for playback"
553
- )
554
-
555
- with gr.Row():
556
- duration = gr.Slider(
557
- label="Duration (seconds)",
558
- minimum=1,
559
- maximum=5,
560
- value=3,
561
- step=1,
562
- info="Video duration in seconds"
563
- )
564
-
565
- with gr.Row():
566
- width = gr.Slider(
567
- label="Width",
568
- minimum=224,
569
- maximum=720,
570
- value=400,
571
- step=8,
572
- info="Video width in pixels (8px steps)"
573
- )
574
- height = gr.Slider(
575
- label="Height",
576
- minimum=224,
577
- maximum=720,
578
- value=224,
579
- step=8,
580
- info="Video height in pixels (8px steps)"
581
- )
582
-
583
- gr.Markdown("### 🎨 LoRA Settings")
584
- lora_id = gr.Textbox(
585
- label="LoRA ID (from upload tab)",
586
- placeholder="Enter your LoRA ID here...",
587
- )
588
-
589
- lora_weight = gr.Slider(
590
- label="LoRA Weight",
591
- minimum=0.0,
592
- maximum=1.0,
593
- step=0.01,
594
- value=1.0,
595
- info="Strength of LoRA influence"
596
- )
597
-
598
- with gr.Column(scale=3):
599
- gr.Markdown("### 📺 Video Stream")
600
-
601
- streaming_video = gr.Video(
602
- label="Live Stream",
603
- streaming=True,
604
- loop=True,
605
- height=400,
606
- autoplay=True,
607
- show_label=False
608
- )
609
-
610
- status_display = gr.HTML(
611
- value=(
612
- "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
613
- "🎬 Ready to start streaming...<br>"
614
- "<small>Configure your prompt and click 'Start Streaming'</small>"
615
- "</div>"
616
- ),
617
- label="Generation Status"
618
- )
619
 
620
  # Connect the generator to the streaming video
621
  start_btn.click(
622
  fn=video_generation_handler_streaming,
623
- inputs=[prompt, seed, fps, width, height, duration, lora_id, lora_weight],
624
  outputs=[streaming_video, status_display]
625
  )
626
-
627
- # Connect LoRA upload to both display fields
628
- lora_file.change(
629
- fn=upload_lora_file,
630
- inputs=[lora_file],
631
- outputs=[lora_id_output, lora_id]
632
- )
633
-
634
 
635
  # --- Launch App ---
636
  if __name__ == "__main__":
 
53
 
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # --- Argument Parsing ---
57
  parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
58
  parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
 
115
  APP_STATE["torch_compile_applied"] = True
116
  print("✅ torch.compile applied to transformer")
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  def frames_to_ts_file(frames, filepath, fps = 15):
119
  """
120
  Convert frames directly to .ts file using PyAV.
 
230
  pipeline.to(dtype=torch.float16).to(gpu)
231
 
232
  @torch.no_grad()
233
+ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5):
234
  """
235
  Generator function that yields .ts video chunks using PyAV for streaming.
236
  """
 
239
 
240
  # print(f"🎬 Starting PyAV streaming: seed: {seed}, duration: {duration}s")
241
 
 
 
 
 
 
 
 
242
  # Setup
243
  conditional_dict = text_encoder(text_prompts=[prompt])
244
  for key, value in conditional_dict.items():
 
400
  print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
401
 
402
  # --- Gradio UI Layout ---
403
+ with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
404
+ gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
 
405
 
406
+ with gr.Row():
407
+ with gr.Column(scale=2):
408
+ with gr.Group():
409
+ prompt = gr.Textbox(
410
+ label="Prompt",
411
+ placeholder="A stylish woman walks down a Tokyo street...",
412
+ lines=4,
413
+ value=""
414
+ )
415
+ start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
416
+
417
+ gr.Markdown("### ⚙️ Settings")
418
+ with gr.Row():
419
+ seed = gr.Number(
420
+ label="Seed",
421
+ value=-1,
422
+ info="Use -1 for random seed",
423
+ precision=0
424
+ )
425
+ fps = gr.Slider(
426
+ label="Playback FPS",
427
+ minimum=1,
428
+ maximum=30,
429
+ value=args.fps,
430
+ step=1,
431
+ visible=False,
432
+ info="Frames per second for playback"
433
+ )
434
+
435
  with gr.Row():
436
+ duration = gr.Slider(
437
+ label="Duration (seconds)",
438
+ minimum=1,
439
+ maximum=5,
440
+ value=3,
441
+ step=1,
442
+ info="Video duration in seconds"
443
+ )
444
+
445
  with gr.Row():
446
+ width = gr.Slider(
447
+ label="Width",
448
+ minimum=224,
449
+ maximum=720,
450
+ value=400,
451
+ step=8,
452
+ info="Video width in pixels (8px steps)"
453
+ )
454
+ height = gr.Slider(
455
+ label="Height",
456
+ minimum=224,
457
+ maximum=720,
458
+ value=224,
459
+ step=8,
460
+ info="Video height in pixels (8px steps)"
461
+ )
462
 
463
+ with gr.Column(scale=3):
464
+ gr.Markdown("### 📺 Video Stream")
465
+ streaming_video = gr.Video(
466
+ label="Live Stream",
467
+ streaming=True,
468
+ loop=True,
469
+ height=400,
470
+ autoplay=True,
471
+ show_label=False
472
+ )
473
+
474
+ status_display = gr.HTML(
475
+ value=(
476
+ "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
477
+ "🎬 Ready to start streaming...<br>"
478
+ "<small>Configure your prompt and click 'Start Streaming'</small>"
479
+ "</div>"
480
+ ),
481
+ label="Generation Status"
482
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
 
484
  # Connect the generator to the streaming video
485
  start_btn.click(
486
  fn=video_generation_handler_streaming,
487
+ inputs=[prompt, seed, fps, width, height, duration],
488
  outputs=[streaming_video, status_display]
489
  )
 
 
 
 
 
 
 
 
490
 
491
  # --- Launch App ---
492
  if __name__ == "__main__":