import os

import torch
import spaces
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    StableVideoDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
)
from diffusers.utils import export_to_video
# Global pipeline cache so each model is loaded at most once per process
_model_cache = {}


def list_available_image_models():
    """Return the list of available image generation models."""
    return [
        "stabilityai/stable-diffusion-xl-base-1.0",
        "stabilityai/sdxl-turbo",
        "runwayml/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-2-1",
    ]


def list_available_video_models():
    """Return the list of available video generation models."""
    return [
        "stabilityai/stable-video-diffusion-img2vid-xt",
        "stabilityai/stable-video-diffusion-img2vid",
    ]


def _get_model_key(model_name, is_img2img=False):
    """Generate a unique key for the model cache."""
    return f"{model_name}_{'img2img' if is_img2img else 'txt2img'}"


def _load_image_pipeline(model_name, is_img2img=False):
    """Load an image generation pipeline, caching it for reuse."""
    model_key = _get_model_key(model_name, is_img2img)
    if model_key not in _model_cache:
        mode = "img2img" if is_img2img else "txt2img"
        print(f"Loading image model: {model_name} ({mode})")
        if "xl" in model_name.lower():
            # SDXL model
            if is_img2img:
                pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                    variant="fp16",
                    use_safetensors=True,
                )
            else:
                pipeline = StableDiffusionXLPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                    variant="fp16",
                    use_safetensors=True,
                )
        else:
            # SD 1.5/2.x model
            if is_img2img:
                pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                )
            else:
                pipeline = StableDiffusionPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                )
        pipeline.enable_model_cpu_offload()
        # Disable the safety checker to skip its extra forward pass
        # (SD 1.x/2.x only; SDXL pipelines do not ship one)
        if hasattr(pipeline, "safety_checker"):
            pipeline.safety_checker = None
        _model_cache[model_key] = pipeline
    return _model_cache[model_key]
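
# Repeated calls with the same model and mode are cheap thanks to the cache:
#
#   pipe_a = _load_image_pipeline("stabilityai/sdxl-turbo")  # loads weights
#   pipe_b = _load_image_pipeline("stabilityai/sdxl-turbo")  # cache hit
#   assert pipe_a is pipe_b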


def _load_video_pipeline(model_name):
    """Load a video generation pipeline, caching it for reuse."""
    if model_name not in _model_cache:
        print(f"Loading video model: {model_name}")
        pipeline = StableVideoDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            variant="fp16",
        )
        pipeline.enable_model_cpu_offload()
        # Chunk the UNet's temporal forward passes for lower VRAM use
        pipeline.unet.enable_forward_chunking(chunk_size=1)
        _model_cache[model_name] = pipeline
    return _model_cache[model_name]
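
# Note: the img2vid-xt variant was fine-tuned to produce 25 frames at 1024x576,
# which is why create_video_segments below requests num_frames=25.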


@spaces.GPU  # ZeroGPU: allocate a GPU while this function runs
def preview_image_generation(prompt, image_model="stabilityai/stable-diffusion-xl-base-1.0",
                             width=1024, height=576, seed=None):
    """
    Generate a preview image from a prompt.

    Args:
        prompt: Text prompt for image generation
        image_model: Model to use
        width/height: Image dimensions
        seed: Random seed (None for random)

    Returns:
        PIL Image object
    """
    pipeline = _load_image_pipeline(image_model)
    generator = None
    if seed is not None:
        generator = torch.Generator(device="cuda").manual_seed(seed)
    with torch.autocast("cuda"):
        image = pipeline(
            prompt,
            width=width,
            height=height,
            generator=generator,
            num_inference_steps=30,
        ).images[0]
    return image
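
# Example usage (a minimal sketch; prompt and seed are illustrative):
#
#   img = preview_image_generation(
#       "a lighthouse at dusk, cinematic lighting",
#       width=1024, height=576, seed=42,
#   )
#   img.save("preview.png")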


@spaces.GPU  # ZeroGPU: allocate a GPU while this function runs
def create_video_segments(
    segments,
    scene_prompts,
    image_model="stabilityai/stable-diffusion-xl-base-1.0",
    video_model="stabilityai/stable-video-diffusion-img2vid-xt",
    width=1024,
    height=576,
    dynamic_fps=True,
    base_fps=None,
    seed=None,
    work_dir=".",
    image_mode="Independent",
    strength=0.5,
    progress_callback=None,
):
    """
    Generate an image and a short video clip for each segment.

    Args:
        segments: List of segment dictionaries with timing info
        scene_prompts: List of text prompts, one per segment
        image_model: Model to use for image generation
        video_model: Model to use for video generation
        width/height: Video dimensions
        dynamic_fps: If True, adjust FPS so 25 frames span each segment's duration
        base_fps: Fixed FPS used when dynamic_fps is False (defaults to 10)
        seed: Random seed (None or 0 for random)
        work_dir: Directory to save intermediate files
        image_mode: "Independent" or "Consistent (Img2Img)" for style continuity
        strength: Img2img strength (0-1; lower preserves more of the reference)
        progress_callback: Function called with (percent, message) updates

    Returns:
        List of file paths to the segment video clips
    """
    # Initialize image and video pipelines
    txt2img_pipe = _load_image_pipeline(image_model)
    video_pipe = _load_video_pipeline(video_model)

    # Set a manual seed if one was provided
    generator = None
    if seed is not None and int(seed) != 0:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

    segment_files = []
    reference_image = None
    for idx, (seg, prompt) in enumerate(zip(segments, scene_prompts)):
        if progress_callback:
            progress_percent = (idx / len(segments)) * 100
            progress_callback(progress_percent, f"Generating scene {idx + 1}/{len(segments)}")

        seg_start = seg["start"]
        seg_end = seg["end"]
        seg_dur = max(seg_end - seg_start, 0.001)

        # Determine the FPS for this segment
        if dynamic_fps:
            # 25 generated frames should span the segment's duration
            fps = 25.0 / seg_dur
            # Cap at 30 so very short segments don't get absurd frame rates
            fps = min(fps, 30.0)
        else:
            fps = base_fps or 10.0  # fixed FPS, defaulting to 10

        # 1. Generate the initial frame with Stable Diffusion
        img_filename = os.path.join(work_dir, f"segment{idx:02d}_img.png")
        with torch.autocast("cuda"):
            if image_mode == "Consistent (Img2Img)" and reference_image is not None:
                # Img2img from the previous frame for style consistency
                img2img_pipe = _load_image_pipeline(image_model, is_img2img=True)
                image = img2img_pipe(
                    prompt=prompt,
                    image=reference_image,
                    strength=strength,
                    generator=generator,
                    num_inference_steps=30,
                ).images[0]
            else:
                # Regular text-to-image generation
                image = txt2img_pipe(
                    prompt=prompt,
                    width=width,
                    height=height,
                    generator=generator,
                    num_inference_steps=30,
                ).images[0]

        # Save the image for inspection
        image.save(img_filename)
        # Update the reference image for the next segment in consistent mode
        if image_mode == "Consistent (Img2Img)":
            reference_image = image

        # 2. Generate video frames from the image with Stable Video Diffusion
        with torch.autocast("cuda"):
            video_frames = video_pipe(
                image,
                num_frames=25,
                fps=round(fps),  # SVD's fps conditioning is documented as an int
                decode_chunk_size=1,
                generator=generator,
            ).frames[0]

        # Save the frames as an mp4 clip at the segment's playback FPS
        seg_filename = os.path.join(work_dir, f"segment_{idx:03d}.mp4")
        export_to_video(video_frames, seg_filename, fps=fps)
        segment_files.append(seg_filename)

        # Free frame memory before the next segment
        del video_frames
        torch.cuda.empty_cache()

    return segment_files
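

if __name__ == "__main__":
    # Quick local demo (a minimal sketch; timings and prompts are illustrative)
    demo_segments = [{"start": 0.0, "end": 2.5}, {"start": 2.5, "end": 5.0}]
    demo_prompts = ["a misty forest at dawn", "sunlight breaking through the canopy"]
    clips = create_video_segments(demo_segments, demo_prompts, work_dir=".", seed=42)
    print(clips)  # e.g. ["./segment_000.mp4", "./segment_001.mp4"]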