from datasets import load_dataset
from sentence_transformers import SentenceTransformer

from embedding_cache import EmbeddingCache
from semantic_index import FaissSemanticIndex

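# `semantic_index` and `embedding_cache` are local modules not shown in this
# file. The commented sketches below are hypothetical illustrations of the
# interface this agent assumes (a flat inner-product FAISS index plus an
# on-disk .npy embedding cache); the real modules may differ.
#
#   # semantic_index.py (sketch)
#   import faiss
#   import numpy as np
#
#   class FaissSemanticIndex:
#       def __init__(self, dim):
#           self.index = faiss.IndexFlatIP(dim)   # inner-product similarity
#           self.payloads = []                    # texts aligned with vectors
#
#       def add(self, embeddings, texts):
#           self.index.add(np.asarray(embeddings, dtype="float32"))
#           self.payloads.extend(texts)
#
#       def search(self, query_vec, k=1):
#           scores, ids = self.index.search(np.asarray(query_vec, dtype="float32"), k)
#           # Return (text, score) pairs so callers can do result[0][0].
#           return [(self.payloads[i], float(s))
#                   for i, s in zip(ids[0], scores[0]) if i != -1]
#
#   # embedding_cache.py (sketch)
#   import os
#   import numpy as np
#   from sentence_transformers import SentenceTransformer
#
#   class EmbeddingCache:
#       def __init__(self, model_name, cache_dir=".embedding_cache"):
#           self.model = SentenceTransformer(model_name)
#           self.cache_dir = cache_dir
#           os.makedirs(cache_dir, exist_ok=True)
#
#       def compute_or_load_embeddings(self, texts, dataset_key):
#           path = os.path.join(self.cache_dir, f"{dataset_key}.npy")
#           if os.path.exists(path):
#               return np.load(path)
#           embeddings = self.model.encode(texts, convert_to_numpy=True)
#           np.save(path, embeddings)
#           return embeddings
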
class NSFWMatchingAgent:
    """Retrieval agent that embeds several Hugging Face conversation
    datasets into a FAISS index and returns the closest stored text
    for a query."""

    def __init__(self):
        self.model_name = "all-MiniLM-L6-v2"
        self.model = SentenceTransformer(self.model_name)
        self.cache = EmbeddingCache(self.model_name)
        # all-MiniLM-L6-v2 emits 384-dimensional embeddings, matching the index.
        self.index = FaissSemanticIndex(dim=384)
        self.texts = []

        # Hugging Face dataset IDs (spellings verbatim as published on the Hub).
        self.sources = [
            "aifeifei798/DPO_Pairs-Roleplay-NSFW",
            "Maxx0/sexting-nsfw-adultconten",
            "QuietImpostor/Claude-3-Opus-Claude-3.5-Sonnnet-9k",
            "HuggingFaceTB/everyday-conversations-llama3.1-2k",
            "Chadgpt-fam/sexting_dataset",
        ]

        self._load_and_index()

    def _load_and_index(self):
        """Download each source dataset, embed its text column (through the
        cache), and add the embeddings to the index."""
        for source in self.sources:
            try:
                dataset = load_dataset(source)
                texts = []
                for split in dataset:
                    # Schemas vary across datasets: prefer 'text', fall back to 'content'.
                    if 'text' in dataset[split].features:
                        texts.extend(dataset[split]['text'])
                    elif 'content' in dataset[split].features:
                        texts.extend(dataset[split]['content'])

                if not texts:
                    print(f"[WARN] No 'text' or 'content' column in {source}; skipping")
                    continue

                embeddings = self.cache.compute_or_load_embeddings(
                    texts, dataset_key=source.replace("/", "_")
                )
                self.index.add(embeddings, texts)
                self.texts.extend(texts)
            except Exception as e:
                print(f"[WARN] Failed to load {source}: {e}")

    def match(self, query: str) -> str:
        """Return the single most similar indexed text for `query`."""
        # FAISS consumes NumPy float32 arrays, not torch tensors, so encode
        # to NumPy rather than with convert_to_tensor=True.
        query_vec = self.model.encode([query], convert_to_numpy=True)
        result = self.index.search(query_vec, k=1)
        return result[0][0] if result else "No match found"
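

# A minimal usage sketch (assumes the Hub datasets above are reachable and the
# local embedding_cache / semantic_index modules are importable):
if __name__ == "__main__":
    agent = NSFWMatchingAgent()   # downloads, embeds, and indexes on startup
    print(agent.match("hey, how was your day?"))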