Skip to content

Commit

Permalink
read/write index with std::function wrapper (facebookresearch#427)
Browse files Browse the repository at this point in the history
* add access function to IndexIVF;

* - access for IndexIVF;
- write_index/read_index with std::function<...>;

* - fix test compile on mac;
- adjust write/read with std::function;

* replace std::function with IOReader/IOWriter;

* remove IndexIVF::access // tmp

* PFN_WRITE/READ => WRITE;

* revert mac compile fix;

* rename;

* fix compile;

* reset CMakeList;

* format; remove unused function/header;
  • Loading branch information
dengoswei authored and mdouze committed May 24, 2018
1 parent 433f5c0 commit abe2b0f
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 43 deletions.
1 change: 0 additions & 1 deletion IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,6 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
}



IndexIVF::~IndexIVF()
{
if (own_invlists) {
Expand Down
129 changes: 88 additions & 41 deletions index_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ static uint32_t fourcc (const char sx[4]) {
**************************************************************/


#define WRITEANDCHECK(ptr, n) { \
size_t ret = fwrite (ptr, sizeof (* (ptr)), n, f); \
FAISS_THROW_IF_NOT_MSG (ret == (n), "write error"); \
#define WRITEANDCHECK(ptr, n) { \
size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
FAISS_THROW_IF_NOT_MSG(ret == (n), "write error"); \
}

#define READANDCHECK(ptr, n) { \
size_t ret = fread (ptr, sizeof (* (ptr)), n, f); \
FAISS_THROW_IF_NOT_MSG (ret == (n), "read error"); \
#define READANDCHECK(ptr, n) { \
size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
FAISS_THROW_IF_NOT_MSG(ret == (n), "read error"); \
}

#define WRITE1(x) WRITEANDCHECK(&(x), 1)
Expand All @@ -106,15 +106,41 @@ struct ScopeFileCloser {
~ScopeFileCloser () {fclose (f); }
};

namespace {

struct FileIOReader: IOReader {
FILE *f = nullptr;

FileIOReader(FILE *rf): f(rf) {}

~FileIOReader() = default;

virtual size_t operator()(
void *ptr, size_t size, size_t nitems) override {
return fread(ptr, size, nitems, f);
}
};

struct FileIOWriter: IOWriter {
FILE *f = nullptr;

FileIOWriter(FILE *wf): f(wf) {}
~FileIOWriter() = default;

virtual size_t operator()(
const void *ptr, size_t size, size_t nitems) override {
return fwrite(ptr, size, nitems, f);
}
};


} // namespace


/*************************************************************
* Write
**************************************************************/

static void write_index_header (const Index *idx, FILE *f) {
static void write_index_header (const Index *idx, IOWriter *f) {
WRITE1 (idx->d);
WRITE1 (idx->ntotal);
Index::idx_t dummy = 1 << 20;
Expand All @@ -124,7 +150,7 @@ static void write_index_header (const Index *idx, FILE *f) {
WRITE1 (idx->metric_type);
}

void write_VectorTransform (const VectorTransform *vt, FILE *f) {
void write_VectorTransform (const VectorTransform *vt, IOWriter *f) {
if (const LinearTransform * lt =
dynamic_cast < const LinearTransform *> (vt)) {
if (dynamic_cast<const RandomRotationMatrix *>(lt)) {
Expand Down Expand Up @@ -167,14 +193,16 @@ void write_VectorTransform (const VectorTransform *vt, FILE *f) {
WRITE1 (vt->is_trained);
}

static void write_ProductQuantizer (const ProductQuantizer *pq, FILE *f) {
// Serialize a ProductQuantizer through the IOWriter f.
// The fields go out in the order d, M, nbits, centroids; the matching
// read_ProductQuantizer (pq, IOReader*) consumes them in the same order.
// Each WRITE1 expands to WRITEANDCHECK, which invokes (*f)(...) and reports
// "write error" through FAISS_THROW_IF_NOT_MSG when fewer items than
// requested were written.
static void write_ProductQuantizer (
const ProductQuantizer *pq, IOWriter *f) {
WRITE1 (pq->d);
WRITE1 (pq->M);
WRITE1 (pq->nbits);
// WRITEVECTOR is defined outside this view; presumably it emits the element
// count followed by the data -- TODO confirm.
WRITEVECTOR (pq->centroids);
}

static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) {

static void write_ScalarQuantizer (
const ScalarQuantizer *ivsc, IOWriter *f) {
WRITE1 (ivsc->qtype);
WRITE1 (ivsc->rangestat);
WRITE1 (ivsc->rangestat_arg);
Expand All @@ -183,7 +211,7 @@ static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) {
WRITEVECTOR (ivsc->trained);
}

static void write_InvertedLists (const InvertedLists *ils, FILE *f) {
static void write_InvertedLists (const InvertedLists *ils, IOWriter *f) {
if (ils == nullptr) {
uint32_t h = fourcc ("il00");
WRITE1 (h);
Expand Down Expand Up @@ -258,10 +286,12 @@ void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname) {
FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
ScopeFileCloser closer(f);
write_ProductQuantizer (pq, f);

FileIOWriter writer(f);
write_ProductQuantizer (pq, &writer);
}

static void write_HNSW (const HNSW *hnsw, FILE *f) {
static void write_HNSW (const HNSW *hnsw, IOWriter *f) {

WRITEVECTOR (hnsw->assign_probas);
WRITEVECTOR (hnsw->cum_nneighbor_per_level);
Expand All @@ -274,10 +304,9 @@ static void write_HNSW (const HNSW *hnsw, FILE *f) {
WRITE1 (hnsw->efConstruction);
WRITE1 (hnsw->efSearch);
WRITE1 (hnsw->upper_beam);

}

static void write_ivf_header (const IndexIVF * ivf, FILE *f) {
static void write_ivf_header (const IndexIVF *ivf, IOWriter *f) {
write_index_header (ivf, f);
WRITE1 (ivf->nlist);
WRITE1 (ivf->nprobe);
Expand All @@ -286,7 +315,7 @@ static void write_ivf_header (const IndexIVF * ivf, FILE *f) {
WRITEVECTOR (ivf->direct_map);
}

void write_index (const Index *idx, FILE *f) {
void write_index (const Index *idx, IOWriter *f) {
if (const IndexFlat * idxf = dynamic_cast<const IndexFlat *> (idx)) {
uint32_t h = fourcc (
idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" :
Expand Down Expand Up @@ -418,6 +447,11 @@ void write_index (const Index *idx, FILE *f) {
}
}

void write_index (const Index *idx, FILE *f) {
FileIOWriter writer(f);
write_index(idx, &writer);
}

void write_index (const Index *idx, const char *fname) {
FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
Expand All @@ -429,14 +463,16 @@ void write_VectorTransform (const VectorTransform *vt, const char *fname) {
FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
ScopeFileCloser closer(f);
write_VectorTransform (vt, f);

FileIOWriter writer(f);
write_VectorTransform (vt, &writer);
}

/*************************************************************
* Read
**************************************************************/

static void read_index_header (Index *idx, FILE *f) {
static void read_index_header (Index *idx, IOReader *f) {
READ1 (idx->d);
READ1 (idx->ntotal);
Index::idx_t dummy;
Expand All @@ -447,7 +483,7 @@ static void read_index_header (Index *idx, FILE *f) {
idx->verbose = false;
}

VectorTransform* read_VectorTransform (FILE *f) {
VectorTransform* read_VectorTransform (IOReader *f) {
uint32_t h;
READ1 (h);
VectorTransform *vt = nullptr;
Expand Down Expand Up @@ -497,7 +533,7 @@ VectorTransform* read_VectorTransform (FILE *f) {


static void read_ArrayInvertedLists_sizes (
FILE *f, std::vector<size_t> & sizes)
IOReader *f, std::vector<size_t> & sizes)
{
size_t nlist = sizes.size();
uint32_t list_type;
Expand All @@ -518,8 +554,7 @@ static void read_ArrayInvertedLists_sizes (
}
}


InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
InvertedLists *read_InvertedLists (IOReader *f, int io_flags) {
uint32_t h;
READ1 (h);
if (h == fourcc ("il00")) {
Expand All @@ -545,23 +580,27 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
}
return ails;
} else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) {
auto impl = dynamic_cast<FileIOReader*>(f);
FAISS_THROW_IF_NOT(NULL != impl);
FILE *raw_f = impl->f;

auto ails = new OnDiskInvertedLists ();
READ1 (ails->nlist);
READ1 (ails->code_size);
ails->read_only = true;
ails->lists.resize (ails->nlist);
std::vector<size_t> sizes (ails->nlist);
read_ArrayInvertedLists_sizes (f, sizes);
size_t o0 = ftell (f), o = o0;
size_t o0 = ftell (raw_f), o = o0;
{ // do the mmap
struct stat buf;
int ret = fstat (fileno(f), &buf);
int ret = fstat (fileno(raw_f), &buf);
FAISS_THROW_IF_NOT_FMT (ret == 0,
"fstat failed: %s", strerror(errno));
ails->totsize = buf.st_size;
ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize,
PROT_READ, MAP_SHARED,
fileno (f), 0);
fileno (raw_f), 0);
FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED,
"could not mmap: %s",
strerror(errno));
Expand All @@ -574,7 +613,7 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
ails->code_size);
}
// resume normal reading of file
fseek (f, o, SEEK_SET);
fseek (raw_f, o, SEEK_SET);
return ails;
} else if (h == fourcc ("ilod")) {
OnDiskInvertedLists *od = new OnDiskInvertedLists();
Expand All @@ -601,24 +640,24 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
}
}

static void read_InvertedLists (IndexIVF *ivf, FILE *f, int io_flags) {
/**
 * Read an inverted-list container from f and attach it to ivf.
 *
 * On success ivf takes ownership of the container (own_invlists = true).
 *
 * @param ivf       index that receives the inverted lists; its nlist and
 *                  code_size must match the ones stored in the stream
 * @param f         source to read from
 * @param io_flags  IO_FLAG_* bits, forwarded to read_InvertedLists (f, ...)
 * @throws through FAISS_THROW_IF_NOT when the stored nlist / code_size do
 *         not match ivf
 */
static void read_InvertedLists (
        IndexIVF *ivf, IOReader *f, int io_flags) {
    InvertedLists *ils = read_InvertedLists (f, io_flags);

    // Hold the freshly read lists until they are handed over to ivf, so a
    // failed consistency check below does not leak them.
    ScopeDeleter1<InvertedLists> del (ils);

    // write_InvertedLists stores a null container as the "il00" marker, so
    // tolerate a null result instead of dereferencing it.
    // NOTE(review): the "il00" read branch is outside this view -- confirm
    // that it returns nullptr.
    FAISS_THROW_IF_NOT (!ils || (ils->nlist == ivf->nlist &&
                                 ils->code_size == ivf->code_size));

    del.release ();
    ivf->invlists = ils;
    ivf->own_invlists = true;
}


static void read_ProductQuantizer (ProductQuantizer *pq, FILE *f) {

// Fill *pq from the IOReader f.
// The fields are consumed in the order d, M, nbits, centroids, which is the
// order write_ProductQuantizer (pq, IOWriter*) emits them in.
// READ1 / READVECTOR are defined outside this view; presumably they wrap
// READANDCHECK, which reports "read error" through FAISS_THROW_IF_NOT_MSG on
// a short read -- TODO confirm.
static void read_ProductQuantizer (ProductQuantizer *pq, IOReader *f) {
READ1 (pq->d);
READ1 (pq->M);
READ1 (pq->nbits);
// Called after d/M/nbits are known and before the centroid table is read
// (presumably it recomputes the fields derived from them -- TODO confirm).
pq->set_derived_values ();
READVECTOR (pq->centroids);
}

static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) {
static void read_ScalarQuantizer (ScalarQuantizer *ivsc, IOReader *f) {
READ1 (ivsc->qtype);
READ1 (ivsc->rangestat);
READ1 (ivsc->rangestat_arg);
Expand All @@ -628,7 +667,7 @@ static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) {
}


static void read_HNSW (HNSW *hnsw, FILE *f) {
static void read_HNSW (HNSW *hnsw, IOReader *f) {
READVECTOR (hnsw->assign_probas);
READVECTOR (hnsw->cum_nneighbor_per_level);
READVECTOR (hnsw->levels);
Expand All @@ -648,14 +687,16 @@ ProductQuantizer * read_ProductQuantizer (const char*fname) {
ScopeFileCloser closer(f);
ProductQuantizer *pq = new ProductQuantizer();
ScopeDeleter1<ProductQuantizer> del (pq);
read_ProductQuantizer(pq, f);

FileIOReader reader(f);
read_ProductQuantizer(pq, &reader);
del.release ();
return pq;
}

static void read_ivf_header (
IndexIVF * ivf, FILE *f,
std::vector<std::vector<Index::idx_t> > *ids = nullptr)
IndexIVF *ivf, IOReader *f,
std::vector<std::vector<Index::idx_t> > *ids = nullptr)
{
read_index_header (ivf, f);
READ1 (ivf->nlist);
Expand Down Expand Up @@ -683,7 +724,7 @@ static ArrayInvertedLists *set_array_invlist(
return ail;
}

static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags)
static IndexIVFPQ *read_ivfpq (IOReader *f, uint32_t h, int io_flags)
{
bool legacy = h == fourcc ("IvQR") || h == fourcc ("IvPQ");

Expand Down Expand Up @@ -720,7 +761,7 @@ static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags)

int read_old_fmt_hack = 0;

Index *read_index (FILE * f, int io_flags) {
Index *read_index (IOReader *f, int io_flags) {
Index * idx = nullptr;
uint32_t h;
READ1 (h);
Expand Down Expand Up @@ -913,6 +954,10 @@ Index *read_index (FILE * f, int io_flags) {
}


/**
 * Read an index from an already opened stdio stream.
 *
 * Thin adapter: wraps the stream in a FileIOReader and forwards to the
 * IOReader overload. The stream is not closed here.
 *
 * @param f         source stream, opened for reading by the caller
 * @param io_flags  IO_FLAG_* bits, passed through unchanged
 * @return whatever the IOReader overload returns
 */
Index *read_index (FILE * f, int io_flags) {
    FileIOReader file_reader(f);
    IOReader *source = &file_reader;
    return read_index(source, io_flags);
}

Index *read_index (const char *fname, int io_flags) {
FILE *f = fopen (fname, "r");
Expand All @@ -929,7 +974,9 @@ VectorTransform *read_VectorTransform (const char *fname) {
perror ("");
abort ();
}
VectorTransform *vt = read_VectorTransform (f);

FileIOReader reader(f);
VectorTransform *vt = read_VectorTransform (&reader);
fclose (f);
return vt;
}
Expand Down
21 changes: 20 additions & 1 deletion index_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,21 @@ struct Index;
struct VectorTransform;
struct IndexIVF;
struct ProductQuantizer;
struct IOReader;
struct IOWriter;

void write_index (const Index *idx, FILE *f);
void write_index (const Index *idx, const char *fname);

void write_index (const Index *idx, IOWriter *writer);


const int IO_FLAG_MMAP = 1;
const int IO_FLAG_READ_ONLY = 2;

Index *read_index (FILE * f, int io_flags = 0);
Index *read_index (const char *fname, int io_flags = 0);

Index *read_index (IOReader *reader, int io_flags = 0);


void write_VectorTransform (const VectorTransform *vt, const char *fname);
Expand All @@ -55,6 +59,21 @@ struct Cloner {
virtual ~Cloner() {}
};

/**
 * Abstract byte source used by the read_* functions.
 *
 * Implementations provide an fread()-like call operator.
 */
struct IOReader {
    virtual ~IOReader() = default;

    /**
     * fread()-style read.
     *
     * @param ptr     destination buffer
     * @param size    size of one item in bytes
     * @param nitems  number of items requested
     * @return number of items read, as reported by the implementation
     */
    virtual size_t operator()(void *ptr, size_t size, size_t nitems) = 0;
};

/**
 * Abstract byte sink used by the write_* functions.
 *
 * Implementations provide an fwrite()-like call operator.
 */
struct IOWriter {
    virtual ~IOWriter() = default;

    /**
     * fwrite()-style write.
     *
     * @param ptr     source buffer
     * @param size    size of one item in bytes
     * @param nitems  number of items to write
     * @return number of items written, as reported by the implementation
     */
    virtual size_t operator()(
        const void *ptr, size_t size, size_t nitems) = 0;
};

}

#endif

0 comments on commit abe2b0f

Please sign in to comment.