diskann_files / test_res_id_to_vec.py
cocoshe's picture
add test
f4aa3fa
import os
import numpy as np
def load_aligned_bin(file_path, dtype=np.int8, alignment=8):
"""
读取并解析diskann规定的二进制向量文件,比如bigann_query.bin等
参数:
file_path (str): 二进制文件路径
dtype (np.dtype): 数据类型,默认为float32
alignment (int): 维度对齐的倍数,默认为8(读取数据集要对齐,保存res的时候没有对齐要求,所以读res的时候不用align)
返回:
np.ndarray: 读取并对齐后的向量数据
int: 向量数量
int: 原始向量维度
int: 对齐后的向量维度
"""
# 获取文件大小
actual_file_size = os.path.getsize(file_path)
with open(file_path, 'rb') as f:
# 读取向量数量和维度信息
npts = int(np.fromfile(f, dtype=np.int32, count=1)[0])
dim = int(np.fromfile(f, dtype=np.int32, count=1)[0])
# 计算预期的文件大小
expected_file_size = 2 * 4 + npts * dim * np.dtype(dtype).itemsize
if actual_file_size != expected_file_size:
raise ValueError(f"文件大小不匹配。实际大小: {actual_file_size} 字节,"
f"预期大小: {expected_file_size} 字节")
# 计算对齐后的维度
rounded_dim = ((dim + alignment - 1) // alignment) * alignment
# rounded_dim = dim
print(f"元数据: #向量 = {npts}, #维度 = {dim}, 对齐维度 = {rounded_dim}")
print(f"分配内存: {npts * rounded_dim * np.dtype(dtype).itemsize} 字节")
# 创建对齐后的数组(填充零)
data = np.zeros((npts, rounded_dim), dtype=dtype)
# 逐行读取向量数据
for i in range(npts):
# 从文件读取原始向量数据
vector = np.fromfile(f, dtype=dtype, count=dim)
# 存储到对齐后的数组位置
data[i, :dim] = vector
print("数据读取完成")
return data, npts, dim, rounded_dim
def load_bin(filename, dtype=np.float32):
"""
读取按照save_bin函数格式保存的二进制文件
参数:
filename (str): 二进制文件路径
dtype (np.dtype): 数据类型,根据文件内容选择np.uint32或np.float32
返回:
np.ndarray: 读取的数据数组
int: 向量数量(npts)
int: 每个向量的维度(ndims)
"""
with open(filename, 'rb') as f:
# 读取前两个整数(npts和ndims)
npts = np.fromfile(f, dtype=np.int32, count=1)[0]
ndims = np.fromfile(f, dtype=np.int32, count=1)[0]
print(f"读取元数据: #向量 = {npts}, #维度 = {ndims}")
# 读取实际数据
data = np.fromfile(f, dtype=dtype, count=npts * ndims)
# 重塑为二维数组
data = data.reshape(npts, ndims)
print(f"成功读取 {data.shape[0]} 个向量,每个向量维度为 {data.shape[1]}")
print(f"数据类型: {data.dtype}")
return data, npts, ndims
# 使用示例
if __name__ == "__main__":
# 读取索引结果文件
idx_filename = "/home/myw/wuchangli/yk/diskann_demo/DiskANN/build/data_backup_clean_test/res_20_idx_uint32.bin" # 替换为实际文件路径
idx_data, npts, ndims = load_bin(idx_filename, dtype=np.uint32)
# 读取距离结果文件
dist_filename = "/home/myw/wuchangli/yk/diskann_demo/DiskANN/build/data_backup_clean_test/res_20_dists_float.bin" # 替换为实际文件路径
dist_data, _, _ = load_bin(dist_filename, dtype=np.float32)
file_path = "/home/myw/wuchangli/yk/diskann_demo/DiskANN/build/data_backup_clean_test/bigann_query.bin"
query_data, query_npts, query_dim, query_rounded_dim = load_aligned_bin(file_path)
# 打印前5个查询结果
print("\n示例结果:")
for i in range(min(5, npts)):
print(f"查询 {i}:")
print(f" 最近邻索引: {idx_data[i, :]}")
print(f" 最近邻距离: {dist_data[i, :]}")
import pdb; pdb.set_trace()
print('load learn vectors')
base_path = "/home/myw/wuchangli/yk/diskann_demo/DiskANN/build/data_backup_clean_test/bigann_learn.bin"
base_vectors, base_npts, base_dim, base_rounded_dim = load_aligned_bin(base_path)
import pdb; pdb.set_trace()
# 查询 0:
# 最近邻索引: [98840598 23883997 96907701 96907734 96908621 3762989 53147362 88013120
# 53147127 96908488]
# 最近邻距离: [68103. 68524. 70149. 71367. 71500. 71746. 71859. 73190. 73243. 73430.]
# 查询 1:
# 最近邻索引: [26755808 34080910 3236088 15610118 66524798 72823982 27594517 37556745
# 26786213 39265851]
# 最近邻距离: [28587. 30958. 31851. 32047. 32234. 33895. 34205. 34521. 34699. 34823.]
# 查询 2:
# 最近邻索引: [71015134 21282510 71010508 6626649 96440577 38613912 71106199 76697998
# 53211679 52141940]
# 最近邻距离: [36378. 41223. 41241. 41245. 41867. 42108. 42120. 42183. 42383. 42662.]
# 查询 3:
# 最近邻索引: [21759610 34536818 21549958 84252617 2083654 30364765 15796826 11393460
# 52313741 10230325]
# 最近邻距离: [2474. 2684. 2701. 2725. 2768. 2869. 2918. 2937. 2967. 2996.]
# 查询 4:
# 最近邻索引: [44759921 57423326 30803610 85912171 93322848 38777531 64765128 50769267
# 33429914 98348143]
# 最近邻距离: [48352. 64203. 65424. 66350. 69264. 69772. 70490. 70679. 71253. 72685.]
# (Pdb) np.sum((base_vectors[98840598].astype('float32') - query_data[0].astype('float32'))**2)
# 68103.0
# (Pdb) np.sum((base_vectors[23883997].astype('float32') - query_data[0].astype('float32'))**2)
# 68524.0
# (Pdb) np.sum((base_vectors[96907701].astype('float32') - query_data[0].astype('float32'))**2)
# 70149.0
# (Pdb) np.sum((base_vectors[26755808].astype('float32') - query_data[1].astype('float32'))**2)
# 28587.0
# (Pdb) np.sum((base_vectors[44759921].astype('float32') - query_data[4].astype('float32'))**2)
# 48352.0