Skip to content

Commit

Permalink
sync with FB version 2017-07-18
Browse files Browse the repository at this point in the history
- implemented ScalarQuantizer (without IVF)
- implemented update for IndexIVFFlat
- implemented L2 normalization preproc
  • Loading branch information
mdouze committed Jul 18, 2017
1 parent 602deba commit f7aedbd
Show file tree
Hide file tree
Showing 24 changed files with 4,531 additions and 1,888 deletions.
92 changes: 56 additions & 36 deletions AutoTune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "IndexIVF.h"
#include "IndexIVFPQ.h"
#include "MetaIndexes.h"
#include "IndexIVFScalarQuantizer.h"
#include "IndexScalarQuantizer.h"


namespace faiss {
Expand Down Expand Up @@ -623,18 +623,28 @@ void ParameterSpace::explore (Index *index,
* index_factory
***************************************************************/

namespace {

struct VTChain {
std::vector<VectorTransform *> chain;
~VTChain () {
for (int i = 0; i < chain.size(); i++) {
delete chain[i];
}
}
};

}

Index *index_factory (int d, const char *description_in, MetricType metric)
{
VectorTransform *vt = nullptr;
VTChain vts;
Index *coarse_quantizer = nullptr;
Index *index = nullptr;
bool add_idmap = false;
bool make_IndexRefineFlat = false;

ScopeDeleter1<Index> del_coarse_quantizer, del_index;
ScopeDeleter1<VectorTransform> del_vt;

char description[strlen(description_in) + 1];
char *ptr;
Expand All @@ -656,18 +666,27 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
Index *index_1 = nullptr;

// VectorTransforms
if (!vt && sscanf (tok, "PCA%d", &d_out) == 1) {
if (sscanf (tok, "PCA%d", &d_out) == 1) {
vt_1 = new PCAMatrix (d, d_out);
d = d_out;
} else if (!vt && sscanf (tok, "PCAR%d", &d_out) == 1) {
} else if (sscanf (tok, "PCAR%d", &d_out) == 1) {
vt_1 = new PCAMatrix (d, d_out, 0, true);
d = d_out;
} else if (!vt && sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
} else if (sscanf (tok, "PCAW%d", &d_out) == 1) {
vt_1 = new PCAMatrix (d, d_out, -0.5, false);
d = d_out;
} else if (sscanf (tok, "PCAWR%d", &d_out) == 1) {
vt_1 = new PCAMatrix (d, d_out, -0.5, true);
d = d_out;
} else if (sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
vt_1 = new OPQMatrix (d, opq_M, d_out);
d = d_out;
} else if (!vt && sscanf (tok, "OPQ%d", &opq_M) == 1) {
} else if (sscanf (tok, "OPQ%d", &opq_M) == 1) {
vt_1 = new OPQMatrix (d, opq_M);
// coarse quantizers
} else if (stok == "L2norm") {
vt_1 = new NormalizationTransform (d, 2.0);

// coarse quantizers
} else if (!coarse_quantizer &&
sscanf (tok, "IVF%d", &ncentroids) == 1) {
if (metric == METRIC_L2) {
Expand Down Expand Up @@ -698,28 +717,25 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
index_1 = index_ivf;
} else {
index_1 = new IndexFlat (d, metric);
if (add_idmap) {
IndexIDMap *idmap = new IndexIDMap(index_1);
idmap->own_fields = true;
index_1 = idmap;
add_idmap = false;
}
}
} else if (!index && (stok == "SQ8" || stok == "SQ4")) {
FAISS_THROW_IF_NOT_MSG(coarse_quantizer,
"ScalarQuantizer works only with an IVF");
ScalarQuantizer::QuantizerType qt =
stok == "SQ8" ? ScalarQuantizer::QT_8bit :
stok == "SQ4" ? ScalarQuantizer::QT_4bit :
ScalarQuantizer::QT_4bit;
IndexIVFScalarQuantizer *index_ivf = new IndexIVFScalarQuantizer (
coarse_quantizer, d, ncentroids, qt, metric);
index_ivf->quantizer_trains_alone =
dynamic_cast<MultiIndexQuantizer*>(coarse_quantizer)
!= nullptr;
del_coarse_quantizer.release ();
index_ivf->own_fields = true;
index_1 = index_ivf;
if (coarse_quantizer) {
IndexIVFScalarQuantizer *index_ivf =
new IndexIVFScalarQuantizer (
coarse_quantizer, d, ncentroids, qt, metric);
index_ivf->quantizer_trains_alone =
dynamic_cast<MultiIndexQuantizer*>(coarse_quantizer)
!= nullptr;
del_coarse_quantizer.release ();
index_ivf->own_fields = true;
index_1 = index_ivf;
} else {
index_1 = new IndexScalarQuantizer (d, qt, metric);
}
} else if (!index && sscanf (tok, "PQ%d+%d", &M, &M2) == 2) {
FAISS_THROW_IF_NOT_MSG(coarse_quantizer,
"PQ with + works only with an IVF");
Expand Down Expand Up @@ -750,13 +766,6 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
IndexPQ *index_pq = new IndexPQ (d, M, 8, metric);
index_pq->do_polysemous_training = true;
index_1 = index_pq;
if (add_idmap) {
IndexIDMap *idmap = new IndexIDMap(index_1);
del_index.set (idmap);
idmap->own_fields = true;
index_1 = idmap;
add_idmap = false;
}
}
} else if (stok == "RFlat") {
make_IndexRefineFlat = true;
Expand All @@ -765,9 +774,16 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
tok, description_in);
}

if (index_1 && add_idmap) {
IndexIDMap *idmap = new IndexIDMap(index_1);
del_index.set (idmap);
idmap->own_fields = true;
index_1 = idmap;
add_idmap = false;
}

if (vt_1) {
vt = vt_1;
del_vt.set (vt);
vts.chain.push_back (vt_1);
}

if (coarse_quantizer_1) {
Expand All @@ -793,10 +809,14 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
"IDMap option not used\n");
}

if (vt) {
IndexPreTransform *index_pt = new IndexPreTransform (vt, index);
del_vt.release ();
if (vts.chain.size() > 0) {
IndexPreTransform *index_pt = new IndexPreTransform (index);
index_pt->own_fields = true;
// add from back
while (vts.chain.size() > 0) {
index_pt->prepend_transform (vts.chain.back());
vts.chain.pop_back ();
}
index = index_pt;
}

Expand Down
7 changes: 7 additions & 0 deletions AuxIndexStructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ void RangeSearchPartialResult::set_result (bool incremental)
}


/***********************************************************************
* IDSelectorRange
***********************************************************************/

IDSelectorRange::IDSelectorRange (idx_t imin, idx_t imax):
imin (imin), imax (imax)
{
Expand All @@ -169,6 +173,9 @@ bool IDSelectorRange::is_member (idx_t id) const
}


/***********************************************************************
* IDSelectorBatch
***********************************************************************/

IDSelectorBatch::IDSelectorBatch (long n, const idx_t *indices)
{
Expand Down
9 changes: 0 additions & 9 deletions AuxIndexStructures.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@
#define FAISS_AUX_INDEX_STRUCTURES_H

#include <vector>

#if __cplusplus >= 201103L
#include <unordered_set>
#endif

#include <set>


#include "Index.h"
Expand Down Expand Up @@ -80,11 +75,7 @@ struct IDSelectorRange: IDSelector {
* hash collisions if lsb's are always the same */
struct IDSelectorBatch: IDSelector {

#if __cplusplus >= 201103L
std::unordered_set<idx_t> set;
#else
std::set<idx_t> set;
#endif

typedef unsigned char uint8_t;
std::vector<uint8_t> bloom; // assumes low bits of id are a good hash value
Expand Down
6 changes: 5 additions & 1 deletion FaissException.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
// Copyright 2004-present Facebook. All Rights Reserved.

#include "FaissException.h"
#include <cstdio>

namespace faiss {

Expand All @@ -28,4 +27,9 @@ FaissException::FaissException(const std::string& m,
funcName, file, line, m.c_str());
}

const char*
FaissException::what() const noexcept {
return msg.c_str();
}

}
4 changes: 1 addition & 3 deletions FaissException.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ class FaissException : public std::exception {
int line);

/// from std::exception
const char* what() const noexcept override
{ return msg.c_str(); }
~FaissException () noexcept override {}
const char* what() const noexcept override;

std::string msg;
};
Expand Down
77 changes: 64 additions & 13 deletions IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,28 @@ void IndexIVF::add (idx_t n, const float * x)
add_with_ids (n, x, nullptr);
}

void IndexIVF::make_direct_map ()
void IndexIVF::make_direct_map (bool new_maintain_direct_map)
{
if (maintain_direct_map) return;

direct_map.resize (ntotal, -1);
for (size_t key = 0; key < nlist; key++) {
const std::vector<long> & idlist = ids[key];

for (long ofs = 0; ofs < idlist.size(); ofs++) {
direct_map [idlist [ofs]] =
key << 32 | ofs;
// nothing to do
if (new_maintain_direct_map == maintain_direct_map)
return;

if (new_maintain_direct_map) {
direct_map.resize (ntotal, -1);
for (size_t key = 0; key < nlist; key++) {
const std::vector<long> & idlist = ids[key];

for (long ofs = 0; ofs < idlist.size(); ofs++) {
FAISS_THROW_IF_NOT_MSG (
0 <= idlist [ofs] && idlist[ofs] < ntotal,
"direct map supported only for seuquential ids");
direct_map [idlist [ofs]] = key << 32 | ofs;
}
}
} else {
direct_map.clear ();
}

maintain_direct_map = true;
maintain_direct_map = new_maintain_direct_map;
}


Expand Down Expand Up @@ -183,7 +190,6 @@ void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)




IndexIVF::~IndexIVF()
{
if (own_fields) delete quantizer;
Expand Down Expand Up @@ -217,6 +223,8 @@ void IndexIVFFlat::add_core (idx_t n, const float * x, const long *xids,

{
FAISS_THROW_IF_NOT (is_trained);
FAISS_THROW_IF_NOT_MSG (!(maintain_direct_map && xids),
"cannot have direct map and add with ids");
const long * idx;
ScopeDeleter<long> del;

Expand Down Expand Up @@ -477,6 +485,49 @@ void IndexIVFFlat::copy_subset_to (IndexIVFFlat & other, int subset_type,
}
}

void IndexIVFFlat::update_vectors (int n, idx_t *new_ids, const float *x)
{
FAISS_THROW_IF_NOT (maintain_direct_map);
FAISS_THROW_IF_NOT (is_trained);
std::vector<idx_t> assign (n);
quantizer->assign (n, x, assign.data());

for (int i = 0; i < n; i++) {
idx_t id = new_ids[i];
FAISS_THROW_IF_NOT_MSG (0 <= id && id < ntotal,
"id to update out of range");
{ // remove old one
long dm = direct_map[id];
long ofs = dm & 0xffffffff;
long il = dm >> 32;
size_t l = ids[il].size();
if (ofs != l - 1) {
long id2 = ids[il].back();
ids[il][ofs] = id2;
direct_map[id2] = (il << 32) | ofs;
memcpy (vecs[il].data() + ofs * d,
vecs[il].data() + (l - 1) * d,
d * sizeof(vecs[il][0]));
}
ids[il].pop_back();
vecs[il].resize((l - 1) * d);
}
{ // insert new one
long il = assign[i];
size_t l = ids[il].size();
long dm = (il << 32) | l;
direct_map[id] = dm;
ids[il].push_back (id);
vecs[il].resize((l + 1) * d);
memcpy (vecs[il].data() + l * d,
x + i * d,
d * sizeof(vecs[il][0]));
}
}

}




void IndexIVFFlat::reset()
Expand Down
19 changes: 16 additions & 3 deletions IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,12 @@ struct IndexIVF: Index {
/// Number of entries currently held in ids[list_no].
size_t get_list_size (size_t list_no) const
{
    const auto & list_ids = ids[list_no];
    return list_ids.size();
}


/// initialize a direct map
void make_direct_map ();
/** initialize a direct map
*
* @param new_maintain_direct_map if true, create a direct map,
* else clear it
*/
void make_direct_map (bool new_maintain_direct_map=true);

/// 1= perfectly balanced, >1: imbalanced
double imbalance_factor () const;
Expand Down Expand Up @@ -184,6 +187,16 @@ struct IndexIVFFlat: IndexIVF {
const long * keys,
float_maxheap_array_t * res) const;

/** Update a subset of vectors.
*
* The index must have a direct_map
*
* @param nv nb of vectors to update
* @param idx vector indices to update, size nv
* @param v vectors of new values, size nv*d
*/
void update_vectors (int nv, idx_t *idx, const float *v);

void reconstruct(idx_t key, float* recons) const override;

void merge_from_residuals(IndexIVF& other) override;
Expand Down
Loading

0 comments on commit f7aedbd

Please sign in to comment.