from sentence_transformers import SentenceTransformer
from semantic_index import FaissSemanticIndex
from embedding_cache import EmbeddingCache
from datasets import load_dataset
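
# NOTE: semantic_index and embedding_cache are local helper modules, not PyPI
# packages. FaissSemanticIndex is assumed to wrap a FAISS index that stores
# texts alongside their vectors; EmbeddingCache is assumed to memoize
# SentenceTransformer embeddings per dataset_key. See the reference sketch at
# the bottom of this file.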
class NSFWMatchingAgent:
    def __init__(self):
        self.model_name = "all-MiniLM-L6-v2"
        self.model = SentenceTransformer(self.model_name)
        self.cache = EmbeddingCache(self.model_name)
        self.index = FaissSemanticIndex(dim=384)  # MiniLM output size
        self.texts = []
        self.sources = [
            "aifeifei798/DPO_Pairs-Roleplay-NSFW",
            "Maxx0/sexting-nsfw-adultconten",
            "QuietImpostor/Claude-3-Opus-Claude-3.5-Sonnnet-9k",
            "HuggingFaceTB/everyday-conversations-llama3.1-2k",
            "Chadgpt-fam/sexting_dataset",
        ]
        self._load_and_index()

    def _load_and_index(self):
        for source in self.sources:
            try:
                dataset = load_dataset(source)
                texts = []
                for split in dataset:
                    # Source datasets name their text column differently.
                    if "text" in dataset[split].features:
                        texts.extend(dataset[split]["text"])
                    elif "content" in dataset[split].features:
                        texts.extend(dataset[split]["content"])
                if not texts:
                    print(f"[WARN] No 'text' or 'content' column in {source}, skipping")
                    continue
                embeddings = self.cache.compute_or_load_embeddings(
                    texts, dataset_key=source.replace("/", "_")
                )
                self.index.add(embeddings, texts)
                self.texts.extend(texts)
            except Exception as e:
                print(f"[WARN] Failed to load {source}: {e}")

    def match(self, query: str) -> str:
        # FAISS operates on NumPy float32 arrays; convert_to_tensor=True would
        # hand the index a torch tensor and require an extra .cpu().numpy() hop.
        query_vec = self.model.encode([query], convert_to_numpy=True)
        result = self.index.search(query_vec, k=1)
        return result[0][0] if result else "No match found"
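

# --- Reference sketch (an assumption, not the real semantic_index module) ---
# A minimal FAISS-backed index compatible with the calls above: add(embeddings,
# texts) stores texts alongside their vectors, and search(query_vecs, k)
# returns one list of matched texts per query (so result[0][0] is the top hit).
# Named differently here to avoid shadowing the real import; cosine similarity
# via L2-normalized inner product is likewise an assumption.
import faiss  # pip install faiss-cpu
import numpy as np


class _ReferenceFaissSemanticIndex:
    def __init__(self, dim: int):
        self.index = faiss.IndexFlatIP(dim)  # exact inner-product search
        self.payloads: list[str] = []

    def add(self, embeddings, texts):
        vecs = np.ascontiguousarray(embeddings, dtype="float32")
        faiss.normalize_L2(vecs)  # normalized inner product == cosine similarity
        self.index.add(vecs)
        self.payloads.extend(texts)

    def search(self, query_vecs, k=1):
        vecs = np.ascontiguousarray(query_vecs, dtype="float32")
        faiss.normalize_L2(vecs)
        _, ids = self.index.search(vecs, k)
        # -1 marks "no result" slots when the index holds fewer than k vectors.
        return [[self.payloads[i] for i in row if i != -1] for row in ids]


# Usage sketch: constructing the agent downloads and embeds all five datasets,
# so expect a slow cold start; later runs should hit the EmbeddingCache.
if __name__ == "__main__":
    agent = NSFWMatchingAgent()
    print(agent.match("hey, how was your day?"))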