Skip to content

Commit 16c4783

Browse files
TensorFlow Hub Authorsandresusanopinto
TensorFlow Hub Authors
authored andcommitted
Adding module_search/utils.py
PiperOrigin-RevId: 275830769
1 parent 9bc54d7 commit 16c4783

File tree

3 files changed

+212
-0
lines changed

3 files changed

+212
-0
lines changed
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
licenses(["notice"]) # Apache 2.0 License
16+
17+
py_library(
18+
name = "utils",
19+
srcs = ["utils.py"],
20+
deps = [
21+
"//tensorflow_hub:expect_numpy_installed",
22+
"//tensorflow_hub:expect_tensorflow_installed",
23+
],
24+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Module search
2+
3+
*WIP* - tool to rank a list of modules for use in a downstream task.
+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utils for module search functionality."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numpy as np
22+
import tensorflow as tf
23+
import tensorflow_hub as hub
24+
import tensorflow_datasets as tfds
25+
26+
27+
def compute_distance_matrix(x_train, x_test, measure="squared_l2"):
28+
"""Calculates the distance matrix between test and train.
29+
30+
Args:
31+
x_train: Matrix (NxD) where each row represents a training sample
32+
x_test: Matrix (MxD) where each row represents a test sample
33+
measure: Distance measure (not necessarly metric) to use
34+
35+
Raises:
36+
NotImplementedError: When the measure is not implemented
37+
38+
Returns:
39+
Matrix (MxN) where elemnt i,j is the distance between
40+
x_test_i and x_train_j.
41+
"""
42+
43+
x_train = tf.convert_to_tensor(x_train, tf.float64)
44+
x_test = tf.convert_to_tensor(x_test, tf.float64)
45+
46+
if measure == "squared_l2":
47+
x_xt = tf.matmul(x_test, tf.transpose(x_train)).numpy()
48+
49+
x_train_2 = tf.reduce_sum(tf.math.square(x_train), 1).numpy()
50+
x_test_2 = tf.reduce_sum(tf.math.square(x_test), 1).numpy()
51+
52+
for i in range(np.shape(x_xt)[0]):
53+
x_xt[i, :] = np.multiply(x_xt[i, :], -2)
54+
x_xt[i, :] = np.add(x_xt[i, :], x_test_2[i])
55+
x_xt[i, :] = np.add(x_xt[i, :], x_train_2)
56+
57+
else:
58+
raise NotImplementedError("Method '{}' is not implemented".format(measure))
59+
60+
return x_xt
61+
62+
63+
def compute_distance_matrix_loo(x, measure="squared_l2"):
64+
"""Calculates the distance matrix for leave-one-out strategy.
65+
66+
Args:
67+
x: Matrix (NxD) where each row represents a sample
68+
measure: Distance measure (not necessarly metric) to use
69+
70+
Raises:
71+
NotImplementedError: When the measure is not implemented
72+
73+
Returns:
74+
Matrix (NxN) where elemnt i,j is the distance between x_i and x_j.
75+
The diagonal is set to infinity
76+
"""
77+
78+
x = tf.convert_to_tensor(x, tf.float64)
79+
80+
if measure == "squared_l2":
81+
x_xt = tf.matmul(x, tf.transpose(x)).numpy()
82+
diag = np.diag(x_xt)
83+
d = np.copy(x_xt)
84+
85+
for i in range(np.shape(d)[0]):
86+
d[i, :] = np.multiply(d[i, :], -2)
87+
d[i, :] = np.add(d[i, :], x_xt[i, i])
88+
d[i, :] = np.add(d[i, :], diag)
89+
d[i, i] = float("inf")
90+
91+
elif measure == "cosine":
92+
d = tf.matmul(x, tf.transpose(x)).numpy()
93+
diag_sqrt = np.sqrt(np.diag(d))
94+
outer = np.outer(diag_sqrt, diag_sqrt)
95+
d = np.ones(np.shape(d)) - np.divide(d, outer)
96+
np.fill_diagonal(d, float("inf"))
97+
98+
else:
99+
raise NotImplementedError("Method '{}' is not implemented".format(measure))
100+
101+
return d
102+
103+
104+
def knn_errorrate(d, y_train, y_test, k=1):
105+
"""Calculate the knn error rate based on the distance matrix d.
106+
107+
Args:
108+
d: distance matrix
109+
y_train: label vector for the training samples
110+
y_test: label vector for the test samples
111+
k: number of direct neighbors for knn
112+
113+
Returns:
114+
knn error rate (1 - accuracy)
115+
"""
116+
117+
if k == 1:
118+
indices = np.argmin(d, axis=1)
119+
120+
cnt = 0
121+
for i in range(len(indices)):
122+
if y_test[i] != y_train[indices[i]]:
123+
cnt += 1
124+
125+
return float(cnt) / len(indices)
126+
127+
indices = np.argpartition(d, k - 1, axis=1)
128+
cnt = 0
129+
for i in range(np.shape(d)[0]):
130+
cnt_i = 0
131+
for j in range(k):
132+
if y_test[i] != y_train[indices[i, j]]:
133+
cnt_i += 1
134+
if cnt_i >= k / 2.0:
135+
cnt += 1
136+
137+
return float(cnt) / np.shape(d)[0]
138+
139+
140+
def knn_errorrate_loo(d, y, k=1):
141+
"""Calculate the leave-one-out expected knn error rate based
142+
on the distance matrix d.
143+
144+
Args:
145+
d: distance matrix, the diagonal should be infinity
146+
y: label matrix
147+
k: number of direct neighbors for knn
148+
149+
Returns:
150+
Expected leave-one-out knn error rate (1 - accuracy)
151+
"""
152+
153+
if k == 1:
154+
indices = np.argmin(d, axis=1)
155+
156+
cnt = 0
157+
for i in range(len(indices)):
158+
if y[i] != y[indices[i]]:
159+
cnt += 1
160+
161+
return float(cnt) / len(indices)
162+
163+
indices = np.argpartition(d, k - 1, axis=1)
164+
cnt = 0
165+
for i in range(np.shape(d)[0]):
166+
cnt_i = 0
167+
for j in range(k):
168+
if y[i] != y[indices[i, j]]:
169+
cnt_i += 1
170+
if cnt_i >= k / 2.0:
171+
cnt += 1
172+
173+
return float(cnt) / np.shape(d)[0]
174+
175+
176+
def load_data(dataset, split, num_examples=None):
177+
ds = tfds.load(dataset, split=split, shuffle_files=False)
178+
if num_examples:
179+
ds = ds.take(num_examples)
180+
return ds
181+
182+
183+
def load_embedding_fn(module):
184+
m = hub.load(module, tags=[])
185+
return lambda x: m.signatures["default"](x)["default"]

0 commit comments

Comments
 (0)