Spaces:
Running
Running
Pravin Barapatre
Pin dependencies for Hugging Face Spaces compatibility and remove submodule issue
db8251f
import torch | |
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler | |
from diffusers.utils import export_to_video | |
import numpy as np | |
import argparse | |
import logging | |
# Module-level logger; the root logger is configured so that INFO-level
# progress messages emitted by this script are shown.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def generate_video_from_text(
    prompt,
    model_id="damo-vilab/text-to-video-ms-1.7b",
    num_frames=16,
    fps=8,
    num_inference_steps=25,
    guidance_scale=7.5,
    seed=None,
    output_path="generated_video.mp4"
):
    """
    Generate a video from a text prompt using a Hugging Face diffusion pipeline.

    Args:
        prompt (str): Text description of the video
        model_id (str): Hugging Face model ID
        num_frames (int): Number of frames to generate (must be > 0)
        fps (int): Frames per second of the exported file (must be > 0)
        num_inference_steps (int): Number of denoising steps (must be > 0)
        guidance_scale (float): Guidance scale for generation
        seed (int): Random seed for reproducibility; None leaves the RNG untouched
        output_path (str): Output video file path

    Returns:
        str: ``output_path``, once the video has been written.

    Raises:
        ValueError: If ``num_frames``, ``fps`` or ``num_inference_steps`` is
            not positive.
        Exception: Any error raised while loading the model, running the
            pipeline or exporting the video is logged and re-raised unchanged.
    """
    # Reject impossible settings up front, before the slow model load starts.
    for name, value in (
        ("num_frames", num_frames),
        ("fps", fps),
        ("num_inference_steps", num_inference_steps),
    ):
        if value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")

    # Check device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    on_gpu = device == "cuda"
    logger.info(f"Using device: {device}")

    try:
        # Set seed for reproducibility
        if seed is not None:
            torch.manual_seed(seed)
            if on_gpu:
                torch.cuda.manual_seed(seed)

        logger.info(f"Loading model: {model_id}")
        # Half precision (and the fp16 weight variant) only on GPU.
        pipeline = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            variant="fp16" if on_gpu else None,
        )

        # Optimize scheduler for faster inference
        if hasattr(pipeline, "scheduler"):
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                pipeline.scheduler.config
            )

        if on_gpu:
            # Model CPU offload moves each sub-model to the GPU on demand, so
            # the pipeline must NOT be moved to CUDA manually beforehand:
            # doing so first fills GPU memory only for offload to undo it.
            pipeline.enable_model_cpu_offload()
            pipeline.enable_vae_slicing()
        else:
            pipeline = pipeline.to(device)

        logger.info(f"Generating video with prompt: {prompt}")
        logger.info(f"Parameters: frames={num_frames}, fps={fps}, steps={num_inference_steps}")

        # Generate video
        result = pipeline(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_frames=num_frames,
        )
        frames = np.asarray(result.frames)
        # Depending on the diffusers release, ``frames`` is either one video
        # (frames, H, W, C) or a batch of videos (batch, frames, H, W, C).
        # A single prompt yields a single video, so unwrap the batch axis.
        if frames.ndim == 5:
            frames = frames[0]

        # Save video (export_to_video expects a sequence of per-frame arrays)
        export_to_video(list(frames), output_path, fps=fps)
        logger.info(f"Video saved to: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error generating video: {e}")
        raise
def main(argv=None):
    """
    Command-line entry point: parse arguments and generate one video.

    Args:
        argv (list[str] | None): Argument list to parse. None (the default)
            makes argparse read ``sys.argv[1:]``, preserving the original
            command-line behaviour; passing a list allows programmatic use.

    Returns:
        int: Process exit status, 0 on success and 1 if generation failed.
    """
    import sys  # local import: only needed to route the error message to stderr

    parser = argparse.ArgumentParser(description="Generate video from text using Hugging Face models")
    parser.add_argument("prompt", help="Text description of the video to generate")
    parser.add_argument("--model", default="damo-vilab/text-to-video-ms-1.7b",
                        help="Hugging Face model ID to use")
    parser.add_argument("--frames", type=int, default=16,
                        help="Number of frames to generate (default: 16)")
    parser.add_argument("--fps", type=int, default=8,
                        help="Frames per second (default: 8)")
    parser.add_argument("--steps", type=int, default=25,
                        help="Number of inference steps (default: 25)")
    parser.add_argument("--guidance", type=float, default=7.5,
                        help="Guidance scale (default: 7.5)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducibility")
    parser.add_argument("--output", default="generated_video.mp4",
                        help="Output video file path (default: generated_video.mp4)")
    args = parser.parse_args(argv)

    # Top-level boundary: any failure becomes a non-zero exit status and a
    # message on stderr, so stdout carries only the success line.
    try:
        output_path = generate_video_from_text(
            prompt=args.prompt,
            model_id=args.model,
            num_frames=args.frames,
            fps=args.fps,
            num_inference_steps=args.steps,
            guidance_scale=args.guidance,
            seed=args.seed,
            output_path=args.output,
        )
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    print(f"Video generated successfully: {output_path}")
    return 0
# Script entry point. ``exit()`` is a helper injected by the ``site`` module
# and may be missing (e.g. ``python -S``); raising SystemExit is the
# import-free equivalent and propagates main()'s return value as exit status.
if __name__ == "__main__":
    raise SystemExit(main())