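"""Image and video generation helpers built on Hugging Face diffusers.

Covers three pieces: text-to-image previews with Stable Diffusion / SDXL,
per-segment video clips with Stable Video Diffusion, and a small in-process
pipeline cache so repeated calls reuse already-loaded models. An optional
img2img mode reuses the previous segment's frame as a reference for visual
consistency across scenes.
"""
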
import os
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    StableVideoDiffusionPipeline,
    DDIMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline
)
from diffusers.utils import export_to_video
from PIL import Image
import numpy as np
import time
import spaces

# Global pipelines cache
_model_cache = {}

def list_available_image_models():
    """Return list of available image generation models"""
    return [
        "stabilityai/stable-diffusion-xl-base-1.0",
        "stabilityai/sdxl-turbo",
        "runwayml/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-2-1"
    ]

def list_available_video_models():
    """Return list of available video generation models"""
    return [
        "stabilityai/stable-video-diffusion-img2vid-xt",
        "stabilityai/stable-video-diffusion-img2vid"
    ]

def _get_model_key(model_name, is_img2img=False):
    """Generate a unique key for the model cache"""
    return f"{model_name}_{'img2img' if is_img2img else 'txt2img'}"

def _load_image_pipeline(model_name, is_img2img=False):
    """Load image generation pipeline with caching"""
    model_key = _get_model_key(model_name, is_img2img)
    
    if model_key not in _model_cache:
        print(f"Loading image model: {model_name} ({is_img2img})")
        
        if "xl" in model_name.lower():
            # SDXL model
            if is_img2img:
                pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                    variant="fp16",
                    use_safetensors=True
                )
            else:
                pipeline = StableDiffusionXLPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                    variant="fp16",
                    use_safetensors=True
                )
        else:
            # SD 1.5/2.x model
            if is_img2img:
                pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16
                )
            else:
                pipeline = StableDiffusionPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16
                )
                
        pipeline.enable_model_cpu_offload()
        pipeline.safety_checker = None  # disable safety checker for performance
        _model_cache[model_key] = pipeline
        
    return _model_cache[model_key]

def _load_video_pipeline(model_name):
    """Load video generation pipeline with caching"""
    if model_name not in _model_cache:
        print(f"Loading video model: {model_name}")
        
        pipeline = StableVideoDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            variant="fp16"
        )
        pipeline.enable_model_cpu_offload()
        
        # Enable forward chunking for lower VRAM use
        pipeline.unet.enable_forward_chunking(chunk_size=1)
        
        _model_cache[model_name] = pipeline
        
    return _model_cache[model_name]

@spaces.GPU
def preview_image_generation(prompt, image_model="stabilityai/stable-diffusion-xl-base-1.0", width=1024, height=576, seed=None):
    """
    Generate a preview image from a prompt
    
    Args:
        prompt: Text prompt for image generation
        image_model: Model to use
        width/height: Image dimensions
        seed: Random seed (None for random)
    
    Returns:
        PIL Image object
    """
    pipeline = _load_image_pipeline(image_model)
    generator = None
    if seed is not None:
        generator = torch.Generator(device="cuda").manual_seed(seed)
    
    with torch.autocast("cuda"):
        image = pipeline(
            prompt,
            width=width,
            height=height,
            generator=generator,
            num_inference_steps=30
        ).images[0]
    
    return image
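
# Example usage (illustrative sketch, not part of the app flow): assumes a CUDA
# GPU and that the default SDXL weights can be downloaded.
#   preview = preview_image_generation("a watercolor lighthouse at dusk", seed=42)
#   preview.save("preview.png")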

@spaces.GPU
def create_video_segments(
    segments,
    scene_prompts,
    image_model="stabilityai/stable-diffusion-xl-base-1.0",
    video_model="stabilityai/stable-video-diffusion-img2vid-xt",
    width=1024,
    height=576,
    dynamic_fps=True,
    base_fps=None,
    seed=None,
    work_dir=".",
    image_mode="Independent",
    strength=0.5,
    progress_callback=None
):
    """
    Generate an image and a short video clip for each segment.
    
    Args:
        segments: List of segment dicts, each with "start" and "end" times in seconds
        scene_prompts: List of text prompts for each segment
        image_model: Model to use for image generation
        video_model: Model to use for video generation
        width/height: Video dimensions
        dynamic_fps: If True, adjust FPS to match segment duration
        base_fps: Base FPS when dynamic_fps is False
        seed: Random seed (None or 0 for random)
        work_dir: Directory to save intermediate files
        image_mode: "Independent" or "Consistent (Img2Img)" for style continuity
        strength: Strength parameter for img2img (0-1, lower preserves more reference)
        progress_callback: Function to call with progress updates
        
    Returns:
        List of file paths to the segment video clips
    """
    # Initialize image and video pipelines
    txt2img_pipe = _load_image_pipeline(image_model)
    video_pipe = _load_video_pipeline(video_model)
    
    # Set manual seed if provided
    generator = None
    if seed is not None and int(seed) != 0:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
    
    segment_files = []
    reference_image = None
    
    for idx, (seg, prompt) in enumerate(zip(segments, scene_prompts)):
        if progress_callback:
            progress_percent = (idx / len(segments)) * 100
            progress_callback(progress_percent, f"Generating scene {idx+1}/{len(segments)}")
        
        seg_start = seg["start"]
        seg_end = seg["end"]
        seg_dur = max(seg_end - seg_start, 0.001)
        
        # Determine FPS for this segment
        if dynamic_fps:
            # Use 25 frames spanning the segment duration
            fps = 25.0 / seg_dur
            # Cap FPS to 30 to avoid too high frame rate for very short segments
            if fps > 30.0: 
                fps = 30.0
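            # Worked example: a 1.0 s segment plays its 25 frames at 25 fps;
            # a 0.5 s segment would need 50 fps, so it is clamped to 30 and the
            # resulting clip runs slightly longer than the segment itself.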
        else:
            fps = base_fps or 10.0  # use given fixed fps, default 10 if not set
        
        # 1. Generate initial frame image with Stable Diffusion
        img_filename = os.path.join(work_dir, f"segment{idx:02d}_img.png")
        
        with torch.autocast("cuda"):
            if image_mode == "Consistent (Img2Img)" and reference_image is not None:
                # Use img2img with reference image for style consistency
                img2img_pipe = _load_image_pipeline(image_model, is_img2img=True)
                image = img2img_pipe(
                    prompt=prompt,
                    image=reference_image,
                    strength=strength,
                    generator=generator,
                    num_inference_steps=30
                ).images[0]
            else:
                # Regular text-to-image generation
                image = txt2img_pipe(
                    prompt=prompt,
                    width=width,
                    height=height,
                    generator=generator,
                    num_inference_steps=30
                ).images[0]
        
        # Save the image for inspection
        image.save(img_filename)
        
        # Update reference image for next segment if using consistent mode
        if image_mode == "Consistent (Img2Img)":
            reference_image = image
        
        # 2. Generate video frames from the image using stable video diffusion
        with torch.autocast("cuda"):
            video_frames = video_pipe(
                image, 
                num_frames=25,
                fps=fps,
                decode_chunk_size=1,
                generator=generator
            ).frames[0]
        
        # Save the video frames to an mp4 clip for this segment
        seg_filename = os.path.join(work_dir, f"segment_{idx:03d}.mp4")
        export_to_video(video_frames, seg_filename, fps=fps)
        segment_files.append(seg_filename)
        
        # Free memory from frames
        del video_frames
        torch.cuda.empty_cache()
        
    # Return list of video segment files
    return segment_files
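
if __name__ == "__main__":
    # Minimal local smoke test. This is an illustrative sketch only: the segment
    # timings, prompts, and seed are made up, and running it assumes a CUDA GPU,
    # enough VRAM for SDXL + Stable Video Diffusion, and network access to
    # download the model weights.
    demo_segments = [{"start": 0.0, "end": 2.5}, {"start": 2.5, "end": 5.0}]
    demo_prompts = [
        "a misty forest at sunrise, cinematic lighting",
        "the same forest at midday, sunbeams breaking through the canopy",
    ]
    clips = create_video_segments(
        demo_segments,
        demo_prompts,
        image_mode="Consistent (Img2Img)",
        strength=0.4,
        seed=1234,
        progress_callback=lambda pct, msg: print(f"{pct:5.1f}% {msg}"),
    )
    print("Generated clips:", clips)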