Skip to content

Commit 578bd3a

Browse files
tanzhenyutensorflower-gardener
authored andcommitted
Move clustering ops to core.
PiperOrigin-RevId: 228808275
1 parent 3fc2b09 commit 578bd3a

16 files changed

+221
-157
lines changed

tensorflow/contrib/factorization/BUILD

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ tf_custom_op_py_library(
2828
"python/ops/wals.py",
2929
],
3030
dso = [
31-
":python/ops/_clustering_ops.so",
3231
":python/ops/_factorization_ops.so",
3332
],
3433
kernels = [
@@ -38,12 +37,12 @@ tf_custom_op_py_library(
3837
srcs_version = "PY2AND3",
3938
deps = [
4039
":factorization_ops_test_utils_py",
41-
":gen_clustering_ops",
4240
":gen_factorization_ops",
4341
"//tensorflow/contrib/framework:framework_py",
4442
"//tensorflow/contrib/util:util_py",
4543
"//tensorflow/python:array_ops",
4644
"//tensorflow/python:check_ops",
45+
"//tensorflow/python:clustering_ops_gen",
4746
"//tensorflow/python:control_flow_ops",
4847
"//tensorflow/python:data_flow_ops",
4948
"//tensorflow/python:embedding_ops",
@@ -77,17 +76,6 @@ py_library(
7776
],
7877
)
7978

80-
# Ops
81-
tf_custom_op_library(
82-
name = "python/ops/_clustering_ops.so",
83-
srcs = [
84-
"ops/clustering_ops.cc",
85-
],
86-
deps = [
87-
"//tensorflow/contrib/factorization/kernels:clustering_ops",
88-
],
89-
)
90-
9179
tf_custom_op_library(
9280
name = "python/ops/_factorization_ops.so",
9381
srcs = [
@@ -100,26 +88,16 @@ tf_custom_op_library(
10088
)
10189

10290
tf_gen_op_libs([
103-
"clustering_ops",
10491
"factorization_ops",
10592
])
10693

10794
cc_library(
10895
name = "all_ops",
10996
deps = [
110-
":clustering_ops_op_lib",
11197
":factorization_ops_op_lib",
11298
],
11399
)
114100

115-
tf_gen_op_wrapper_py(
116-
name = "gen_clustering_ops",
117-
out = "python/ops/gen_clustering_ops.py",
118-
deps = [
119-
":clustering_ops_op_lib",
120-
],
121-
)
122-
123101
tf_gen_op_wrapper_py(
124102
name = "gen_factorization_ops",
125103
out = "python/ops/gen_factorization_ops.py",

tensorflow/contrib/factorization/kernels/BUILD

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
1111
cc_library(
1212
name = "all_kernels",
1313
deps = [
14-
":clustering_ops",
1514
":masked_matmul_ops",
1615
":wals_solver_ops",
1716
"@protobuf_archive//:protobuf_headers",
@@ -29,17 +28,6 @@ cc_library(
2928
alwayslink = 1,
3029
)
3130

32-
cc_library(
33-
name = "clustering_ops",
34-
srcs = ["clustering_ops.cc"],
35-
deps = [
36-
"//tensorflow/core:framework_headers_lib",
37-
"//third_party/eigen3",
38-
"@protobuf_archive//:protobuf_headers",
39-
],
40-
alwayslink = 1,
41-
)
42-
4331
cc_library(
4432
name = "masked_matmul_ops",
4533
srcs = ["masked_matmul_ops.cc"],
@@ -51,19 +39,3 @@ cc_library(
5139
],
5240
alwayslink = 1,
5341
)
54-
55-
tf_cc_test(
56-
name = "clustering_ops_test",
57-
srcs = ["clustering_ops_test.cc"],
58-
deps = [
59-
":clustering_ops",
60-
"//tensorflow/contrib/factorization:clustering_ops_op_lib",
61-
"//tensorflow/core:core_cpu",
62-
"//tensorflow/core:framework",
63-
"//tensorflow/core:lib",
64-
"//tensorflow/core:protos_all_cc",
65-
"//tensorflow/core:test",
66-
"//tensorflow/core:test_main",
67-
"//tensorflow/core:testlib",
68-
],
69-
)

tensorflow/contrib/factorization/ops/clustering_ops.cc

Lines changed: 0 additions & 91 deletions
This file was deleted.

tensorflow/contrib/factorization/python/ops/clustering_ops.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,23 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
from tensorflow.contrib.factorization.python.ops import gen_clustering_ops
22-
# go/tf-wildcard-import
23-
# pylint: disable=wildcard-import
24-
from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import *
25-
# pylint: enable=wildcard-import
26-
from tensorflow.contrib.util import loader
2721
from tensorflow.python.framework import constant_op
2822
from tensorflow.python.framework import dtypes
2923
from tensorflow.python.framework import ops
3024
from tensorflow.python.ops import array_ops
3125
from tensorflow.python.ops import check_ops
3226
from tensorflow.python.ops import control_flow_ops
27+
from tensorflow.python.ops import gen_clustering_ops
3328
from tensorflow.python.ops import math_ops
3429
from tensorflow.python.ops import nn_impl
3530
from tensorflow.python.ops import random_ops
3631
from tensorflow.python.ops import state_ops
3732
from tensorflow.python.ops import variable_scope
3833
from tensorflow.python.ops.embedding_ops import embedding_lookup
39-
from tensorflow.python.platform import resource_loader
40-
41-
_clustering_ops = loader.load_op_library(
42-
resource_loader.get_path_to_datafile('_clustering_ops.so'))
34+
# go/tf-wildcard-import
35+
# pylint: disable=wildcard-import
36+
from tensorflow.python.ops.gen_clustering_ops import *
37+
# pylint: enable=wildcard-import
4338

4439
# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\)
4540
# which is the square root of the sum of the absolute squares of the elements

tensorflow/core/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,7 @@ tf_gen_op_libs(
10741074
"tensor_forest_ops",
10751075
"candidate_sampling_ops",
10761076
"checkpoint_ops",
1077+
"clustering_ops",
10771078
"collective_ops",
10781079
"control_flow_ops",
10791080
"ctc_ops",
@@ -1228,6 +1229,7 @@ cc_library(
12281229
":tensor_forest_ops_op_lib",
12291230
":candidate_sampling_ops_op_lib",
12301231
":checkpoint_ops_op_lib",
1232+
":clustering_ops_op_lib",
12311233
":collective_ops_op_lib",
12321234
":control_flow_ops_op_lib",
12331235
":ctc_ops_op_lib",
@@ -1382,6 +1384,7 @@ cc_library(
13821384
"//tensorflow/core/kernels:tensor_forest_ops",
13831385
"//tensorflow/core/kernels:candidate_sampler_ops",
13841386
"//tensorflow/core/kernels:checkpoint_ops",
1387+
"//tensorflow/core/kernels:clustering_ops",
13851388
"//tensorflow/core/kernels:collective_ops",
13861389
"//tensorflow/core/kernels:control_flow_ops",
13871390
"//tensorflow/core/kernels:ctc_ops",
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
op {
2+
graph_op_name: "KMC2ChainInitialization"
3+
visibility: HIDDEN
4+
in_arg {
5+
name: "distances"
6+
description: <<END
7+
Vector with squared distances to the closest previously sampled cluster center
8+
for each candidate point.
9+
END
10+
}
11+
in_arg {
12+
name: "seed"
13+
description: <<END
14+
Scalar. Seed for initializing the random number generator.
15+
END
16+
}
17+
out_arg {
18+
name: "index"
19+
description: <<END
20+
Scalar with the index of the sampled point.
21+
END
22+
}
23+
summary: "Returns the index of a data point that should be added to the seed set."
24+
description: <<END
25+
Entries in distances are assumed to be squared distances of candidate points to
26+
the already sampled centers in the seed set. The op constructs one Markov chain
27+
of the k-MC^2 algorithm and returns the index of one candidate point to be added
28+
as an additional cluster center.
29+
END
30+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
op {
2+
graph_op_name: "KmeansPlusPlusInitialization"
3+
visibility: HIDDEN
4+
in_arg {
5+
name: "points"
6+
description: <<END
7+
Matrix of shape (n, d). Rows are assumed to be input points.
8+
END
9+
}
10+
in_arg {
11+
name: "num_to_sample"
12+
description: <<END
13+
Scalar. The number of rows to sample. This value must not be larger than n.
14+
END
15+
}
16+
in_arg {
17+
name: "seed"
18+
description: <<END
19+
Scalar. Seed for initializing the random number generator.
20+
END
21+
}
22+
in_arg {
23+
name: "num_retries_per_sample"
24+
description: <<END
25+
Scalar. For each row that is sampled, this parameter
26+
specifies the number of additional points to draw from the current
27+
distribution before selecting the best. If a negative value is specified, a
28+
heuristic is used to sample O(log(num_to_sample)) additional points.
29+
END
30+
}
31+
out_arg {
32+
name: "samples"
33+
description: <<END
34+
Matrix of shape (num_to_sample, d). The sampled rows.
35+
END
36+
}
37+
summary: "Selects num_to_sample rows of input using the KMeans++ criterion."
38+
description: <<END
39+
Rows of points are assumed to be input points. One row is selected at random.
40+
Subsequent rows are sampled with probability proportional to the squared L2
41+
distance from the nearest row selected thus far till num_to_sample rows have
42+
been sampled.
43+
END
44+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
op {
2+
graph_op_name: "NearestNeighbors"
3+
visibility: HIDDEN
4+
in_arg {
5+
name: "points"
6+
description: <<END
7+
Matrix of shape (n, d). Rows are assumed to be input points.
8+
END
9+
}
10+
in_arg {
11+
name: "centers"
12+
description: <<END
13+
Matrix of shape (m, d). Rows are assumed to be centers.
14+
END
15+
}
16+
in_arg {
17+
name: "k"
18+
description: <<END
19+
Number of nearest centers to return for each point. If k is larger than m, then
20+
only m centers are returned.
21+
END
22+
}
23+
out_arg {
24+
name: "nearest_center_indices"
25+
description: <<END
26+
Matrix of shape (n, min(m, k)). Each row contains the indices of the centers
27+
closest to the corresponding point, ordered by increasing distance.
28+
END
29+
}
30+
out_arg {
31+
name: "nearest_center_distances"
32+
description: <<END
33+
Matrix of shape (n, min(m, k)). Each row contains the squared L2 distance to the
34+
corresponding center in nearest_center_indices.
35+
END
36+
}
37+
summary: "Selects the k nearest centers for each point."
38+
description: <<END
39+
Rows of points are assumed to be input points. Rows of centers are assumed to be
40+
the list of candidate centers. For each point, the k centers that have least L2
41+
distance to it are computed.
42+
END
43+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
op {
2+
graph_op_name: "KMC2ChainInitialization"
3+
visibility: HIDDEN
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
op {
2+
graph_op_name: "KmeansPlusPlusInitialization"
3+
visibility: HIDDEN
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
op {
2+
graph_op_name: "NearestNeighbors"
3+
visibility: HIDDEN
4+
}

0 commit comments

Comments
 (0)