import json
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any

model_name = "Qwen/Qwen3-Embedding-0.6B"
print(f"Loading SentenceTransformer model: '{model_name}'...")
# Use torch to detect the GPU: faiss.get_num_gpus() may be unavailable in
# faiss-cpu builds, while torch is already a sentence-transformers dependency.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
model = SentenceTransformer(model_name, device=device)
if device == 'cuda':
    model = model.half()  # Use FP16 on GPU to halve memory usage
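
# Note: FAISS expects float32 vectors, so query embeddings produced by the
# FP16 model are cast back to float32 before searching (see
# retrieve_relevant_chunks below).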

def load_faiss_index_and_metadata(index_path: str, metadata_path: str, chunks_json_path: str) -> tuple:
    """
    Loads the FAISS index, metadata, and original code chunks.

    Args:
        index_path (str): Path to the FAISS index file.
        metadata_path (str): Path to the metadata JSON file.
        chunks_json_path (str): Path to the original code_chunks.json file.

    Returns:
        tuple: (FAISS index, metadata list, chunks dictionary), or
        (None, None, None) if any file fails to load.
    """
    print(f"Loading FAISS index from '{index_path}'...")
    try:
        index = faiss.read_index(index_path)
    except Exception as e:
        print(f"Error loading FAISS index: {e}")
        return None, None, None

    print(f"Loading metadata from '{metadata_path}'...")
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except FileNotFoundError:
        print(f"Error: The file '{metadata_path}' was not found.")
        return None, None, None

    print(f"Loading code chunks from '{chunks_json_path}'...")
    try:
        with open(chunks_json_path, 'r', encoding='utf-8') as f:
            chunks = json.load(f)
        # Build a dictionary for O(1) lookup by chunk_id
        chunks_dict = {chunk['chunk_id']: chunk for chunk in chunks}
    except FileNotFoundError:
        print(f"Error: The file '{chunks_json_path}' was not found.")
        return None, None, None

    return index, metadata, chunks_dict
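
# Example usage (a minimal sketch; the file names below are the artifacts
# written by create_faiss.py and must exist on disk):
#
#     index, metadata, chunks_dict = load_faiss_index_and_metadata(
#         index_path="code_faiss.index",
#         metadata_path="code_metadata.json",
#         chunks_json_path="code_chunks.json",
#     )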

def retrieve_relevant_chunks(query: str, model_name: str, index: faiss.IndexFlatL2, metadata: List[Dict], chunks_dict: Dict, top_k: int = 5) -> List[Dict]:
    """
    Encodes the query and retrieves the top-k most relevant code chunks.

    Args:
        query (str): The user's input query.
        model_name (str): Name of the SentenceTransformer model. Kept for API
            compatibility; the module-level model is used for encoding.
        index (faiss.IndexFlatL2): The loaded FAISS index.
        metadata (List[Dict]): The metadata for the indexed chunks.
        chunks_dict (Dict): Dictionary mapping chunk_id to chunk details.
        top_k (int): Number of top results to return.

    Returns:
        List[Dict]: List of dictionaries containing the retrieved chunks and their metadata.
    """
    # Prepare the query text to mimic the chunk format used during indexing
    query_text = f"Type: query\nDocstring: {query}\n---\n{query}"
    query_embedding = model.encode([query_text], show_progress_bar=False).astype('float32')

    # Perform the FAISS search
    print(f"Searching for top {top_k} relevant chunks...")
    distances, indices = index.search(query_embedding, top_k)

    # Collect results
    results = []
    for idx, distance in zip(indices[0], distances[0]):
        # FAISS pads missing results with -1, so guard both bounds
        if 0 <= idx < len(metadata):
            meta = metadata[idx]
            chunk_id = meta['chunk_id']
            chunk = chunks_dict.get(chunk_id, {})
            results.append({
                'chunk_id': chunk_id,
                'file_path': meta.get('file_path', 'Unknown'),
                'start_line': meta.get('start_line', -1),
                'end_line': meta.get('end_line', -1),
                'name': meta.get('name', 'Unknown'),
                'chunk_type': meta.get('chunk_type', 'Unknown'),
                'docstring': chunk.get('docstring', 'No docstring.'),
                'content': chunk.get('content', 'No content available.'),
                'distance': float(distance)  # L2 distance: lower means more similar
            })
        else:
            print(f"Warning: Index {idx} out of range for metadata.")
    return results
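
# Interpreting scores (a sketch; this assumes the indexed embeddings were
# L2-normalized, in which case IndexFlatL2's squared-L2 distance d maps to
# cosine similarity as 1 - d / 2):
#
#     for r in results:
#         cosine_sim = 1.0 - r['distance'] / 2.0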

def print_results(results: List[Dict]) -> str:
    """
    Formats the retrieved results into a Markdown string with GitHub links
    and syntax highlighting.

    Args:
        results (List[Dict]): List of retrieved chunk details.

    Returns:
        str: Markdown-formatted results, or a fallback message if empty.
    """
    if not results:
        return "No relevant chunks found."

    GITHUB_BASE_URL = "https://github.com/ultralytics/ultralytics/blob/main/"
    markdown_output = ""
    for i, result in enumerate(results, 1):
        file_path = result.get('file_path', 'Unknown')
        start_line = result.get('start_line', -1)
        end_line = result.get('end_line', -1)

        # Construct a direct link to the code on GitHub
        if file_path != 'Unknown' and start_line != -1:
            github_link = f"{GITHUB_BASE_URL}{file_path}#L{start_line}-L{end_line}"
            header = f"### {i}. [{file_path}]({github_link}) (Lines {start_line}-{end_line})"
        else:
            header = f"### {i}. {file_path} (Lines {start_line}-{end_line})"

        markdown_output += f"{header}\n"
        markdown_output += f"**Type:** `{result.get('chunk_type', 'N/A')}` **Name:** `{result.get('name', 'N/A')}`\n\n"
        markdown_output += "```python\n"
        markdown_output += result.get('content', '# No content available.') + "\n"
        markdown_output += "```\n---\n"
    return markdown_output

# if __name__ == "__main__":
#     # --- CONFIGURATION ---
#     INDEX_PATH = "code_faiss.index"
#     METADATA_PATH = "code_metadata.json"
#     CHUNKS_JSON_PATH = "code_chunks.json"
#     MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"  # Must match the model used in create_faiss.py
#     TOP_K = 5  # Number of results to retrieve
#
#     # --- EXECUTION ---
#     # Load the FAISS index, metadata, and chunks
#     index, metadata, chunks_dict = load_faiss_index_and_metadata(
#         index_path=INDEX_PATH,
#         metadata_path=METADATA_PATH,
#         chunks_json_path=CHUNKS_JSON_PATH,
#     )
#     if index is None or metadata is None or chunks_dict is None:
#         print("Failed to load index, metadata, or chunks. Exiting.")
#         exit(1)
#
#     # Get the user query
#     print("\nEnter your query (e.g., 'function to process text data'):")
#     query = input("> ")
#
#     # Retrieve and display results
#     results = retrieve_relevant_chunks(
#         query=query,
#         model_name=MODEL_NAME,
#         index=index,
#         metadata=metadata,
#         chunks_dict=chunks_dict,
#         top_k=TOP_K,
#     )
#     print(print_results(results))  # print_results returns a Markdown string