-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathinspect_tools.py
84 lines (68 loc) · 2.67 KB
/
inspect_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import faiss
def get_invlist(invlists, l):
    """ Fetch the content of inverted list number ``l``.

    Returns a pair ``(list_ids, list_codes)`` of numpy arrays copied out of
    the downcast inverted-lists object:

    - ``list_ids``: int64 array with ``list_size(l)`` entries;
    - ``list_codes``: uint8 array of shape ``(list_size, code_size)`` when the
      lists report a valid ``code_size``; otherwise (a BlockInvertedLists)
      shape ``(n_blocks, block_size // n_per_block, n_per_block)`` with
      ``n_blocks = ceil(list_size / n_per_block)``.

    The id and code buffers obtained from the inverted lists are released in
    a ``finally`` clause, so they are returned even if copying fails.
    """
    il = faiss.downcast_InvertedLists(invlists)
    size = il.list_size(l)
    list_ids = np.zeros(size, dtype='int64')
    ids_ptr = None
    codes_ptr = None
    try:
        ids_ptr = il.get_ids(l)
        if size > 0:
            faiss.memcpy(faiss.swig_ptr(list_ids), ids_ptr, list_ids.nbytes)
        codes_ptr = il.get_codes(l)
        if il.code_size == faiss.InvertedLists.INVALID_CODE_SIZE:
            # BlockInvertedLists: codes are grouped n_per_block vectors at a
            # time; each group is copied as block_size // n_per_block rows
            # of n_per_block bytes.
            per_block = il.n_per_block
            n_blocks = -(-size // per_block)  # ceiling division
            shape = (n_blocks, il.block_size // per_block, per_block)
        else:
            shape = (size, il.code_size)
        list_codes = np.zeros(shape, dtype='uint8')
        if size > 0:
            faiss.memcpy(
                faiss.swig_ptr(list_codes), codes_ptr, list_codes.nbytes)
    finally:
        if ids_ptr is not None:
            il.release_ids(l, ids_ptr)
        if codes_ptr is not None:
            il.release_codes(l, codes_ptr)
    return list_ids, list_codes
def get_invlist_sizes(invlists):
    """ Return an int64 array whose entry ``i`` is ``invlists.list_size(i)``,
    for every list number ``i`` in ``range(invlists.nlist)``. """
    nlist = invlists.nlist
    return np.fromiter(
        (invlists.list_size(list_no) for list_no in range(nlist)),
        dtype='int64',
        count=nlist,
    )
def print_object_fields(obj):
    """ Print ``name = value`` for every field of an object known to SWIG.

    Proxy classes generated by older SWIG versions list their fields in a
    ``__swig_getmethods__`` dict, which is used when present. SWIG >= 4
    no longer generates that dict and exposes fields as ``property``
    attributes instead, so in that case the properties defined on the
    object's class and its bases are listed (most-derived class first,
    each name once).
    """
    cls = obj.__class__
    names = getattr(cls, "__swig_getmethods__", None)
    if names is None:
        fields = {}  # dict used as an insertion-ordered set
        for klass in cls.__mro__:
            for name, value in vars(klass).items():
                # "thisown" is SWIG's ownership flag, not a data field
                if isinstance(value, property) and name != "thisown":
                    fields.setdefault(name, None)
        names = list(fields)
    for name in names:
        print(f"{name} = {getattr(obj, name)}")
def get_pq_centroids(pq):
    """ Return the product-quantizer centroids as a numpy array of shape
    ``(pq.M, pq.ksub, pq.dsub)``. """
    flat = faiss.vector_to_array(pq.centroids)
    shape = (pq.M, pq.ksub, pq.dsub)
    return flat.reshape(shape)
def get_LinearTransform_matrix(pca):
    """ Return the ``(A, b)`` pair of a faiss linear transform.

    Despite the parameter name, this works for any linear transform (PCA,
    OPQ, random rotation, etc.). ``A`` is reshaped to ``(d_out, d_in)``;
    ``b`` is returned as the flat bias vector read from the object.
    """
    bias = faiss.vector_to_array(pca.b)
    matrix = faiss.vector_to_array(pca.A)
    return matrix.reshape(pca.d_out, pca.d_in), bias
def get_additive_quantizer_codebooks(aq):
    """ Return the codebooks of an additive quantizer.

    The result is a list of ``aq.M`` arrays. Entry ``m`` is the slice of the
    ``(-1, aq.d)``-shaped codebook table between rows
    ``codebook_offsets[m]`` and ``codebook_offsets[m + 1]``.
    """
    table = faiss.vector_to_array(aq.codebooks).reshape(-1, aq.d)
    offsets = faiss.vector_to_array(aq.codebook_offsets)
    return [table[offsets[m]:offsets[m + 1]] for m in range(aq.M)]
def get_flat_data(index):
    """ Return a copy of the data matrix of an IndexFlat: its code bytes
    viewed as float32 and shaped ``(index.ntotal, index.d)``. """
    raw_codes = faiss.vector_to_array(index.codes)
    vectors = raw_codes.view("float32")
    return vectors.reshape(index.ntotal, index.d)