# similarity/app.py
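"""Gradio demo: semantic similarity search over a small, file-backed list of
queries. Embeddings come from a Sentence-Transformers model; scores get a
global bias plus a density-based penalty that down-weights near-duplicate
entries."""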
import gradio as gr
import os
import numpy as np
from sentence_transformers import SentenceTransformer
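
# Plain-text store of saved queries, one per line (created on first run if missing)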
FILE_PATH = "outputRUPDATED.txt"

class SimilaritySearch:
    def __init__(self,
                 model_name: str = "sentence-transformers/stsb-roberta-large",
                 bias: float = 0.0,
                 density_rate: float = 0.5,
                 max_penalty: float = 0.20):
        """
        Semantic similarity search over stored queries using an SBERT
        (RoBERTa-large) embedding model. `bias` is added to every score,
        `density_rate` scales the density-based penalty, and `max_penalty`
        caps that penalty per entry.
        """
        # Load a high-performance sentence embedding model
        self.embedder = SentenceTransformer(model_name)
        # Stored queries
        self.data: list[str] = []
        # Normalized embeddings matrix
        self.embedded_texts: np.ndarray = np.empty((0,))
        # Score adjustment parameters
        self.bias = bias
        self.density_rate = density_rate
        self.max_penalty = max_penalty
        # Penalty per entry based on dataset density
        self.penalties: np.ndarray = np.array([])
        # Load existing data and compute initial penalties
        self.load_data()
        self.compute_penalties()

    def load_data(self):
        """Load saved queries from FILE_PATH (creating it if missing) and embed them."""
        # Ensure the storage file exists
        if not os.path.exists(FILE_PATH):
            open(FILE_PATH, 'w').close()
        # Read stored queries
        with open(FILE_PATH, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f if line.strip()]
        self.data = lines
        # Compute embeddings if data exists
        if self.data:
            self.embedded_texts = self.embedder.encode(self.data, normalize_embeddings=True)
        else:
            self.embedded_texts = np.empty((0,))

    def compute_penalties(self):
        """Derive a per-entry penalty from how densely similar each stored query is to the rest."""
        n = len(self.data)
        if n < 2:
            self.penalties = np.zeros(n)
            return
        # Ensure embeddings are 2D
        emb = self.embedded_texts
        if emb.ndim == 1:
            emb = emb.reshape(1, -1)
        # Pairwise similarity matrix
        sim_matrix = np.dot(emb, emb.T)
        # Zero out self-similarity
        np.fill_diagonal(sim_matrix, 0.0)
        # Compute average neighbor similarity (density)
        densities = sim_matrix.mean(axis=1)
        # Convert densities to penalties
        self.penalties = np.minimum(densities * self.density_rate, self.max_penalty)
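
    # Worked example of the penalty formula above (illustrative numbers, not
    # taken from any real dataset): with density_rate=0.5 and max_penalty=0.20,
    # an entry whose computed density (mean of its similarity row, self-term
    # zeroed) is 0.30 is penalised by min(0.30 * 0.5, 0.20) = 0.15, while a
    # near-duplicate entry with density 0.90 is capped at 0.20.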

    def search(self, query: str, top_n: int = 5) -> list[str]:
        """Return the top_n stored queries most similar to `query`, formatted as "(score) text"."""
        if not self.data:
            return ["⚠️ No data to search. Add some queries first."]
        # Embed the query with the same model used for the stored data
        q_emb = self.embedder.encode(query, normalize_embeddings=True)
        emb = self.embedded_texts
        if emb.ndim == 1:
            emb = emb.reshape(1, -1)
        # Compute cosine similarities
        sims = np.dot(emb, q_emb).flatten()
        # Adjust: add global bias, subtract density-based penalty
        adjusted = sims + self.bias - self.penalties
        final = np.clip(adjusted, 0.0, 1.0)
        # Select top-n matches
        top_n = min(top_n, len(self.data))
        idxs = np.argsort(final)[::-1][:top_n]
        return [f"({final[i]:.3f}) {self.data[i]}" for i in idxs]

    def add_query(self, query: str) -> str:
        """Persist a new query to FILE_PATH and update the in-memory embeddings and penalties."""
        q = query.strip()
        if not q:
            return "⚠️ Empty input. Not saved."
        if q in self.data:
            return f"⚠️ Query already exists: \"{q}\""
        # Persist new query
        with open(FILE_PATH, 'a', encoding='utf-8') as f:
            f.write(f"{q}\n")
        # Update in-memory structures
        self.data.append(q)
        new_emb = self.embedder.encode([q], normalize_embeddings=True)
        if self.embedded_texts.size == 0:
            self.embedded_texts = new_emb
        else:
            if self.embedded_texts.ndim == 1:
                self.embedded_texts = self.embedded_texts.reshape(1, -1)
            self.embedded_texts = np.vstack([self.embedded_texts, new_emb])
        # Recompute density penalties
        self.compute_penalties()
        return f"✅ Saved: \"{q}\""
# Initialize the search engine with the chosen model
search_engine = SimilaritySearch()

def perform_search_and_maybe_save(query, save_to_file=False):
    """Run a search and optionally persist the query; returns (results text, save status)."""
    if not query.strip():
        return "Please enter a search query.", ""
    results = search_engine.search(query)
    save_msg = search_engine.add_query(query) if save_to_file else ""
    return "\n".join(results), save_msg

def load_file_contents() -> str:
    """Return the raw contents of the query store for display."""
    with open(FILE_PATH, 'r', encoding='utf-8') as f:
        return f.read()

with gr.Blocks() as demo:
    gr.Markdown("# 🔍 Semantic Search Engine")
    query = gr.Textbox(label="Search Query")
    save_checkbox = gr.Checkbox(label="Save this query to file?", value=False)
    search_btn = gr.Button("Search", variant="primary")
    with gr.Row():
        result = gr.Textbox(label="Top Matches", lines=3)
        save_status = gr.Textbox(label="Save Status", lines=2)
    file_content = gr.Textbox(label="Current File Content", lines=10)
    # Run the search/save handler, then refresh the file view so a freshly
    # saved query shows up immediately.
    search_btn.click(perform_search_and_maybe_save,
                     inputs=[query, save_checkbox],
                     outputs=[result, save_status]
                     ).then(load_file_contents, None, file_content)
    # Populate the file view when the page first loads
    demo.load(load_file_contents, inputs=None, outputs=file_content)

if __name__ == "__main__":
    demo.launch()