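# app.py — Gradio text-to-audio demo built on GenAU, packaged as a
# Hugging Face Space (serves on 0.0.0.0:7860, see demo.launch below).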
import sys

# Make the bundled GenAU repository importable when running from the Space root.
sys.path.append("GenAU")

import importlib
import subprocess

# A stray apex install can conflict with the model code at import time;
# if one is present in the environment, uninstall it before loading the model.
try:
    importlib.import_module("apex")
    subprocess.run(["pip", "uninstall", "-y", "apex"], check=True)
except ImportError:
    pass
import os

import gradio as gr
import soundfile as sf
import torch
from pytorch_lightning import seed_everything

# GenAU imports (resolved via the sys.path entry above).
from src.tools.configuration import Configuration
from src.tools.download_manager import get_checkpoint_path
from src.tools.training_utils import get_restore_step
from src.utilities.model.model_util import instantiate_from_config
def load_model(
    model_name: str = "genau-l-full-hq-data",
    config_yaml_path: str = None,
    checkpoint_path: str = None,
):
    assert torch.cuda.is_available(), "CUDA is not available."

    # Download the default config / checkpoint when explicit paths are not given.
    if config_yaml_path is None:
        config_yaml_path = get_checkpoint_path(f"{model_name}_config")
    if checkpoint_path is None:
        checkpoint_path = get_checkpoint_path(model_name)
    print("checkpoint_path", checkpoint_path)

    configuration = Configuration(config_yaml_path)
    config_dict = configuration.get_config()
    if checkpoint_path is not None:
        config_dict["reload_from_ckpt"] = checkpoint_path

    # splitext on the basename (rather than split(".") on the full path) keeps
    # dots in parent directories from producing an empty experiment name.
    exp_name = os.path.splitext(os.path.basename(config_yaml_path))[0]
    exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))
    log_path = config_dict["logging"]["log_directory"]

    if "reload_from_ckpt" in config_dict and config_dict["reload_from_ckpt"]:
        resume_from_checkpoint = config_dict["reload_from_ckpt"]
    else:
        # Otherwise fall back to the latest checkpoint in the experiment's log directory.
        ckpt_folder = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
        if not os.path.exists(ckpt_folder):
            raise RuntimeError(f"No checkpoint directory found at {ckpt_folder}")
        restore_step, _ = get_restore_step(ckpt_folder)
        resume_from_checkpoint = os.path.join(ckpt_folder, restore_step)

    config_dict["model"]["params"]["ckpt_path"] = resume_from_checkpoint
    latent_diffusion = instantiate_from_config(config_dict["model"])
    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()
    return latent_diffusion, config_dict
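# A minimal sketch of bypassing the hub download and loading from explicit
# local files instead; both paths below are hypothetical placeholders:
#
# model, cfg = load_model(
#     config_yaml_path="GenAU/configs/example_config.yaml",  # hypothetical path
#     checkpoint_path="checkpoints/example.ckpt",            # hypothetical path
# )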
def infer_gradio(
    prompt: str,
    seed: int,
    cfg_weight: float,
    n_cand: int,
    ddim_steps: int,
):
    """
    Inference function called by Gradio's interface.
    Returns an (sr, data) tuple that gr.Audio(type="numpy") can play.
    """
    # Cast in case the UI delivers the seed as a float.
    seed_everything(int(seed))

    # text_to_audio writes the generated clip to disk and returns its path.
    saved_wav_path = latent_diffusion.text_to_audio(
        prompt=prompt,
        ddim_steps=ddim_steps,
        unconditional_guidance_scale=cfg_weight,
        n_gen=n_cand,
        use_ema=True,
    )
    data, sr = sf.read(saved_wav_path)
    return (sr, data)
# Load the model once at import time so every Gradio request reuses it.
latent_diffusion, config_dict = load_model(
    model_name="genau-l-full-hq-data",  # or whichever default model you want
    config_yaml_path=None,              # or the path to your .yaml if you have one
    checkpoint_path=None,               # or a direct path to a .ckpt file
)
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Audio Demo")

    with gr.Row():
        prompt_input = gr.Textbox(
            lines=2,
            label="Prompt",
            placeholder="Type your text prompt here...",
            value="A calm piano melody.",
        )

    with gr.Accordion("Advanced Parameters", open=False):
        seed_slider = gr.Number(
            value=0,
            label="Random Seed",
        )
        cfg_slider = gr.Slider(
            minimum=0.1, maximum=10.0, value=4.0, step=0.1,
            label="Classifier-Free Guidance Weight",
        )
        n_cand_slider = gr.Slider(
            minimum=1, maximum=4, value=1, step=1,
            label="Number of Candidates",
        )
        ddim_steps_slider = gr.Slider(
            minimum=10, maximum=500, value=100, step=10,
            label="DDIM Steps",
        )

    generate_button = gr.Button("Generate Audio")
    audio_output = gr.Audio(type="numpy", label="Generated Audio")

    generate_button.click(
        fn=infer_gradio,
        inputs=[prompt_input, seed_slider, cfg_slider, n_cand_slider, ddim_steps_slider],
        outputs=[audio_output],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)
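# If generations run long on Spaces hardware, Gradio's built-in request queue
# avoids HTTP timeouts; a minimal sketch (max_size chosen arbitrarily):
#
# demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860)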