import torch
import torch.nn as nn
import torchaudio
from transformers import PreTrainedModel, PretrainedConfig
from BigVGAN import bigvgan
from BigVGAN.meldataset import get_mel_spectrogram
import argparse
from model import OptimizedAudioRestorationModel
import librosa
from inference_long import apply_overlap_windowing_waveform, reconstruct_waveform_from_windows

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Configuration class for VoiceRestore
class VoiceRestoreConfig(PretrainedConfig):
    model_type = "voice_restore"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.steps = kwargs.get("steps", 16)
        self.cfg_strength = kwargs.get("cfg_strength", 0.5)
        self.window_size_sec = kwargs.get("window_size_sec", 5.0)
        self.overlap = kwargs.get("overlap", 0.5)
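# Example (illustrative only; these are the four config keys defined above,
# with made-up values): override the inference defaults at construction time.
# config = VoiceRestoreConfig(steps=32, cfg_strength=1.0, window_size_sec=10.0, overlap=0.25)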
# Model class for VoiceRestore
class VoiceRestore(PreTrainedModel):
    config_class = VoiceRestoreConfig

    def __init__(self, config: VoiceRestoreConfig):
        super().__init__(config)
        self.steps = config.steps
        self.cfg_strength = config.cfg_strength
        self.window_size_sec = config.window_size_sec
        self.overlap = config.overlap

        # Initialize the BigVGAN vocoder (mel spectrogram -> waveform)
        self.bigvgan_model = bigvgan.BigVGAN.from_pretrained(
            'nvidia/bigvgan_v2_24khz_100band_256x',
            use_cuda_kernel=False,
            force_download=False
        ).to(device)
        self.bigvgan_model.remove_weight_norm()

        # Optimized restoration model
        self.optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=self.bigvgan_model)

        # Load the restoration checkpoint, unwrapping a training-style
        # {'model_state_dict': ...} container if present
        save_path = "./pytorch_model.bin"
        state_dict = torch.load(save_path, map_location=device)
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        self.optimized_model.voice_restore.load_state_dict(state_dict, strict=True)
        self.optimized_model.eval()
    def forward(self, input_path, output_path, short=True):
        # Restore the audio using the parameters from the config
        if short:
            self.restore_audio_short(self.optimized_model, input_path, output_path, self.steps, self.cfg_strength)
        else:
            self.restore_audio_long(self.optimized_model, input_path, output_path, self.steps, self.cfg_strength, self.window_size_sec, self.overlap)
    def restore_audio_short(self, model, input_path, output_path, steps, cfg_strength):
        """
        Short inference for audio restoration: process the whole file in one pass.
        """
        # Load the audio file and resample to the model's target rate
        device_type = device.type
        audio, sr = torchaudio.load(input_path)
        if sr != model.target_sample_rate:
            audio = torchaudio.functional.resample(audio, sr, model.target_sample_rate)
        audio = audio.mean(dim=0, keepdim=True) if audio.size(0) > 1 else audio  # Convert to mono if multi-channel
        audio = audio.to(device)

        with torch.inference_mode():
            with torch.autocast(device_type):
                restored_wav = model(audio, steps=steps, cfg_strength=cfg_strength)
                restored_wav = restored_wav.squeeze(0).float().cpu()  # Move to CPU after processing

        # Save the restored audio
        torchaudio.save(output_path, restored_wav, model.target_sample_rate)
    def restore_audio_long(self, model, input_path, output_path, steps, cfg_strength, window_size_sec, overlap):
        """
        Long inference for audio restoration: split the waveform into
        overlapping windows, restore each window, then recombine.
        """
        # Load the audio file at 24 kHz mono
        wav, sr = librosa.load(input_path, sr=24000, mono=True)
        wav = torch.FloatTensor(wav).unsqueeze(0)  # Shape: [1, num_samples]

        window_size_samples = int(window_size_sec * sr)
        wav_windows = apply_overlap_windowing_waveform(wav, window_size_samples, overlap)

        restored_wav_windows = []
        for wav_window in wav_windows:
            wav_window = wav_window.to(device)
            processed_mel = get_mel_spectrogram(wav_window, self.bigvgan_model.h).to(device)

            # Restore the mel spectrogram, then vocode it back to a waveform
            with torch.no_grad():
                with torch.autocast(device.type):
                    restored_mel = model.voice_restore.sample(processed_mel.transpose(1, 2), steps=steps, cfg_strength=cfg_strength)
                    restored_mel = restored_mel.squeeze(0).transpose(0, 1)
                restored_wav = self.bigvgan_model(restored_mel.unsqueeze(0)).squeeze(0).float().cpu()
            restored_wav_windows.append(restored_wav)
            torch.cuda.empty_cache()

        restored_wav_windows = torch.stack(restored_wav_windows)
        restored_wav = reconstruct_waveform_from_windows(restored_wav_windows, window_size_samples, overlap)

        # Save the restored audio
        torchaudio.save(output_path, restored_wav.unsqueeze(0), 24000)
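# For reference, a minimal sketch of what the two helpers imported from
# inference_long are assumed to do. This is an illustration only: the imported
# names are real, but the hop computation, the [N, window] chunk layout, and
# the uniform overlap-add averaging below are assumptions, not the project's
# actual implementation.
def _sketch_overlap_windows(wav: torch.Tensor, window: int, overlap: float) -> torch.Tensor:
    # Slice a [1, T] waveform into overlapping [N, window] chunks with
    # hop = window * (1 - overlap); the final chunk is zero-padded.
    hop = int(window * (1 - overlap))
    starts = range(0, max(wav.shape[1] - window, 0) + 1, hop)
    chunks = [wav[0, s:s + window] for s in starts]
    chunks = [torch.nn.functional.pad(c, (0, window - c.shape[0])) for c in chunks]
    return torch.stack(chunks)

def _sketch_reconstruct_from_windows(windows: torch.Tensor, window: int, overlap: float) -> torch.Tensor:
    # Overlap-add the [N, window] chunks back into one waveform, averaging
    # the overlapping regions so window seams do not double in amplitude.
    hop = int(window * (1 - overlap))
    total = hop * (windows.shape[0] - 1) + window
    out = torch.zeros(total)
    weight = torch.zeros(total)
    for i in range(windows.shape[0]):
        out[i * hop:i * hop + window] += windows[i]
        weight[i * hop:i * hop + window] += 1.0
    return out / weight.clamp(min=1.0)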
# # Function to load the model using AutoModel
# from transformers import AutoModel
#
# def load_voice_restore_model(checkpoint_path: str):
#     model = AutoModel.from_pretrained(checkpoint_path, config=VoiceRestoreConfig())
#     return model
#
# # Example Usage
# model = load_voice_restore_model("./checkpoints/voice-restore-20d-16h-optim.pt")
# model("test_input.wav", "test_output.wav")