from transformers import ClapModel, ClapProcessor
import gradio as gr
import torch
import torchaudio
import os
import numpy as np
from qdrant_client import QdrantClient
import dotenv

dotenv.load_dotenv()
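
# Required environment variables (loaded from .env above):
#   HUGGINGFACE_API_TOKEN - token with access to the Audiogen/clap-2 model repo
#   AZURE_SAS_TOKEN       - SAS token for the Azure Blob Storage account
#   QDRANT_URL            - URL of the Qdrant instance holding the embeddings
#   QDRANT_API_KEY        - API key for that Qdrant instance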

class ClapSSGradio():
    """Gradio demo for CLAP-based semantic search over an audio collection."""

    def __init__(
        self,
        name,
        model="clap-2",
        k=10,
    ):
        self.name = name
        self.k = k

        self.model = ClapModel.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
        self.processor = ClapProcessor.from_pretrained(
            f"Audiogen/{model}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))

        self.sas_token = os.environ['AZURE_SAS_TOKEN']
        self.account_name = 'Audiogen'
        self.storage_name = 'audiogentrainingdataeun'

        self._start_qdrant()
    def _start_qdrant(self):
        # Connect to the Qdrant instance that holds the precomputed embeddings.
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
    @torch.no_grad()  # inference only; also lets .numpy() work on the outputs
    def _embed_query(self, query, audio_file):
        if audio_file is not None:
            # Audio query: embed the uploaded clip with the CLAP audio tower.
            waveform, sample_rate = torchaudio.load(audio_file.name)
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, 48000)

            # CLAP expects 10 s at 48 kHz (480000 samples): pad short clips,
            # truncate long ones.
            if waveform.shape[-1] < 480000:
                waveform = torch.nn.functional.pad(
                    waveform, (0, 480000 - waveform.shape[-1]))
            elif waveform.shape[-1] > 480000:
                waveform = waveform[..., :480000]

            # Downmix to mono before feature extraction.
            audio_prompt_features = self.processor(
                audios=waveform.mean(0), return_tensors='pt', sampling_rate=48000
            )['input_features']

            e = self.model.get_audio_features(
                input_features=audio_prompt_features)[0]
            if torch.isnan(e).any():
                raise ValueError("Audio features are NaN")
            # Return a plain list so the vector matches what Qdrant expects.
            return e.cpu().numpy().tolist()
        else:
            # Text query: embed the search string with the CLAP text tower.
            inputs = self.processor(
                text=query, return_tensors="pt", padding='max_length',
                max_length=77, truncation=True)
            return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]
    def _similarity_search(self, query, threshold, audio_file):
        results = self.client.search(
            collection_name=self.name,
            query_vector=self._embed_query(query, audio_file),
            limit=self.k,
            score_threshold=threshold,
        )

        containers = [result.payload['container'] for result in results]
        filenames = [result.id for result in results]
        captions = [result.payload['caption'] for result in results]
        scores = [result.score for result in results]

        # Log the hits to stdout.
        print(f"\nQuery: {query}\n")
        for i, (container, filename, caption, score) in enumerate(zip(containers, filenames, captions, scores)):
            print(f"{i}: {container} - {caption}. Score: {score}")

        waveforms = self._download_results(containers, filenames)

        if len(waveforms) == 0:
            print("\nNo results found")

        # Pad the output list with silence so every Audio component gets a value.
        if len(waveforms) < self.k:
            waveforms.extend([(48000, np.zeros((480000, 2)))
                              for _ in range(self.k - len(waveforms))])

        return waveforms
    def _download_results(self, containers: list, filenames: list):
        # Build SAS-signed blob URLs, one per hit, using each result's container.
        urls = [
            f"https://{self.storage_name}.blob.core.windows.net/{container}/{file_name}.flac?{self.sas_token}"
            for container, file_name in zip(containers, filenames)
        ]

        waveforms = []
        for url in urls:
            waveform, sample_rate = torchaudio.load(url)
            # Gradio's Audio component expects (sample_rate, array of [frames, channels]).
            waveforms.append((sample_rate, waveform.numpy().T))
        return waveforms
    def launch(self, share=False):
        # Gradio layout: query controls on the left, k result players on the right.
        with gr.Blocks(title='Clap Semantic Search') as ui:
            with gr.Row():
                with gr.Column(variant='panel'):
                    search = gr.Textbox(placeholder='Search Samples')
                    float_input = gr.Number(
                        label='Similarity threshold [min: 0.1 max: 1]',
                        value=0.5, minimum=0.1, maximum=1)
                    audio_file = gr.File(
                        label='Upload an Audio File', type="file")
                    search_button = gr.Button("Search")

                with gr.Column():
                    audioboxes = []
                    gr.Markdown("Output")
                    for i in range(self.k):
                        t = gr.Audio(label=f"{i}", visible=True)
                        audioboxes.append(t)

            search_button.click(
                fn=self._similarity_search,
                inputs=[search, float_input, audio_file],
                outputs=audioboxes)

        ui.launch(share=share)

if __name__ == "__main__":
    app = ClapSSGradio("demo")
    app.launch(share=False)
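
# ---------------------------------------------------------------------------
# Sketch (not part of the original Space): one way the Qdrant collection this
# app searches could be created and populated. Assumptions: the vector size
# (512) matches the projection dimension of the Audiogen/clap-2 CLAP
# checkpoint, cosine distance fits the 0.1-1 score threshold used above, and
# each point's UUID id doubles as the .flac filename that _download_results
# fetches, with 'container' and 'caption' fields in the payload.
# ---------------------------------------------------------------------------
def _build_collection_sketch(client: QdrantClient, name: str = "demo"):
    from qdrant_client.http.models import Distance, VectorParams, PointStruct

    client.create_collection(
        collection_name=name,
        vectors_config=VectorParams(size=512, distance=Distance.COSINE),
    )
    client.upsert(
        collection_name=name,
        points=[
            PointStruct(
                id="00000000-0000-0000-0000-000000000001",  # hypothetical UUID, also the blob filename stem
                vector=[0.0] * 512,  # replace with a real CLAP audio embedding
                payload={"container": "snake", "caption": "a snake hissing"},
            )
        ],
    )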