-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
159 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# CurvFaiss | ||
|
||
## Introduction | ||
|
||
```CurvFaiss``` is a library for efficient similarity search and clustering of dense vectors in non-Euclidean manifolds. | ||
|
||
Based on [*Faiss*](https://github.com/facebookresearch/faiss), ```CurvFaiss``` develops a new Index ```IndexFlatStereographic``` to support nearest neighbors searching with stereographic distance metric. Together with ```CurvLearn```, non-Euclidean model training and efficient inference are feasible. | ||
|
||
Currently ```CurvFaiss``` supports retrieving neighbors in Hyperbolic, Euclidean, Spherical space. The indexing method is based on exact searching. Due to the parallelism in both the data level and instruction level, the indices can be built in less than two hours for 100 million nodes. | ||
|
||
To those who want to apply on their own customized metric or optimize the indexing method, a hands-on [*tutorial*](customized.md) is also provided. | ||
|
||
## Installation | ||
|
||
```CurvFaiss``` requires curvlearn and python3. | ||
|
||
The preferred way for installing is via `pip`. | ||
|
||
```bash | ||
pip install curvfaiss | ||
``` | ||
|
||
Since the source codes are compiled under CentOS, as for other platforms, we recommend users follow the [*tutorial*](customized.md) to solve the code dependency. | ||
|
||
## Usage | ||
|
||
A frequent problem is the runtime dependency. | ||
```bash | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c "from os.path import abspath,dirname,join; import curvlearn as cl; print(join(dirname((dirname(cl.__file__))),'curvfaiss'))"` | ||
``` | ||
|
||
Since ```IndexFlatStereographic``` is inherited from ```IndexFlat```, the usgae is the same with ```IndexFlatL2``` in faiss except with an additional parameter ```curvature```. | ||
|
||
```python | ||
import curvfaiss | ||
|
||
# build index, retrievaling in hyperbolic, euclidean, spherical metric with respect to curvature < 0, = 0, > 0 | ||
index = curvfaiss.IndexFlatStereographic(dim, curvature) | ||
index.add(embedding) | ||
|
||
# knn search | ||
knn_distance, knn_index = index.search(query, topk) | ||
|
||
print(knn_distance, knn_index) | ||
``` | ||
|
||
See the [*full demo*](knn.py) here! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Tutorial on developing customized distance metric in Faiss | ||
|
||
## Build the environment | ||
1) build the docker from [*Dockerfile*](https://github.com/facebookresearch/faiss/blob/main/Dockerfile). | ||
2) install swig4, cmake, anaconda with py=3.7/numpy. | ||
3) follow the installation [*wiki*](https://github.com/facebookresearch/faiss/wiki/Installing-Faiss#compiling-the-python-interface-within-an-anaconda-install), [*readme*](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md#building-from-source) to verify the correctness. | ||
|
||
|
||
## Develop the new metric | ||
Here is an [*example*](https://github.com/XuZhirong/faiss/commit/a7d5466884d924226aa1d43dd5a41dd44606817c). | ||
|
||
The main modifications include: declaring the index in ```IndexFlat```, implementing the metric in ```utils```. | ||
|
||
## Compile and Publish | ||
1) remove the build files and rebuild. | ||
2) put *.so files to the proper location. | ||
3) write [*setup.py*](https://github.com/facebookresearch/faiss/blob/main/faiss/python/setup.py), where the ```package_data``` field should include the *.so files. | ||
4) upload to pypi/docker. [*Instructions*](https://zhuanlan.zhihu.com/p/61174349) are here. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from absl import app | ||
from absl import flags | ||
|
||
import numpy as np | ||
import curvfaiss | ||
from curvlearn import manifolds | ||
import tensorflow as tf | ||
|
||
FLAGS = flags.FLAGS | ||
flags.DEFINE_float("curvature", 0.0, "Curvature of stereographic model, default is 0") | ||
flags.DEFINE_integer("batch", 16, "Embedding batch size") | ||
flags.DEFINE_integer("dim", 8, "Embedding dimension") | ||
flags.DEFINE_integer("topk", 10, "Find the top k nearest neighbors, default is 10") | ||
|
||
|
||
def generate_sample(): | ||
np.random.seed(2022) | ||
embedding = np.random.random((FLAGS.batch, FLAGS.dim)).astype('float32') | ||
id = np.arange(FLAGS.batch) | ||
|
||
def curvlearn_op(emb): | ||
manifold = manifolds.Stereographic() | ||
emb = manifold.to_manifold(tf.constant(emb, dtype=tf.float32), FLAGS.curvature) | ||
|
||
# src = [emb[0],..,emb[0],emb[1],...,emb[1],...] | ||
# dst = [emb[0],emb[1],...,emb[batch],emb[0],...] | ||
src = tf.reshape(tf.tile(emb, [1, FLAGS.batch]), [-1, FLAGS.dim]) | ||
dst = tf.tile(emb, [FLAGS.batch, 1]) | ||
distance = tf.reshape(manifold.distance(src, dst, FLAGS.curvature), [FLAGS.batch, FLAGS.batch]) | ||
|
||
sess = tf.Session() | ||
emb, distance = sess.run([emb, distance]) | ||
|
||
return emb, distance | ||
|
||
embedding, distance = curvlearn_op(embedding) | ||
|
||
knn_index = np.argsort(distance, axis=-1) | ||
knn_id = np.array([[id[col] for col in row[:FLAGS.topk]] for row in knn_index]) | ||
|
||
return embedding, id, knn_id | ||
|
||
|
||
def faiss_knn(emb, id, query=None): | ||
def build_index(embedding): | ||
""" | ||
Stereographic distance only supports brute-force indexing method for now. | ||
""" | ||
n, d = embedding.shape[0], embedding.shape[1] | ||
|
||
index = curvfaiss.IndexFlatStereographic(d, FLAGS.curvature) | ||
index.add(embedding) | ||
|
||
return index | ||
|
||
if query is None: | ||
query = emb | ||
|
||
index = build_index(emb) | ||
D, I = index.search(query, FLAGS.topk) | ||
I = np.array([[id[col] for col in row] for row in I]) | ||
|
||
return I | ||
|
||
|
||
def main(argv): | ||
del argv | ||
|
||
embedding, id, golden_knn = generate_sample() | ||
knn = faiss_knn(embedding, id) | ||
|
||
print("curvfaiss sanity check: {}!".format("passed" if np.array_equal(golden_knn, knn) else "failed")) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |