Skip to content

Commit 3b11daf

Browse files
committed
Mutablehashtable lookup support full size dynamic default values.
This PR is one part of RFC:tensorflow/community#237
1 parent a038cf2 commit 3b11daf

File tree

7 files changed

+124
-16
lines changed

7 files changed

+124
-16
lines changed

tensorflow/core/framework/lookup_interface.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
8383
const Tensor& default_value) {
8484
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
8585
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
86-
if (default_value.shape() != value_shape()) {
86+
TensorShape fullsize_value_shape = key.shape();
87+
for (int i = 0; i < key_shape().dims(); ++i) {
88+
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
89+
}
90+
fullsize_value_shape.AppendShape(value_shape());
91+
if (default_value.shape() != value_shape() &&
92+
default_value.shape() != fullsize_value_shape) {
8793
return errors::InvalidArgument(
88-
"Expected shape ", value_shape().DebugString(),
89-
" for default value, got ", default_value.shape().DebugString());
94+
"Expected shape ", value_shape().DebugString(), " or ",
95+
fullsize_value_shape.DebugString(), " for default value, got ",
96+
default_value.shape().DebugString());
9097
}
9198
return Status::OK();
9299
}

tensorflow/core/framework/lookup_interface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
128128
// requirements are satisfied, otherwise it returns InvalidArgument:
129129
// - DataType of the tensor keys equals to the table key_dtype
130130
// - DataType of the tensor default_value equals to the table value_dtype
131-
// - the default_value tensor shape matches the table's value shape.
131+
// - the default_value tensor has the required shape given keys and the
132+
// tables's value shape.
132133
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
133134

134135
string DebugString() const override {

tensorflow/core/kernels/lookup_table_op.cc

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,25 @@ class MutableHashTableOfScalars final : public LookupInterface {
5656

5757
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
5858
const Tensor& default_value) override {
59-
const V default_val = default_value.flat<V>()(0);
6059
const auto key_values = key.flat<K>();
6160
auto value_values = value->flat<V>();
61+
const auto default_flat = default_value.flat<V>();
62+
63+
int64 total = value_values.size();
64+
int64 default_total = default_flat.size();
65+
bool is_full_size_default = (total == default_total);
6266

6367
tf_shared_lock l(mu_);
6468
for (int64 i = 0; i < key_values.size(); ++i) {
69+
// is_full_size_default is true:
70+
// Each key has an independent default value, key_values(i)
71+
// corresponding uses default_flat(i) as its default value.
72+
//
73+
// is_full_size_default is false:
74+
// All keys will share the default_flat(0) as default value.
6575
value_values(i) = gtl::FindWithDefault(
66-
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
76+
table_, SubtleMustCopyIfIntegral(key_values(i)),
77+
is_full_size_default ? default_flat(i) : default_flat(0));
6778
}
6879

6980
return Status::OK();
@@ -173,11 +184,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
173184

174185
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
175186
const Tensor& default_value) override {
176-
const auto default_flat = default_value.flat<V>();
187+
const auto default_flat = default_value.flat_inner_dims<V, 2>();
177188
const auto key_values = key.flat<K>();
178189
auto value_values = value->flat_inner_dims<V, 2>();
179190
int64 value_dim = value_shape_.dim_size(0);
180191

192+
int64 total = value_values.size();
193+
int64 default_total = default_flat.size();
194+
bool is_full_size_default = (total == default_total);
195+
181196
tf_shared_lock l(mu_);
182197
for (int64 i = 0; i < key_values.size(); ++i) {
183198
ValueArray* value_vec =
@@ -187,8 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
187202
value_values(i, j) = value_vec->at(j);
188203
}
189204
} else {
205+
// is_full_size_default is true:
206+
// Each key has an independent default value, key_values(i)
207+
// corresponding uses default_flat(i) as its default value.
208+
//
209+
// is_full_size_default is false:
210+
// All keys will share the default_flat(0) as default value.
190211
for (int64 j = 0; j < value_dim; j++) {
191-
value_values(i, j) = default_flat(j);
212+
value_values(i, j) =
213+
is_full_size_default ? default_flat(i, j) : default_flat(0, j);
192214
}
193215
}
194216
}

tensorflow/core/ops/lookup_ops.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,6 @@ REGISTER_OP("LookupTableFindV2")
169169
ShapeHandle handle;
170170
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
171171

172-
// Default value must be scalar or vector.
173-
ShapeHandle keys;
174-
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
175-
176172
ShapeAndType value_shape_and_type;
177173
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
178174
c,

tensorflow/core/ops/lookup_ops_test.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ namespace {
2525
TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) {
2626
ShapeInferenceTestOp op("LookupTableFindV2");
2727
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?");
28-
INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]");
2928
TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2")
3029
.Input({"table_handle", 0, DT_RESOURCE})
3130
.Input({"keys", 0, DT_INT64})

tensorflow/python/kernel_tests/lookup_ops_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3375,6 +3375,71 @@ def testMutableHashTableFindHighRank(self):
33753375
result = self.evaluate(output)
33763376
self.assertAllEqual([[0, 1], [-1, -1]], result)
33773377

3378+
def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
3379+
default_val = [-1, -1]
3380+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3381+
default_val)
3382+
3383+
input_string = constant_op.constant([["brain", "salad"],
3384+
["tank", "tarkus"]])
3385+
3386+
invalid_default_val = constant_op.constant(
3387+
[[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)
3388+
3389+
with self.assertRaisesRegex(
3390+
(ValueError, errors_impl.InvalidArgumentError),
3391+
"Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"):
3392+
self.evaluate(table.lookup(input_string, invalid_default_val))
3393+
3394+
invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
3395+
dtypes.int64)
3396+
with self.assertRaisesRegex(
3397+
(ValueError, errors_impl.InvalidArgumentError),
3398+
"Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
3399+
self.evaluate(table.lookup(input_string, invalid_default_val))
3400+
3401+
def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
3402+
default_val = -1
3403+
keys = constant_op.constant(["brain", "salad", "surgery"])
3404+
values = constant_op.constant([0, 1, 2], dtypes.int64)
3405+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3406+
default_val)
3407+
3408+
self.evaluate(table.insert(keys, values))
3409+
self.assertAllEqual(3, self.evaluate(table.size()))
3410+
3411+
input_string = constant_op.constant([["brain", "salad"],
3412+
["tank", "tarkus"]])
3413+
3414+
dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
3415+
dtypes.int64)
3416+
output = table.lookup(input_string, dynamic_default_val)
3417+
self.assertAllEqual([2, 2], output.get_shape())
3418+
3419+
result = self.evaluate(output)
3420+
self.assertAllEqual([[0, 1], [-4, -5]], result)
3421+
3422+
def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
3423+
default_val = [-1, -1]
3424+
keys = constant_op.constant(["brain", "salad", "surgery"])
3425+
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
3426+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3427+
default_val)
3428+
3429+
self.evaluate(table.insert(keys, values))
3430+
self.assertAllEqual(3, self.evaluate(table.size()))
3431+
3432+
input_string = constant_op.constant([["brain", "salad"],
3433+
["tank", "tarkus"]])
3434+
3435+
dynamic_default_val = constant_op.constant(
3436+
[[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
3437+
output = table.lookup(input_string, dynamic_default_val)
3438+
self.assertAllEqual([2, 2, 2], output.get_shape())
3439+
3440+
result = self.evaluate(output)
3441+
self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)
3442+
33783443
def testMutableHashTableInsertHighRank(self):
33793444
default_val = -1
33803445
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])

tensorflow/python/ops/lookup_ops.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,14 +1849,31 @@ def remove(self, keys, name=None):
18491849

18501850
return op
18511851

1852-
def lookup(self, keys, name=None):
1852+
def lookup(self, keys, dynamic_default_values=None, name=None):
18531853
"""Looks up `keys` in a table, outputs the corresponding values.
18541854
18551855
The `default_value` is used for keys not present in the table.
18561856
18571857
Args:
18581858
keys: Keys to look up. Can be a tensor of any shape. Must match the
18591859
table's key_dtype.
1860+
dynamic_default_values: The values to use if a key is missing in the
1861+
table. If None (by default), the `table.default_value` will be used.
1862+
Shape of `dynamic_default_values` must be same with
1863+
`table.default_value` or the lookup result tensor.
1864+
In the latter case, each key will have a different default value.
1865+
1866+
For example:
1867+
1868+
```python
1869+
keys = [0, 1, 3]
1870+
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
1871+
1872+
# The key '0' will use [1, 3, 4] as default value.
1873+
# The key '1' will use [2, 3, 9] as default value.
1874+
# The key '3' will use [8, 3, 0] as default value.
1875+
```
1876+
18601877
name: A name for the operation (optional).
18611878
18621879
Returns:
@@ -1870,8 +1887,9 @@ def lookup(self, keys, name=None):
18701887
(self.resource_handle, keys, self._default_value)):
18711888
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
18721889
with ops.colocate_with(self.resource_handle):
1873-
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
1874-
self._default_value)
1890+
values = gen_lookup_ops.lookup_table_find_v2(
1891+
self.resource_handle, keys, dynamic_default_values
1892+
if dynamic_default_values is not None else self._default_value)
18751893
return values
18761894

18771895
def insert(self, keys, values, name=None):

0 commit comments

Comments
 (0)