Skip to content

Commit 90ab9c1

Browse files
talumbau and tensorflower-gardener
authored and committed
Add a Resource for KV Cache buffer storage
PiperOrigin-RevId: 620097541
1 parent 8ae67af commit 90ab9c1

File tree

5 files changed

+181
-1
lines changed

5 files changed

+181
-1
lines changed

tensorflow/lite/experimental/resource/BUILD

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,32 @@ package(
66
licenses = ["notice"],
77
)
88

9+
# Library target for the CacheBuffer resource: builds cache_buffer.cc and
# exposes cache_buffer.h. Depends on the sibling ":resource" target and on the
# TFLite C API / kernel utility targets listed in `deps`.
# NOTE(review): `hdrs` also lists common.h from //tensorflow/lite/core/c even
# though ":common" from that package is already in `deps` - confirm this is
# intentional.
cc_library(
10+
name = "cache_buffer",
11+
srcs = ["cache_buffer.cc"],
12+
hdrs = [
13+
"cache_buffer.h",
14+
"//tensorflow/lite/core/c:common.h",
15+
],
16+
deps = [
17+
":resource",
18+
"//tensorflow/lite/core/c:c_api_types",
19+
"//tensorflow/lite/core/c:common",
20+
"//tensorflow/lite/kernels:kernel_util",
21+
"//tensorflow/lite/kernels/internal:compatibility",
22+
],
23+
)
24+
25+
# Unit test target for ":cache_buffer"; built from cache_buffer_test.cc and
# linked against the GoogleTest main provided by gtest_main.
cc_test(
26+
name = "cache_buffer_test",
27+
srcs = ["cache_buffer_test.cc"],
28+
deps = [
29+
":cache_buffer",
30+
"//tensorflow/lite/c:common",
31+
"@com_google_googletest//:gtest_main",
32+
],
33+
)
34+
935
cc_library(
1036
name = "resource",
1137
srcs = [
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/lite/experimental/resource/cache_buffer.h"
17+
18+
#include <cstdlib>
19+
#include <cstring>
20+
21+
#include "tensorflow/lite/core/c/c_api_types.h"
22+
#include "tensorflow/lite/core/c/common.h"
23+
#include "tensorflow/lite/kernels/internal/compatibility.h"
24+
#include "tensorflow/lite/kernels/kernel_util.h"
25+
26+
namespace tflite {
27+
namespace resource {
28+
29+
// Name assigned to the `name` field of the tensor owned by a CacheBuffer.
constexpr char kCacheBufferTensorName[] = "CacheBuffer";
30+
31+
// Sets up the backing tensor as a zero-filled buffer with the given `shape`
// and element `type`.
//
// The shape is copied, so the caller keeps ownership of `shape`.
//
// Returns kTfLiteOk on success. Returns kTfLiteError if the shape could not be
// copied, or the status reported by TfLiteTensorRealloc if the buffer could
// not be allocated; in both failure cases the object is left marked as
// uninitialized.
TfLiteStatus CacheBuffer::Initialize(const TfLiteIntArray &shape,
                                     const TfLiteType &type) {
  // Set basic parameters.
  tensor_.name = kCacheBufferTensorName;
  tensor_.allocation_type = kTfLiteDynamic;
  tensor_.type = type;

  // Release the shape installed by an earlier successful call so that
  // re-initializing the buffer does not leak it.
  if (is_initialized_ && tensor_.dims != nullptr) {
    TfLiteIntArrayFree(tensor_.dims);
    tensor_.dims = nullptr;
  }

  // Set the shape and allocate the memory.
  tensor_.dims = TfLiteIntArrayCopy(&shape);
  if (tensor_.dims == nullptr) {
    is_initialized_ = false;
    return kTfLiteError;
  }
  const size_t num_bytes = TfLiteTypeGetSize(type) * NumElements(&tensor_);
  const TfLiteStatus realloc_status = TfLiteTensorRealloc(num_bytes, &tensor_);
  if (realloc_status != kTfLiteOk) {
    // Do not report success for a buffer that was never allocated; also drop
    // the shape copy so a later retry starts from a clean state.
    TfLiteIntArrayFree(tensor_.dims);
    tensor_.dims = nullptr;
    is_initialized_ = false;
    return realloc_status;
  }

  // A zero-sized request may legitimately leave the data pointer null, and
  // passing a null pointer to memset is undefined even for a zero length.
  if (tensor_.data.raw != nullptr) {
    memset(tensor_.data.raw, 0, tensor_.bytes);
  }
  is_initialized_ = true;
  return kTfLiteOk;
}
47+
48+
// Reports how many entries of the buffer are currently marked as used.
size_t CacheBuffer::GetNumEntries() const {
  return num_entries_;
}
49+
50+
// Records `count` as the number of used entries in the buffer.
//
// `count` is checked (via TFLITE_DCHECK) against the size of dimension 2 of
// the tensor shape. The check also requires that a shape with at least three
// dimensions has been set, so the dimension is never read from a missing or
// too-short shape, and that the dimension is non-negative, so the
// signed-to-unsigned conversion cannot make an invalid shape look large.
void CacheBuffer::SetNumEntries(size_t count) {
  TFLITE_DCHECK(tensor_.dims != nullptr && tensor_.dims->size > 2 &&
                tensor_.dims->data[2] >= 0 &&
                count <= static_cast<size_t>(tensor_.dims->data[2]));
  num_entries_ = count;
}
54+
55+
} // namespace resource
56+
} // namespace tflite
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_CACHE_BUFFER_H_
16+
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_CACHE_BUFFER_H_
17+
18+
#include <memory>
19+
#include <unordered_map>
20+
21+
#include "tensorflow/lite/core/c/common.h"
22+
#include "tensorflow/lite/experimental/resource/resource_variable.h"
23+
#include "tensorflow/lite/kernels/kernel_util.h"
24+
25+
namespace tflite {
26+
namespace resource {
27+
28+
/// WARNING: Experimental interface, subject to change.
29+
// A Cache Buffer class. Useful for keeping the keys and values of a
30+
// transformer block attention mechanism in autoregressive decode.
31+
// Ops can access this buffer and add tensors to it. It also keeps track of the
32+
// number of used entries in the cache.
33+
class CacheBuffer : public ResourceVariable {
34+
public:
35+
CacheBuffer() = default;
36+
CacheBuffer(const CacheBuffer &) = delete;
37+
CacheBuffer &operator=(const CacheBuffer &) = delete;
38+
// Initialize tensor of a certain shape using the provided type.
39+
TfLiteStatus Initialize(const TfLiteIntArray &shape, const TfLiteType &type);
40+
size_t GetNumEntries() const;
41+
void SetNumEntries(size_t count);
42+
43+
private:
44+
// The number of entries currently used in the buffer;
45+
size_t num_entries_ = 0;
46+
};
47+
48+
} // namespace resource
49+
} // namespace tflite
50+
51+
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_CACHE_BUFFER_H_
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/lite/experimental/resource/cache_buffer.h"
16+
17+
#include <gtest/gtest.h>
18+
#include "tensorflow/lite/c/common.h"
19+
20+
namespace tflite {
21+
namespace resource {
22+
23+
// Checks that Initialize() produces a tensor with the requested type and
// shape, allocates the expected number of bytes, and that the entry counter
// starts at zero and can be updated.
TEST(CacheBufferTest, Initialize) {
  TfLiteIntArray* shape = TfLiteIntArrayCreate(4);
  ASSERT_NE(shape, nullptr);
  shape->data[0] = 1;
  shape->data[1] = 3;
  shape->data[2] = 5;
  shape->data[3] = 7;

  const TfLiteType type = kTfLiteFloat32;
  CacheBuffer cache_buffer;
  const TfLiteStatus status = cache_buffer.Initialize(*shape, type);
  // Initialize() takes its own copy of the shape, so release ours right away;
  // otherwise a failing ASSERT below would return early and leak it.
  TfLiteIntArrayFree(shape);
  ASSERT_EQ(status, kTfLiteOk);

  const TfLiteTensor* tensor = cache_buffer.GetTensor();
  ASSERT_NE(tensor, nullptr);
  EXPECT_EQ(tensor->type, type);
  ASSERT_NE(tensor->dims, nullptr);
  ASSERT_EQ(tensor->dims->size, 4);
  EXPECT_EQ(tensor->dims->data[0], 1);
  EXPECT_EQ(tensor->dims->data[1], 3);
  EXPECT_EQ(tensor->dims->data[2], 5);
  EXPECT_EQ(tensor->dims->data[3], 7);
  // 1 * 3 * 5 * 7 float32 elements, 4 bytes each.
  EXPECT_EQ(tensor->bytes, 420u);
  ASSERT_NE(tensor->data.raw, nullptr);

  EXPECT_EQ(cache_buffer.GetNumEntries(), 0u);
  cache_buffer.SetNumEntries(3);
  EXPECT_EQ(cache_buffer.GetNumEntries(), 3u);
}
45+
46+
} // namespace resource
47+
} // namespace tflite

tensorflow/lite/experimental/resource/resource_variable.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class ResourceVariable : public ResourceBase {
5050
return is_initialized_ ? tensor_.bytes : 0;
5151
}
5252

53-
private:
53+
protected:
5454
// The tensor (and its buffer stored in `tensor_.data`) is fully owned by
5555
// the `ResourceVariable` object.
5656
TfLiteTensor tensor_;

0 commit comments

Comments
 (0)