Update app.py
app.py CHANGED
@@ -53,20 +53,6 @@ import numpy as np
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# LoRA Storage Configuration
-STORAGE_PATH = Path(DATA_ROOT) / "storage"
-LORA_PATH = STORAGE_PATH / "loras"
-OUTPUT_PATH = STORAGE_PATH / "output"
-
-# Create necessary directories
-STORAGE_PATH.mkdir(parents=True, exist_ok=True)
-LORA_PATH.mkdir(parents=True, exist_ok=True)
-OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
-
-# Global variables for LoRA management
-current_lora_id = None
-current_lora_path = None
-
 # --- Argument Parsing ---
 parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
 parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
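The storage tree removed above backed a content-addressable scheme: the upload helper removed in the next hunk names each LoRA file by its SHA-256 digest, so identical uploads deduplicate automatically and the digest doubles as a stable ID. A minimal standalone sketch of that pattern (the store location is hypothetical, not the app's DATA_ROOT):

    import hashlib
    import shutil
    from pathlib import Path

    STORE = Path("/tmp/lora-store")          # hypothetical location
    STORE.mkdir(parents=True, exist_ok=True)

    def store_by_hash(src: str) -> str:
        """Copy src into STORE under its SHA-256 digest and return the digest."""
        h = hashlib.sha256()
        with open(src, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):  # hash in 4 KiB chunks
                h.update(chunk)
        digest = h.hexdigest()
        dest = STORE / f"{digest}.safetensors"
        if not dest.exists():                # same content already stored: skip copy
            shutil.copy(src, dest)
        return digest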
@@ -129,89 +115,6 @@ if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
     APP_STATE["torch_compile_applied"] = True
     print("✅ torch.compile applied to transformer")
 
-def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
-    """Upload a LoRA file and return a hash-based ID for future reference"""
-    if file is None:
-        return "", ""
-
-    try:
-        # Calculate SHA256 hash of the file
-        sha256_hash = hashlib.sha256()
-        with open(file.name, "rb") as f:
-            for chunk in iter(lambda: f.read(4096), b""):
-                sha256_hash.update(chunk)
-        file_hash = sha256_hash.hexdigest()
-
-        # Create destination path using hash
-        dest_path = LORA_PATH / f"{file_hash}.safetensors"
-
-        # Check if file already exists
-        if dest_path.exists():
-            print(f"LoRA file already exists!")
-            return file_hash, file_hash
-
-        # Copy the file to the destination
-        shutil.copy(file.name, dest_path)
-
-        print(f"LoRA file uploaded!")
-        return file_hash, file_hash
-    except Exception as e:
-        print(f"Error uploading LoRA file: {e}")
-        raise gr.Error(f"Failed to upload LoRA file: {str(e)}")
-
-def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
-    """Get the path to a LoRA file from its hash-based ID"""
-    if not lora_id:
-        return None
-
-    # Check if file exists
-    lora_path = LORA_PATH / f"{lora_id}.safetensors"
-    if lora_path.exists():
-        return lora_path
-
-    return None
-
-def manage_lora_weights(lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
-    """Manage LoRA weights for the transformer model"""
-    global current_lora_id, current_lora_path
-
-    # Determine if we should use LoRA
-    using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0
-
-    # If not using LoRA but we have one loaded, clear it
-    if not using_lora and current_lora_id is not None:
-        print(f"Clearing current LoRA")
-        current_lora_id = None
-        current_lora_path = None
-        return False, None
-
-    # If using LoRA, check if we need to change weights
-    if using_lora:
-        lora_path = get_lora_file_path(lora_id)
-
-        if not lora_path:
-            print(f"No LoRA file with this ID was found. Using base model instead.")
-
-            # If we had a LoRA loaded, clear it
-            if current_lora_id is not None:
-                print(f"Clearing current LoRA")
-                current_lora_id = None
-                current_lora_path = None
-
-            return False, None
-
-        # If LoRA ID changed, update
-        if lora_id != current_lora_id:
-            print(f"Loading LoRA..")
-            current_lora_id = lora_id
-            current_lora_path = lora_path
-        else:
-            print(f"Using a LoRA!")
-
-        return True, lora_path
-
-    return False, None
-
 def frames_to_ts_file(frames, filepath, fps = 15):
     """
     Convert frames directly to .ts file using PyAV.
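The docstring above describes the streaming mechanism the demo keeps: frames are encoded straight into an MPEG-TS segment with PyAV. A sketch of what such a conversion can look like (the body below is an assumption based on the docstring, not the app's actual implementation):

    import av

    def frames_to_ts(frames, filepath, fps=15):
        """Encode a list of HxWx3 uint8 RGB numpy frames into an MPEG-TS file."""
        container = av.open(filepath, mode="w", format="mpegts")
        stream = container.add_stream("h264", rate=fps)
        stream.height, stream.width = frames[0].shape[:2]
        stream.pix_fmt = "yuv420p"               # required by most h264 encoders
        for frame in frames:
            video_frame = av.VideoFrame.from_ndarray(frame, format="rgb24")
            for packet in stream.encode(video_frame):
                container.mux(packet)
        for packet in stream.encode():           # flush buffered packets
            container.mux(packet)
        container.close()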
@@ -327,7 +230,7 @@ pipeline = CausalInferencePipeline(
 pipeline.to(dtype=torch.float16).to(gpu)
 
 @torch.no_grad()
-def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5, lora_id=None, lora_weight=1.0):
+def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5):
     """
     Generator function that yields .ts video chunks using PyAV for streaming.
     """
@@ -336,13 +239,6 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
 
     # print(f"🎬 Starting PyAV streaming: seed: {seed}, duration: {duration}s")
 
-    # Handle LoRA weights
-    using_lora, lora_path = manage_lora_weights(lora_id, lora_weight)
-    if using_lora:
-        print(f"🎨 Using LoRA with weight factor {lora_weight}")
-    else:
-        print("🎨 Using base model (no LoRA)")
-
     # Setup
     conditional_dict = text_encoder(text_prompts=[prompt])
     for key, value in conditional_dict.items():
@@ -504,133 +400,93 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
     print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
 
 # --- Gradio UI Layout ---
-with gr.Blocks(title="Wan2.1 1.3B
-    gr.Markdown("
-    gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B and LoRA [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
-
-    [... earlier prompt/seed widgets not captured in the page extraction ...]
-                    precision=0
-                )
-                fps = gr.Slider(
-                    label="Playback FPS",
-                    minimum=1,
-                    maximum=30,
-                    value=args.fps,
-                    step=1,
-                    visible=False,
-                    info="Frames per second for playback"
-                )
-
-            with gr.Row():
-                duration = gr.Slider(
-                    label="Duration (seconds)",
-                    minimum=1,
-                    maximum=5,
-                    value=3,
-                    step=1,
-                    info="Video duration in seconds"
-                )
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=224,
-                    maximum=720,
-                    value=400,
-                    step=8,
-                    info="Video width in pixels (8px steps)"
-                )
-                height = gr.Slider(
-                    label="Height",
-                    minimum=224,
-                    maximum=720,
-                    value=224,
-                    step=8,
-                    info="Video height in pixels (8px steps)"
-                )
-
-            gr.Markdown("### 🎨 LoRA Settings")
-            lora_id = gr.Textbox(
-                label="LoRA ID (from upload tab)",
-                placeholder="Enter your LoRA ID here...",
-            )
-
-            lora_weight = gr.Slider(
-                label="LoRA Weight",
-                minimum=0.0,
-                maximum=1.0,
-                step=0.01,
-                value=1.0,
-                info="Strength of LoRA influence"
-            )
-
-        with gr.Column(scale=3):
-            gr.Markdown("### 📺 Video Stream")
-
-            streaming_video = gr.Video(
-                label="Live Stream",
-                streaming=True,
-                loop=True,
-                height=400,
-                autoplay=True,
-                show_label=False
-            )
-
-            status_display = gr.HTML(
-                value=(
-                    "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
-                    "🎬 Ready to start streaming...<br>"
-                    "<small>Configure your prompt and click 'Start Streaming'</small>"
-                    "</div>"
-                ),
-                label="Generation Status"
-            )
+with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
+    gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            with gr.Group():
+                prompt = gr.Textbox(
+                    label="Prompt",
+                    placeholder="A stylish woman walks down a Tokyo street...",
+                    lines=4,
+                    value=""
+                )
+                start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
+
+            gr.Markdown("### ⚙️ Settings")
+            with gr.Row():
+                seed = gr.Number(
+                    label="Seed",
+                    value=-1,
+                    info="Use -1 for random seed",
+                    precision=0
+                )
+                fps = gr.Slider(
+                    label="Playback FPS",
+                    minimum=1,
+                    maximum=30,
+                    value=args.fps,
+                    step=1,
+                    visible=False,
+                    info="Frames per second for playback"
+                )
+
+            with gr.Row():
+                duration = gr.Slider(
+                    label="Duration (seconds)",
+                    minimum=1,
+                    maximum=5,
+                    value=3,
+                    step=1,
+                    info="Video duration in seconds"
+                )
+
+            with gr.Row():
+                width = gr.Slider(
+                    label="Width",
+                    minimum=224,
+                    maximum=720,
+                    value=400,
+                    step=8,
+                    info="Video width in pixels (8px steps)"
+                )
+                height = gr.Slider(
+                    label="Height",
+                    minimum=224,
+                    maximum=720,
+                    value=224,
+                    step=8,
+                    info="Video height in pixels (8px steps)"
+                )
+
+        with gr.Column(scale=3):
+            gr.Markdown("### 📺 Video Stream")
+            streaming_video = gr.Video(
+                label="Live Stream",
+                streaming=True,
+                loop=True,
+                height=400,
+                autoplay=True,
+                show_label=False
+            )
+
+            status_display = gr.HTML(
+                value=(
+                    "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
+                    "🎬 Ready to start streaming...<br>"
+                    "<small>Configure your prompt and click 'Start Streaming'</small>"
+                    "</div>"
+                ),
+                label="Generation Status"
+            )
 
     # Connect the generator to the streaming video
     start_btn.click(
         fn=video_generation_handler_streaming,
-        inputs=[prompt, seed, fps, width, height, duration, lora_id, lora_weight],
+        inputs=[prompt, seed, fps, width, height, duration],
         outputs=[streaming_video, status_display]
     )
-
-    # Connect LoRA upload to both display fields
-    lora_file.change(
-        fn=upload_lora_file,
-        inputs=[lora_file],
-        outputs=[lora_id_output, lora_id]
-    )
 
 # --- Launch App ---
 if __name__ == "__main__":
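After this change the click handler takes only the six generation inputs. Because streaming_video is created with streaming=True, Gradio consumes the handler as a generator: each yielded (ts_path, status_html) pair appends a video chunk to the stream and refreshes the status panel. A toy sketch of that contract (the handler body is hypothetical):

    import gradio as gr

    def toy_handler(prompt, seed=42, fps=15, width=400, height=224, duration=5):
        # Yield one (video_chunk_path, status_html) pair per generated block;
        # each yielded .ts chunk is appended to the live stream.
        for block in range(duration):
            ts_path = f"/tmp/block_{block}.ts"   # hypothetical chunk written elsewhere
            yield ts_path, f"<div>block {block + 1}/{duration}</div>"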