diff --git a/README.md b/README.md index 2edb735..b8b0a86 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,9 @@ English | [简体中文](README_CN.md) #### *NEWS!* +- [April 16] [*CurvFaiss*](examples/curvfaiss/README.md), the nearest neighbor search tool in the non-Euclidean space, is available now! - [March 29] Our paper *[AMCAD: Adaptive Mixed-Curvature Representation based Advertisement Retrieval System](https://arxiv.org/abs/2203.14683)* has been accepted by ICDE'22! -- [November 03] The [Detailed explanation](https://mp.weixin.qq.com/s/uP_wU5nnd7faBoo5B_7Xfw?spm=ata.21736010.0.0.5b9d2c7eQ67WZg) is posted on Alimama Tech Blog! +- [November 03] The [detailed explanation](https://mp.weixin.qq.com/s/uP_wU5nnd7faBoo5B_7Xfw?spm=ata.21736010.0.0.5b9d2c7eQ67WZg) is posted on Alimama Tech Blog! ## Why Non-Euclidean Geometry diff --git a/README_CN.md b/README_CN.md index 3ef0395..c1257d1 100644 --- a/README_CN.md +++ b/README_CN.md @@ -5,6 +5,7 @@ [English](./README.md) | 简体中文 #### *更新!* +- [04/16] 非欧最近邻检索工具[*CurvFaiss*](examples/curvfaiss/README.md)现已发布! - [03/29] [*AMCAD: Adaptive Mixed-Curvature Representation based Advertisement Retrieval System*](https://arxiv.org/abs/2203.14683) 发表在ICDE'22! - [11/03] [*曲率学习框架详解*](https://mp.weixin.qq.com/s/uP_wU5nnd7faBoo5B_7Xfw?spm=ata.21736010.0.0.5b9d2c7eQ67WZg) 发布在阿里妈妈技术公众号! diff --git a/docs/sample.png b/docs/sample.png index 0e1089c..05531f6 100644 Binary files a/docs/sample.png and b/docs/sample.png differ diff --git a/examples/curvfaiss/README.md b/examples/curvfaiss/README.md new file mode 100644 index 0000000..26953ef --- /dev/null +++ b/examples/curvfaiss/README.md @@ -0,0 +1,47 @@ +# CurvFaiss + +## Introduction + +```CurvFaiss``` is a library for efficient similarity search and clustering of dense vectors in non-Euclidean manifolds. + +Based on [*Faiss*](https://github.com/facebookresearch/faiss), ```CurvFaiss``` develops a new Index ```IndexFlatStereographic``` to support nearest neighbors searching with stereographic distance metric. Together with ```CurvLearn```, non-Euclidean model training and efficient inference are feasible. + +Currently ```CurvFaiss``` supports retrieving neighbors in Hyperbolic, Euclidean, Spherical space. The indexing method is based on exact searching. Due to the parallelism in both the data level and instruction level, the indices can be built in less than two hours for 100 million nodes. + +To those who want to apply on their own customized metric or optimize the indexing method, a hands-on [*tutorial*](customized.md) is also provided. + +## Installation + +```CurvFaiss``` requires curvlearn and python3. + +The preferred way for installing is via `pip`. + +```bash +pip install curvfaiss +``` + +Since the source codes are compiled under CentOS, as for other platforms, we recommend users follow the [*tutorial*](customized.md) to solve the code dependency. + +## Usage + +A frequent problem is the runtime dependency. +```bash +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c "from os.path import abspath,dirname,join; import curvlearn as cl; print(join(dirname((dirname(cl.__file__))),'curvfaiss'))"` +``` + +Since ```IndexFlatStereographic``` is inherited from ```IndexFlat```, the usgae is the same with ```IndexFlatL2``` in faiss except with an additional parameter ```curvature```. + +```python +import curvfaiss + +# build index, retrievaling in hyperbolic, euclidean, spherical metric with respect to curvature < 0, = 0, > 0 +index = curvfaiss.IndexFlatStereographic(dim, curvature) +index.add(embedding) + +# knn search +knn_distance, knn_index = index.search(query, topk) + +print(knn_distance, knn_index) +``` + +See the [*full demo*](knn.py) here! \ No newline at end of file diff --git a/examples/curvfaiss/customized.md b/examples/curvfaiss/customized.md new file mode 100644 index 0000000..4221edc --- /dev/null +++ b/examples/curvfaiss/customized.md @@ -0,0 +1,18 @@ +# Tutorial on developing customized distance metric in Faiss + +## Build the environment +1) build the docker from [*Dockerfile*](https://github.com/facebookresearch/faiss/blob/main/Dockerfile). +2) install swig4, cmake, anaconda with py=3.7/numpy. +3) follow the installation [*wiki*](https://github.com/facebookresearch/faiss/wiki/Installing-Faiss#compiling-the-python-interface-within-an-anaconda-install), [*readme*](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md#building-from-source) to verify the correctness. + + +## Develop the new metric +Here is an [*example*](https://github.com/XuZhirong/faiss/commit/a7d5466884d924226aa1d43dd5a41dd44606817c). + +The main modifications include: declaring the index in ```IndexFlat```, implementing the metric in ```utils```. + +## Compile and Publish +1) remove the build files and rebuild. +2) put *.so files to the proper location. +3) write [*setup.py*](https://github.com/facebookresearch/faiss/blob/main/faiss/python/setup.py), where the ```package_data``` field should include the *.so files. +4) upload to pypi/docker. [*Instructions*](https://zhuanlan.zhihu.com/p/61174349) are here. \ No newline at end of file diff --git a/examples/curvfaiss/knn.py b/examples/curvfaiss/knn.py new file mode 100644 index 0000000..ba7e2bf --- /dev/null +++ b/examples/curvfaiss/knn.py @@ -0,0 +1,91 @@ +# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from absl import app +from absl import flags + +import numpy as np +import curvfaiss +from curvlearn import manifolds +import tensorflow as tf + +FLAGS = flags.FLAGS +flags.DEFINE_float("curvature", 0.0, "Curvature of stereographic model, default is 0") +flags.DEFINE_integer("batch", 16, "Embedding batch size") +flags.DEFINE_integer("dim", 8, "Embedding dimension") +flags.DEFINE_integer("topk", 10, "Find the top k nearest neighbors, default is 10") + + +def generate_sample(): + np.random.seed(2022) + embedding = np.random.random((FLAGS.batch, FLAGS.dim)).astype('float32') + id = np.arange(FLAGS.batch) + + def curvlearn_op(emb): + manifold = manifolds.Stereographic() + emb = manifold.to_manifold(tf.constant(emb, dtype=tf.float32), FLAGS.curvature) + + # src = [emb[0],..,emb[0],emb[1],...,emb[1],...] + # dst = [emb[0],emb[1],...,emb[batch],emb[0],...] + src = tf.reshape(tf.tile(emb, [1, FLAGS.batch]), [-1, FLAGS.dim]) + dst = tf.tile(emb, [FLAGS.batch, 1]) + distance = tf.reshape(manifold.distance(src, dst, FLAGS.curvature), [FLAGS.batch, FLAGS.batch]) + + sess = tf.Session() + emb, distance = sess.run([emb, distance]) + + return emb, distance + + embedding, distance = curvlearn_op(embedding) + + knn_index = np.argsort(distance, axis=-1) + knn_id = np.array([[id[col] for col in row[:FLAGS.topk]] for row in knn_index]) + + return embedding, id, knn_id + + +def faiss_knn(emb, id, query=None): + def build_index(embedding): + """ + Stereographic distance only supports brute-force indexing method for now. + """ + n, d = embedding.shape[0], embedding.shape[1] + + index = curvfaiss.IndexFlatStereographic(d, FLAGS.curvature) + index.add(embedding) + + return index + + if query is None: + query = emb + + index = build_index(emb) + D, I = index.search(query, FLAGS.topk) + I = np.array([[id[col] for col in row] for row in I]) + + return I + + +def main(argv): + del argv + + embedding, id, golden_knn = generate_sample() + knn = faiss_knn(embedding, id) + + print("curvfaiss sanity check: {}!".format("passed" if np.array_equal(golden_knn, knn) else "failed")) + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file