# Hugging Face Space file header (page residue, preserved as comments):
# Author: fdaudens (HF Staff)
# Commit 64001ff (verified): "Create app.py"
# File size: 7.09 kB
import gradio as gr
import spaces
import torch
from pydub import AudioSegment
import numpy as np
import io
from scipy.io import wavfile
from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
import base64
from scipy.io.wavfile import write
import os
# Module-level model/processor singletons. They start empty and are filled
# lazily by load_model(), which then reuses them on later calls.
model = processor = None
def load_model():
"""Load model and processor once"""
global model, processor
if model is None:
model = ColQwen2_5Omni.from_pretrained(
"vidore/colqwen-omni-v0.1",
torch_dtype=torch.bfloat16,
device_map="cpu", # Start on CPU for ZeroGPU
attn_implementation="eager" # ZeroGPU compatible
).eval()
processor = ColQwen2_5OmniProcessor.from_pretrained("manu/colqwen-omni-v0.1")
return model, processor
def chunk_audio(audio_file, chunk_length=30):
    """Split an audio file into consecutive mono 16 kHz chunks.

    Args:
        audio_file: Path to the audio file (``str`` or ``os.PathLike``, which is
            what ``gr.Audio(type="filepath")`` passes), or a file-like wrapper
            that exposes the path as ``.name``.
        chunk_length: Chunk duration in seconds; the final chunk may be shorter.

    Returns:
        list: One sample array per chunk, as decoded by ``scipy.io.wavfile.read``
        from a mono 16 kHz WAV export of that chunk.

    Raises:
        ValueError: If ``chunk_length`` is not positive.
    """
    if chunk_length <= 0:
        raise ValueError(f"chunk_length must be positive, got {chunk_length!r}")

    # A plain string path has no ``.name`` attribute; only use ``.name`` for
    # file-like wrappers.
    if isinstance(audio_file, (str, os.PathLike)):
        path = audio_file
    else:
        path = audio_file.name

    audio = AudioSegment.from_file(path)
    target_rate = 16000
    # Slider values may arrive as floats, and range() needs an int step;
    # clamp to at least 1 ms so the step can never be zero.
    chunk_length_ms = max(1, int(round(chunk_length * 1000)))

    audios = []
    for start in range(0, len(audio), chunk_length_ms):
        chunk = audio[start:start + chunk_length_ms]
        chunk = chunk.set_channels(1).set_frame_rate(target_rate)
        buf = io.BytesIO()
        chunk.export(buf, format="wav")
        buf.seek(0)
        _, data = wavfile.read(buf)
        audios.append(data)
    return audios
@spaces.GPU(duration=120)
def embed_audio_chunks(audios):
    """Embed each audio chunk with the ColQwen-Omni model on the GPU.

    Args:
        audios: Sequence of per-chunk sample arrays (as produced by chunk_audio).

    Returns:
        list: One embedding tensor per chunk, moved to the CPU, in the same
        order as ``audios``.
    """
    from torch.utils.data import DataLoader

    # Nothing to embed: skip the GPU round trip entirely.
    if len(audios) == 0:
        return []

    model, processor = load_model()
    model = model.to("cuda")
    try:
        dataloader = DataLoader(
            dataset=audios,
            batch_size=4,
            shuffle=False,
            collate_fn=processor.process_audios,
        )
        embeddings = []
        with torch.no_grad():
            for batch_doc in dataloader:
                batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
                embeddings_doc = model(**batch_doc)
                embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    finally:
        # Always move the shared model back to the CPU and free GPU memory,
        # even if a batch raised, so it is not left stranded on CUDA.
        model.to("cpu")
        torch.cuda.empty_cache()
    return embeddings
@spaces.GPU(duration=60)
def search_audio(query, embeddings, audios, top_k=5):
    """Return the indices of the chunks that best match ``query``.

    Args:
        query: Natural-language search string.
        embeddings: Per-chunk embeddings (as produced by embed_audio_chunks).
        audios: The chunk sample arrays. Not used here; kept for interface
            compatibility with existing callers.
        top_k: Maximum number of indices to return.

    Returns:
        list[int]: At most ``top_k`` chunk indices, highest score first. There
        are fewer when there are fewer chunks than ``top_k``, and none when
        ``embeddings`` is empty.
    """
    # topk() raises when k exceeds the number of candidates, which happened
    # for any recording with fewer than top_k chunks.
    k = min(top_k, len(embeddings))
    if k <= 0:
        return []

    model, processor = load_model()
    model = model.to("cuda")
    try:
        batch_queries = processor.process_queries([query]).to(model.device)
        with torch.no_grad():
            query_embeddings = model(**batch_queries)
        # Score the single query against every chunk embedding.
        scores = processor.score_multi_vector(query_embeddings, embeddings)
    finally:
        # Restore the shared model to the CPU even if encoding/scoring failed.
        model.to("cpu")
        torch.cuda.empty_cache()
    return scores[0].topk(k).indices.tolist()
def audio_to_base64(data, rate=16000):
    """Serialize ``data`` as an in-memory WAV file and return it base64-encoded.

    Args:
        data: Sample array accepted by ``scipy.io.wavfile.write``.
        rate: Sample rate in Hz written into the WAV header.

    Returns:
        str: The base64 text of the complete WAV file (header plus samples).
    """
    with io.BytesIO() as wav_buffer:
        write(wav_buffer, rate, data)
        wav_bytes = wav_buffer.getvalue()
    return base64.b64encode(wav_bytes).decode("utf-8")
def _ask_openai(query, audios, indices, openai_key):
    """Ask gpt-4o-audio-preview to answer ``query`` from the chunks at ``indices``.

    Returns:
        str: A line to append to the results text. Client/API failures are
        reported in that text rather than raised, so retrieval results still show.
    """
    content = [{"type": "text", "text": f"Answer the query using the audio files. Query: {query}"}]
    for idx in indices:
        content.extend([
            {"type": "text", "text": f"Audio chunk #{idx}:"},
            {
                "type": "input_audio",
                "input_audio": {
                    "data": audio_to_base64(audios[idx]),
                    "format": "wav",
                },
            },
        ])
    try:
        from openai import OpenAI

        client = OpenAI(api_key=openai_key)
        completion = client.chat.completions.create(
            model="gpt-4o-audio-preview",
            messages=[{"role": "user", "content": content}],
        )
    except Exception as e:  # UI boundary: surface any client/API error as text
        return f"\nOpenAI Error: {str(e)}"
    return f"\nOpenAI Answer: {completion.choices[0].message.content}"


def _plot_waveform(data, index):
    """Build a matplotlib Figure showing the waveform of chunk ``index``."""
    from matplotlib.figure import Figure

    # A bare Figure (not pyplot) is not tracked in pyplot's global registry,
    # so per-request figures are not kept alive after Gradio renders them.
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(data)
    ax.set_title(f"Waveform of top matching chunk (#{index})")
    ax.set_xlabel("Samples")
    ax.set_ylabel("Amplitude")
    fig.tight_layout()
    return fig


def process_audio_rag(audio_file, query, chunk_length=30, use_openai=False, openai_key=None):
    """Run the full AudioRAG pipeline for one UI request.

    The audio is chunked and embedded, the chunks most relevant to ``query``
    are retrieved, the best chunk is saved as a WAV file and plotted, and an
    OpenAI answer built from the top 3 chunks is optionally requested.

    Args:
        audio_file: Uploaded audio as passed by the Gradio Audio input.
        query: Natural-language search query.
        chunk_length: Chunk duration in seconds, forwarded to chunk_audio.
        use_openai: When truthy, also request an answer from OpenAI.
        openai_key: OpenAI API key used when ``use_openai`` is set.

    Returns:
        tuple: ``(result_text, wav_path, figure)``. ``wav_path`` and ``figure``
        are ``None`` when there is no chunk to show.
    """
    if not audio_file:
        return "Please upload an audio file", None, None
    if not query or not query.strip():
        return "Please enter a search query", None, None

    audios = chunk_audio(audio_file, chunk_length)
    if not audios:
        return "The uploaded audio is empty", None, None

    embeddings = embed_audio_chunks(audios)
    top_indices = search_audio(query, embeddings, audios)
    if not top_indices:
        return "No matching audio chunks found", None, None

    result_text = f"Found {len(top_indices)} relevant audio chunks:\n"
    result_text += f"Chunk indices: {top_indices}\n\n"

    best = top_indices[0]
    # Save the best-matching chunk so the Audio output can play it.
    first_chunk_path = "result_chunk.wav"
    wavfile.write(first_chunk_path, 16000, audios[best])

    if use_openai:
        if openai_key:
            result_text += _ask_openai(query, audios, top_indices[:3], openai_key)  # Use top 3 chunks
        else:
            # Previously skipped silently; tell the user why no answer appeared.
            result_text += "\nOpenAI Error: no API key provided"

    return result_text, first_chunk_path, _plot_waveform(audios[best], best)
# Gradio UI: upload/query controls on the left, search results on the right.
with gr.Blocks(title="AudioRAG Demo") as demo:
    gr.Markdown("# AudioRAG Demo - Semantic Audio Search")
    gr.Markdown("Upload an audio file and search through it using natural language queries!")

    with gr.Row():
        # Left column: the recording, the query and retrieval settings.
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio File", type="filepath")
            query_input = gr.Textbox(
                label="Search Query",
                placeholder="What are you looking for in the audio?",
            )
            chunk_length = gr.Slider(
                minimum=10,
                maximum=60,
                value=30,
                step=5,
                label="Chunk Length (seconds)",
            )
            with gr.Accordion("OpenAI Integration (Optional)", open=False):
                use_openai = gr.Checkbox(label="Use OpenAI for answer generation")
                openai_key = gr.Textbox(label="OpenAI API Key", type="password")
            search_btn = gr.Button("Search Audio", variant="primary")

        # Right column: one component per element of the tuple the handler returns.
        with gr.Column():
            output_text = gr.Textbox(label="Results", lines=10)
            output_audio = gr.Audio(label="Top Matching Audio Chunk", type="filepath")
            output_plot = gr.Plot(label="Audio Waveform")

    # Input order must match process_audio_rag's positional parameters,
    # and output order its return tuple.
    search_btn.click(
        process_audio_rag,
        inputs=[audio_input, query_input, chunk_length, use_openai, openai_key],
        outputs=[output_text, output_audio, output_plot],
    )

    # Examples populate only the audio, query and chunk-length inputs.
    gr.Examples(
        examples=[
            ["example_audio.wav", "Was Hannibal well liked by his men?", 30],
            ["podcast.mp3", "What did they say about climate change?", 20],
        ],
        inputs=[audio_input, query_input, chunk_length],
    )
def _main():
    """Load the model and processor once at startup, then serve the Gradio app."""
    load_model()
    demo.launch()


if __name__ == "__main__":
    _main()