import logging

import torch
import numpy as np
import gradio as gr
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TextToVideoGenerator:
    def __init__(self):
        self.pipeline = None
        self.current_model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Available models
        self.models = {
            "damo-vilab/text-to-video-ms-1.7b": {
                "name": "DAMO Text-to-Video MS-1.7B",
                "description": "Fast and efficient text-to-video model",
                "max_frames": 16,
                "fps": 8,
            },
            "cerspense/zeroscope_v2_XL": {
                "name": "Zeroscope v2 XL",
                "description": "High-quality text-to-video model",
                "max_frames": 24,
                "fps": 6,
            },
            "stabilityai/stable-video-diffusion-img2vid-xt": {
                "name": "Stable Video Diffusion XT",
                "description": "Image-to-video model (requires an initial image)",
                "max_frames": 25,
                "fps": 6,
            },
        }
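        # NOTE: the max_frames / fps values above are working defaults chosen
        # for this app; the authoritative limits come from each model card and
        # may differ across checkpoint revisions (treat them as assumptions).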
    def load_model(self, model_id):
        """Load the specified model, reusing it if already loaded."""
        if self.current_model == model_id and self.pipeline is not None:
            return f"Model {self.models[model_id]['name']} is already loaded"

        try:
            logger.info(f"Loading model: {model_id}")

            # Clear cached GPU memory before loading a new pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Load the pipeline; some checkpoints do not ship fp16 variant
            # files, so fall back to the default weights if needed.
            if self.device == "cuda":
                try:
                    self.pipeline = DiffusionPipeline.from_pretrained(
                        model_id, torch_dtype=torch.float16, variant="fp16"
                    )
                except (OSError, ValueError):
                    self.pipeline = DiffusionPipeline.from_pretrained(
                        model_id, torch_dtype=torch.float16
                    )
            else:
                self.pipeline = DiffusionPipeline.from_pretrained(
                    model_id, torch_dtype=torch.float32
                )

            # Swap in a faster multistep scheduler for quicker inference
            if hasattr(self.pipeline, "scheduler"):
                self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                    self.pipeline.scheduler.config
                )

            if self.device == "cuda":
                # CPU offload keeps sub-modules on the GPU only while they
                # run, so an explicit .to("cuda") beforehand is redundant.
                self.pipeline.enable_model_cpu_offload()
                if hasattr(self.pipeline, "enable_vae_slicing"):
                    self.pipeline.enable_vae_slicing()
            else:
                self.pipeline = self.pipeline.to(self.device)

            self.current_model = model_id
            logger.info(f"Successfully loaded model: {model_id}")
            return f"Successfully loaded {self.models[model_id]['name']}"
        except Exception as e:
            # Reset state so the next call attempts a clean reload
            self.pipeline = None
            self.current_model = None
            logger.error(f"Error loading model: {str(e)}")
            return f"Error loading model: {str(e)}"
    def generate_video(self, prompt, model_id, num_frames=16, fps=8, num_inference_steps=25, guidance_scale=7.5, seed=None):
        """Generate a video from a text prompt."""
        try:
            # Stable Video Diffusion is image-to-video: its pipeline takes an
            # input image rather than a text prompt, so guard against it here.
            if "img2vid" in model_id:
                return None, "Stable Video Diffusion requires an initial image; choose a text-to-video model for prompt-only generation."

            # Load model if not already loaded
            if self.current_model != model_id:
                load_result = self.load_model(model_id)
                if "Error" in load_result:
                    return None, load_result

            # Set seed for reproducibility (Gradio may pass it as a float)
            if seed is not None:
                seed = int(seed)
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

            # Clamp the frame count to the model's limit; honor the user's fps
            model_config = self.models[model_id]
            num_frames = min(int(num_frames), model_config["max_frames"])
            fps = int(fps)

            logger.info(f"Generating video with prompt: {prompt}")
            logger.info(f"Parameters: frames={num_frames}, fps={fps}, steps={num_inference_steps}")

            # Generate video
            video_frames = self.pipeline(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_frames=num_frames,
            ).frames

            # diffusers' output format varies by version: newer releases return
            # a batch (first axis = videos), older ones a flat list of frames.
            if isinstance(video_frames, np.ndarray) and video_frames.ndim == 5:
                video_frames = video_frames[0]
            elif isinstance(video_frames, list) and video_frames and isinstance(video_frames[0], list):
                video_frames = video_frames[0]

            # Save video
            output_path = f"generated_video_{seed if seed is not None else 'random'}.mp4"
            export_to_video(video_frames, output_path, fps=fps)

            logger.info(f"Video saved to: {output_path}")
            return output_path, f"Video generated successfully! Saved as {output_path}"
        except Exception as e:
            logger.error(f"Error generating video: {str(e)}")
            return None, f"Error generating video: {str(e)}"
    def get_available_models(self):
        """Get the list of available model IDs."""
        return list(self.models.keys())

    def get_model_info(self, model_id):
        """Get information about a specific model."""
        if model_id in self.models:
            return self.models[model_id]
        return None
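# Example: programmatic use without the Gradio UI (illustrative sketch; model
# weights download on the first call, and the prompt below is made up):
#
#   gen = TextToVideoGenerator()
#   gen.load_model("damo-vilab/text-to-video-ms-1.7b")
#   path, status = gen.generate_video(
#       prompt="A red fox running through fresh snow",
#       model_id="damo-vilab/text-to-video-ms-1.7b",
#       num_frames=16, fps=8, seed=42,
#   )
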
# Initialize the generator
generator = TextToVideoGenerator()
def create_interface():
    """Create the Gradio interface."""
    def generate_video_interface(prompt, model_id, num_frames, fps, num_inference_steps, guidance_scale, seed):
        if not prompt.strip():
            return None, "Please enter a prompt"
        return generator.generate_video(
            prompt=prompt,
            model_id=model_id,
            num_frames=num_frames,
            fps=fps,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    # Create interface
    with gr.Blocks(title="Text-to-Video Generator", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# Text-to-Video Generation with Hugging Face Models")
        gr.Markdown("Generate videos from text descriptions using state-of-the-art AI models")

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                with gr.Group():
                    gr.Markdown("## Input Parameters")

                    prompt = gr.Textbox(
                        label="Text Prompt",
                        placeholder="Enter your video description here...",
                        lines=3,
                        max_lines=5,
                    )

                    model_id = gr.Dropdown(
                        choices=generator.get_available_models(),
                        value=generator.get_available_models()[0],
                        label="Model",
                        info="Select the model to use for generation",
                    )

                    with gr.Row():
                        num_frames = gr.Slider(
                            minimum=8,
                            maximum=24,
                            value=16,
                            step=1,
                            label="Number of Frames",
                            info="More frames = longer video",
                        )
                        fps = gr.Slider(
                            minimum=4,
                            maximum=12,
                            value=8,
                            step=1,
                            label="FPS",
                            info="Frames per second",
                        )

                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            minimum=10,
                            maximum=50,
                            value=25,
                            step=1,
                            label="Inference Steps",
                            info="More steps = better quality but slower",
                        )
                        guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance Scale",
                            info="Higher values = more prompt adherence",
                        )
                    seed = gr.Number(
                        label="Seed (Optional)",
                        value=None,
                        precision=0,  # return an integer rather than a float
                        info="Set for reproducible results",
                    )
                generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

                # Output section
                with gr.Group():
                    gr.Markdown("## Output")
                    status_text = gr.Textbox(label="Status", interactive=False)
                    video_output = gr.Video(label="Generated Video")

            with gr.Column(scale=1):
                # Model information
                with gr.Group():
                    gr.Markdown("## Model Information")
                    model_info = gr.JSON(label="Current Model Details")

                # Examples
                with gr.Group():
                    gr.Markdown("## Example Prompts")
                    examples = [
                        ["A beautiful sunset over the ocean with waves crashing on the shore"],
                        ["A cat playing with a ball of yarn in a cozy living room"],
                        ["A futuristic city with flying cars and neon lights"],
                        ["A butterfly emerging from a cocoon in a garden"],
                        ["A rocket launching into space with fire and smoke"],
                    ]
                    gr.Examples(
                        examples=examples,
                        inputs=prompt,
                        label="Try these examples",
                    )

        # Event handlers
        generate_btn.click(
            fn=generate_video_interface,
            inputs=[prompt, model_id, num_frames, fps, num_inference_steps, guidance_scale, seed],
            outputs=[video_output, status_text],
        )

        # Update model info when the selected model changes
        def update_model_info(model_id):
            return generator.get_model_info(model_id)

        model_id.change(
            fn=update_model_info,
            inputs=model_id,
            outputs=model_info,
        )

        # Load initial model info
        interface.load(
            lambda: generator.get_model_info(generator.get_available_models()[0]),
            outputs=model_info,
        )

    return interface
if __name__ == "__main__":
    # Create and launch the interface; queue long-running generation jobs
    interface = create_interface()
    interface.queue()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        # share=True is ignored (with a warning) on Hugging Face Spaces,
        # so it is omitted here.
    )
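
# Rough local-run sketch (exact pins belong in requirements.txt; the package
# list below is an assumption inferred from the imports above):
#   pip install torch diffusers transformers accelerate gradio opencv-python
#   python app.py   # conventional Spaces entry point; UI at http://localhost:7860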