Spaces: Running on L40S
import gradio as gr
import torch
import gc
import numpy as np
import random
import os
import tempfile
import soundfile as sf  # NOTE(review): not referenced anywhere in the visible code

# Set before importing elastic_models so the variable is already in the
# environment when that package is imported (presumably it reads its log
# level at import time — confirm). Keep this line above the imports below.
os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
from transformers import AutoProcessor, pipeline
from elastic_models.transformers import MusicgenForConditionalGeneration
def set_seed(seed: int = 42):
    """Seed every random number generator the app touches.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU,
    the current CUDA device and all CUDA devices). It also switches cuDNN into
    deterministic, non-benchmarking mode so repeated runs pick the same kernels.

    Args:
        seed: Value fed to every generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def cleanup_gpu():
    """Clean up GPU memory to avoid TensorRT conflicts.

    Runs Python's garbage collector first, then (when CUDA is available) waits
    for queued CUDA work and returns the allocator's unused cached blocks to
    the driver. On a CPU-only machine only the garbage collection runs.
    """
    # Collect first: torch.cuda.empty_cache() can only release cached blocks
    # that no live tensor owns, so tensors kept alive only by unreachable
    # reference cycles must be freed before the cache is emptied.
    gc.collect()
    if torch.cuda.is_available():
        # Let outstanding kernels finish before handing memory back.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
def cleanup_temp_files(max_age_seconds=3600, temp_dir=None, pattern="tmp*.wav"):
    """Clean up old temporary audio files.

    Deletes files matching ``pattern`` inside ``temp_dir`` whose last
    modification is older than ``max_age_seconds``. This is best-effort:
    files that vanish or cannot be removed are skipped silently.

    Args:
        max_age_seconds: Minimum age, in seconds, of a file to delete.
            Defaults to one hour.
        temp_dir: Directory to scan. Defaults to ``tempfile.gettempdir()``.
        pattern: Glob pattern, relative to ``temp_dir``, of candidate files.

    Returns:
        List of paths that were removed.
    """
    import glob
    import time

    if temp_dir is None:
        temp_dir = tempfile.gettempdir()
    cutoff_time = time.time() - max_age_seconds
    # Escape the directory so glob metacharacters in the path (e.g. "[") are
    # matched literally; only ``pattern`` is interpreted as a glob.
    search = os.path.join(glob.escape(temp_dir), pattern)
    removed = []
    for temp_file in glob.glob(search):
        try:
            # mtime rather than ctime: on POSIX getctime() is the inode-change
            # time (reset by chmod/rename), not the creation time.
            if os.path.getmtime(temp_file) < cutoff_time:
                os.remove(temp_file)
                removed.append(temp_file)
                print(f"[CLEANUP] Removed old temp file: {temp_file}")
        except OSError:
            # Deliberate best-effort: the file may have disappeared, be a
            # directory, or belong to another user.
            continue
    return removed
# Lazily built on the first load_model() call and reused afterwards.
_generator = None  # text-to-audio pipeline wrapping the MusicGen model
_processor = None  # processor whose tokenizer is handed to the pipeline


def load_model():
    """Return the shared ``(generator, processor)`` pair, building it once.

    On the first call this:
    - frees GPU memory;
    - loads the ``facebook/musicgen-large`` processor and fp16 model on CUDA
      when available, otherwise on CPU;
    - wraps the model and tokenizer in a ``text-to-audio`` pipeline.

    Later calls return the cached objects.

    NOTE(review): no lock guards the initialization; concurrent first calls
    could each load the model — confirm the app's concurrency settings.
    """
    global _generator, _processor
    if _generator is not None:
        return _generator, _processor

    print("[MODEL] Starting model initialization...")
    cleanup_gpu()
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[MODEL] Using device: {target}")

    model_id = "facebook/musicgen-large"

    print("[MODEL] Loading processor...")
    _processor = AutoProcessor.from_pretrained(model_id)

    print("[MODEL] Loading model...")
    musicgen = MusicgenForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device=target,
        mode="S",
        __paged=True,
    )
    musicgen.eval()

    print("[MODEL] Creating pipeline...")
    _generator = pipeline(
        task="text-to-audio",
        model=musicgen,
        tokenizer=_processor.tokenizer,
        device=target,
    )
    print("[MODEL] Model initialization completed successfully")
    return _generator, _processor
def calculate_max_tokens(duration_seconds, token_rate=50):
    """Convert a requested clip length into a generation token budget.

    Args:
        duration_seconds: Requested audio length in seconds (int or float).
        token_rate: Tokens generated per second of audio. Defaults to 50, the
            rate this app has always used.

    Returns:
        Integer token count, used for both ``max_new_tokens`` and
        ``min_new_tokens`` by the caller.
    """
    # round() instead of int(): int() truncates, so float products landing a
    # hair below a whole number (0.58 * 50 == 28.999999999999996) lost a token.
    max_new_tokens = int(round(duration_seconds * token_rate))
    print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
    return max_new_tokens
def _postprocess_audio(audio_data):
    """Convert raw pipeline audio into a 1-D float32 waveform peaking at 0.95.

    Accepts a torch tensor or array-like of rank 1-3:
    - rank 3: the leading axis is dropped by taking index 0;
    - rank 2: the longer axis is treated as time and only the first channel
      is kept;
    - all-zero audio is returned unscaled.
    """
    if hasattr(audio_data, 'cpu'):
        # torch.Tensor: detach, move to host and upcast before NumPy export.
        audio_data = audio_data.detach().cpu().float().numpy()
        print(f"[GENERATION] Audio shape after tensor conversion: {audio_data.shape}")
    # Work in float32 so peak normalization is not done in half precision.
    audio_data = np.asarray(audio_data, dtype=np.float32)
    if audio_data.ndim == 3:
        audio_data = audio_data[0]
    if audio_data.ndim == 2:
        if audio_data.shape[0] < audio_data.shape[1]:
            audio_data = audio_data.T  # put the (longer) time axis first
        if audio_data.shape[1] > 1:
            audio_data = audio_data[:, 0]  # keep only the first channel
    audio_data = audio_data.flatten()
    print(f"[GENERATION] Audio shape after flattening: {audio_data.shape}")
    max_val = np.max(np.abs(audio_data))
    if max_val > 0:
        audio_data = audio_data / max_val * 0.95  # Scale to 95% to avoid clipping
    return audio_data.astype(np.float32, copy=False)


def generate_music(text_prompt, duration=10, guidance_scale=3.0, seed=42):
    """Generate a music clip for ``text_prompt`` with the shared MusicGen pipeline.

    Args:
        text_prompt: Text description of the music.
        duration: Clip length in seconds, converted to a token budget.
        guidance_scale: Classifier-free guidance strength passed to generate.
        seed: RNG seed applied before sampling. Defaults to 42, so the same
            inputs reproduce the same clip.

    Returns:
        ``(sample_rate, audio)`` where ``audio`` is a 1-D float32 NumPy array
        peak-normalized to 0.95, or ``None`` if anything fails. Failures are
        logged with a full traceback.
    """
    import traceback

    try:
        generator, _ = load_model()
        print("[GENERATION] Starting generation...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
        print(f"[GENERATION] Duration: {duration}s")
        print(f"[GENERATION] Guidance scale: {guidance_scale}")
        cleanup_gpu()
        set_seed(seed)
        print(f"[GENERATION] Using seed: {seed}")
        max_new_tokens = calculate_max_tokens(duration)
        generation_params = {
            'do_sample': True,
            'guidance_scale': guidance_scale,
            'max_new_tokens': max_new_tokens,
            # Same as max: forces the clip to the full requested length.
            'min_new_tokens': max_new_tokens,
            'cache_implementation': 'paged',
        }
        outputs = generator(
            [text_prompt],
            batch_size=1,
            generate_kwargs=generation_params,
        )
        print("[GENERATION] Generation completed successfully")
        output = outputs[0]
        audio_data = output['audio']
        sample_rate = output['sampling_rate']
        print(f"[GENERATION] Audio shape: {audio_data.shape}")
        print(f"[GENERATION] Sample rate: {sample_rate}")
        print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
        print(f"[GENERATION] Audio is numpy: {type(audio_data)}")
        audio_data = _postprocess_audio(audio_data)
        print(f"[GENERATION] Final audio shape: {audio_data.shape}")
        print(f"[GENERATION] Audio range: [{np.min(audio_data):.3f}, {np.max(audio_data):.3f}]")
        print(f"[GENERATION] Sample rate: {sample_rate}")
        return (sample_rate, audio_data)
    except Exception as e:
        # Top-level UI boundary: log (with stack) and return no audio.
        print(f"[ERROR] Generation failed: {str(e)}")
        traceback.print_exc()
        cleanup_gpu()
        return None
# Static UI text, kept out of the layout code below.
EXAMPLE_PROMPTS = [
    "A groovy funk bassline with a tight drum beat",
    "Relaxing acoustic guitar melody",
    "Electronic dance music with heavy bass",
    "Classical violin concerto",
    "Reggae with steel drums and bass",
    "Rock ballad with electric guitar solo",
    "Jazz piano improvisation with brushed drums",
    "Ambient synthwave with retro vibes",
]

TIPS_MARKDOWN = """
- Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
- Higher guidance scale = follows prompt more closely
- Lower guidance scale = more creative/varied results
- Duration is limited to 30 seconds for faster generation
"""

LIMITATIONS_MARKDOWN = """
<div style="text-align: center; color: #666; font-size: 12px; margin-top: 2rem;">
<strong>Limitations:</strong><br>
• The model is not able to generate realistic vocals.<br>
• The model has been trained with English descriptions and will not perform as well in other languages.<br>
• The model does not perform equally well for all music styles and cultures.<br>
• The model sometimes generates end of songs, collapsing to silence.<br>
• It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
</div>
"""

# Two-column layout: inputs on the left, generated audio and tips on the right.
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate",
                value=EXAMPLE_PROMPTS[0],
                lines=3,
            )
            with gr.Row():
                duration = gr.Slider(
                    label="Duration (seconds)",
                    minimum=5,
                    maximum=30,
                    step=1,
                    value=10,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    info="Higher values follow prompt more closely",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=3.0,
                )
            generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy",
                interactive=False,
            )
            with gr.Accordion("Tips", open=False):
                gr.Markdown(TIPS_MARKDOWN)

    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, guidance_scale],
        outputs=audio_output,
        show_progress=True,
    )

    gr.Examples(
        examples=EXAMPLE_PROMPTS,
        inputs=text_input,
        label="Example Prompts",
    )

    gr.Markdown("---")
    gr.Markdown(LIMITATIONS_MARKDOWN)

if __name__ == "__main__":
    # Purge stale temp audio from earlier runs, then start serving.
    cleanup_temp_files()
    demo.launch()