jbilcke-hf HF Staff committed on
Commit
6896d2c
·
1 Parent(s): b735958
Files changed (2) hide show
  1. app.py +7 -6
  2. app_with_streaming.py +11 -6
app.py CHANGED
@@ -394,7 +394,7 @@ def video_generation_handler(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, heigh
394
 
395
  # --- Gradio UI Layout ---
396
  with gr.Blocks(title="Wan2.1 1.3B Self-Forcing demo") as demo:
397
- gr.Markdown("Video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009))")
398
 
399
  with gr.Row():
400
  with gr.Column(scale=2):
@@ -409,11 +409,12 @@ with gr.Blocks(title="Wan2.1 1.3B Self-Forcing demo") as demo:
409
 
410
  gr.Markdown("### ⚙️ Settings")
411
  with gr.Row():
412
- seed = gr.Number(
413
- label="Seed",
414
- value=-1,
415
- info="Use -1 for random seed",
416
- precision=0
 
417
  )
418
  fps = gr.Slider(
419
  label="Playback FPS",
 
394
 
395
  # --- Gradio UI Layout ---
396
  with gr.Blocks(title="Wan2.1 1.3B Self-Forcing demo") as demo:
397
+ gr.Markdown("Video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
398
 
399
  with gr.Row():
400
  with gr.Column(scale=2):
 
409
 
410
  gr.Markdown("### ⚙️ Settings")
411
  with gr.Row():
412
+ seed = gr.Slider(
413
+ label="Generation Seed (-1 for random)",
414
+ minimum=-1,
415
+ maximum=2147483647, # 2^31 - 1
416
+ step=1,
417
+ value=-1
418
  )
419
  fps = gr.Slider(
420
  label="Playback FPS",
app_with_streaming.py CHANGED
@@ -263,7 +263,11 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
263
  rnd = torch.Generator(gpu).manual_seed(int(seed))
264
  pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
265
  pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
266
- noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
 
 
 
 
267
 
268
  vae_cache, latents_cache = None, None
269
  if not APP_STATE["current_use_taehv"] and not args.trt:
@@ -432,11 +436,12 @@ with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
432
 
433
  gr.Markdown("### ⚙️ Settings")
434
  with gr.Row():
435
- seed = gr.Number(
436
- label="Seed",
437
- value=-1,
438
- info="Use -1 for random seed",
439
- precision=0
 
440
  )
441
  fps = gr.Slider(
442
  label="Playback FPS",
 
263
  rnd = torch.Generator(gpu).manual_seed(int(seed))
264
  pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
265
  pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
266
+
267
+ # Calculate latent dimensions based on actual width/height (assuming 8x downsampling)
268
+ latent_height = height // 8
269
+ latent_width = width // 8
270
+ noise = torch.randn([1, 21, 16, latent_height, latent_width], device=gpu, dtype=torch.float16, generator=rnd)
271
 
272
  vae_cache, latents_cache = None, None
273
  if not APP_STATE["current_use_taehv"] and not args.trt:
 
436
 
437
  gr.Markdown("### ⚙️ Settings")
438
  with gr.Row():
439
+ seed = gr.Slider(
440
+ label="Generation Seed (-1 for random)",
441
+ minimum=-1,
442
+ maximum=2147483647, # 2^31 - 1
443
+ step=1,
444
+ value=-1
445
  )
446
  fps = gr.Slider(
447
  label="Playback FPS",