import os.path
from abc import ABC, abstractmethod

import faiss
import numpy as np
import pandas as pd
from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import sessionmaker, declarative_base

from config import Config

Base = declarative_base()

class Storage(ABC):
    """Abstract Storage class."""

    # factory method: choose the backend based on the configuration
    @staticmethod
    def create_storage(cfg: Config) -> 'Storage':
        """Create a storage object."""
        if cfg.use_postgres:
            return _PostgresStorage(cfg)
        else:
            return _IndexStorage(cfg)

    @abstractmethod
    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Add multiple embeddings."""

    @abstractmethod
    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Get the texts closest to the provided embedding."""

    @abstractmethod
    def get_all_embeddings(self, name: str):
        """Get all embeddings."""

    @abstractmethod
    def clear(self, name: str):
        """Clear the database."""

    @abstractmethod
    def been_indexed(self, name: str) -> bool:
        """Check if the database has been indexed."""

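# Illustrative usage sketch (added commentary, not part of the original module):
# callers go through the factory above rather than constructing a backend
# directly. Assuming a Config that exposes `use_postgres` plus `index_path` or
# `postgres_url`, usage would look roughly like:
#
#   storage = Storage.create_storage(Config())
#   storage.add_all([('first paragraph', [0.1] * 1536)], name='my-doc')
#   top_hits = storage.get_texts([0.1] * 1536, name='my-doc', limit=3)
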
class _IndexStorage(Storage):
    """IndexStorage: a local FAISS index plus a CSV file holding the indexed texts."""

    def __init__(self, cfg: Config):
        """Initialize the storage."""
        self._cfg = cfg

    def add_all(self, embeddings: list[tuple[str, list[float]]], name):
        """Add multiple embeddings."""
        texts, index = self._load(name)
        # assign consecutive ids that match the row positions in the texts table
        ids = np.array([len(texts) + i for i, _ in enumerate(embeddings)])
        texts = pd.concat([texts, pd.DataFrame(
            {'index': len(texts) + i, 'text': text} for i, (text, _) in enumerate(embeddings))])
        array = np.array([emb for text, emb in embeddings])
        index.add_with_ids(array, ids)
        self._save(texts, index, name)

    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Get the texts closest to the provided embedding."""
        texts, index = self._load(name)
        _, indexs = index.search(np.array([embedding]), limit)
        # FAISS pads missing results with -1, so drop those before the lookup
        indexs = [i for i in indexs[0] if i >= 0]
        return [f'paragraph {p}: {t}' for _, p, t in texts.iloc[indexs].values]

    def get_all_embeddings(self, name: str):
        """Get all embeddings."""
        texts, index = self._load(name)
        texts = texts.text.tolist()
        embeddings = index.reconstruct_n(0, len(texts))
        return list(zip(texts, embeddings))

    def clear(self, name: str):
        """Clear the database."""
        self._delete(name)

    def been_indexed(self, name: str) -> bool:
        """Check whether both index files for this name exist on disk."""
        return os.path.exists(os.path.join(self._cfg.index_path, f'{name}.csv')) and os.path.exists(
            os.path.join(self._cfg.index_path, f'{name}.bin'))

    def _save(self, texts, index, name: str):
        """Persist the texts table and the FAISS index to disk."""
        texts.to_csv(os.path.join(self._cfg.index_path, f'{name}.csv'))
        faiss.write_index(index, os.path.join(self._cfg.index_path, f'{name}.bin'))

    def _load(self, name: str):
        """Load the texts table and FAISS index, creating empty ones if needed."""
        if self.been_indexed(name):
            texts = pd.read_csv(os.path.join(self._cfg.index_path, f'{name}.csv'))
            index = faiss.read_index(os.path.join(self._cfg.index_path, f'{name}.bin'))
        else:
            texts = pd.DataFrame(columns=['index', 'text'])
            # IDMap2 with Flat: brute-force index that supports explicit ids
            index = faiss.index_factory(1536, "IDMap2,Flat", faiss.METRIC_INNER_PRODUCT)
        return texts, index

    def _delete(self, name: str):
        """Remove the on-disk files for this name, if they exist."""
        try:
            os.remove(os.path.join(self._cfg.index_path, f'{name}.csv'))
            os.remove(os.path.join(self._cfg.index_path, f'{name}.bin'))
        except FileNotFoundError:
            pass

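# Note on the index layout above (added commentary, not from the original
# module): "IDMap2,Flat" wraps a brute-force IndexFlat with an explicit id map,
# which is what lets add_with_ids() use ids that line up with the CSV row
# positions and lets reconstruct_n() hand stored vectors back for
# get_all_embeddings(). METRIC_INNER_PRODUCT ranks like cosine similarity only
# for unit-length vectors, which holds for the 1536-dimensional OpenAI-style
# embeddings implied by index_factory(1536, ...); FAISS stores vectors as
# float32 internally.
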
def singleton(cls):
    """Class decorator that caches and reuses a single instance per class."""
    instances = {}

    def get_instance(cfg):
        if cls not in instances:
            instances[cls] = cls(cfg)
        return instances[cls]

    return get_instance

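# Added commentary: singleton() above turns the decorated class into a cached
# one-instance factory, e.g.
#
#   @singleton
#   class _Foo:                      # _Foo is a hypothetical example class
#       def __init__(self, cfg): ...
#
#   assert _Foo(cfg) is _Foo(cfg)    # every call returns the same instance
#
# It is applied to _PostgresStorage below so that repeated
# Storage.create_storage() calls reuse one engine and session instead of
# reconnecting each time.
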
@singleton
class _PostgresStorage(Storage):
    """PostgresStorage: embeddings stored in Postgres via the pgvector extension."""

    def __init__(self, cfg: Config):
        """Initialize the storage."""
        self._postgresql = cfg.postgres_url
        self._engine = create_engine(self._postgresql)
        Base.metadata.create_all(self._engine)
        session = sessionmaker(bind=self._engine)
        self._session = session()

    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Add multiple embeddings."""
        data = [self.EmbeddingEntity(text=text, embedding=embedding, name=name) for text, embedding in embeddings]
        self._session.add_all(data)
        self._session.commit()

    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Get the texts closest to the provided embedding."""
        result = self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).order_by(
            self.EmbeddingEntity.embedding.cosine_distance(embedding)).limit(limit).all()
        return [f'paragraph {s.id}: {s.text}' for s in result]

    def get_all_embeddings(self, name: str):
        """Get all embeddings."""
        result = self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).all()
        return [(s.text, s.embedding) for s in result]

    def clear(self, name: str):
        """Clear the database."""
        self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).delete()
        self._session.commit()

    def been_indexed(self, name: str) -> bool:
        """Check whether any rows exist for this name."""
        return self._session.query(self.EmbeddingEntity).filter_by(name=name).first() is not None

    def __del__(self):
        """Close the session."""
        self._session.close()

    class EmbeddingEntity(Base):
        """ORM row: one text chunk with its name tag and 1536-dimensional embedding."""
        __tablename__ = 'embedding'
        id = Column(Integer, primary_key=True)
        name = Column(String)
        text = Column(String)
        embedding = Column(Vector(1536))
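

if __name__ == '__main__':
    # Minimal smoke test for the FAISS-backed storage (illustrative addition,
    # not part of the original module). It stands in for the real Config with a
    # SimpleNamespace carrying only the two fields _IndexStorage actually reads.
    from types import SimpleNamespace

    demo_cfg = SimpleNamespace(use_postgres=False, index_path='.')
    storage = Storage.create_storage(demo_cfg)

    storage.clear('demo')  # start from a clean slate
    vector = np.random.rand(1536).astype('float32').tolist()
    storage.add_all([('hello world', vector)], name='demo')
    print(storage.get_texts(vector, name='demo', limit=1))
    storage.clear('demo')  # remove demo.csv / demo.bin again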