genau-demo / app.py
Moayed's picture
update requirements
9a900bb
import sys
sys.path.append("GenAU")
import importlib
import subprocess
# Remove the `apex` package from this environment if it is installed.
# NOTE(review): the reason is not visible in this file -- presumably an
# installed apex conflicts with the GenAU / PyTorch stack; confirm before
# dropping this step.
# `find_spec` locates the package without executing it, so a broken apex
# install cannot crash start-up here and no apex module is left cached in
# `sys.modules`. `sys.executable -m pip` targets the interpreter actually
# running this app rather than whichever `pip` is first on PATH.
import importlib.util

if importlib.util.find_spec("apex") is not None:
    subprocess.run(
        [sys.executable, "-m", "pip", "uninstall", "-y", "apex"],
        check=True,
    )
    # Drop cached finder state so later imports see the package is gone.
    importlib.invalidate_caches()
from src.tools.training_utils import get_restore_step
from src.utilities.model.model_util import instantiate_from_config
from src.tools.configuration import Configuration
from src.tools.download_manager import get_checkpoint_path
import os
import torch
import gradio as gr
import soundfile as sf
from pytorch_lightning import seed_everything
import importlib
import subprocess
def load_model(
    model_name: str = "genau-l-full-hq-data",
    config_yaml_path: str = None,
    checkpoint_path: str = None,
):
    """Build the latent-diffusion model from a config and move it to the GPU.

    Args:
        model_name: Key passed to ``get_checkpoint_path`` to locate the config
            (``f"{model_name}_config"``) and the weights (``model_name``) when
            the explicit paths below are not given.
        config_yaml_path: Path to the YAML config. Defaults to the path
            resolved for ``f"{model_name}_config"``.
        checkpoint_path: Path to the checkpoint to load. Defaults to the path
            resolved for ``model_name``. If neither this nor the config's own
            ``reload_from_ckpt`` entry provides a checkpoint, the latest one
            under ``<logging.log_directory>/<config dir name>/<config name>/checkpoints``
            is used.

    Returns:
        ``(latent_diffusion, config_dict)``: the instantiated model, switched
        to eval mode and moved to CUDA, and the object returned by
        ``Configuration.get_config()``. ``config_dict["model"]["params"]["ckpt_path"]``
        is set to the checkpoint that was used.

    Raises:
        RuntimeError: if CUDA is unavailable, or if no checkpoint was supplied
            and the fallback checkpoint directory does not exist.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")

    if config_yaml_path is None:
        config_yaml_path = get_checkpoint_path(f"{model_name}_config")
    if checkpoint_path is None:
        checkpoint_path = get_checkpoint_path(model_name)
    print("checkpoint_path", checkpoint_path)

    configuration = Configuration(config_yaml_path)
    config_dict = configuration.get_config()
    if checkpoint_path is not None:
        config_dict["reload_from_ckpt"] = checkpoint_path

    if "reload_from_ckpt" in config_dict and config_dict["reload_from_ckpt"]:
        resume_from_checkpoint = config_dict["reload_from_ckpt"]
    else:
        # No checkpoint given: fall back to the newest one in the log folder.
        log_path = config_dict["logging"]["log_directory"]
        # Take the basename *before* stripping the extension; splitting the
        # full path on "." breaks on dotted directories ("./configs",
        # "~/.cache/...") and yields an empty experiment name.
        exp_name = os.path.basename(config_yaml_path).split(".")[0]
        exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))
        ckpt_folder = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
        if not os.path.exists(ckpt_folder):
            raise RuntimeError(f"No checkpoint directory found at {ckpt_folder}")
        restore_step, _ = get_restore_step(ckpt_folder)
        resume_from_checkpoint = os.path.join(ckpt_folder, restore_step)

    config_dict["model"]["params"]["ckpt_path"] = resume_from_checkpoint
    latent_diffusion = instantiate_from_config(config_dict["model"])
    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()
    return latent_diffusion, config_dict
def infer_gradio(
    prompt: str,
    seed: int,
    cfg_weight: float,
    n_cand: int,
    ddim_steps: int
):
    """Generate audio for ``prompt`` with the module-level model (Gradio callback).

    Args:
        prompt: Text description of the audio to generate.
        seed: Seed passed to ``seed_everything``. ``None`` (e.g. an empty
            Number box) is forwarded unchanged.
        cfg_weight: Classifier-free guidance scale, passed as
            ``unconditional_guidance_scale``.
        n_cand: Number of candidates to generate, passed as ``n_gen``.
        ddim_steps: Number of DDIM sampling steps.

    Returns:
        ``(sample_rate, data)`` read from the audio file whose path
        ``text_to_audio`` returns. This is the tuple order
        ``gr.Audio(type="numpy")`` expects.
    """
    # Gradio Number/Slider values can arrive as floats (e.g. 100.0); the seed,
    # candidate count and step count are integral, so cast them here.
    if seed is not None:
        seed = int(seed)
    seed_everything(seed)
    saved_wav_path = latent_diffusion.text_to_audio(
        prompt=prompt,
        ddim_steps=int(ddim_steps),
        unconditional_guidance_scale=cfg_weight,
        n_gen=int(n_cand),
        use_ema=True,
    )
    data, sr = sf.read(saved_wav_path)
    # soundfile yields (data, samplerate); Gradio wants (samplerate, data).
    return (sr, data)
# Build the model once at import time; `infer_gradio` reads this module-level
# `latent_diffusion` on every request. Both paths are left as None so that
# `load_model` resolves the config and checkpoint from the model name; pass
# `config_yaml_path=` / `checkpoint_path=` to use local files instead.
# Note: `config_yaml` receives the loaded configuration, not a file path.
latent_diffusion, config_yaml = load_model(
    "genau-l-full-hq-data",
    checkpoint_path=None,
    config_yaml_path=None,
)
# Gradio UI: a prompt box, collapsible sampling parameters, a generate button
# and an audio player, wired to `infer_gradio`.
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Audio Demo")

    with gr.Row():
        prompt_input = gr.Textbox(
            lines=2,
            label="Prompt",
            placeholder="Type your text prompt here...",
            value="A calm piano melody.",
        )

    with gr.Accordion("Advanced Parameters", open=False):
        # precision=0 makes Gradio round the value and pass it as an int;
        # with the default precision=None the seed would arrive as a float.
        seed_slider = gr.Number(
            value=0,
            label="Random Seed",
            precision=0,
        )
        cfg_slider = gr.Slider(
            minimum=0.1, maximum=10.0, value=4.0, step=0.1,
            label="Classifier-Free Guidance Weight",
        )
        n_cand_slider = gr.Slider(
            minimum=1, maximum=4, value=1, step=1,
            label="Number of Candidates",
        )
        ddim_steps_slider = gr.Slider(
            minimum=10, maximum=500, value=100, step=10,
            label="DDIM Steps",
        )

    generate_button = gr.Button("Generate Audio")
    audio_output = gr.Audio(type="numpy", label="Generated Audio")

    # Input order must match infer_gradio's parameters:
    # (prompt, seed, cfg_weight, n_cand, ddim_steps).
    generate_button.click(
        fn=infer_gradio,
        inputs=[prompt_input, seed_slider, cfg_slider, n_cand_slider, ddim_steps_slider],
        outputs=[audio_output],
    )

# Listen on all network interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)