Spaces:
Runtime error
Runtime error
Add threshold field
Browse files
app.py
CHANGED
|
@@ -23,9 +23,6 @@ class ClapSSGradio():
|
|
| 23 |
self.name = name
|
| 24 |
self.k = k
|
| 25 |
|
| 26 |
-
print("Env?!")
|
| 27 |
-
print(os.getenv('HUGGINGFACE_API_TOKEN')[:2])
|
| 28 |
-
|
| 29 |
self.model = ClapModel.from_pretrained(
|
| 30 |
f"Audiogen/{name}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
|
| 31 |
self.tokenizer = ClapProcessor.from_pretrained(
|
|
@@ -48,12 +45,12 @@ class ClapSSGradio():
|
|
| 48 |
query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
|
| 49 |
return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]
|
| 50 |
|
| 51 |
-
def _similarity_search(self, query):
|
| 52 |
results = self.client.search(
|
| 53 |
collection_name=self.name,
|
| 54 |
query_vector=self._embed_query(query),
|
| 55 |
limit=self.k,
|
| 56 |
-
score_threshold=
|
| 57 |
)
|
| 58 |
|
| 59 |
containers = [result.payload['container'] for result in results]
|
|
@@ -94,21 +91,17 @@ class ClapSSGradio():
|
|
| 94 |
def launch(self, share=False):
|
| 95 |
# gradio app structure
|
| 96 |
with gr.Blocks(title='Clap Semantic Search') as ui:
|
| 97 |
-
|
| 98 |
with gr.Row():
|
| 99 |
with gr.Column(variant='panel'):
|
| 100 |
search = gr.Textbox(placeholder='Search Samples')
|
| 101 |
-
|
| 102 |
with gr.Column():
|
| 103 |
audioboxes = []
|
| 104 |
gr.Markdown("Output")
|
| 105 |
for i in range(self.k):
|
| 106 |
t = gr.components.Audio(label=f"{i}", visible=True)
|
| 107 |
audioboxes.append(t)
|
| 108 |
-
|
| 109 |
-
search.submit(fn=self._similarity_search, inputs=[
|
| 110 |
-
search], outputs=audioboxes)
|
| 111 |
-
|
| 112 |
ui.launch(share=share)
|
| 113 |
|
| 114 |
|
|
|
|
| 23 |
self.name = name
|
| 24 |
self.k = k
|
| 25 |
|
|
|
|
|
|
|
|
|
|
| 26 |
self.model = ClapModel.from_pretrained(
|
| 27 |
f"Audiogen/{name}", use_auth_token=os.getenv('HUGGINGFACE_API_TOKEN'))
|
| 28 |
self.tokenizer = ClapProcessor.from_pretrained(
|
|
|
|
| 45 |
query, return_tensors="pt", padding='max_length', max_length=77, truncation=True)
|
| 46 |
return self.model.get_text_features(**inputs).cpu().numpy().tolist()[0]
|
| 47 |
|
| 48 |
+
def _similarity_search(self, query, threshold):
|
| 49 |
results = self.client.search(
|
| 50 |
collection_name=self.name,
|
| 51 |
query_vector=self._embed_query(query),
|
| 52 |
limit=self.k,
|
| 53 |
+
score_threshold=threshold,
|
| 54 |
)
|
| 55 |
|
| 56 |
containers = [result.payload['container'] for result in results]
|
|
|
|
| 91 |
def launch(self, share=False):
|
| 92 |
# gradio app structure
|
| 93 |
with gr.Blocks(title='Clap Semantic Search') as ui:
|
|
|
|
| 94 |
with gr.Row():
|
| 95 |
with gr.Column(variant='panel'):
|
| 96 |
search = gr.Textbox(placeholder='Search Samples')
|
| 97 |
+
float_input = gr.Number(label='Similarity threshold [min: 0.1 max: 1]', default=0.5, minimum=0.1, maximum=1)
|
| 98 |
with gr.Column():
|
| 99 |
audioboxes = []
|
| 100 |
gr.Markdown("Output")
|
| 101 |
for i in range(self.k):
|
| 102 |
t = gr.components.Audio(label=f"{i}", visible=True)
|
| 103 |
audioboxes.append(t)
|
| 104 |
+
search.submit(fn=self._similarity_search, inputs=[search, float_input], outputs=audioboxes)
|
|
|
|
|
|
|
|
|
|
| 105 |
ui.launch(share=share)
|
| 106 |
|
| 107 |
|