Spaces:
Runtime error
Runtime error
File size: 7,085 Bytes
64001ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import gradio as gr
import spaces
import torch
from pydub import AudioSegment
import numpy as np
import io
from scipy.io import wavfile
from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
import base64
from scipy.io.wavfile import write
import os
# Global model variables, populated lazily by load_model()
model = None
processor = None


def load_model():
    """Load the ColQwen2.5-Omni model and processor once and cache them in module globals.

    The model is loaded in bfloat16 on the CPU with eager attention; the
    GPU-decorated functions move it to CUDA only for the duration of a call.

    Returns:
        Tuple ``(model, processor)``.
    """
    global model, processor
    # Check both globals and publish them together: a failure while loading the
    # processor must not leave a loaded model cached next to ``processor = None``
    # (later calls would then skip loading and hand out a None processor).
    if model is None or processor is None:
        loaded_model = ColQwen2_5Omni.from_pretrained(
            "vidore/colqwen-omni-v0.1",
            torch_dtype=torch.bfloat16,
            device_map="cpu",  # Start on CPU for ZeroGPU
            attn_implementation="eager",  # ZeroGPU compatible
        ).eval()
        # NOTE(review): the processor comes from a different repo than the model
        # ("manu/..." vs "vidore/...") — confirm this pairing is intentional.
        loaded_processor = ColQwen2_5OmniProcessor.from_pretrained("manu/colqwen-omni-v0.1")
        model, processor = loaded_model, loaded_processor
    return model, processor
def chunk_audio(audio_file, chunk_length=30):
    """Split an audio file into mono 16 kHz chunks of at most ``chunk_length`` seconds.

    Args:
        audio_file: Path (``str`` / ``os.PathLike``) to the audio file, or an
            object whose ``.name`` attribute holds that path.
        chunk_length: Chunk duration in seconds. May be a float (Gradio sliders
            deliver floats); must be at least one millisecond.

    Returns:
        List of sample arrays as returned by ``scipy.io.wavfile.read``, one per
        chunk, in order. The last chunk may be shorter than ``chunk_length``.

    Raises:
        ValueError: If ``chunk_length`` rounds down to zero milliseconds or less.
    """
    # range() needs an int step; a float slider value would raise TypeError.
    chunk_length_ms = int(chunk_length * 1000)
    if chunk_length_ms <= 0:
        raise ValueError(f"chunk_length must be positive, got {chunk_length!r}")

    # gr.Audio(type="filepath") passes a plain path string; file-like objects
    # expose their path via ``.name``. Accept both.
    if isinstance(audio_file, (str, os.PathLike)):
        audio_path = os.fspath(audio_file)
    else:
        audio_path = audio_file.name

    audio = AudioSegment.from_file(audio_path)
    target_rate = 16000
    audios = []
    for start_ms in range(0, len(audio), chunk_length_ms):
        # pydub slices in milliseconds; downmix and resample each chunk for the model.
        chunk = audio[start_ms:start_ms + chunk_length_ms]
        chunk = chunk.set_channels(1).set_frame_rate(target_rate)
        buf = io.BytesIO()
        chunk.export(buf, format="wav")
        buf.seek(0)
        _rate, data = wavfile.read(buf)
        audios.append(data)
    return audios
@spaces.GPU(duration=120)
def embed_audio_chunks(audios, batch_size=4):
    """Embed audio chunks with the ColQwen2.5-Omni model on the GPU.

    Args:
        audios: Sequence of chunk sample arrays, as produced by ``chunk_audio``.
        batch_size: Number of chunks processed per forward pass.

    Returns:
        List with one CPU tensor per input chunk, in input order. Empty when
        ``audios`` is empty.
    """
    from torch.utils.data import DataLoader

    if len(audios) == 0:
        # Nothing to embed; avoid moving the model to the GPU for no work.
        return []

    model, processor = load_model()
    model = model.to("cuda")
    try:
        dataloader = DataLoader(
            dataset=audios,
            batch_size=batch_size,
            shuffle=False,  # keep output order aligned with ``audios``
            collate_fn=processor.process_audios,
        )
        embeddings = []
        with torch.no_grad():
            for batch_doc in dataloader:
                batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
                embeddings_doc = model(**batch_doc)
                embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    finally:
        # Always move the model back to CPU and free cached GPU memory, even when
        # a batch fails, so the next call starts from the same state.
        model.to("cpu")
        torch.cuda.empty_cache()
    return embeddings
@spaces.GPU(duration=60)
def search_audio(query, embeddings, audios, top_k=5):
    """Rank embedded audio chunks against a text query.

    Args:
        query: Natural-language search string.
        embeddings: Per-chunk embeddings as returned by ``embed_audio_chunks``.
        audios: Chunk sample arrays; unused here, kept for interface compatibility.
        top_k: Maximum number of chunk indices to return.

    Returns:
        Up to ``top_k`` chunk indices, best match first. Empty when there are
        no embeddings or ``top_k`` is not positive.
    """
    # topk() raises if k exceeds the number of candidates, which happens for any
    # upload shorter than top_k chunks, so clamp k to what is available.
    k = min(top_k, len(embeddings))
    if k <= 0:
        return []

    model, processor = load_model()
    model = model.to("cuda")
    try:
        batch_queries = processor.process_queries([query]).to(model.device)
        with torch.no_grad():
            query_embeddings = model(**batch_queries)
        # Row 0 holds this single query's score against every chunk.
        scores = processor.score_multi_vector(query_embeddings, embeddings)
        top_indices = scores[0].topk(k).indices.tolist()
    finally:
        # Restore the model to CPU even if encoding or scoring fails.
        model.to("cpu")
        torch.cuda.empty_cache()
    return top_indices
def audio_to_base64(data, rate=16000):
    """Serialize a waveform to WAV and return those bytes base64-encoded.

    Args:
        data: Sample array accepted by ``scipy.io.wavfile.write``.
        rate: Sample rate in Hz recorded in the WAV header.

    Returns:
        The base64 encoding of the complete WAV file, as a ``str``.
    """
    with io.BytesIO() as wav_buffer:
        write(wav_buffer, rate, data)
        wav_bytes = wav_buffer.getvalue()
    return base64.b64encode(wav_bytes).decode("utf-8")
def _save_chunk_wav(data, rate=16000):
    """Write one chunk to a uniquely named temporary WAV file and return its path."""
    import tempfile

    # One file per request: a fixed filename would let concurrent users
    # overwrite each other's result before Gradio serves it.
    with tempfile.NamedTemporaryFile(prefix="result_chunk_", suffix=".wav", delete=False) as tmp:
        wavfile.write(tmp, rate, data)
    return tmp.name


def _openai_answer(query, audios, indices, api_key):
    """Ask gpt-4o-audio-preview to answer ``query`` from the given chunks; return a result line."""
    content = [{"type": "text", "text": f"Answer the query using the audio files. Query: {query}"}]
    for idx in indices:
        content.extend([
            {"type": "text", "text": f"Audio chunk #{idx}:"},
            {
                "type": "input_audio",
                "input_audio": {
                    "data": audio_to_base64(audios[idx]),
                    "format": "wav",
                },
            },
        ])
    try:
        # Imported here so a missing ``openai`` package is reported in the UI
        # like any other OpenAI failure instead of crashing the request.
        from openai import OpenAI

        client = OpenAI(api_key=api_key)
        completion = client.chat.completions.create(
            model="gpt-4o-audio-preview",
            messages=[{"role": "user", "content": content}],
        )
        return f"\nOpenAI Answer: {completion.choices[0].message.content}"
    except Exception as e:  # optional feature: surface the error text instead of failing the search
        return f"\nOpenAI Error: {str(e)}"


def _waveform_figure(data, chunk_index):
    """Build a matplotlib Figure plotting one chunk's waveform."""
    # Object-oriented Figure API instead of pyplot: pyplot keeps every figure
    # alive in global state (one leaked figure per request) and is not
    # thread-safe, while Gradio handlers may run concurrently.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.plot(data)
    ax.set_title(f"Waveform of top matching chunk (#{chunk_index})")
    ax.set_xlabel("Samples")
    ax.set_ylabel("Amplitude")
    fig.tight_layout()
    return fig


def process_audio_rag(audio_file, query, chunk_length=30, use_openai=False, openai_key=None):
    """Chunk, embed, and search an uploaded audio file for ``query``.

    Args:
        audio_file: Uploaded audio path (or file-like object with ``.name``).
        query: Natural-language search query.
        chunk_length: Chunk duration in seconds.
        use_openai: Whether to ask OpenAI to answer from the top chunks.
        openai_key: OpenAI API key; OpenAI is used only when this is non-empty.

    Returns:
        ``(result_text, top_chunk_wav_path, waveform_figure)``. On invalid input
        or no results, ``(message, None, None)``.
    """
    if not audio_file:
        return "Please upload an audio file", None, None
    if not query or not query.strip():
        return "Please enter a search query", None, None

    audios = chunk_audio(audio_file, chunk_length)
    if not audios:
        return "The uploaded audio contains no samples", None, None

    embeddings = embed_audio_chunks(audios)
    top_indices = search_audio(query, embeddings, audios)
    if not top_indices:
        return "No matching audio chunks found", None, None

    best = top_indices[0]
    result_text = f"Found {len(top_indices)} relevant audio chunks:\n"
    result_text += f"Chunk indices: {top_indices}\n\n"

    first_chunk_path = _save_chunk_wav(audios[best])

    if use_openai and openai_key:
        result_text += _openai_answer(query, audios, top_indices[:3], openai_key)  # top 3 chunks

    return result_text, first_chunk_path, _waveform_figure(audios[best], best)
# Create Gradio interface.
# Example rows point at audio files expected next to this script; rows whose
# file is missing are skipped instead of being offered as broken examples.
EXAMPLES = [
    ["example_audio.wav", "Was Hannibal well liked by his men?", 30],
    ["podcast.mp3", "What did they say about climate change?", 20],
]

with gr.Blocks(title="AudioRAG Demo") as demo:
    gr.Markdown("# AudioRAG Demo - Semantic Audio Search")
    gr.Markdown("Upload an audio file and search through it using natural language queries!")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio File", type="filepath")
            query_input = gr.Textbox(label="Search Query", placeholder="What are you looking for in the audio?")
            chunk_length = gr.Slider(minimum=10, maximum=60, value=30, step=5, label="Chunk Length (seconds)")
            with gr.Accordion("OpenAI Integration (Optional)", open=False):
                use_openai = gr.Checkbox(label="Use OpenAI for answer generation")
                openai_key = gr.Textbox(label="OpenAI API Key", type="password")
            search_btn = gr.Button("Search Audio", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="Results", lines=10)
            output_audio = gr.Audio(label="Top Matching Audio Chunk", type="filepath")
            output_plot = gr.Plot(label="Audio Waveform")

    search_btn.click(
        fn=process_audio_rag,
        inputs=[audio_input, query_input, chunk_length, use_openai, openai_key],
        outputs=[output_text, output_audio, output_plot],
    )

    available_examples = [row for row in EXAMPLES if os.path.isfile(row[0])]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=[audio_input, query_input, chunk_length],
        )


if __name__ == "__main__":
    # Load the model once at startup so the first request does not pay the load cost.
    load_model()
    demo.launch()