# nsfwdata/agents/nsfw_agent.py
from sentence_transformers import SentenceTransformer
from semantic_index import FaissSemanticIndex
from embedding_cache import EmbeddingCache
from datasets import load_dataset
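
# NOTE: `semantic_index` and `embedding_cache` are project-local helper modules, not
# PyPI packages. From the way they are used below, the assumed interface is roughly:
#   FaissSemanticIndex(dim): .add(embeddings, texts) stores vectors with their texts,
#                            .search(query_vec, k) returns a list of (text, score) pairs.
#   EmbeddingCache(model_name): .compute_or_load_embeddings(texts, dataset_key=...)
#                               returns cached or freshly computed embeddings.
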
class NSFWMatchingAgent:
    """Match queries against text pulled from several Hugging Face datasets,
    using MiniLM sentence embeddings stored in a FAISS index."""

    def __init__(self):
        self.model_name = "all-MiniLM-L6-v2"
        self.model = SentenceTransformer(self.model_name)
        self.cache = EmbeddingCache(self.model_name)
        self.index = FaissSemanticIndex(dim=384)  # all-MiniLM-L6-v2 output dimension
        self.texts = []
        self.sources = [
            "aifeifei798/DPO_Pairs-Roleplay-NSFW",
            "Maxx0/sexting-nsfw-adultconten",
            "QuietImpostor/Claude-3-Opus-Claude-3.5-Sonnnet-9k",
            "HuggingFaceTB/everyday-conversations-llama3.1-2k",
            "Chadgpt-fam/sexting_dataset",
        ]
        self._load_and_index()

    def _load_and_index(self):
        """Load each source dataset, embed its text, and add it to the index."""
        for source in self.sources:
            try:
                dataset = load_dataset(source)
                texts = []
                for split in dataset:
                    if 'text' in dataset[split].features:
                        texts.extend(dataset[split]['text'])
                    elif 'content' in dataset[split].features:
                        texts.extend(dataset[split]['content'])
                if not texts:
                    print(f"[WARN] No 'text' or 'content' column in {source}, skipping")
                    continue
                embeddings = self.cache.compute_or_load_embeddings(
                    texts, dataset_key=source.replace("/", "_")
                )
                self.index.add(embeddings, texts)
                self.texts.extend(texts)
            except Exception as e:
                print(f"[WARN] Failed to load {source}: {e}")

    def match(self, query: str) -> str:
        """Return the indexed text closest to the query, or a fallback message."""
        # FAISS works on float32 numpy arrays, so encode to numpy rather than a torch tensor.
        query_vec = self.model.encode([query], convert_to_numpy=True)
        result = self.index.search(query_vec, k=1)
        return result[0][0] if result else "No match found"
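

# Minimal usage sketch (an illustration, not part of the original file): constructing
# the agent downloads and embeds every source dataset, so the first run is slow and
# requires network access to the Hugging Face Hub.
if __name__ == "__main__":
    agent = NSFWMatchingAgent()            # builds the index from all sources
    print(agent.match("an example query"))  # prints the closest indexed text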