Add tiktoken #3015

Closed · wants to merge 1 commit
6 changes: 6 additions & 0 deletions .gitmodules
@@ -62,3 +62,9 @@
[submodule "examples/third-party/LLaVA"]
path = examples/third-party/LLaVA
url = https://github.com/haotian-liu/LLaVA.git
[submodule "examples/models/llama2/third-party/re2"]
path = examples/models/llama2/third-party/re2
url = https://github.com/google/re2.git
[submodule "examples/models/llama2/third-party/abseil-cpp"]
path = examples/models/llama2/third-party/abseil-cpp
url = https://github.com/abseil/abseil-cpp.git
19 changes: 16 additions & 3 deletions examples/models/llama2/CMakeLists.txt
@@ -21,6 +21,8 @@ project(llama_runner)
# Duplicating options as root CMakeLists.txt
option(EXECUTORCH_BUILD_OPTIMIZED "Build the optimized kernels" OFF)

option(EXECUTORCH_BUILD_RE2 "Build RE2" OFF)

include(CMakeDependentOption)
#
# pthreadpool: build pthreadpool library. Disable on unsupported platforms
@@ -86,8 +88,19 @@ endif()

# llama_runner library
add_subdirectory(runner)

set(link_libraries)
if(EXECUTORCH_BUILD_RE2)
# find RE2 for tokenizer
set(ABSL_ENABLE_INSTALL ON)
set(_pic_flag
${CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/abseil-cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2)
set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
target_link_libraries(llama_runner PUBLIC re2::re2)
endif()

set(link_libraries gflags)
set(_srcs main.cpp)

if(EXECUTORCH_BUILD_OPTIMIZED)
@@ -162,7 +175,7 @@ if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
endif()

target_include_directories(llama_main PUBLIC ${_common_include_directories})
target_link_libraries(llama_main PUBLIC gflags llama_runner ${link_libraries})
target_link_libraries(llama_main PUBLIC llama_runner ${link_libraries})
target_compile_options(llama_main PUBLIC ${_common_compile_options})

if(APPLE)
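For context on the PIC bookkeeping above: CMAKE_POSITION_INDEPENDENT_CODE is stashed in _pic_flag, forced ON while the vendored abseil-cpp and re2 subdirectories are added (their static libraries must be position-independent to link into the runner), and then restored so the setting does not leak into the rest of the build. ABSL_ENABLE_INSTALL appears to be set so the vendored abseil exports the targets that re2's CMake expects.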
10 changes: 9 additions & 1 deletion examples/models/llama2/main.cpp
@@ -39,6 +39,11 @@ DEFINE_int32(
-1,
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");

DEFINE_bool(
use_tiktoken,
false,
"Use Tiktoken tokenizer instead of the default BPE tokenizer.");

int32_t main(int32_t argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

@@ -57,6 +62,8 @@ int32_t main(int32_t argc, char** argv) {

int32_t cpu_threads = FLAGS_cpu_threads;

bool use_tiktoken = FLAGS_use_tiktoken;

#if defined(ET_USE_THREADPOOL)
uint32_t num_performant_cores = cpu_threads == -1
? torch::executorch::cpuinfo::get_num_performant_cores()
@@ -69,7 +76,8 @@
}
#endif
// create llama runner
::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
::torch::executor::Runner runner(
model_path, tokenizer_path, temperature, use_tiktoken);

// generate
runner.generate(prompt, seq_len);
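With this flag the tokenizer becomes a launch-time choice. A hypothetical invocation, using the flag names from the DEFINE_ macros above and placeholder file paths:

llama_main --model_path=llama2.pte --tokenizer_path=tokenizer.model --use_tiktoken=true

(gflags also accepts the bare boolean form --use_tiktoken.)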
2 changes: 1 addition & 1 deletion examples/models/llama2/runner/generation.py
@@ -71,7 +71,7 @@ def generate( # noqa: C901
self,
prompt_tokens: List[List[int]],
max_gen_len: int,
temperature: float = 0.6,
temperature: float = 0.8,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
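For context on this default change: sampling temperature T divides the logits before the softmax, p_i ∝ exp(logit_i / T), so the higher default of 0.8 flattens the next-token distribution and yields somewhat more diverse samples than 0.6. It also brings the Python default in line with the C++ runner's default of 0.8f (see runner.h below).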
13 changes: 10 additions & 3 deletions examples/models/llama2/runner/runner.cpp
@@ -11,6 +11,7 @@

#include <executorch/examples/models/llama2/runner/runner.h>
#include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/runner_util/managed_tensor.h>

@@ -37,8 +38,10 @@ std::string statsToJsonString(const Runner::Stats& stats);
Runner::Runner(
const std::string& model_path,
const std::string& tokenizer_path,
const float temperature)
: module_(std::make_unique<Module>(
const float temperature,
bool use_tiktoken)
: use_tiktoken_(use_tiktoken),
module_(std::make_unique<Module>(
model_path,
Module::MlockConfig::UseMlockIgnoreErrors)),
tokenizer_path_(tokenizer_path),
@@ -77,7 +80,11 @@ Error Runner::load() {
append_eos_ = getMetadataHelper("append_eos_to_prompt", false);

// Load tokenizer
tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
if (use_tiktoken_) {
tokenizer_ = std::make_unique<Tiktoken>(vocab_size_, bos_id_, eos_id_);
} else {
tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
}
tokenizer_->load(tokenizer_path_);
if (tokenizer_->bos_tok() != bos_id_) {
ET_LOG(
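Both implementations are assigned to the same tokenizer_ member, so everything downstream of load() is untouched by the choice. The Tiktoken loader is presumably where the re2 and base64.h additions in this PR come in: tiktoken-style vocabulary files pair base64-encoded token bytes with integer ranks, and input text is pre-split with a regular expression before rank-based merging.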
4 changes: 3 additions & 1 deletion examples/models/llama2/runner/runner.h
@@ -29,7 +29,8 @@ class Runner {
explicit Runner(
const std::string& model_path,
const std::string& tokenizer_path,
const float temperature = 0.8f);
const float temperature = 0.8f,
bool use_tiktoken = false);

struct Stats {
// Scaling factor for timestamps - in this case, we use ms.
@@ -85,6 +86,7 @@
int32_t n_bos_;
int32_t n_eos_;
int32_t max_seq_len_;
bool use_tiktoken_;
bool use_kv_cache_;
bool use_sdpa_with_kv_cache_;
bool append_eos_;
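A minimal caller-side sketch against this signature; the file names are placeholders, and the generate() call mirrors main.cpp above:

#include <executorch/examples/models/llama2/runner/runner.h>

int main() {
  // Placeholder model and tokenizer paths.
  ::torch::executor::Runner runner(
      "llama2.pte",
      "tokenizer.model",
      /*temperature=*/0.8f,
      /*use_tiktoken=*/true); // pick Tiktoken instead of the default BPETokenizer
  runner.generate("Hello, world", /*seq_len=*/128);
  return 0;
}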
1 change: 1 addition & 0 deletions examples/models/llama2/third-party/abseil-cpp
Submodule abseil-cpp added at 854193
1 change: 1 addition & 0 deletions examples/models/llama2/third-party/re2
Submodule re2 added at ac82d4
180 changes: 180 additions & 0 deletions examples/models/llama2/tokenizer/base64.h
@@ -0,0 +1,180 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// @lint-ignore-every LICENSELINT
/**************************************************************************
Copyright (c) 2023 sewenew

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*************************************************************************/

#pragma once

#include <executorch/runtime/platform/assert.h>
#include <string>
#include <string_view>

namespace torch {
namespace executor {
namespace base64 {

std::string decode(const std::string_view& input);

namespace detail {

constexpr uint32_t DECODE_TABLE[] = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255,
255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255,
255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

inline void validate(uint32_t v) {
ET_CHECK_MSG(v != 255, "invalid char");
}

inline void decode(const std::string_view& input, std::string& output) {
ET_CHECK_MSG(
input.size() == 4, "input length must be 4, got %zu", input.size());

uint32_t val = 0;

uint8_t c = input[0];
auto v = DECODE_TABLE[c];
validate(v);
val = v;

c = input[1];
v = DECODE_TABLE[c];
validate(v);
val = (val << 6) | v;

c = input[2];
v = DECODE_TABLE[c];
validate(v);
val = (val << 6) | v;

c = input[3];
v = DECODE_TABLE[c];
validate(v);
val = (val << 6) | v;

output.push_back(static_cast<char>((val >> 16) & 0xFF));
output.push_back(static_cast<char>((val >> 8) & 0xFF));
output.push_back(static_cast<char>(val & 0xFF));
}

inline void decode_1_padding(
const std::string_view& input,
std::string& output) {
ET_CHECK_MSG(
input.size() == 3, "input length must be 3, got %zu", input.size());

uint32_t val = 0;

uint8_t c = input[0];
auto v = DECODE_TABLE[c];
validate(v);
val = v;

c = input[1];
v = DECODE_TABLE[c];
validate(v);
val = (val << 6) | v;

c = input[2];
v = DECODE_TABLE[c];
validate(v);
val = (val << 6) | v;

output.push_back(static_cast<char>((val >> 10) & 0xFF));
output.push_back(static_cast<char>((val >> 2) & 0xFF));
}

inline void decode_2_padding(
const std::string_view& input,
std::string& output) {
ET_CHECK_MSG(
input.size() == 2, "input length must be 2, got %zu", input.size());

uint32_t val = 0;

uint8_t c = input[0];
auto v = DECODE_TABLE[c];
validate(v);
val = v;

c = input[1];
v = DECODE_TABLE[c];
validate(v);
val = (val << 6) | v;

output.push_back(static_cast<char>((val >> 4) & 0xFF));
}

} // namespace detail

inline std::string decode(const std::string_view& input) {
ET_CHECK_MSG(!input.empty(), "empty input");

// Faster than `input.size() % 4`.
ET_CHECK_MSG(
(input.size() & 3) == 0 && input.size() >= 4,
"input length must be larger than 4 and is multiple of 4, got %zu",
input.size());

std::string output;
output.reserve(input.size() / 4 * 3);
auto idx = 0U;
for (; idx < input.size() - 4; idx += 4) {
detail::decode(input.substr(idx, 4), output);
}

// Last 4 bytes. Might contain padding.
if (input[idx + 3] == '=') {
if (input[idx + 2] == '=') {
// Two paddings.
detail::decode_2_padding(input.substr(idx, 2), output);
} else {
// One padding.
detail::decode_1_padding(input.substr(idx, 3), output);
}
} else {
// No padding.
detail::decode(input.substr(idx, 4), output);
}

return output;
}

} // namespace base64

} // namespace executor
} // namespace torch
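A minimal sketch of this decoder in isolation; the input string is ordinary base64 and not taken from the PR (tiktoken vocabulary files store each token's bytes this way):

#include <executorch/examples/models/llama2/tokenizer/base64.h>

#include <cstdio>

int main() {
  // "aGVsbG8=" is the base64 encoding of "hello"; the single '=' sends the
  // final quartet through detail::decode_1_padding().
  std::string decoded = torch::executor::base64::decode("aGVsbG8=");
  printf("%s\n", decoded.c_str()); // prints "hello"
  return 0;
}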
6 changes: 6 additions & 0 deletions examples/models/llama2/tokenizer/targets.bzl
@@ -5,10 +5,13 @@ def define_common_targets():
name = "tokenizer",
srcs = [
"bpe_tokenizer.cpp",
"tiktoken.cpp",
],
exported_headers = [
"tokenizer.h",
"bpe_tokenizer.h",
"tiktoken.h",
"base64.h",
],
exported_deps = [
"//executorch/runtime/core/exec_aten:lib",
Expand All @@ -17,6 +20,9 @@ def define_common_targets():
visibility = [
"@EXECUTORCH_CLIENTS",
],
exported_external_deps = [
"re2",
],
)

runtime.python_library(
23 changes: 23 additions & 0 deletions examples/models/llama2/tokenizer/test/targets.bzl
@@ -20,13 +20,36 @@
},
)

runtime.cxx_test(
name = "test_tiktoken",
srcs = [
"test_tiktoken.cpp",
],
deps = [
"//executorch/examples/models/llama2/tokenizer:tokenizer",
],
env = {
"RESOURCES_PATH": "$(location :resources_fb_only)/resources",
},
external_deps = [
"re2",
],
)

runtime.filegroup(
name = "resources",
srcs = native.glob([
"resources/**",
]),
)

runtime.filegroup(
name = "resources_fb_only",
srcs = native.glob([
"resources/fb/**",
]),
)

runtime.python_test(
name = "test_tokenizer_py",
srcs = [