Indexing #321

Merged: 6 commits, Mar 30, 2023
5,351 changes: 2,676 additions & 2,675 deletions examples/indexing_colab.ipynb

Large diffs are not rendered by default.

3,182 changes: 1,592 additions & 1,590 deletions examples/multimodal_example.ipynb

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions examples/sampler_io_cookbook.ipynb
@@ -537,8 +537,7 @@
 {
 "name": "stdout",
 "output_type": "stream",
-"text": [
-]
+"text": []
 },
 {
 "data": {
@@ -943,7 +942,11 @@
 "# use the test split for indexing and querying\n",
 "print(\"\\n\" + \"#\" * 10 + \" Test Sampler \" + \"#\" * 10)\n",
 "test_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(\n",
-"    \"oxford_iiit_pet\", splits=\"test\", total_examples_per_class=20, classes_per_batch=tfds_classes_per_batch, preprocess_fn=resize\n",
+"    \"oxford_iiit_pet\",\n",
+"    splits=\"test\",\n",
+"    total_examples_per_class=20,\n",
+"    classes_per_batch=tfds_classes_per_batch,\n",
+"    preprocess_fn=resize,\n",
 ")"
 ]
 },
15 changes: 8 additions & 7 deletions examples/supervised/visualization.ipynb
@@ -377,10 +377,11 @@
 "\n",
 "print(f\"Class IDs seen during training {train_cls}\")\n",
 "\n",
+"\n",
 "def img_augmentation(img_batch, y, *args):\n",
 "    # random resize and crop.\n",
 "    batch_size = tf.shape(img_batch)[0]\n",
-"    img_batch = tf.image.random_crop(img_batch, (batch_size,target_img_size,target_img_size,3))\n",
+"    img_batch = tf.image.random_crop(img_batch, (batch_size, target_img_size, target_img_size, 3))\n",
 "    # random horizontal flip\n",
 "    img_batch = tf.image.random_flip_left_right(img_batch)\n",
 "    return img_batch, y\n",
@@ -603,10 +604,10 @@
 "\n",
 "# building model\n",
 "model = tfsim.architectures.EfficientNetSim(\n",
-"    train_ds.example_shape, \n",
+"    train_ds.example_shape,\n",
 "    embedding_size,\n",
-"    pooling=\"gem\", # Can change to use `gem` -> GeneralizedMeanPooling2D\n",
-"    gem_p=3.0, # Increase the contrast between activations in the feature map.\n",
+"    pooling=\"gem\",  # Can change to use `gem` -> GeneralizedMeanPooling2D\n",
+"    gem_p=3.0,  # Increase the contrast between activations in the feature map.\n",
 ")"
 ]
 },
@@ -685,12 +686,12 @@
 "loss = tfsim.losses.CircleLoss(gamma=gamma)\n",
 "\n",
 "# Create an NMSLib Search instance using Brute Force search.\n",
-"# This will be slower but avoids any errors associated with an \n",
+"# This will be slower but avoids any errors associated with an\n",
 "# aproximate nearest neighbor search.\n",
 "brute_force_search = tfsim.search.NMSLibSearch(\n",
 "    distance=distance,\n",
 "    dim=embedding_size,\n",
-"    method='brute_force',\n",
+"    method=\"brute_force\",\n",
 "    # space_params = None,\n",
 "    # data_type = tfsim.search.nmslib_search.nmslib.DataType.DENSE_VECTOR,\n",
 "    # dtype = tfsim.search.nmslib_search.nmslib.DistType.FLOAT,\n",
@@ -700,7 +701,7 @@
 "\n",
 "# compiling and training\n",
 "model.compile(\n",
-"    optimizer=tf.keras.optimizers.Adam(LR), \n",
+"    optimizer=tf.keras.optimizers.Adam(LR),\n",
 "    loss=loss,\n",
 "    distance=distance,\n",
 "    search=brute_force_search,\n",
2 changes: 1 addition & 1 deletion examples/supervised_hello_world.ipynb
@@ -342,7 +342,7 @@
 "source": [
 "def get_model():\n",
 "    inputs = tf.keras.layers.Input(shape=(28, 28, 1))\n",
-"    x = tf.keras.layers.Rescaling(1. / 255.)(inputs)\n",
+"    x = tf.keras.layers.Rescaling(1.0 / 255.0)(inputs)\n",
 "    x = tf.keras.layers.Conv2D(32, 3, activation=\"relu\")(x)\n",
 "    x = tf.keras.layers.Conv2D(32, 3, activation=\"relu\")(x)\n",
 "    x = tf.keras.layers.MaxPool2D()(x)\n",
27 changes: 12 additions & 15 deletions examples/unsupervised_hello_world.ipynb
@@ -420,7 +420,7 @@
 "    INIT_LR = 1e-3  # Initial LR for the learning rate schedule, see section B.1 in the paper.\n",
 "    TEMPERATURE = 0.5  # Tuned for CIFAR10, see section B.9 in the paper.\n",
 "elif ALGORITHM == \"vicreg\":\n",
-"    INIT_LR = 1e-3 "
+"    INIT_LR = 1e-3"
 ]
 },
 {
@@ -445,10 +445,7 @@
 "outputs": [],
 "source": [
 "def img_scaling(img):\n",
-"    return tf.keras.applications.imagenet_utils.preprocess_input(\n",
-"        img, \n",
-"        data_format=None, \n",
-"        mode='torch')\n",
+"    return tf.keras.applications.imagenet_utils.preprocess_input(img, data_format=None, mode=\"torch\")\n",
 "\n",
 "\n",
 "@tf.function\n",
@@ -474,10 +471,10 @@
 "    img = tfsim.augmenters.augmentation_utils.cropping.crop_and_resize(\n",
 "        img, CIFAR_IMG_SIZE, CIFAR_IMG_SIZE, area_range=area_range\n",
 "    )\n",
-"    \n",
+"\n",
 "    # The following transforms expect the data to be [0, 1]\n",
-"    img /= 255.\n",
-"    \n",
+"    img /= 255.0\n",
+"\n",
 "    # random color jitter\n",
 "    def _jitter_transform(x):\n",
 "        return tfsim.augmenters.augmentation_utils.color_jitter.color_jitter_rand(\n",
@@ -503,10 +500,10 @@
 "\n",
 "    # random horizontal flip\n",
 "    img = tf.image.random_flip_left_right(img)\n",
-"    \n",
+"\n",
 "    # scale the data back to [0, 255]\n",
-"    img = img * 255.\n",
-"    img = tf.clip_by_value(img, 0., 255.)\n",
+"    img = img * 255.0\n",
+"    img = tf.clip_by_value(img, 0.0, 255.0)\n",
 "\n",
 "    return img\n",
 "\n",
@@ -688,7 +685,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"projector = None # Passing None will automatically build the default projector.\n",
+"projector = None  # Passing None will automatically build the default projector.\n",
 "\n",
 "# Uncomment to build a custom projector.\n",
 "# def get_projector(input_dim, dim, activation=\"relu\", num_layers: int = 3):\n",
@@ -724,7 +721,7 @@
 "#     return projector\n",
 "\n",
 "# projector = get_projector(input_dim=backbone.output.shape[-1], dim=DIM, num_layers=2)\n",
-"# projector.summary()\n"
+"# projector.summary()"
 ]
 },
 {
@@ -742,7 +739,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"predictor = None # Passing None will automatically build the default predictor.\n",
+"predictor = None  # Passing None will automatically build the default predictor.\n",
 "\n",
 "# Uncomment to build a custom predictor.\n",
 "# def get_predictor(input_dim, hidden_dim=512, activation=\"relu\"):\n",
@@ -1126,7 +1123,7 @@
 "    )\n",
 "    # random horizontal flip\n",
 "    img = tf.image.random_flip_left_right(img)\n",
-"    img = tf.clip_by_value(img, 0., 255.)\n",
+"    img = tf.clip_by_value(img, 0.0, 255.0)\n",
 "\n",
 "    return img"
 ]
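Note on the img_scaling change above: collapsing the preprocess_input call onto one line is pure formatting, but it is worth recalling what mode="torch" does, since the augmentation code around it assumes [0, 255] inputs. A minimal illustration using the standard Keras API (the printed values follow from the ImageNet channel statistics):

import numpy as np
import tensorflow as tf

# mode="torch" scales pixels to [0, 1], then normalizes each channel with the
# ImageNet mean/std, matching the torchvision convention.
img = np.full((1, 2, 2, 3), 255.0, dtype="float32")
out = tf.keras.applications.imagenet_utils.preprocess_input(img, mode="torch")
print(out[0, 0, 0])  # approx. [2.25, 2.43, 2.64] == (1 - mean) / std per channel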
23 changes: 13 additions & 10 deletions tensorflow_similarity/indexer.py
@@ -33,7 +33,7 @@

 # internal
 from .distances import Distance
-from .evaluators import Evaluator, MemoryEvaluator
+from .evaluators import Evaluator
 from .search import LinearSearch, NMSLibSearch, Search, make_search
 from .stores import MemoryStore, Store, make_store
 from .types import FloatTensor, Lookup, PandasDataFrame, Tensor
@@ -108,15 +108,6 @@ def __init__(
         self.kv_store_type = kv_store if isinstance(kv_store, str) else type(kv_store).__name__
         if isinstance(kv_store, Store):
             self.kv_store: Store = kv_store
-        # initialize internal structures
-        self._init_structures()
-
-    def reset(self) -> None:
-        "Reinitialize the indexer"
-        self._init_structures()
-
-    def _init_structures(self) -> None:
-        "(re)initialize internal storage structure"

         if self.search_type == "nmslib":
             self.search = NMSLibSearch(distance=self.distance, dim=self.embedding_size)
@@ -143,6 +134,18 @@ def _init_structures(self) -> None:
             # self.kv_store should have been already initialized
             raise ValueError("You need to either supply a know key value " "store name or a Store() object")

+        # initialize internal structures
+        self._init_structures()
+
+    def reset(self) -> None:
+        "Reinitialize the indexer"
+        self.search.reset()
+        self.kv_store.reset()
+        self._init_structures()
+
+    def _init_structures(self) -> None:
+        "(re)initialize internal stats structure"
+
         # stats
         self._stats: DefaultDict[str, int] = defaultdict(int)
         self._lookup_timings_buffer: Deque[float] = deque([], maxlen=self.stat_buffer_size)
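Note on the indexer change above: reset() previously only re-ran _init_structures(), which left stale entries behind in the search index and the key-value store; it now delegates clearing to the Search and Store objects, and _init_structures() is narrowed to the stats buffers. A sketch of the new contract, assuming constructor keywords inferred from the attributes visible in the diff (embedding_size, search, kv_store; check the class docstring):

from tensorflow_similarity.indexer import Indexer

# Hypothetical construction; argument names are assumptions.
indexer = Indexer(embedding_size=64, search="linear", kv_store="memory")

# ... add items through the indexer ...

indexer.reset()
# reset() now performs:
#   self.search.reset()      # clear the underlying search index
#   self.kv_store.reset()    # clear the stored records
#   self._init_structures()  # reinitialize only the stats/timing buffers
# so a reset indexer behaves like a freshly constructed one.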
32 changes: 19 additions & 13 deletions tensorflow_similarity/search/faiss_search.py
@@ -10,7 +10,7 @@
 import numpy as np
 from termcolor import cprint

-from tensorflow_similarity.distances import Distance
+from tensorflow_similarity.distances import Distance, distance_canonicalizer
 from tensorflow_similarity.types import FloatTensor

 from .search import Search
@@ -66,20 +66,22 @@ def __init__(
             f"| - normalize: {self.normalize}",
         ]
         cprint("\n".join(t_msg) + "\n", "green")
-        if self.algo == "ivfpq":
-            assert dim % m == 0, f"dim={dim}, m={m}"
+        self.reset()
+
+    def reset(self):
+        assert self.dim % self.m == 0, f"dim={self.dim}, m={self.m}"
+        if self.algo == "ivfpq":
             metric = faiss.METRIC_L2
             prefix = ""
-            if distance == "cosine":
+            if self.distance == distance_canonicalizer("cosine"):
                 prefix = "L2norm,"
                 metric = faiss.METRIC_INNER_PRODUCT
                 # this distance requires both the input and query vectors to be normalized
-            ivf_string = f"IVF{nlist},"
-            pq_string = f"PQ{m}x{nbits}"
+            ivf_string = f"IVF{self.nlist},"
+            pq_string = f"PQ{self.m}x{self.nbits}"
             factory_string = prefix + ivf_string + pq_string
-            self.index = faiss.index_factory(dim, factory_string, metric)
+            self.index = faiss.index_factory(self.dim, factory_string, metric)
             # quantizer = faiss.IndexFlatIP(
             #     dim
             # )  # we keep the same L2 distance flat index
@@ -91,16 +93,16 @@ def __init__(
             #     dim
             # )  # we keep the same L2 distance flat index
             # self.index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, nbits)
-            self.index.nprobe = nprobe  # set how many of nearest cells to search
-        elif algo == "flat":
-            if distance == "cosine":
+            self.index.nprobe = self.nprobe  # set how many of nearest cells to search
+        elif self.algo == "flat":
+            if self.distance == distance_canonicalizer("cosine"):
                 # this is exact match using cosine/dot-product Distance
-                self.index = faiss.IndexFlatIP(dim)
-            elif distance == "l2":
+                self.index = faiss.IndexFlatIP(self.dim)
+            elif self.distance == distance_canonicalizer("l2"):
                 # this is exact match using L2 distance
-                self.index = faiss.IndexFlatL2(dim)
+                self.index = faiss.IndexFlatL2(self.dim)
             else:
-                raise ValueError(f"distance {distance} not supported")
+                raise ValueError(f"distance {self.distance} not supported")

     def is_built(self):
         return self.algo == "flat" or self.index.is_trained
@@ -162,6 +164,7 @@ def batch_add(
         idxs: Sequence[int],
         verbose: int = 1,
         normalize: bool = True,
+        build: bool = True,
         **kwargs,
     ):
         """Add a batch of embeddings to the search index.
@@ -175,6 +178,9 @@
         """
         if normalize:
             faiss.normalize_L2(embeddings)
+        if build and not self.is_built():
+            print("building Faiss index")
+            self.build_index(samples=embeddings, normalize=normalize)
         if self.algo != "flat":
             # flat does not accept indexes as parameters and assumes incremental
             # indexes
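Note on the faiss_search.py changes above: moving index construction into reset() lets the same object be cleared and rebuilt from its stored config, and the new build flag on batch_add trains an untrained index (e.g. IVFPQ, whose coarse quantizer needs a training pass) on the first batch, while "flat" indexes report is_built() immediately. A sketch of the intended flow; the class name, import path, and constructor keywords (distance, dim, algo, m, nlist) are assumptions inferred from the file name and the attributes in the diff:

import numpy as np

# from tensorflow_similarity.search.faiss_search import FaissSearch  # path assumed
search = FaissSearch(distance="cosine", dim=16, algo="ivfpq", m=8, nlist=32)

embeddings = np.random.rand(4096, 16).astype("float32")
idxs = np.arange(4096)

# build=True (the new default) calls build_index() on the first add, so the
# IVFPQ quantizer is trained before any vectors are inserted.
search.batch_add(embeddings, idxs)
assert search.is_built()

search.reset()  # rebuilds an empty, untrained index from the config
assert not search.is_built()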
17 changes: 10 additions & 7 deletions tensorflow_similarity/search/linear_search.py
@@ -49,8 +49,7 @@ def __init__(self, distance: Distance | str, dim: int, verbose: int = 0, name: s
             f"| - name: {self.name}",
         ]
         cprint("\n".join(t_msg) + "\n", "green")
-        self.db: List[FloatTensor] = []
-        self.ids: List[int] = []
+        self.reset()

     def is_built(self):
         return True
@@ -68,17 +67,17 @@ def batch_lookup(
             k: Number of nearest neighboors embedding to lookup. Defaults to 5.
         """

-        items = len(self.ids)
         if normalize:
             query = tf.math.l2_normalize(embeddings, axis=1)
         else:
             query = embeddings
         db_tensor = tf.convert_to_tensor(self.db)
-        sims = self.distance(query, db_tensor)
-        similarity, id_idxs = tf.math.top_k(sims, k)
+        dists = self.distance(query, db_tensor)
+        dists, id_idxs = tf.math.top_k(tf.math.negative(dists), k)
+        dists = tf.math.negative(dists)
         id_idxs = id_idxs.numpy()
         ids_array = np.array(self.ids)
-        return list(np.array([ids_array[x] for x in id_idxs])), list(similarity)
+        return list(np.array([ids_array[x] for x in id_idxs])), list(dists)

     def lookup(self, embedding: FloatTensor, k: int = 5, normalize: bool = True) -> tuple[list[int], list[float]]:
         """Find embedding K nearest neighboors embeddings.
@@ -100,7 +99,7 @@ def add(self, embedding: FloatTensor, idx: int, verbose: int = 1, normalize: boo
         allow to lookup the data associated with a given embedding.
         """
         if normalize:
-            embedding = tf.math.l2_normalize(np.array([embedding], dtype=tf.keras.backend.floatx()), axis=1)
+            embedding = tf.math.l2_normalize(np.array([embedding], dtype=tf.keras.backend.floatx()), axis=1)[0]
         self.ids.append(idx)
         self.db.append(embedding)

@@ -150,6 +149,10 @@ def load(self, path: str):
         self.db = data[0]
         self.ids = data[1]

+    def reset(self):
+        self.db: List[FloatTensor] = []
+        self.ids: List[int] = []
+
     def __make_config_path(self, path):
         return Path(path) / "config.json"

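Note on the batch_lookup fix above: tf.math.top_k selects the LARGEST values, so feeding it raw distances returned the k farthest neighbors (and mislabeled them as similarity). Negating before top_k and negating back yields the k smallest distances. A self-contained illustration of the trick:

import tensorflow as tf

dists = tf.constant([[0.9, 0.1, 0.4, 0.7]])  # distances from one query to a 4-item db

# top_k on the negated distances picks the k smallest originals.
neg_top, id_idxs = tf.math.top_k(tf.math.negative(dists), k=2)
nearest = tf.math.negative(neg_top)

print(nearest.numpy())  # [[0.1 0.4]] -> the two smallest distances
print(id_idxs.numpy())  # [[1 2]]    -> their positions in the database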