File size: 6,013 Bytes
f4aa3fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
import numpy as np
def load_aligned_bin(file_path, dtype=np.int8, alignment=8, chunk_rows=1 << 16):
    """
    Read a DiskANN-format binary vector file (e.g. bigann_query.bin) and
    zero-pad every vector so its length is a multiple of ``alignment``.

    File layout, as parsed below: an int32 point count ``npts``, an int32
    dimension ``dim``, then ``npts * dim`` elements of ``dtype`` stored row by
    row. The file size must match that layout exactly.

    Args:
        file_path (str): Path to the binary file.
        dtype (np.dtype): Element type of the stored vectors (default
            ``np.int8``).
        alignment (int): Pad the dimension up to a multiple of this value
            (default 8). Dataset files are read aligned; result files are
            saved without alignment, so read those with ``load_bin`` instead.
        chunk_rows (int): Number of vectors read per ``np.fromfile`` call.
            Bounds the temporary memory used while copying into the padded
            array; larger values mean fewer calls.

    Returns:
        tuple: ``(data, npts, dim, rounded_dim)``.
        ``data`` is an ``(npts, rounded_dim)`` array of ``dtype`` whose
        columns ``dim:`` are zero. ``npts``, ``dim`` and ``rounded_dim`` are
        Python ints.

    Raises:
        ValueError: If ``alignment`` or ``chunk_rows`` is < 1, the header is
            truncated or negative, or the file size does not match the header.
    """
    if alignment < 1:
        raise ValueError(f"alignment must be >= 1, got {alignment}")
    if chunk_rows < 1:
        raise ValueError(f"chunk_rows must be >= 1, got {chunk_rows}")

    itemsize = np.dtype(dtype).itemsize
    actual_file_size = os.path.getsize(file_path)
    with open(file_path, 'rb') as f:
        # Header: two int32 values, the vector count and the dimension.
        header = np.fromfile(f, dtype=np.int32, count=2)
        if header.size != 2:
            raise ValueError(
                f"File too short for the 8-byte header: {actual_file_size} bytes")
        # Convert to Python ints so size arithmetic below cannot overflow int32.
        npts, dim = int(header[0]), int(header[1])
        if npts < 0 or dim < 0:
            raise ValueError(f"Invalid header: npts={npts}, dim={dim}")

        expected_file_size = 2 * 4 + npts * dim * itemsize
        if actual_file_size != expected_file_size:
            raise ValueError(f"文件大小不匹配。实际大小: {actual_file_size} 字节,"
                             f"预期大小: {expected_file_size} 字节")

        # Round the dimension up to the next multiple of `alignment`.
        rounded_dim = ((dim + alignment - 1) // alignment) * alignment
        print(f"元数据: #向量 = {npts}, #维度 = {dim}, 对齐维度 = {rounded_dim}")
        print(f"分配内存: {npts * rounded_dim * itemsize} 字节")

        # Padding columns stay zero; only the first `dim` columns are filled.
        data = np.zeros((npts, rounded_dim), dtype=dtype)
        if dim > 0:
            # Read many rows per call instead of one np.fromfile per vector,
            # which is prohibitively slow for files with millions of rows.
            for start in range(0, npts, chunk_rows):
                rows = min(chunk_rows, npts - start)
                chunk = np.fromfile(f, dtype=dtype, count=rows * dim)
                data[start:start + rows, :dim] = chunk.reshape(rows, dim)
    print("数据读取完成")
    return data, npts, dim, rounded_dim
def load_bin(filename, dtype=np.float32):
    """
    Read a binary file in the save_bin layout: int32 ``npts``, int32
    ``ndims``, then ``npts * ndims`` elements of ``dtype`` stored row by row.

    Unlike ``load_aligned_bin``, no dimension padding is applied.

    Args:
        filename (str): Path to the binary file.
        dtype (np.dtype): Element type of the payload, e.g. ``np.uint32`` for
            neighbor ids or ``np.float32`` for distances.

    Returns:
        tuple: ``(data, npts, ndims)`` where ``data`` has shape
        ``(npts, ndims)`` and ``npts``/``ndims`` are Python ints.

    Raises:
        ValueError: If the header is truncated or negative, or the file holds
            fewer than ``npts * ndims`` elements.
    """
    with open(filename, 'rb') as f:
        # Header: two int32 values, the vector count and the dimension.
        header = np.fromfile(f, dtype=np.int32, count=2)
        if header.size != 2:
            raise ValueError(f"File too short for the 8-byte header: {filename}")
        # Python ints: keeps npts * ndims from overflowing int32 on large files.
        npts, ndims = int(header[0]), int(header[1])
        if npts < 0 or ndims < 0:
            raise ValueError(f"Invalid header: npts={npts}, ndims={ndims}")
        print(f"读取元数据: #向量 = {npts}, #维度 = {ndims}")

        count = npts * ndims
        if count > 0:
            data = np.fromfile(f, dtype=dtype, count=count)
        else:
            data = np.empty(0, dtype=dtype)

    if data.size != count:
        raise ValueError(
            f"Truncated file {filename}: expected {count} elements, read {data.size}")
    data = data.reshape(npts, ndims)
    print(f"成功读取 {data.shape[0]} 个向量,每个向量维度为 {data.shape[1]}")
    print(f"数据类型: {data.dtype}")
    return data, npts, ndims
# Usage example: read DiskANN search results plus the query and base vectors,
# then spot-check that the stored distances equal recomputed squared L2.
if __name__ == "__main__":
    # Directory holding the DiskANN build artifacts; replace with your own path.
    data_dir = "/home/myw/wuchangli/yk/diskann_demo/DiskANN/build/data_backup_clean_test"

    # Result file with the neighbor ids of each query.
    idx_filename = os.path.join(data_dir, "res_20_idx_uint32.bin")
    idx_data, npts, ndims = load_bin(idx_filename, dtype=np.uint32)
    # Result file with the matching distances.
    dist_filename = os.path.join(data_dir, "res_20_dists_float.bin")
    dist_data, _, _ = load_bin(dist_filename, dtype=np.float32)
    file_path = os.path.join(data_dir, "bigann_query.bin")
    query_data, query_npts, query_dim, query_rounded_dim = load_aligned_bin(file_path)

    # Print the results of the first 5 queries.
    num_show = min(5, npts)
    print("\n示例结果:")
    for i in range(num_show):
        print(f"查询 {i}:")
        print(f" 最近邻索引: {idx_data[i, :]}")
        print(f" 最近邻距离: {dist_data[i, :]}")

    print('load learn vectors')
    base_path = os.path.join(data_dir, "bigann_learn.bin")
    base_vectors, base_npts, base_dim, base_rounded_dim = load_aligned_bin(base_path)

    # Replaces the former interactive pdb breakpoints: recompute the squared
    # L2 distance from each query to its returned neighbors and compare with
    # the stored distances (the recorded session below shows they matched).
    # Assumes the index was built on bigann_learn.bin, as in that session.
    # Both arrays use the same zero padding, so padding does not affect the sum.
    for i in range(min(num_show, query_npts)):
        neighbors = idx_data[i]
        # Skip ids outside the base set (e.g. sentinel values for missing hits).
        valid = neighbors < base_npts
        if not valid.any():
            print(f"query {i}: no valid neighbor ids to check")
            continue
        diffs = (base_vectors[neighbors[valid]].astype(np.float32)
                 - query_data[i].astype(np.float32))
        recomputed = np.sum(diffs ** 2, axis=1)
        max_err = float(np.max(np.abs(recomputed - dist_data[i][valid])))
        print(f"query {i}: max |recomputed - stored| distance = {max_err}")

# Sample output recorded from one run, followed by the manual pdb check that
# squared L2 distances recomputed from the vectors match the stored ones.
# 查询 0:
# 最近邻索引: [98840598 23883997 96907701 96907734 96908621 3762989 53147362 88013120
# 53147127 96908488]
# 最近邻距离: [68103. 68524. 70149. 71367. 71500. 71746. 71859. 73190. 73243. 73430.]
# 查询 1:
# 最近邻索引: [26755808 34080910 3236088 15610118 66524798 72823982 27594517 37556745
# 26786213 39265851]
# 最近邻距离: [28587. 30958. 31851. 32047. 32234. 33895. 34205. 34521. 34699. 34823.]
# 查询 2:
# 最近邻索引: [71015134 21282510 71010508 6626649 96440577 38613912 71106199 76697998
# 53211679 52141940]
# 最近邻距离: [36378. 41223. 41241. 41245. 41867. 42108. 42120. 42183. 42383. 42662.]
# 查询 3:
# 最近邻索引: [21759610 34536818 21549958 84252617 2083654 30364765 15796826 11393460
# 52313741 10230325]
# 最近邻距离: [2474. 2684. 2701. 2725. 2768. 2869. 2918. 2937. 2967. 2996.]
# 查询 4:
# 最近邻索引: [44759921 57423326 30803610 85912171 93322848 38777531 64765128 50769267
# 33429914 98348143]
# 最近邻距离: [48352. 64203. 65424. 66350. 69264. 69772. 70490. 70679. 71253. 72685.]
# (Pdb) np.sum((base_vectors[98840598].astype('float32') - query_data[0].astype('float32'))**2)
# 68103.0
# (Pdb) np.sum((base_vectors[23883997].astype('float32') - query_data[0].astype('float32'))**2)
# 68524.0
# (Pdb) np.sum((base_vectors[96907701].astype('float32') - query_data[0].astype('float32'))**2)
# 70149.0
# (Pdb) np.sum((base_vectors[26755808].astype('float32') - query_data[1].astype('float32'))**2)
# 28587.0
# (Pdb) np.sum((base_vectors[44759921].astype('float32') - query_data[4].astype('float32'))**2)
# 48352.0