import os
import sys
import importlib
import subprocess

sys.path.append("GenAU")

# If apex is importable, uninstall it before loading the model code,
# since its presence can interfere with this setup.
try:
    importlib.import_module("apex")
    subprocess.run(["pip", "uninstall", "-y", "apex"], check=True)
except ImportError:
    pass

import torch
import gradio as gr
import soundfile as sf
from pytorch_lightning import seed_everything

from src.tools.training_utils import get_restore_step
from src.utilities.model.model_util import instantiate_from_config
from src.tools.configuration import Configuration
from src.tools.download_manager import get_checkpoint_path


def load_model(
    model_name: str = "genau-l-full-hq-data",
    config_yaml_path: str = None,
    checkpoint_path: str = None,
):
    assert torch.cuda.is_available(), "CUDA is not available."

    # Resolve (and, if needed, download) the config and checkpoint for the model.
    if config_yaml_path is None:
        config_yaml_path = get_checkpoint_path(f"{model_name}_config")
    if checkpoint_path is None:
        checkpoint_path = get_checkpoint_path(model_name)
    print("checkpoint_path", checkpoint_path)

    configuration = Configuration(config_yaml_path)
    config_dict = configuration.get_config()

    if checkpoint_path is not None:
        config_dict["reload_from_ckpt"] = checkpoint_path

    # Derive the experiment names from the config path; they are only needed
    # to locate a fallback checkpoint directory.
    exp_name = os.path.basename(os.path.splitext(config_yaml_path)[0])
    exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))
    log_path = config_dict["logging"]["log_directory"]

    if config_dict.get("reload_from_ckpt"):
        resume_from_checkpoint = config_dict["reload_from_ckpt"]
    else:
        # Otherwise, try to load the latest checkpoint from the log directory.
        ckpt_folder = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
        if not os.path.exists(ckpt_folder):
            raise RuntimeError(f"No checkpoint directory found at {ckpt_folder}")
        restore_step, _ = get_restore_step(ckpt_folder)
        resume_from_checkpoint = os.path.join(ckpt_folder, restore_step)

    config_dict["model"]["params"]["ckpt_path"] = resume_from_checkpoint

    latent_diffusion = instantiate_from_config(config_dict["model"])
    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()
    return latent_diffusion, config_dict


def infer_gradio(
    prompt: str, seed: int, cfg_weight: float, n_cand: int, ddim_steps: int
):
    """
    Inference function called by the Gradio interface.

    Returns an (sample_rate, data) tuple that Gradio's Audio component plays.
    """
    seed_everything(seed)
    saved_wav_path = latent_diffusion.text_to_audio(
        prompt=prompt,
        ddim_steps=ddim_steps,
        unconditional_guidance_scale=cfg_weight,
        n_gen=n_cand,
        use_ema=True,
    )
    data, sr = sf.read(saved_wav_path)
    return (sr, data)


latent_diffusion, config_dict = load_model(
    model_name="genau-l-full-hq-data",  # or whichever default model you want
    config_yaml_path=None,              # or a path to your .yaml if you have one
    checkpoint_path=None,               # or a direct path to a .ckpt file
)
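# Optional smoke test before building the UI (a sketch; the GENAU_SMOKE_TEST
# environment variable is a hypothetical name, not part of GenAU). It generates
# one clip outside Gradio, assuming text_to_audio returns the path of the saved
# WAV file, exactly as infer_gradio above relies on.
if os.environ.get("GENAU_SMOKE_TEST"):
    seed_everything(0)
    test_wav_path = latent_diffusion.text_to_audio(
        prompt="A calm piano melody.",
        ddim_steps=100,
        unconditional_guidance_scale=4.0,
        n_gen=1,
        use_ema=True,
    )
    print("Smoke-test audio saved to", test_wav_path)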
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Audio Demo")
    with gr.Row():
        prompt_input = gr.Textbox(
            lines=2,
            label="Prompt",
            placeholder="Type your text prompt here...",
            value="A calm piano melody.",
        )
    with gr.Accordion("Advanced Parameters", open=False):
        seed_slider = gr.Number(value=0, label="Random Seed")
        cfg_slider = gr.Slider(
            minimum=0.1,
            maximum=10.0,
            value=4.0,
            step=0.1,
            label="Classifier-Free Guidance Weight",
        )
        n_cand_slider = gr.Slider(
            minimum=1, maximum=4, value=1, step=1, label="Number of Candidates"
        )
        ddim_steps_slider = gr.Slider(
            minimum=10, maximum=500, value=100, step=10, label="DDIM Steps"
        )

    generate_button = gr.Button("Generate Audio")
    audio_output = gr.Audio(type="numpy", label="Generated Audio")

    generate_button.click(
        fn=infer_gradio,
        inputs=[prompt_input, seed_slider, cfg_slider, n_cand_slider, ddim_steps_slider],
        outputs=[audio_output],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)
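# To call the running demo programmatically from another process, one option is
# Gradio's client package (a sketch; it assumes gradio_client is installed and
# that the endpoint keeps its default name, derived from the function name):
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860/")
#   result = client.predict(
#       "A calm piano melody.",  # prompt
#       0,                       # seed
#       4.0,                     # CFG weight
#       1,                       # number of candidates
#       100,                     # DDIM steps
#       api_name="/infer_gradio",
#   )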