# Hugging Face Spaces app (Space status shown when this copy was captured: "Sleeping").
import gradio as gr | |
import os | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
# Plain-text store of saved queries: one query per line, read and written as UTF-8.
FILE_PATH: str = "outputRUPDATED.txt"
class SimilaritySearch:
    """Semantic search over short text queries persisted one per line in a file.

    Every stored query is embedded with a SentenceTransformer model using
    L2-normalized embeddings, so a dot product equals cosine similarity.
    Search scores are adjusted by a global ``bias`` and by a per-entry
    "density" penalty that down-weights entries which are, on average, very
    similar to the rest of the stored data. The adjusted scores are then
    clipped to ``[0, 1]``.
    """

    def __init__(self,
                 model_name: str = "sentence-transformers/stsb-roberta-large",
                 bias: float = 0.0,
                 density_rate: float = 0.5,
                 max_penalty: float = 0.20,
                 file_path: "str | None" = None):
        """Load the embedding model and the stored queries.

        Args:
            model_name: SentenceTransformer model name or path used to embed
                both stored and incoming queries.
            bias: Constant added to every cosine similarity before clipping.
            density_rate: Multiplier that turns an entry's mean similarity to
                the other entries into its penalty.
            max_penalty: Upper bound on any single entry's penalty.
            file_path: Storage file (one query per line). Defaults to the
                module-level ``FILE_PATH``. The file is created if missing.
        """
        self.embedder = SentenceTransformer(model_name)
        self.file_path = FILE_PATH if file_path is None else file_path
        # Stored queries, in file order, without duplicates.
        self.data: list[str] = []
        # Row i is the normalized embedding of data[i]; a 1-D empty array
        # while there is no data.
        self.embedded_texts: np.ndarray = np.empty((0,))
        # Score adjustment parameters (see search()).
        self.bias = bias
        self.density_rate = density_rate
        self.max_penalty = max_penalty
        # penalties[i] is subtracted from data[i]'s similarity at search time.
        self.penalties: np.ndarray = np.array([])
        self.load_data()
        self.compute_penalties()

    def _embedding_matrix(self) -> np.ndarray:
        """Return the stored embeddings as a 2-D (entries x dim) array."""
        emb = np.asarray(self.embedded_texts)
        return emb.reshape(1, -1) if emb.ndim == 1 else emb

    @staticmethod
    def _ends_without_newline(path) -> bool:
        """Return True if *path* is a non-empty file whose last byte is not a line break."""
        try:
            with open(path, 'rb') as f:
                f.seek(0, os.SEEK_END)
                if f.tell() == 0:
                    return False
                f.seek(-1, os.SEEK_END)
                return f.read(1) not in (b"\n", b"\r")
        except FileNotFoundError:
            return False

    def load_data(self):
        """(Re)load the stored queries from ``self.file_path`` and embed them.

        Creates the file if it is missing. Blank lines are skipped, and
        surrounding whitespace is stripped. A repeated line is kept only once
        (first occurrence wins), which matches the no-duplicates rule enforced
        by ``add_query``. Penalties are not recomputed here.
        """
        # Append mode creates a missing file but never truncates an existing
        # one. An exists()-check followed by open(..., 'w') could wipe a file
        # that appeared in between.
        with open(self.file_path, 'a', encoding='utf-8'):
            pass
        with open(self.file_path, 'r', encoding='utf-8') as f:
            lines = [text for line in f if (text := line.strip())]
        # dict.fromkeys de-duplicates while preserving first-seen order.
        self.data = list(dict.fromkeys(lines))
        if self.data:
            self.embedded_texts = np.asarray(
                self.embedder.encode(self.data, normalize_embeddings=True))
        else:
            self.embedded_texts = np.empty((0,))

    def compute_penalties(self):
        """Recompute the per-entry density penalties.

        An entry's density is its mean cosine similarity to the *other* n - 1
        entries. Its penalty is ``density * density_rate``, bounded to
        ``[0, max_penalty]``. With fewer than two entries all penalties are 0.
        """
        n = len(self.data)
        if n < 2:
            self.penalties = np.zeros(n)
            return
        emb = self._embedding_matrix()
        # sum_j e_i . e_j  ==  e_i . (sum_j e_j). Subtracting the self term
        # e_i . e_i leaves the similarity to the other entries. This avoids
        # materializing the O(n^2) pairwise matrix on every insert.
        total = emb @ emb.sum(axis=0)
        self_sim = (emb * emb).sum(axis=1)
        densities = (total - self_sim) / (n - 1)
        # Lower bound 0: a negative mean similarity must not become a bonus.
        self.penalties = np.clip(densities * self.density_rate, 0.0, self.max_penalty)

    def search(self, query: str, top_n: int = 5) -> list[str]:
        """Return the ``top_n`` best matches for *query*, best first.

        Each result is formatted as ``"(score) text"``. The score is
        ``cosine + bias - penalty``, clipped to ``[0, 1]``. Equal scores keep
        storage order. A ``top_n`` of 0 or less yields an empty list. With no
        stored data, a single warning string is returned instead.
        """
        if not self.data:
            return ["⚠️ No data to search. Add some queries first."]
        q_emb = np.asarray(self.embedder.encode(query, normalize_embeddings=True)).reshape(-1)
        sims = self._embedding_matrix() @ q_emb
        final = np.clip(sims + self.bias - self.penalties, 0.0, 1.0)
        k = max(0, min(top_n, len(self.data)))
        # Stable sort on the negated score: highest first, ties in storage order.
        idxs = np.argsort(-final, kind="stable")[:k]
        return [f"({final[i]:.3f}) {self.data[i]}" for i in idxs]

    def add_query(self, query: str) -> str:
        """Persist *query* to the store and add it to the in-memory index.

        Line breaks inside the query are replaced by spaces, and the result is
        stripped, because the store holds exactly one query per line.

        Returns:
            A user-facing status message: saved, empty input, or duplicate.
        """
        # An embedded "\n" or "\r" would be read back by load_data() as
        # several separate queries, so flatten the query to a single line.
        q = query.replace("\r", " ").replace("\n", " ").strip()
        if not q:
            return "⚠️ Empty input. Not saved."
        if q in self.data:
            return f"⚠️ Query already exists: \"{q}\""
        # Embed first, so that a model failure leaves both the file and the
        # in-memory state untouched.
        new_emb = np.asarray(self.embedder.encode([q], normalize_embeddings=True))
        # A hand-edited file may lack a final newline. Without this prefix,
        # the new query would be glued onto the last stored line.
        prefix = "\n" if self._ends_without_newline(self.file_path) else ""
        with open(self.file_path, 'a', encoding='utf-8') as f:
            f.write(f"{prefix}{q}\n")
        self.data.append(q)
        if self.embedded_texts.size == 0:
            self.embedded_texts = new_emb
        else:
            self.embedded_texts = np.vstack([self._embedding_matrix(), new_emb])
        self.compute_penalties()
        return f"✅ Saved: \"{q}\""
# Single module-level engine used by every request handler below. Building it
# loads the embedding model and the stored queries once, when the module is imported.
search_engine: SimilaritySearch = SimilaritySearch()
def perform_search_and_maybe_save(query, save_to_file=False):
    """UI handler: search for *query* and optionally persist it.

    Returns a ``(results_text, save_message)`` pair for the two output
    textboxes. Missing or blank input short-circuits with a prompt. The search
    runs before the save, so a newly saved query is not matched against itself.
    """
    # query may be None (e.g. a programmatic call without a value); treat it
    # like an empty box instead of raising AttributeError on .strip().
    text = (query or "").strip()
    if not text:
        return "Please enter a search query.", ""
    results = search_engine.search(text)
    save_msg = search_engine.add_query(text) if save_to_file else ""
    return "\n".join(results), save_msg
def load_file_contents() -> str:
    """Return the raw text of the query store (``FILE_PATH``) for display.

    Returns an empty string when the file does not exist (e.g. it was removed
    after start-up), rather than raising inside the UI event.
    """
    try:
        with open(FILE_PATH, 'r', encoding='utf-8') as f:
            return f.read()
    except FileNotFoundError:
        return ""
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 Semantic Search Engine")
    query = gr.Textbox(label="Search Query")
    save_checkbox = gr.Checkbox(label="Save this query to file?", value=False)
    search_btn = gr.Button("Search", variant="primary")
    with gr.Row():
        result = gr.Textbox(label="Top Matches", lines=3)
        save_status = gr.Textbox(label="Save Status", lines=2)
    file_content = gr.Textbox(label="Current File Content", lines=10)
    # Refresh the file view only after the search/save handler has finished.
    # Two separate .click() listeners on the same button are dispatched
    # independently, so the file could be read before the new query is appended.
    search_btn.click(perform_search_and_maybe_save,
                     inputs=[query, save_checkbox],
                     outputs=[result, save_status]).then(
        load_file_contents, None, file_content)
    # Blocks.load takes (fn, inputs, outputs). The textbox is the OUTPUT here;
    # passing it positionally as the 2nd argument would make it an input.
    demo.load(load_file_contents, inputs=None, outputs=file_content)

if __name__ == "__main__":
    demo.launch()