Spaces:
Sleeping
Sleeping
Update core/models/knowledge_base.py
Browse files- core/models/knowledge_base.py +116 -52
core/models/knowledge_base.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1 |
-
|
2 |
-
import faiss
|
3 |
-
import pickle
|
4 |
import logging
|
5 |
-
import
|
6 |
from pathlib import Path
|
7 |
from typing import List, Dict, Any
|
|
|
|
|
|
|
8 |
from sentence_transformers import SentenceTransformer
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
11 |
|
12 |
class OptimizedGazaKnowledgeBase:
|
|
|
|
|
13 |
def __init__(self, vector_store_dir: str = "./vector_store"):
|
14 |
self.vector_store_dir = Path(vector_store_dir)
|
15 |
self.faiss_index = None
|
@@ -17,60 +19,121 @@ class OptimizedGazaKnowledgeBase:
|
|
17 |
self.chunks = []
|
18 |
self.metadata = []
|
19 |
self.is_initialized = False
|
20 |
-
|
21 |
def initialize(self):
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
|
|
51 |
if not self.is_initialized:
|
52 |
-
raise RuntimeError("Knowledge base not initialized
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
def get_stats(self) -> Dict[str, Any]:
|
|
|
72 |
if not self.is_initialized:
|
73 |
return {"status": "not_initialized"}
|
|
|
74 |
return {
|
75 |
"status": "initialized",
|
76 |
"total_chunks": len(self.chunks),
|
@@ -79,3 +142,4 @@ class OptimizedGazaKnowledgeBase:
|
|
79 |
"index_type": type(self.faiss_index).__name__,
|
80 |
"sources": list(set(meta.get("source", "unknown") for meta in self.metadata))
|
81 |
}
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
2 |
+
import pickle
|
3 |
from pathlib import Path
|
4 |
from typing import List, Dict, Any
|
5 |
+
|
6 |
+
import faiss
|
7 |
+
import numpy as np
|
8 |
from sentence_transformers import SentenceTransformer
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
11 |
|
12 |
class OptimizedGazaKnowledgeBase:
    """Optimized knowledge base that loads a pre-built FAISS index and its assets.

    Expects three artifacts inside ``vector_store_dir``:
      * ``index.faiss``  -- serialized FAISS index (required)
      * ``chunks.txt``   -- text chunks in a "=== Chunk N ===" delimited dump (required)
      * ``metadata.pkl`` -- pickled dict holding a "metadata" list (optional)
    """

    def __init__(self, vector_store_dir: str = "./vector_store"):
        self.vector_store_dir = Path(vector_store_dir)
        self.faiss_index = None
        # Set lazily in initialize(); declared here so the attribute always exists.
        self.embedding_model = None
        self.chunks: List[str] = []
        self.metadata: List[Dict[str, Any]] = []
        self.is_initialized = False

    def initialize(self) -> None:
        """Load the pre-made FAISS index, text chunks, metadata and query encoder.

        Raises:
            FileNotFoundError: if ``index.faiss`` or ``chunks.txt`` is missing.
            Exception: any other loading error is logged, then re-raised.
        """
        try:
            logger.info("Loading pre-made FAISS index and assets...")

            # 1. Load FAISS index
            index_path = self.vector_store_dir / "index.faiss"
            if not index_path.exists():
                raise FileNotFoundError(f"FAISS index not found at {index_path}")
            self.faiss_index = faiss.read_index(str(index_path))
            logger.info("Loaded FAISS index: %d vectors, %d dimensions",
                        self.faiss_index.ntotal, self.faiss_index.d)

            # 2. Load text chunks
            chunks_path = self.vector_store_dir / "chunks.txt"
            if not chunks_path.exists():
                raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
            with open(chunks_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            # Parse the formatted dump: "=== Chunk" lines delimit records;
            # "Source:"/"Length:" header lines are skipped; everything else
            # is chunk body text joined with spaces.
            current_chunk = ""
            for line in lines:
                line = line.strip()
                if line.startswith("=== Chunk") and current_chunk:
                    self.chunks.append(current_chunk.strip())
                    current_chunk = ""
                elif not line.startswith("===") and not line.startswith("Source:") and not line.startswith("Length:"):
                    current_chunk += line + " "
            # Flush the trailing chunk (no closing delimiter after the last record).
            if current_chunk:
                self.chunks.append(current_chunk.strip())
            logger.info("Loaded %d text chunks", len(self.chunks))

            # 3. Load metadata (optional; fall back to empty per-chunk dicts)
            metadata_path = self.vector_store_dir / "metadata.pkl"
            if metadata_path.exists():
                # NOTE: pickle is acceptable only because this file ships with
                # the app itself -- never load untrusted pickles.
                with open(metadata_path, 'rb') as f:
                    metadata_dict = pickle.load(f)

                if isinstance(metadata_dict, dict) and 'metadata' in metadata_dict:
                    self.metadata = metadata_dict['metadata']
                    logger.info("Loaded %d metadata entries", len(self.metadata))
                else:
                    logger.warning("Metadata format not recognized, using empty metadata")
                    # Independent dicts per chunk: "[{}] * n" would alias ONE
                    # shared dict, so mutating one entry would mutate them all.
                    self.metadata = [{} for _ in self.chunks]
            else:
                logger.warning("No metadata file found, using empty metadata")
                self.metadata = [{} for _ in self.chunks]

            # 4. Initialize embedding model for query encoding
            logger.info("Loading embedding model for queries...")
            self.embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
            logger.info("Embedding model loaded")

            # 5. Verify data consistency: chunk count must match index size
            if len(self.chunks) != self.faiss_index.ntotal:
                logger.warning("Mismatch: %d chunks vs %d vectors",
                               len(self.chunks), self.faiss_index.ntotal)
                # Trim chunks/metadata to match index size
                self.chunks = self.chunks[:self.faiss_index.ntotal]
                self.metadata = self.metadata[:self.faiss_index.ntotal]
                logger.info("Trimmed to %d chunks to match index", len(self.chunks))

            self.is_initialized = True
            logger.info("Knowledge base initialization complete!")

        except Exception as e:
            logger.error("Failed to initialize knowledge base: %s", e)
            raise

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``k`` chunks most similar to ``query``.

        Best-effort: any runtime failure during the search degrades to an
        empty result list (logged), rather than propagating.

        Raises:
            RuntimeError: if called before :meth:`initialize`.
        """
        if not self.is_initialized:
            raise RuntimeError("Knowledge base not initialized")

        try:
            # 1. Encode query as a float32 matrix, as FAISS requires
            query_embedding = self.embedding_model.encode([query])
            query_vector = np.array(query_embedding, dtype=np.float32)

            # 2. Search FAISS index
            distances, indices = self.faiss_index.search(query_vector, k)

            # 3. Prepare results; FAISS pads missing hits with index -1,
            #    so guard the range before dereferencing.
            results = []
            for distance, idx in zip(distances[0], indices[0]):
                if 0 <= idx < len(self.chunks):
                    chunk_metadata = self.metadata[idx] if idx < len(self.metadata) else {}
                    results.append({
                        "text": self.chunks[idx],
                        # Convert L2 distance to a (0, 1] similarity score
                        "score": float(1.0 / (1.0 + distance)),
                        "source": chunk_metadata.get("source", "unknown"),
                        "chunk_index": int(idx),
                        "distance": float(distance),
                        "metadata": chunk_metadata,
                    })

            logger.info("Search for '%s' returned %d results", query, len(results))
            return results

        except Exception as e:
            # Deliberate best-effort: a failed search yields "no results".
            logger.error("Search error: %s", e)
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Get knowledge base statistics"""
        if not self.is_initialized:
            return {"status": "not_initialized"}

        return {
            "status": "initialized",
            "total_chunks": len(self.chunks),
            # NOTE(review): the diff rendering lost two stat entries between
            # "total_chunks" and "index_type" -- restore them from VCS history.
            "index_type": type(self.faiss_index).__name__,
            "sources": list(set(meta.get("source", "unknown") for meta in self.metadata)),
        }