Move examples/qualcomm out from under the torch namespace (#5400)
Summary:
The code under examples/... is a proxy for user code, and users should never declare code under the `torch::` or `executorch::` namespaces.

Move this code under the `example::` namespace to make it clearer that users should use their own namespaces when writing code like this.
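
For illustration, a minimal sketch of the intended pattern follows. The `example::` namespace, the `executorch::runtime::Error` type, and the old `torch::executor` nesting are taken from this diff; the `Runner` declaration shown and the header path are placeholders rather than the exact code in this commit.

```cpp
// Example/user code should not reopen the runtime's namespaces:
//
//   namespace torch {
//   namespace executor {
//   class Runner { /* ... */ };
//   }  // namespace executor
//   }  // namespace torch
//
// Instead, declare the code in a namespace of its own and pull in runtime
// types with explicit using-declarations.

#include <executorch/runtime/core/error.h>  // assumed location of executorch::runtime::Error

using executorch::runtime::Error;

namespace example {

class Runner {
 public:
  // Placeholder declaration; the real runner's methods live in runner.cpp.
  Error load();
};

}  // namespace example
```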

Pull Request resolved: #5400

Test Plan:
- Built using the instructions at https://github.com/pytorch/executorch/blob/main/examples/qualcomm/README.md
- test-llama-runner-qnn-linux CI job succeeded

Reviewed By: shoumikhin

Differential Revision: D62969111

Pulled By: dbort

fbshipit-source-id: 9ec27528dd85f60d8c538d54ce6ddf621e63cf52
dbort authored and facebook-github-bot committed Sep 18, 2024
1 parent b89c52c commit 8ef6c79
Showing 13 changed files with 228 additions and 181 deletions.
examples/qualcomm/executor_runner/qnn_executor_runner.cpp: 33 changes (24 additions, 9 deletions)
@@ -71,9 +71,24 @@ DEFINE_int32(
20000000, // 20MB
"Size of the debug buffer in bytes to allocate for intermediate outputs and program outputs logging.");

using namespace torch::executor;
using torch::executor::MemoryAllocator;
using torch::executor::util::FileDataLoader;
using executorch::aten::Tensor;
using executorch::aten::TensorImpl;
using executorch::etdump::ETDumpGen;
using executorch::etdump::ETDumpResult;
using executorch::extension::FileDataLoader;
using executorch::extension::prepare_input_tensors;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::EventTracerDebugLogLevel;
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MemoryManager;
using executorch::runtime::Method;
using executorch::runtime::MethodMeta;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::Span;
using executorch::runtime::TensorInfo;

class CustomMemory {
public:
@@ -112,7 +127,7 @@ class CustomMemory {
};

int main(int argc, char** argv) {
runtime_init();
executorch::runtime::runtime_init();

gflags::ParseCommandLineFlags(&argc, &argv, true);
if (argc != 1) {
@@ -211,7 +226,7 @@ int main(int argc, char** argv) {
// the method can mutate the memory-planned buffers, so the method should only
// be used by a single thread at a time, but it can be reused.
//
torch::executor::ETDumpGen etdump_gen = torch::executor::ETDumpGen();
ETDumpGen etdump_gen;
Result<Method> method =
program->load_method(method_name, &memory_manager, &etdump_gen);
ET_CHECK_MSG(
@@ -261,7 +276,7 @@ int main(int argc, char** argv) {
}
for (int output_index = 0; output_index < method->outputs_size();
++output_index) {
const exec_aten::Tensor& t = method->get_output(output_index).toTensor();
const Tensor& t = method->get_output(output_index).toTensor();
out_custom_mem.push_back(
std::make_unique<CustomMemory>(FLAGS_shared_buffer));
std::unique_ptr<CustomMemory>& custom_mem_ptr = out_custom_mem.back();
@@ -415,7 +430,7 @@ int main(int argc, char** argv) {
elapsed_time / inference_index);
} else {
// if no input is provided, fill the inputs with default values
auto inputs = util::prepare_input_tensors(*method);
auto inputs = prepare_input_tensors(*method);
ET_CHECK_MSG(
inputs.ok(),
"Could not prepare inputs: 0x%" PRIx32,
@@ -434,7 +449,7 @@ int main(int argc, char** argv) {

// Dump the etdump data containing profiling/debugging data to the specified
// file.
etdump_result result = etdump_gen.get_etdump_data();
ETDumpResult result = etdump_gen.get_etdump_data();
if (result.buf != nullptr && result.size > 0) {
ET_LOG(
Info,
@@ -452,7 +467,7 @@ int main(int argc, char** argv) {
Info,
"Write debug output binary to %s, Size = %zu",
FLAGS_debug_output_path.c_str(),
FLAGS_debug_buffer_size);
(size_t)FLAGS_debug_buffer_size);
FILE* f = fopen(FLAGS_debug_output_path.c_str(), "w+");
fwrite((uint8_t*)debug_buffer, 1, FLAGS_debug_buffer_size, f);
fclose(f);
examples/qualcomm/oss_scripts/llama2/qnn_llama_runner.cpp: 11 changes (6 additions, 5 deletions)
@@ -23,8 +23,6 @@
#include <fstream>
#include <vector>

using torch::executor::MemoryAllocator;

DEFINE_string(
model_path,
"qnn_llama2.pte",
@@ -49,9 +47,12 @@ DEFINE_int32(
128,
"Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");

int main(int argc, char** argv) {
using namespace torch::executor;
using executorch::runtime::Error;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;

int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
@@ -60,7 +61,7 @@ int main(int argc, char** argv) {
int32_t seq_len = FLAGS_seq_len;

// create llama runner
Runner runner(FLAGS_model_path, tokenizer_path, temperature);
example::Runner runner(FLAGS_model_path, tokenizer_path, temperature);
ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method");

// MethodMeta describes the memory requirements of the method.
examples/qualcomm/oss_scripts/llama2/runner/runner.cpp: 72 changes (43 additions, 29 deletions)
@@ -22,11 +22,27 @@
#include <memory>
#include <sstream>

namespace torch {
namespace executor {
using executorch::aten::ScalarType;
using executorch::aten::SizesType;
using executorch::aten::Tensor;
using executorch::extension::from_blob;
using executorch::extension::Module;
using executorch::extension::TensorPtr;
using executorch::extension::llm::BPETokenizer;
using executorch::extension::llm::Sampler;
using executorch::extension::llm::time_in_ms;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;
using executorch::runtime::TensorInfo;

// TODO: Remove this usage of an internal-only function.
using executorch::runtime::internal::set_tensor_data;

namespace example {

namespace {
using namespace executorch::extension;
static constexpr auto kTopp = 0.9f;
void printReport(const Runner::Stats& stats);
std::string statsToJsonString(const Runner::Stats& stats);
@@ -57,7 +73,7 @@ Error Runner::load() {
if (is_loaded()) {
return Error::Ok;
}
stats_.model_load_start_ms = util::time_in_ms();
stats_.model_load_start_ms = time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

// Read out metadata from the model
@@ -97,7 +113,7 @@ Error Runner::load() {
temperature_,
kTopp,
static_cast<unsigned long long>(std::time(nullptr)));
stats_.model_load_end_ms = util::time_in_ms();
stats_.model_load_end_ms = time_in_ms();

return Error::Ok;
}
@@ -125,7 +141,7 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
}

template <typename T>
int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
T* logits = logits_tensor.mutable_data_ptr<T>();

// Since the logits are for all tokens, get the last token probabilities
@@ -135,7 +151,7 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {

// Given an input token. Set up the inputs for the model and execute a single
// step. Returning the logits tensor.
Result<exec_aten::Tensor> Runner::run_model_step(
Result<Tensor> Runner::run_model_step(
int64_t input_token,
TensorPtr& token,
TensorPtr& start_pos,
@@ -167,7 +183,7 @@ Result<exec_aten::Tensor> Runner::run_model_step(
char* new_inp_addr = io_mem_mgr_.update_k_caches_read(j, el_size);
// inputs
ET_CHECK_MSG(
internal::set_tensor_data(
set_tensor_data(
*kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
"Failed to set input tensor when updating k_cache");
}
@@ -177,13 +193,13 @@ Result<exec_aten::Tensor> Runner::run_model_step(
char* new_inp_addr = io_mem_mgr_.update_v_caches_read(v_idx, v_offset);

ET_CHECK_MSG(
internal::set_tensor_data(
set_tensor_data(
*kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
"Failed to set input tensor when updating v_cache");
// outputs
char* new_out_addr = io_mem_mgr_.update_v_caches_write(v_idx, v_offset);
ET_CHECK_MSG(
internal::set_tensor_data(
set_tensor_data(
*kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok,
"Failed to set output tensor when updating v_cache");
ET_CHECK_MSG(
@@ -210,7 +226,7 @@ Error Runner::generate(

// First token time only measures the time it takes to encode the prompt and
// return a response token.
stats_.inference_start_ms = util::time_in_ms();
stats_.inference_start_ms = time_in_ms();
shouldStop_ = false;

// Set the sequence length to the max seq length if not provided
@@ -235,21 +251,21 @@
"Sequence length exceeded - please increase the seq_len value passed to generate()");

int32_t pos = 0, prev_token, cur_token = prompt_tokens[0];
std::vector<exec_aten::SizesType> token_shape = {1, 1};
std::vector<SizesType> token_shape = {1, 1};

io_mem_mgr_.get_input_token_ptr()[0] = 0;
std::vector<exec_aten::SizesType> start_pos_shape = {1, 1};
std::vector<SizesType> start_pos_shape = {1, 1};

float* atten_mask_ptr =
reinterpret_cast<float*>(io_mem_mgr_.get_atten_mask_ptr());
std::fill(atten_mask_ptr, atten_mask_ptr + max_seq_len_, -255);
atten_mask_ptr[max_seq_len_ - 1] = 0;

std::vector<exec_aten::SizesType> atten_mask_shape = {1, max_seq_len_};
std::vector<SizesType> atten_mask_shape = {1, max_seq_len_};

std::vector<exec_aten::SizesType> logits_data_shape = {1, vocab_size_};
std::vector<SizesType> logits_data_shape = {1, vocab_size_};

std::vector<exec_aten::SizesType> hidden_states_data_shape = {1, 1, dim_};
std::vector<SizesType> hidden_states_data_shape = {1, 1, dim_};

// initialize tensor wrappers
auto token = from_blob(
@@ -274,7 +290,7 @@ Error Runner::generate(
method_meta->input_tensor_meta(input_index);

auto tensor_shape = tensor_meta->sizes();
std::vector<exec_aten::SizesType> sizes(
std::vector<SizesType> sizes(
tensor_shape.data(), tensor_shape.data() + tensor_shape.size());
kv_tensors.emplace_back(from_blob(
io_mem_mgr_.get_k_caches_read_ptr(i),
Expand All @@ -284,7 +300,7 @@ Error Runner::generate(
// outputs
Result<TensorInfo> out_tensor_meta = method_meta->output_tensor_meta(i + 1);
tensor_shape = out_tensor_meta->sizes();
sizes = std::vector<exec_aten::SizesType>{
sizes = std::vector<SizesType>{
tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};
kv_outputs.emplace_back(from_blob(
io_mem_mgr_.get_k_caches_write_ptr(i),
@@ -303,7 +319,7 @@ Error Runner::generate(
Result<TensorInfo> tensor_meta =
method_meta->input_tensor_meta(input_index);
auto tensor_shape = tensor_meta->sizes();
std::vector<exec_aten::SizesType> sizes(
std::vector<SizesType> sizes(
tensor_shape.data(), tensor_shape.data() + tensor_shape.size());

kv_tensors.emplace_back(from_blob(
Expand All @@ -315,7 +331,7 @@ Error Runner::generate(
Result<TensorInfo> out_tensor_meta =
method_meta->output_tensor_meta(output_index);
tensor_shape = out_tensor_meta->sizes();
sizes = std::vector<exec_aten::SizesType>{
sizes = std::vector<SizesType>{
tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};

kv_outputs.push_back(from_blob(
@@ -342,19 +358,18 @@ Error Runner::generate(
auto logits_res = run_model_step(
cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs);
if (pos == num_prompt_tokens) {
stats_.first_token_ms = util::time_in_ms();
stats_.first_token_ms = time_in_ms();
} else if (pos == num_prompt_tokens - 1) {
stats_.prompt_eval_end_ms = util::time_in_ms();
stats_.prompt_eval_end_ms = time_in_ms();
}

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
exec_aten::Tensor& logits_tensor = logits_res.get();
Tensor& logits_tensor = logits_res.get();
prev_token = cur_token;
long sample_start_time_ms = util::time_in_ms();
long sample_start_time_ms = time_in_ms();

cur_token = logitsToToken<float>(logits_tensor);
stats_.aggregate_sampling_time_ms +=
util::time_in_ms() - sample_start_time_ms;
stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;

// advance the state machine
if (pos < num_prompt_tokens - 1) {
@@ -381,7 +396,7 @@ Error Runner::generate(
break;
}
}
stats_.inference_end_ms = util::time_in_ms();
stats_.inference_end_ms = time_in_ms();

if (pos == seq_len) {
ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
@@ -650,5 +665,4 @@ template bool Runner::getMetadataHelper<bool>(
std::string method_name,
bool default_val);

} // namespace executor
} // namespace torch
} // namespace example