Publish CurvFaiss
XuZhirong committed Apr 16, 2022
1 parent e0cd89b commit 04d2d72
Showing 6 changed files with 159 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
@@ -5,8 +5,9 @@
English | [简体中文](README_CN.md)

#### *NEWS!*
- [April 16] [*CurvFaiss*](examples/curvfaiss/README.md), a nearest neighbor search tool for non-Euclidean spaces, is available now!
- [March 29] Our paper *[AMCAD: Adaptive Mixed-Curvature Representation based Advertisement Retrieval System](https://arxiv.org/abs/2203.14683)* has been accepted by ICDE'22!
- [November 03] The [Detailed explanation](https://mp.weixin.qq.com/s/uP_wU5nnd7faBoo5B_7Xfw?spm=ata.21736010.0.0.5b9d2c7eQ67WZg) is posted on Alimama Tech Blog!
- [November 03] The [detailed explanation](https://mp.weixin.qq.com/s/uP_wU5nnd7faBoo5B_7Xfw?spm=ata.21736010.0.0.5b9d2c7eQ67WZg) is posted on Alimama Tech Blog!

## Why Non-Euclidean Geometry

1 change: 1 addition & 0 deletions README_CN.md
@@ -5,6 +5,7 @@
[English](./README.md) | 简体中文

#### *NEWS!*
- [04/16] The non-Euclidean nearest neighbor search tool [*CurvFaiss*](examples/curvfaiss/README.md) is now available!
- [03/29] [*AMCAD: Adaptive Mixed-Curvature Representation based Advertisement Retrieval System*](https://arxiv.org/abs/2203.14683) has been published at ICDE'22!
- [11/03] [*A Detailed Explanation of the Curvature Learning Framework*](https://mp.weixin.qq.com/s/uP_wU5nnd7faBoo5B_7Xfw?spm=ata.21736010.0.0.5b9d2c7eQ67WZg) is posted on the Alimama Tech official account!

Binary file modified docs/sample.png
47 changes: 47 additions & 0 deletions examples/curvfaiss/README.md
@@ -0,0 +1,47 @@
# CurvFaiss

## Introduction

```CurvFaiss``` is a library for efficient similarity search and clustering of dense vectors in non-Euclidean manifolds.

Based on [*Faiss*](https://github.com/facebookresearch/faiss), ```CurvFaiss``` adds a new index, ```IndexFlatStereographic```, that supports nearest neighbor search under the stereographic distance metric. Together with ```CurvLearn```, it makes non-Euclidean model training and efficient inference feasible.
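
For intuition, the stereographic distance can be sketched in plain Python. This is a minimal sketch that assumes the curvature-parameterized formulation common in the literature, `d(x, y) = 2 * inv_tan_k(||(-x) (+) y||)` with Möbius addition `(+)`; it is not taken from the ```CurvFaiss``` sources, and the names `mobius_add`, `inv_tan_k`, and `stereographic_distance` are illustrative only.

```python
import math


def mobius_add(x, y, k):
    """Mobius addition of x and y for curvature k (assumed formulation)."""
    xy = sum(a * b for a, b in zip(x, y))
    x2 = sum(a * a for a in x)
    y2 = sum(b * b for b in y)
    denom = 1 - 2 * k * xy + k * k * x2 * y2
    coef_x = 1 - 2 * k * xy - k * y2
    coef_y = 1 + k * x2
    return [(coef_x * a + coef_y * b) / denom for a, b in zip(x, y)]


def inv_tan_k(u, k):
    """Curvature-dependent inverse tangent: atanh (k < 0), identity (k = 0), atan (k > 0)."""
    if k < 0:
        s = math.sqrt(-k)
        return math.atanh(s * u) / s
    if k > 0:
        s = math.sqrt(k)
        return math.atan(s * u) / s
    return u


def stereographic_distance(x, y, k):
    """Distance between x and y under the assumed stereographic model of curvature k."""
    diff = mobius_add([-a for a in x], y, k)
    norm = math.sqrt(sum(d * d for d in diff))
    return 2 * inv_tan_k(norm, k)
```

In this sketch, `k = 0` reduces to twice the Euclidean distance, so the neighbor ranking coincides with plain L2 search; negative and positive `k` switch to the hyperbolic and spherical branches respectively.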

Currently, ```CurvFaiss``` supports retrieving neighbors in hyperbolic, Euclidean, and spherical spaces. The index performs exact (brute-force) search. Thanks to parallelism at both the data level and the instruction level, an index over 100 million nodes can be built in less than two hours.
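
Exact search means every query is compared against every stored vector and the `topk` smallest distances are kept. The sketch below illustrates that idea in plain Python; it uses Euclidean distance (`math.dist`) as a stand-in metric, and the function name `exact_knn` and its `distance` parameter are hypothetical, not part of ```CurvFaiss```.

```python
import heapq
import math


def exact_knn(database, queries, topk, distance=math.dist):
    """Return (distances, indices) of the topk nearest database rows per query."""
    all_distances, all_indices = [], []
    for query in queries:
        # Score every stored vector; ties are broken by the smaller index.
        scored = ((distance(query, row), i) for i, row in enumerate(database))
        best = heapq.nsmallest(topk, scored)
        all_distances.append([d for d, _ in best])
        all_indices.append([i for _, i in best])
    return all_distances, all_indices
```

Each query costs one distance evaluation per stored vector, which is why the real implementation relies on data-level and instruction-level parallelism rather than on an approximate index structure.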

For those who want to apply their own customized metric or optimize the indexing method, a hands-on [*tutorial*](customized.md) is also provided.

## Installation

```CurvFaiss``` requires ```curvlearn``` and Python 3.

The preferred way to install it is via `pip`.

```bash
pip install curvfaiss
```

Since the source code is compiled under CentOS, users on other platforms are advised to follow the [*tutorial*](customized.md) to resolve the code dependencies.

## Usage

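A frequent problem is the runtime dependency: the ```curvfaiss``` directory has to be on ```LD_LIBRARY_PATH``` so that the shared libraries can be found.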
```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c "from os.path import abspath,dirname,join; import curvlearn as cl; print(join(dirname((dirname(cl.__file__))),'curvfaiss'))"`
```
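
The embedded Python one-liner above takes the directory two levels up from `curvlearn/__init__.py` (normally `site-packages`) and appends `curvfaiss`. The same path logic is spelled out below as a small sketch; the helper name `sibling_package_dir` and the example path are hypothetical.

```python
import posixpath


def sibling_package_dir(module_file, sibling="curvfaiss"):
    """Directory of `sibling`, assumed to be installed next to the package owning module_file."""
    site_packages = posixpath.dirname(posixpath.dirname(module_file))
    return posixpath.join(site_packages, sibling)


# Hypothetical install location, for illustration only.
print(sibling_package_dir("/opt/venv/lib/python3.7/site-packages/curvlearn/__init__.py"))
```

If the printed directory does not exist or contains no `*.so` files, the export will not help, which is a quick way to check the install.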

Since ```IndexFlatStereographic``` inherits from ```IndexFlat```, its usage is the same as ```IndexFlatL2``` in faiss, except for an additional parameter ```curvature```.

```python
import curvfaiss

# build index; retrieval uses the hyperbolic, Euclidean, or spherical metric for curvature < 0, = 0, > 0 respectively
index = curvfaiss.IndexFlatStereographic(dim, curvature)
index.add(embedding)

# knn search
knn_distance, knn_index = index.search(query, topk)

print(knn_distance, knn_index)
```
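
The vectors passed to the index should already be valid points for the chosen curvature; the full demo maps its embeddings with `manifold.to_manifold` before indexing. As a rough sanity check, and assuming the usual stereographic model in which negative curvature restricts points to an open ball of radius `1 / sqrt(-curvature)`, a sketch could look like this (the helper `in_stereographic_domain` is hypothetical, and whether ```CurvFaiss``` validates inputs itself is not stated here):

```python
import math


def in_stereographic_domain(point, curvature, eps=1e-6):
    """True if `point` is a valid coordinate for the given curvature (assumed model)."""
    if curvature >= 0:
        # Euclidean and spherical cases: every finite coordinate is valid.
        return True
    radius = 1.0 / math.sqrt(-curvature)
    norm = math.sqrt(sum(v * v for v in point))
    return norm < radius - eps
```

A check like this can be run over the embeddings and queries before calling `index.add` and `index.search`.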

See the [*full demo*](knn.py) here!
18 changes: 18 additions & 0 deletions examples/curvfaiss/customized.md
@@ -0,0 +1,18 @@
# Tutorial on developing a customized distance metric in Faiss

## Build the environment
1) Build the Docker image from the [*Dockerfile*](https://github.com/facebookresearch/faiss/blob/main/Dockerfile).
2) Install swig4, cmake, and Anaconda with Python 3.7 and numpy.
3) Follow the installation [*wiki*](https://github.com/facebookresearch/faiss/wiki/Installing-Faiss#compiling-the-python-interface-within-an-anaconda-install) and [*readme*](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md#building-from-source) to verify that the build is correct.


## Develop the new metric
Here is an [*example*](https://github.com/XuZhirong/faiss/commit/a7d5466884d924226aa1d43dd5a41dd44606817c).

The main modifications are declaring the index in ```IndexFlat``` and implementing the metric in ```utils```.

## Compile and Publish
1) Remove the build files and rebuild.
2) Put the `*.so` files in the proper location.
3) Write [*setup.py*](https://github.com/facebookresearch/faiss/blob/main/faiss/python/setup.py), where the ```package_data``` field should include the `*.so` files.
4) Upload to PyPI/Docker. [*Instructions*](https://zhuanlan.zhihu.com/p/61174349) are here.
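
As a rough illustration of step 3, the keyword arguments for `setuptools.setup()` might be assembled as follows. This is a sketch under assumptions: the package name `my_curvfaiss` and the version are placeholders, and the linked Faiss `setup.py` remains the reference for the real layout.

```python
def setup_kwargs(package="my_curvfaiss"):
    """Keyword arguments for setuptools.setup(); the package name is a placeholder."""
    return {
        "name": package,
        "version": "0.0.1",
        "packages": [package],
        # Ship the compiled shared objects that live inside the package directory.
        "package_data": {package: ["*.so"]},
        # Shared libraries cannot be loaded from inside a zipped package.
        "zip_safe": False,
    }
```

In an actual `setup.py` these would be passed on with `setup(**setup_kwargs())` after `from setuptools import setup`.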
91 changes: 91 additions & 0 deletions examples/curvfaiss/knn.py
@@ -0,0 +1,91 @@
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from absl import app
from absl import flags

import numpy as np
import curvfaiss
from curvlearn import manifolds
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_float("curvature", 0.0, "Curvature of stereographic model, default is 0")
flags.DEFINE_integer("batch", 16, "Embedding batch size")
flags.DEFINE_integer("dim", 8, "Embedding dimension")
flags.DEFINE_integer("topk", 10, "Find the top k nearest neighbors, default is 10")


def generate_sample():
    np.random.seed(2022)
    embedding = np.random.random((FLAGS.batch, FLAGS.dim)).astype('float32')
    id = np.arange(FLAGS.batch)

    def curvlearn_op(emb):
        manifold = manifolds.Stereographic()
        emb = manifold.to_manifold(tf.constant(emb, dtype=tf.float32), FLAGS.curvature)

        # src = [emb[0],..,emb[0],emb[1],...,emb[1],...]
        # dst = [emb[0],emb[1],...,emb[batch],emb[0],...]
        src = tf.reshape(tf.tile(emb, [1, FLAGS.batch]), [-1, FLAGS.dim])
        dst = tf.tile(emb, [FLAGS.batch, 1])
        distance = tf.reshape(manifold.distance(src, dst, FLAGS.curvature), [FLAGS.batch, FLAGS.batch])

        sess = tf.Session()
        emb, distance = sess.run([emb, distance])

        return emb, distance

    embedding, distance = curvlearn_op(embedding)

    knn_index = np.argsort(distance, axis=-1)
    knn_id = np.array([[id[col] for col in row[:FLAGS.topk]] for row in knn_index])

    return embedding, id, knn_id


def faiss_knn(emb, id, query=None):
    def build_index(embedding):
        """
        Stereographic distance only supports brute-force indexing method for now.
        """
        n, d = embedding.shape[0], embedding.shape[1]

        index = curvfaiss.IndexFlatStereographic(d, FLAGS.curvature)
        index.add(embedding)

        return index

    if query is None:
        query = emb

    index = build_index(emb)
    D, I = index.search(query, FLAGS.topk)
    I = np.array([[id[col] for col in row] for row in I])

    return I


def main(argv):
    del argv

    embedding, id, golden_knn = generate_sample()
    knn = faiss_knn(embedding, id)

    print("curvfaiss sanity check: {}!".format("passed" if np.array_equal(golden_knn, knn) else "failed"))


if __name__ == "__main__":
    app.run(main)
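
The `tf.tile`/`tf.reshape` pattern in `curvlearn_op` lays out every ordered pair of embeddings so that a single `manifold.distance` call yields the full batch-by-batch distance matrix used as the golden result. A plain-Python sketch of the same layout follows; the helper name `pairwise_layout` is illustrative and not part of the demo.

```python
def pairwise_layout(rows):
    """Build the (src, dst) lists so that pair number i * n + j is (rows[i], rows[j])."""
    n = len(rows)
    # src repeats each row n times in a block: r0, r0, ..., r1, r1, ...
    src = [rows[i] for i in range(n) for _ in range(n)]
    # dst cycles through all rows n times: r0, r1, ..., r0, r1, ...
    dst = [rows[j] for _ in range(n) for j in range(n)]
    return src, dst
```

Reshaping the resulting `n * n` distances to `[n, n]` then places the distance from `rows[i]` to `rows[j]` at position `[i][j]`, which is what the subsequent `np.argsort` along the last axis relies on.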
