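# Gradio demo for GenAU text-to-audio generation: loads a pretrained latent
# diffusion checkpoint and serves a simple prompt-to-audio web UI.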
import sys

sys.path.append("GenAU")

import importlib
import subprocess

# If NVIDIA apex happens to be installed, remove it before importing GenAU code.
try:
    importlib.import_module("apex")
    subprocess.run(["pip", "uninstall", "-y", "apex"], check=True)
except ImportError:
    pass

import os

import gradio as gr
import soundfile as sf
import torch
from pytorch_lightning import seed_everything

from src.tools.configuration import Configuration
from src.tools.download_manager import get_checkpoint_path
from src.tools.training_utils import get_restore_step
from src.utilities.model.model_util import instantiate_from_config

def load_model(
    model_name: str = "genau-l-full-hq-data",
    config_yaml_path: str = None,
    checkpoint_path: str = None,
):
    """
    Load a pretrained GenAU latent diffusion model.

    If config_yaml_path or checkpoint_path are not given, they are resolved
    from the model name via the download manager.
    """
    assert torch.cuda.is_available(), "CUDA is not available."

    if config_yaml_path is None:
        config_yaml_path = get_checkpoint_path(f"{model_name}_config")
    if checkpoint_path is None:
        checkpoint_path = get_checkpoint_path(model_name)
    print("checkpoint_path", checkpoint_path)

    configuration = Configuration(config_yaml_path)
    config_dict = configuration.get_config()
    if checkpoint_path is not None:
        config_dict["reload_from_ckpt"] = checkpoint_path

    # Derive the experiment name/group from the config path; checkpoints live
    # under <log_directory>/<exp_group_name>/<exp_name>/checkpoints.
    exp_name = os.path.basename(os.path.splitext(config_yaml_path)[0])
    exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))
    log_path = config_dict["logging"]["log_directory"]

    if config_dict.get("reload_from_ckpt"):
        resume_from_checkpoint = config_dict["reload_from_ckpt"]
    else:
        # Otherwise try to load the latest checkpoint from the log directory.
        ckpt_folder = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
        if not os.path.exists(ckpt_folder):
            raise RuntimeError(f"No checkpoint directory found at {ckpt_folder}")
        restore_step, _ = get_restore_step(ckpt_folder)
        resume_from_checkpoint = os.path.join(ckpt_folder, restore_step)

    config_dict["model"]["params"]["ckpt_path"] = resume_from_checkpoint
    latent_diffusion = instantiate_from_config(config_dict["model"])
    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()
    return latent_diffusion, config_dict
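
# Example of calling the loaded model directly, outside Gradio (a sketch; it
# assumes a CUDA GPU and that the default checkpoints can be downloaded):
#
#   model, cfg = load_model()
#   wav_path = model.text_to_audio(
#       prompt="a dog barking in the distance",
#       ddim_steps=100,
#       unconditional_guidance_scale=4.0,
#       n_gen=1,
#       use_ema=True,
#   )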

def infer_gradio(
    prompt: str,
    seed: int,
    cfg_weight: float,
    n_cand: int,
    ddim_steps: int,
):
    """
    Inference function called by Gradio's interface.
    Returns a (sample_rate, data) tuple to play in the Gradio UI.
    """
    seed_everything(seed)
    # Uses the module-level latent_diffusion loaded below.
    saved_wav_path = latent_diffusion.text_to_audio(
        prompt=prompt,
        ddim_steps=ddim_steps,
        unconditional_guidance_scale=cfg_weight,
        n_gen=n_cand,
        use_ema=True,
    )
    data, sr = sf.read(saved_wav_path)
    return (sr, data)

# Load the model once at startup so every Gradio request reuses it.
latent_diffusion, config_yaml = load_model(
    model_name="genau-l-full-hq-data",  # or whichever default model you want
    config_yaml_path=None,              # or a path to your .yaml if you have one
    checkpoint_path=None,               # or a direct path to a .ckpt file
)

with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Audio Demo")

    with gr.Row():
        prompt_input = gr.Textbox(
            lines=2,
            label="Prompt",
            placeholder="Type your text prompt here...",
            value="A calm piano melody.",
        )

    with gr.Accordion("Advanced Parameters", open=False):
        seed_slider = gr.Number(
            value=0,
            label="Random Seed",
        )
        cfg_slider = gr.Slider(
            minimum=0.1, maximum=10.0, value=4.0, step=0.1,
            label="Classifier-Free Guidance Weight",
        )
        n_cand_slider = gr.Slider(
            minimum=1, maximum=4, value=1, step=1,
            label="Number of Candidates",
        )
        ddim_steps_slider = gr.Slider(
            minimum=10, maximum=500, value=100, step=10,
            label="DDIM Steps",
        )

    generate_button = gr.Button("Generate Audio")
    audio_output = gr.Audio(type="numpy", label="Generated Audio")

    generate_button.click(
        fn=infer_gradio,
        inputs=[prompt_input, seed_slider, cfg_slider, n_cand_slider, ddim_steps_slider],
        outputs=[audio_output],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)
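
# If the Space needs to serve several users at once, Gradio's request queue
# could be enabled before launching (a possible variant, not part of the
# original app):
#
#   demo.queue().launch(server_name="0.0.0.0", server_port=7860)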