import gradio as gr
import tempfile
import random
import os
import shutil
import hashlib
import uuid
from pathlib import Path
import logging
import torch
import numpy as np
from typing import Optional, Tuple
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
STORAGE_PATH = Path(os.getenv('STORAGE_PATH', './data'))
LORA_PATH = STORAGE_PATH / "loras"
OUTPUT_PATH = STORAGE_PATH / "output"

MODEL_VERSION = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
DEFAULT_PROMPT_PREFIX = ""

# Create necessary directories
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
LORA_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Global variables to track model state
pipe = None
current_lora_id = None
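# (Module-scope state lets repeated Gradio requests reuse the loaded
# pipeline instead of re-initializing it on every call.)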

def format_time(seconds: float) -> str:
    """Format time duration in seconds to human readable string"""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    
    parts = []
    if hours > 0:
        parts.append(f"{hours}h")
    if minutes > 0:
        parts.append(f"{minutes}m")
    if secs > 0 or not parts:
        parts.append(f"{secs}s")
        
    return " ".join(parts)

def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
    """Upload a LoRA file and return a hash-based ID for future reference
    
    Args:
        file: Uploaded file object from Gradio
        
    Returns:
        Tuple[str, str]: Hash-based ID for the stored file (returned twice for both outputs)
    """
    if file is None:
        return "", ""
        
    try:
        # Calculate SHA256 hash of the file
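        # (read in 4 KiB chunks so large .safetensors files are hashed
        # without loading the whole file into memory)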
        sha256_hash = hashlib.sha256()
        with open(file.name, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                sha256_hash.update(chunk)
        file_hash = sha256_hash.hexdigest()
        
        # Create destination path using hash
        dest_path = LORA_PATH / f"{file_hash}.safetensors"
        
        # Check if file already exists
        if dest_path.exists():
            logger.info("LoRA file already exists")
            return file_hash, file_hash
        
        # Copy the file to the destination
        shutil.copy(file.name, dest_path)
        
        logger.info(f"a new LoRA file has been uploaded")
        return file_hash, file_hash
    except Exception as e:
        logger.error(f"Error uploading LoRA file: {e}")
        raise gr.Error(f"Failed to upload LoRA file: {str(e)}")

def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
    """Get the path to a LoRA file from its hash-based ID
    
    Args:
        lora_id: Hash-based ID of the stored LoRA file
        
    Returns:
        Path: Path to the LoRA file if found, None otherwise
    """
    if not lora_id:
        return None
        
    # Check if file exists
    lora_path = LORA_PATH / f"{lora_id}.safetensors"
    if lora_path.exists():
        return lora_path
    
    return None

def get_or_create_pipeline(
    enable_cpu_offload: bool = True, 
    flow_shift: float = 3.0
) -> WanPipeline:
    """Get existing pipeline or create a new one if necessary
    
    Args:
        enable_cpu_offload: Whether to enable CPU offload
        flow_shift: Flow shift parameter for scheduler
        
    Returns:
        WanPipeline: The pipeline for generation
    """
    global pipe
    
    if pipe is None:
        # Create a new pipeline
        logger.info("Creating new pipeline")
        
        # Load VAE
        vae = AutoencoderKLWan.from_pretrained(MODEL_VERSION, subfolder="vae", torch_dtype=torch.float32)
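        # (The VAE stays in float32 while the transformer below runs in
        # bfloat16; this mixed-precision split is the usual Wan setup.)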
        
        # Load transformer
        pipe = WanPipeline.from_pretrained(MODEL_VERSION, vae=vae, torch_dtype=torch.bfloat16)
        
        # Configure scheduler
        pipe.scheduler = UniPCMultistepScheduler.from_config(
            pipe.scheduler.config, 
            flow_shift=flow_shift
        )
        
        # Move to GPU
        pipe.to("cuda")
        
        # Enable CPU offload if requested
        if enable_cpu_offload:
            logger.info("Enabling CPU offload")
            pipe.enable_model_cpu_offload()
    else:
        # Update existing pipeline's scheduler if needed
        if pipe.scheduler.config.flow_shift != flow_shift:
            logger.info(f"Updating scheduler flow_shift from {pipe.scheduler.config.flow_shift} to {flow_shift}")
            pipe.scheduler = UniPCMultistepScheduler.from_config(
                pipe.scheduler.config, 
                flow_shift=flow_shift
            )
    
    return pipe

def manage_lora_weights(pipe: WanPipeline, lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
    """Manage LoRA weights, loading/unloading only when necessary
    
    Args:
        pipe: The pipeline to manage LoRA weights for
        lora_id: Hash-based ID of the LoRA file to use
        lora_weight: Weight of LoRA contribution
        
    Returns:
        Tuple[bool, Optional[Path]]: (Is using LoRA, Path to LoRA file)
    """
    global current_lora_id
    
    # Determine if we should use LoRA
    using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0
    
    # If not using LoRA but we have one loaded, unload it
    if not using_lora and current_lora_id is not None:
        logger.info(f"Unloading current LoRA with ID")
        try:
            # Unload current LoRA weights
            pipe.unload_lora_weights()
            current_lora_id = None
        except Exception as e:
            logger.error(f"Error unloading LoRA weights: {e}")
        return False, None
    
    # If using LoRA, check if we need to change weights
    if using_lora:
        lora_path = get_lora_file_path(lora_id)
        
        if not lora_path:
            # Log the event but continue with base model
            logger.warning(f"LoRA file with ID {lora_id} not found. Using base model instead.")
            
            # If we had a LoRA loaded, unload it
            if current_lora_id is not None:
                logger.info(f"Unloading current LoRA")
                try:
                    pipe.unload_lora_weights()
                except Exception as e:
                    logger.error(f"Error unloading LoRA weights: {e}")
                current_lora_id = None
                
            return False, None
        
        # If LoRA ID changed, update weights
        if lora_id != current_lora_id:
            # If we had a LoRA loaded, unload it first
            if current_lora_id is not None:
                logger.info(f"Unloading current LoRA")
                try:
                    pipe.unload_lora_weights()
                except Exception as e:
                    logger.error(f"Error unloading LoRA weights: {e}")
            
            # Load new LoRA weights
            logger.info("Using a LoRA")
            try:
                pipe.load_lora_weights(lora_path, weight_name=str(lora_path), adapter_name="default")
                current_lora_id = lora_id
            except Exception as e:
                logger.error(f"Error loading LoRA weights: {e}")
                return False, None
        else:
            logger.info(f"Using currently loaded LoRA with ID")
        
        return True, lora_path
    
    return False, None
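
# Note: lora_weight currently only gates whether a LoRA is applied at all.
# A minimal sketch of applying the scale after loading (assuming the
# peft-backed diffusers adapter API; "default" is the adapter_name used in
# manage_lora_weights):
#
#     pipe.set_adapters(["default"], adapter_weights=[lora_weight])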
    
def generate_video(
    prompt: str,
    negative_prompt: str,
    prompt_prefix: str,
    width: int,
    height: int,
    num_frames: int,
    guidance_scale: float,
    flow_shift: float,
    lora_id: Optional[str],
    lora_weight: float,
    inference_steps: int,
    fps: int = 16,
    seed: int = -1,
    enable_cpu_offload: bool = True,
    conditioning_image: Optional[str] = None,
    progress=gr.Progress()
) -> str:
    """Generate a video using the Wan model with optional LoRA weights
    
    Args:
        prompt: Text prompt for generation
        negative_prompt: Negative text prompt
        prompt_prefix: Prefix to add to all prompts
        width: Output video width
        height: Output video height
        num_frames: Number of frames to generate
        guidance_scale: Classifier-free guidance scale
        flow_shift: Flow shift parameter for scheduler
        lora_id: Hash-based ID of the LoRA file to use
        lora_weight: Weight of LoRA contribution
        inference_steps: Number of inference steps
        fps: Frames per second for output video
        seed: Random seed (-1 for random)
        enable_cpu_offload: Whether to enable CPU offload for VRAM optimization
        conditioning_image: Path to conditioning image for image-to-video (not used in this app)
        progress: Gradio progress callback
        
    Returns:
        str: Video path
    """
    global pipe, current_lora_id  # declared before first use so assignments update module-level state
    
    try:
        # Progress 0-5%: Initialize and check inputs
        progress(0.00, desc="Initializing generation")
        
        # Add prefix to prompt
        progress(0.02, desc="Processing prompt")
        if prompt_prefix and not prompt.startswith(prompt_prefix):
            full_prompt = f"{prompt_prefix}{prompt}"
        else:
            full_prompt = prompt
        
        # Adjust num_frames to the model requirement (must be 8*k + 1)
        adjusted_num_frames = ((num_frames - 1) // 8) * 8 + 1
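        # e.g. a request for 50 frames becomes ((50 - 1) // 8) * 8 + 1 = 49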
        if adjusted_num_frames != num_frames:
            logger.info(f"Adjusted number of frames from {num_frames} to {adjusted_num_frames} to match model requirements")
            num_frames = adjusted_num_frames
        
        # Set up random seed
        progress(0.03, desc="Setting up random seed")
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
            logger.info(f"Using randomly generated seed: {seed}")
        
        # Set random seeds for reproducibility
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        generator = torch.Generator(device="cuda").manual_seed(seed)
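        # The global seeds cover any auxiliary randomness; the CUDA
        # generator is the source the pipeline itself draws noise from.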
        
        # Progress 5-25%: Get or create pipeline
        progress(0.05, desc="Preparing model")
        pipe = get_or_create_pipeline(enable_cpu_offload, flow_shift)
        
        # Progress 25-40%: Manage LoRA weights
        progress(0.25, desc="Managing LoRA weights")
        using_lora, lora_path = manage_lora_weights(pipe, lora_id, lora_weight)
        
        # Create temporary file for the output
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
            output_path = temp_file.name
        
        # Progress 40-90%: Generate the video
        progress(0.40, desc="Starting video generation")
        
        # Set up timing for generation
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)
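        # (CUDA events record asynchronously; elapsed_time() is only valid
        # after the torch.cuda.synchronize() call below.)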
        
        start_time.record()
        # Update progress once before generation starts
        progress(0.45, desc="Running diffusion process")
        
        # Generate the video without callback
        output = pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            num_inference_steps=inference_steps,
            generator=generator,
            # Note: cross_attention_kwargs={"scale": lora_weight} is deliberately
            # not passed here; per-call LoRA weight/scale handling will be
            # implemented later.
        ).frames[0]
        
        # Update progress after generation completes
        progress(0.90, desc="Generation complete")
        
        end_time.record()
        torch.cuda.synchronize()
        generation_time = start_time.elapsed_time(end_time) / 1000  # Convert to seconds
        
        logger.info(f"Video generation completed in {format_time(generation_time)}")
        
        # Progress 90-95%: Export video
        progress(0.90, desc="Exporting video")
        export_to_video(output, output_path, fps=fps)
        
        # Progress 95-100%: Save output and clean up
        progress(0.95, desc="Saving video")
        
        # Save a copy to our output directory with UUID for potential future reference
        output_id = str(uuid.uuid4())
        saved_output_path = OUTPUT_PATH / f"{output_id}.mp4"
        shutil.copy(output_path, saved_output_path)
        logger.info(f"Saved video with ID: {output_id}")
        
        # The pipeline is intentionally kept loaded so later requests can reuse it
        progress(0.98, desc="Cleaning up resources")
        
        progress(1.0, desc="Generation complete")
        
        return output_path
    
    except Exception as e:
        import traceback
        error_msg = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        
        # Clean up CUDA memory on error
        if pipe is not None:
            # Try to unload any LoRA weights on error
            if current_lora_id is not None:
                try:
                    pipe.unload_lora_weights()
                    current_lora_id = None
                except Exception:
                    pass
            
            # Release the pipeline on critical errors
            try:
                pipe = None
                torch.cuda.empty_cache()
            except Exception:
                pass
        
        # Re-raise as Gradio error for UI display
        raise gr.Error(f"Error generating video: {str(e)}")

# Create the Gradio app
with gr.Blocks(title="Video Generation API") as app:
    
    with gr.Tabs():
        # LoRA Upload Tab
        with gr.TabItem("1️⃣ Upload LoRA"):
            gr.Markdown("## Upload LoRA Weights")
            gr.Markdown("Upload your custom LoRA weights file to use for generation. The file will be automatically stored and you'll receive a unique hash-based ID.")
            
            with gr.Row():
                lora_file = gr.File(label="LoRA File (safetensors format)")
                
            with gr.Row():
                lora_id_output = gr.Textbox(label="LoRA Hash ID (use this in the generation tab)", interactive=False)
            
            # This will be connected after all components are defined
        
        # Video Generation Tab
        with gr.TabItem("2️⃣ Generate Video"):
            
            with gr.Row():
                with gr.Column(scale=1):
                    # Input parameters
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3
                    )
                    
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Enter negative prompt here...",
                        lines=3,
                        value="worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background"
                    )
                    
                    prompt_prefix = gr.Textbox(
                        label="Prompt Prefix",
                        placeholder="Prefix to add to all prompts",
                        value=DEFAULT_PROMPT_PREFIX
                    )
                    
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=1280,
                            step=8,
                            value=1280
                        )
                        
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=720,
                            step=8,
                            value=720
                        )
                    
                    with gr.Row():
                        num_frames = gr.Slider(
                            label="Number of Frames",
                            minimum=9,
                            maximum=257,
                            step=8,
                            value=49
                        )
                        
                        fps = gr.Slider(
                            label="FPS",
                            minimum=1,
                            maximum=60,
                            step=1,
                            value=16
                        )
                    
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0
                        )
                        
                        flow_shift = gr.Slider(
                            label="Flow Shift",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=3.0
                        )
                    
                    lora_id = gr.Textbox(
                        label="LoRA ID (from upload tab)",
                        placeholder="Enter your LoRA ID here...",
                    )
                    
                    with gr.Row():
                        lora_weight = gr.Slider(
                            label="LoRA Weight",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.7
                        )
                        
                        inference_steps = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=30
                        )
                    
                    seed = gr.Slider(
                        label="Generation Seed (-1 for random)",
                        minimum=-1,
                        maximum=2147483647,  # 2^31 - 1
                        step=1,
                        value=-1
                    )
                    
                    enable_cpu_offload = gr.Checkbox(
                        label="Enable Model CPU Offload (for low-VRAM GPUs)",
                        value=False
                    )
                    
                    generate_btn = gr.Button(
                        "Generate Video",
                        variant="primary"
                    )
                
                with gr.Column(scale=1):
                    # Output component - just the video preview
                    preview_video = gr.Video(
                        label="Generated Video",
                        interactive=False
                    )
            
            # Connect the generate button
            generate_btn.click(
                fn=generate_video,
                inputs=[
                    prompt,
                    negative_prompt,
                    prompt_prefix,
                    width,
                    height,
                    num_frames,
                    guidance_scale,
                    flow_shift,
                    lora_id,
                    lora_weight,
                    inference_steps,
                    fps,
                    seed,
                    enable_cpu_offload
                ],
                outputs=[
                    preview_video
                ]
            )
    
    # Connect LoRA upload to both display fields
    lora_file.change(
        fn=upload_lora_file,
        inputs=[lora_file],
        outputs=[lora_id_output, lora_id]
    )

# Launch the app
if __name__ == "__main__":
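    # For container or remote deployments, standard Gradio arguments apply,
    # e.g. app.launch(server_name="0.0.0.0", server_port=7860)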
    app.launch()