From 8ef6c79d55f5cec9c39e5e82008e47820773feca Mon Sep 17 00:00:00 2001
From: Dave Bort
Date: Wed, 18 Sep 2024 14:53:08 -0700
Subject: [PATCH] Move examples/qualcomm out from under the torch namespace (#5400)

Summary:
The code under examples/... is a proxy for user code, and users should never
declare code under the `torch::` or `executorch::` namespaces. Move this code
under the `example::` namespace to make it clearer that users should use their
own namespaces when writing code like this.

Pull Request resolved: https://github.com/pytorch/executorch/pull/5400

Test Plan:
- Built using the instructions at https://github.com/pytorch/executorch/blob/main/examples/qualcomm/README.md
- test-llama-runner-qnn-linux CI job succeeded

Reviewed By: shoumikhin

Differential Revision: D62969111

Pulled By: dbort

fbshipit-source-id: 9ec27528dd85f60d8c538d54ce6ddf621e63cf52
---
 .../executor_runner/qnn_executor_runner.cpp   | 33 +++++---
 .../oss_scripts/llama2/qnn_llama_runner.cpp   | 11 +--
 .../oss_scripts/llama2/runner/runner.cpp      | 72 +++++++++++-------
 .../oss_scripts/llama2/runner/runner.h        | 48 ++++++------
 .../llama/llama2/qaihub_llama2_7b_runner.cpp  |  4 +-
 .../llama/llama3/qaihub_llama3_8b_runner.cpp  |  4 +-
 .../qaihub_scripts/llama/runner/io_memory.cpp | 17 +++--
 .../qaihub_scripts/llama/runner/io_memory.h   | 76 ++++++++++---------
 .../qaihub_scripts/llama/runner/runner.cpp    | 34 +++++----
 .../qaihub_scripts/llama/runner/runner.h      | 24 +++---
 .../qaihub_stable_diffusion_runner.cpp        |  7 +-
 .../stable_diffusion/runner/runner.cpp        | 58 +++++++-------
 .../stable_diffusion/runner/runner.h          | 21 +++--
 13 files changed, 228 insertions(+), 181 deletions(-)

diff --git a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp
index f901304136..7235e36681 100644
--- a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp
+++ b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp
@@ -71,9 +71,24 @@ DEFINE_int32( 20000000, // 20MB "Size of the debug buffer in bytes to allocate for intermediate outputs and program outputs logging.");
-using namespace torch::executor;
-using torch::executor::MemoryAllocator;
-using torch::executor::util::FileDataLoader;
+using executorch::aten::Tensor;
+using executorch::aten::TensorImpl;
+using executorch::etdump::ETDumpGen;
+using executorch::etdump::ETDumpResult;
+using executorch::extension::FileDataLoader;
+using executorch::extension::prepare_input_tensors;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::EventTracerDebugLogLevel;
+using executorch::runtime::HierarchicalAllocator;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::MemoryManager;
+using executorch::runtime::Method;
+using executorch::runtime::MethodMeta;
+using executorch::runtime::Program;
+using executorch::runtime::Result;
+using executorch::runtime::Span;
+using executorch::runtime::TensorInfo;
 class CustomMemory { public:
@@ -112,7 +127,7 @@ }; int main(int argc, char** argv) {
- runtime_init();
+ executorch::runtime::runtime_init();
 gflags::ParseCommandLineFlags(&argc, &argv, true); if (argc != 1) {
@@ -211,7 +226,7 @@ int main(int argc, char** argv) { // the method can mutate the memory-planned buffers, so the method should only // be used by a single thread at at time, but it can be reused. 
// - torch::executor::ETDumpGen etdump_gen = torch::executor::ETDumpGen(); + ETDumpGen etdump_gen; Result method = program->load_method(method_name, &memory_manager, &etdump_gen); ET_CHECK_MSG( @@ -261,7 +276,7 @@ int main(int argc, char** argv) { } for (int output_index = 0; output_index < method->outputs_size(); ++output_index) { - const exec_aten::Tensor& t = method->get_output(output_index).toTensor(); + const Tensor& t = method->get_output(output_index).toTensor(); out_custom_mem.push_back( std::make_unique(FLAGS_shared_buffer)); std::unique_ptr& custom_mem_ptr = out_custom_mem.back(); @@ -415,7 +430,7 @@ int main(int argc, char** argv) { elapsed_time / inference_index); } else { // if no input is provided, fill the inputs with default values - auto inputs = util::prepare_input_tensors(*method); + auto inputs = prepare_input_tensors(*method); ET_CHECK_MSG( inputs.ok(), "Could not prepare inputs: 0x%" PRIx32, @@ -434,7 +449,7 @@ int main(int argc, char** argv) { // Dump the etdump data containing profiling/debugging data to the specified // file. - etdump_result result = etdump_gen.get_etdump_data(); + ETDumpResult result = etdump_gen.get_etdump_data(); if (result.buf != nullptr && result.size > 0) { ET_LOG( Info, @@ -452,7 +467,7 @@ int main(int argc, char** argv) { Info, "Write debug output binary to %s, Size = %zu", FLAGS_debug_output_path.c_str(), - FLAGS_debug_buffer_size); + (size_t)FLAGS_debug_buffer_size); FILE* f = fopen(FLAGS_debug_output_path.c_str(), "w+"); fwrite((uint8_t*)debug_buffer, 1, FLAGS_debug_buffer_size, f); fclose(f); diff --git a/examples/qualcomm/oss_scripts/llama2/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama2/qnn_llama_runner.cpp index 599accfd1e..1e46f919dc 100644 --- a/examples/qualcomm/oss_scripts/llama2/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama2/qnn_llama_runner.cpp @@ -23,8 +23,6 @@ #include #include -using torch::executor::MemoryAllocator; - DEFINE_string( model_path, "qnn_llama2.pte", @@ -49,9 +47,12 @@ DEFINE_int32( 128, "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens."); -int main(int argc, char** argv) { - using namespace torch::executor; +using executorch::runtime::Error; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::MethodMeta; +using executorch::runtime::Result; +int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); @@ -60,7 +61,7 @@ int main(int argc, char** argv) { int32_t seq_len = FLAGS_seq_len; // create llama runner - Runner runner(FLAGS_model_path, tokenizer_path, temperature); + example::Runner runner(FLAGS_model_path, tokenizer_path, temperature); ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method"); // MethodMeta describes the memory requirements of the method. 
diff --git a/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp index 59025f6567..2d37748fbe 100644 --- a/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp @@ -22,11 +22,27 @@ #include #include -namespace torch { -namespace executor { +using executorch::aten::ScalarType; +using executorch::aten::SizesType; +using executorch::aten::Tensor; +using executorch::extension::from_blob; +using executorch::extension::Module; +using executorch::extension::TensorPtr; +using executorch::extension::llm::BPETokenizer; +using executorch::extension::llm::Sampler; +using executorch::extension::llm::time_in_ms; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::MethodMeta; +using executorch::runtime::Result; +using executorch::runtime::TensorInfo; + +// TODO: Remove this usage of an internal-only function. +using executorch::runtime::internal::set_tensor_data; + +namespace example { namespace { -using namespace executorch::extension; static constexpr auto kTopp = 0.9f; void printReport(const Runner::Stats& stats); std::string statsToJsonString(const Runner::Stats& stats); @@ -57,7 +73,7 @@ Error Runner::load() { if (is_loaded()) { return Error::Ok; } - stats_.model_load_start_ms = util::time_in_ms(); + stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); // Read out metadata from the model @@ -97,7 +113,7 @@ Error Runner::load() { temperature_, kTopp, static_cast(std::time(nullptr))); - stats_.model_load_end_ms = util::time_in_ms(); + stats_.model_load_end_ms = time_in_ms(); return Error::Ok; } @@ -125,7 +141,7 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) { } template -int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) { +int32_t Runner::logitsToToken(const Tensor& logits_tensor) { T* logits = logits_tensor.mutable_data_ptr(); // Since the logits are for all tokens, get the last token probabilities @@ -135,7 +151,7 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) { // Given an input token. Set up the inputs for the model and execute a single // step. Returning the logits tensor. -Result Runner::run_model_step( +Result Runner::run_model_step( int64_t input_token, TensorPtr& token, TensorPtr& start_pos, @@ -167,7 +183,7 @@ Result Runner::run_model_step( char* new_inp_addr = io_mem_mgr_.update_k_caches_read(j, el_size); // inputs ET_CHECK_MSG( - internal::set_tensor_data( + set_tensor_data( *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok, "Failed to set input tensor when updating k_cache"); } @@ -177,13 +193,13 @@ Result Runner::run_model_step( char* new_inp_addr = io_mem_mgr_.update_v_caches_read(v_idx, v_offset); ET_CHECK_MSG( - internal::set_tensor_data( + set_tensor_data( *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok, "Failed to set input tensor when updating v_cache"); // outputs char* new_out_addr = io_mem_mgr_.update_v_caches_write(v_idx, v_offset); ET_CHECK_MSG( - internal::set_tensor_data( + set_tensor_data( *kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok, "Failed to set output tensor when updating v_cache"); ET_CHECK_MSG( @@ -210,7 +226,7 @@ Error Runner::generate( // First token time only measures the time it takes to encode the prompt and // return a response token. 
- stats_.inference_start_ms = util::time_in_ms(); + stats_.inference_start_ms = time_in_ms(); shouldStop_ = false; // Set the sequence length to the max seq length if not provided @@ -235,21 +251,21 @@ Error Runner::generate( "Sequence length exceeded - please increase the seq_len value passed to generate()"); int32_t pos = 0, prev_token, cur_token = prompt_tokens[0]; - std::vector token_shape = {1, 1}; + std::vector token_shape = {1, 1}; io_mem_mgr_.get_input_token_ptr()[0] = 0; - std::vector start_pos_shape = {1, 1}; + std::vector start_pos_shape = {1, 1}; float* atten_mask_ptr = reinterpret_cast(io_mem_mgr_.get_atten_mask_ptr()); std::fill(atten_mask_ptr, atten_mask_ptr + max_seq_len_, -255); atten_mask_ptr[max_seq_len_ - 1] = 0; - std::vector atten_mask_shape = {1, max_seq_len_}; + std::vector atten_mask_shape = {1, max_seq_len_}; - std::vector logits_data_shape = {1, vocab_size_}; + std::vector logits_data_shape = {1, vocab_size_}; - std::vector hidden_states_data_shape = {1, 1, dim_}; + std::vector hidden_states_data_shape = {1, 1, dim_}; // initialize tensor wrappers auto token = from_blob( @@ -274,7 +290,7 @@ Error Runner::generate( method_meta->input_tensor_meta(input_index); auto tensor_shape = tensor_meta->sizes(); - std::vector sizes( + std::vector sizes( tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); kv_tensors.emplace_back(from_blob( io_mem_mgr_.get_k_caches_read_ptr(i), @@ -284,7 +300,7 @@ Error Runner::generate( // outpus Result out_tensor_meta = method_meta->output_tensor_meta(i + 1); tensor_shape = out_tensor_meta->sizes(); - sizes = std::vector{ + sizes = std::vector{ tensor_shape.data(), tensor_shape.data() + tensor_shape.size()}; kv_outputs.emplace_back(from_blob( io_mem_mgr_.get_k_caches_write_ptr(i), @@ -303,7 +319,7 @@ Error Runner::generate( Result tensor_meta = method_meta->input_tensor_meta(input_index); auto tensor_shape = tensor_meta->sizes(); - std::vector sizes( + std::vector sizes( tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); kv_tensors.emplace_back(from_blob( @@ -315,7 +331,7 @@ Error Runner::generate( Result out_tensor_meta = method_meta->output_tensor_meta(output_index); tensor_shape = out_tensor_meta->sizes(); - sizes = std::vector{ + sizes = std::vector{ tensor_shape.data(), tensor_shape.data() + tensor_shape.size()}; kv_outputs.push_back(from_blob( @@ -342,19 +358,18 @@ Error Runner::generate( auto logits_res = run_model_step( cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs); if (pos == num_prompt_tokens) { - stats_.first_token_ms = util::time_in_ms(); + stats_.first_token_ms = time_in_ms(); } else if (pos == num_prompt_tokens - 1) { - stats_.prompt_eval_end_ms = util::time_in_ms(); + stats_.prompt_eval_end_ms = time_in_ms(); } ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); - exec_aten::Tensor& logits_tensor = logits_res.get(); + Tensor& logits_tensor = logits_res.get(); prev_token = cur_token; - long sample_start_time_ms = util::time_in_ms(); + long sample_start_time_ms = time_in_ms(); cur_token = logitsToToken(logits_tensor); - stats_.aggregate_sampling_time_ms += - util::time_in_ms() - sample_start_time_ms; + stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; // advance the state machine if (pos < num_prompt_tokens - 1) { @@ -381,7 +396,7 @@ Error Runner::generate( break; } } - stats_.inference_end_ms = util::time_in_ms(); + stats_.inference_end_ms = time_in_ms(); if (pos == seq_len) { ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len); @@ -650,5 +665,4 @@ 
template bool Runner::getMetadataHelper( std::string method_name, bool default_val); -} // namespace executor -} // namespace torch +} // namespace example diff --git a/examples/qualcomm/oss_scripts/llama2/runner/runner.h b/examples/qualcomm/oss_scripts/llama2/runner/runner.h index 4f7482b4c0..700cb94f52 100644 --- a/examples/qualcomm/oss_scripts/llama2/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama2/runner/runner.h @@ -106,21 +106,20 @@ class RpcMemAllocator { return reinterpret_cast(ptr_) + name##_pos_[idx]; \ } -namespace torch { -namespace executor { +namespace example { class IoMemMgr { public: // Allocate a big memory which is capable to contain all IO of all modules IoMemMgr(){}; - IoMemMgr(MethodMeta method_meta); + IoMemMgr(executorch::runtime::MethodMeta method_meta); struct InfoAttrs { - std::unique_ptr tensor_meta; + std::unique_ptr tensor_meta; size_t size = 0; std::vector shape; uint32_t rank; size_t element_size; - exec_aten::ScalarType dtype; + executorch::aten::ScalarType dtype; }; struct IoInfo { @@ -186,15 +185,16 @@ class IoMemMgr { std::vector v_caches_write_pos_; IoInfo io_info_; - std::unique_ptr method_meta_; + std::unique_ptr method_meta_; RpcMemAllocator rpc_mem_allocator{QnnMemDescriptor::kCustom}; - std::unordered_map scalar_type_to_size = { - {ScalarType::Int, sizeof(int32_t)}, - {ScalarType::Float, sizeof(float)}, - {ScalarType::Char, sizeof(int8_t)}, - {ScalarType::Short, sizeof(int16_t)}, - {ScalarType::Byte, sizeof(uint8_t)}, - {ScalarType::Bits16, sizeof(uint16_t)}, + std::unordered_map scalar_type_to_size = + { + {executorch::aten::ScalarType::Int, sizeof(int32_t)}, + {executorch::aten::ScalarType::Float, sizeof(float)}, + {executorch::aten::ScalarType::Char, sizeof(int8_t)}, + {executorch::aten::ScalarType::Short, sizeof(int16_t)}, + {executorch::aten::ScalarType::Byte, sizeof(uint8_t)}, + {executorch::aten::ScalarType::Bits16, sizeof(uint16_t)}, }; }; @@ -232,23 +232,24 @@ class Runner { }; bool is_loaded() const; - Error load(); - Error mem_alloc(size_t alignment, size_t seq_len); - Error generate( + executorch::runtime::Error load(); + executorch::runtime::Error mem_alloc(size_t alignment, size_t seq_len); + executorch::runtime::Error generate( const std::string& prompt, int32_t seq_len, std::function token_callback = {}, std::function stats_callback = {}); void stop(); - Result get_method_meta(); + executorch::runtime::Result + get_method_meta(); private: // metadata template T getMetadataHelper(std::string method_name, T default_val); template - int32_t logitsToToken(const exec_aten::Tensor& logits_tensor); - Result run_model_step( + int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor); + executorch::runtime::Result run_model_step( int64_t input_token, ::executorch::extension::TensorPtr& token, ::executorch::extension::TensorPtr& start_pos, @@ -265,16 +266,15 @@ class Runner { int32_t head_dim_; int32_t dim_; std::unordered_set model_methods_; - std::unique_ptr module_; + std::unique_ptr module_; std::string tokenizer_path_; std::string model_path_; float temperature_; - std::unique_ptr tokenizer_; - std::unique_ptr sampler_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; bool shouldStop_{false}; Stats stats_; IoMemMgr io_mem_mgr_; }; -} // namespace executor -} // namespace torch +} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp b/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp index d69aa0aa7a..3de97cde7e 100644 --- 
a/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp +++ b/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp @@ -49,8 +49,6 @@ DEFINE_double(logits_scale, 0.0, "Path to logits scale file"); DEFINE_int32(logits_offset, 0, "Path to logits offset file"); int main(int argc, char** argv) { - using namespace torch::executor; - gflags::ParseCommandLineFlags(&argc, &argv, true); std::vector models_path = { @@ -62,7 +60,7 @@ int main(int argc, char** argv) { FLAGS_freq_cos_path, FLAGS_freq_sin_path}; // create llama runner - Runner runner( + example::Runner runner( models_path, pos_embs_path, {8, 8, 8, 8}, diff --git a/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b_runner.cpp b/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b_runner.cpp index 9d06e8118d..7591b7ae1e 100644 --- a/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b_runner.cpp +++ b/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b_runner.cpp @@ -54,8 +54,6 @@ DEFINE_double(logits_scale, 0.0, "Path to logits scale file"); DEFINE_int32(logits_offset, 0, "Path to logits offset file"); int main(int argc, char** argv) { - using namespace torch::executor; - gflags::ParseCommandLineFlags(&argc, &argv, true); std::vector models_path = { @@ -68,7 +66,7 @@ int main(int argc, char** argv) { FLAGS_freq_cos_path, FLAGS_freq_sin_path}; // create llama runner - Runner runner( + example::Runner runner( models_path, pos_embs_path, {4, 8, 8, 8, 4}, diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp b/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp index bdcd9ea014..9dc1ee7e25 100644 --- a/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp +++ b/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp @@ -6,13 +6,21 @@ * LICENSE file in the root directory of this source tree. 
*/ +#include #include #include #include -namespace torch { -namespace executor { +using executorch::aten::Tensor; +using executorch::aten::TensorImpl; +using executorch::extension::Module; +using executorch::runtime::Error; +using executorch::runtime::MethodMeta; +using executorch::runtime::Result; +using executorch::runtime::TensorInfo; + +namespace example { Memory::Memory( const std::vector& pos_embs_path, @@ -476,7 +484,7 @@ void KVCachedMemory::update_io( ThreadPool::ThreadPool() : stop_(false) { size_t hc = (std::thread::hardware_concurrency() + 3) / 4; // maximum number should be divisible by head dimension which equals to 32 - num_workers_ = min(32, hc * 4); + num_workers_ = std::min(32, hc * 4); for (size_t i = 0; i < num_workers_; ++i) { threads_.emplace_back([this]() { while (1) { @@ -520,5 +528,4 @@ size_t ThreadPool::num_workers() { return num_workers_; } -} // namespace executor -} // namespace torch +} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h b/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h index df64bf8263..4ad7264cc9 100644 --- a/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h +++ b/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h @@ -26,44 +26,47 @@ #define QAIHUB_LLAMA_LOGITS 32000 #endif -namespace torch { -namespace executor { +namespace example { class Memory { public: Memory( const std::vector& pos_embs_path, - std::vector>& modules); + std::vector>& modules); virtual ~Memory(); virtual void prepare_io( - const std::vector>& methods_meta) = 0; + const std::vector< + executorch::runtime::Result>& + methods_meta) = 0; virtual void update_io( int64_t cur_token, int64_t pos, - std::vector>& output_tensors) = 0; + std::vector>& output_tensors) = 0; void* get_mutable_ptr(); - std::vector get_input_tensors(int shard_index); - std::vector get_output_tensors(int shard_index); + std::vector get_input_tensors(int shard_index); + std::vector get_output_tensors(int shard_index); protected: std::unique_ptr data_ptr_; - std::vector> input_tensors_; - std::vector> output_tensors_; + std::vector> input_tensors_; + std::vector> output_tensors_; std::vector pos_embs_path_; - std::vector> modules_; + std::vector> modules_; }; class BertMemory : public Memory { public: BertMemory( const std::vector& pos_embs_path, - std::vector>& modules, + std::vector>& modules, std::vector shard_layers); - void prepare_io(const std::vector>& methods_meta) override; + void prepare_io(const std::vector>& methods_meta) override; void update_io( int64_t cur_token, int64_t pos, - std::vector>& output_tensors) override; + std::vector>& output_tensors) + override; struct IO { int32_t input_ids[1024 * 2]; uint16_t hidden_state[1024 * 4096]; @@ -76,14 +79,14 @@ class BertMemory : public Memory { }; private: - std::unique_ptr input_ids_; - std::unique_ptr hidden_state_; - std::unique_ptr attention_mask_; - std::unique_ptr position_ids_cos_; - std::unique_ptr position_ids_sin_; - std::vector> k_cache_; - std::vector> v_cache_; - std::unique_ptr logits_; + std::unique_ptr input_ids_; + std::unique_ptr hidden_state_; + std::unique_ptr attention_mask_; + std::unique_ptr position_ids_cos_; + std::unique_ptr position_ids_sin_; + std::vector> k_cache_; + std::vector> v_cache_; + std::unique_ptr logits_; std::vector shard_layers_; int num_heads_; }; @@ -117,13 +120,15 @@ class KVCachedMemory : public Memory { public: KVCachedMemory( const std::vector& pos_embs_path, - std::vector>& modules, + std::vector>& modules, std::vector 
shard_layers); - void prepare_io(const std::vector>& methods_meta) override; + void prepare_io(const std::vector>& methods_meta) override; void update_io( int64_t cur_token, int64_t pos, - std::vector>& output_tensors) override; + std::vector>& output_tensors) + override; struct IO { int32_t input_ids; uint16_t hidden_state[4096]; @@ -142,16 +147,16 @@ class KVCachedMemory : public Memory { }; private: - std::unique_ptr input_ids_; - std::unique_ptr hidden_state_; - std::unique_ptr attention_mask_; - std::unique_ptr position_ids_cos_; - std::unique_ptr position_ids_sin_; - std::vector> k_cache_in_; - std::vector> v_cache_in_; - std::vector> k_cache_out_; - std::vector> v_cache_out_; - std::unique_ptr logits_; + std::unique_ptr input_ids_; + std::unique_ptr hidden_state_; + std::unique_ptr attention_mask_; + std::unique_ptr position_ids_cos_; + std::unique_ptr position_ids_sin_; + std::vector> k_cache_in_; + std::vector> v_cache_in_; + std::vector> k_cache_out_; + std::vector> v_cache_out_; + std::unique_ptr logits_; std::vector lr_update_kv_; std::vector> futures_; ThreadPool thread_pool_; @@ -159,5 +164,4 @@ class KVCachedMemory : public Memory { int num_heads_; }; -} // namespace executor -} // namespace torch +} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp b/examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp index d150d1a04e..721c16209c 100644 --- a/examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp @@ -29,8 +29,16 @@ #include "arm_neon.h" #endif -namespace torch { -namespace executor { +using executorch::aten::Tensor; +using executorch::extension::Module; +using executorch::extension::llm::Sampler; +using executorch::extension::llm::time_in_ms; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::MethodMeta; +using executorch::runtime::Result; + +namespace example { namespace { static constexpr auto kTopp = 0.9f; @@ -71,7 +79,7 @@ Runner::Runner( eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); version_ = LlamaVersion::kLlama3; #else - tokenizer_ = std::make_unique(); + tokenizer_ = std::make_unique(); tokenizer_->load(tokenizer_path_); version_ = LlamaVersion::kLlama2; #endif @@ -170,7 +178,7 @@ Error Runner::generate( std::vector> input_tensors, output_tensors; std::vector> inputs; if (!is_loaded()) { - stats_.model_load_start_ms = util::time_in_ms(); + stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); for (int i = 0; i < modules_.size(); ++i) { input_tensors.emplace_back(io_mem_->get_input_tensors(i)); @@ -185,10 +193,10 @@ Error Runner::generate( inputs.emplace_back( std::vector(begin(input_tensors[i]), end(input_tensors[i]))); } - stats_.model_load_end_ms = util::time_in_ms(); + stats_.model_load_end_ms = time_in_ms(); } - stats_.inference_start_ms = util::time_in_ms(); + stats_.inference_start_ms = time_in_ms(); seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? 
seq_len : max_seq_len_; std::string post_process_prompt; @@ -275,16 +283,15 @@ Error Runner::generate( Tensor& logits_tensor = output_tensors.back().back(); if (pos == num_prompt_tokens) { - stats_.first_token_ms = util::time_in_ms(); + stats_.first_token_ms = time_in_ms(); } else if (pos == num_prompt_tokens - 1) { - stats_.prompt_eval_end_ms = util::time_in_ms(); + stats_.prompt_eval_end_ms = time_in_ms(); } - long sample_start_time_ms = util::time_in_ms(); + long sample_start_time_ms = time_in_ms(); prev_token = cur_token; cur_token = logitsToToken(logits_tensor); - stats_.aggregate_sampling_time_ms += - util::time_in_ms() - sample_start_time_ms; + stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; if (pos < num_prompt_tokens - 1) { cur_token = prompt_tokens[pos + 1]; @@ -303,7 +310,7 @@ Error Runner::generate( break; } } - stats_.inference_end_ms = util::time_in_ms(); + stats_.inference_end_ms = time_in_ms(); if (pos == seq_len) { ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len); @@ -405,5 +412,4 @@ std::vector> Runner::get_methods_meta() { } return methods_meta; } -} // namespace executor -} // namespace torch +} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/runner.h b/examples/qualcomm/qaihub_scripts/llama/runner/runner.h index bd24ea6beb..0d15114bc6 100644 --- a/examples/qualcomm/qaihub_scripts/llama/runner/runner.h +++ b/examples/qualcomm/qaihub_scripts/llama/runner/runner.h @@ -22,8 +22,7 @@ #include #include -namespace torch { -namespace executor { +namespace example { class Runner { public: @@ -64,15 +63,16 @@ class Runner { }; bool is_loaded() const; - Error load(); - Error generate( + executorch::runtime::Error load(); + executorch::runtime::Error generate( const std::string& prompt, const std::string& system_prompt, int32_t seq_len, std::function token_callback = {}, std::function stats_callback = {}); void stop(); - std::vector> get_methods_meta(); + std::vector> + get_methods_meta(); private: enum EvalMode { @@ -86,8 +86,9 @@ class Runner { kLlama3, }; - int32_t logitsToToken(const exec_aten::Tensor& logits_tensor); - void run_model_step(std::vector>& inputs); + int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor); + void run_model_step( + std::vector>& inputs); // metadata int32_t bos_id_; std::unordered_set eos_id_; @@ -96,11 +97,11 @@ class Runner { const int32_t vocab_size_; const int32_t max_seq_len_; int32_t eval_mode_; - std::vector> modules_; + std::vector> modules_; std::string tokenizer_path_; float temperature_; - std::unique_ptr tokenizer_; - std::unique_ptr sampler_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; Stats stats_; std::unique_ptr io_mem_; const float logits_scale_; @@ -108,5 +109,4 @@ class Runner { LlamaVersion version_; }; -} // namespace executor -} // namespace torch +} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp index 687a260c4a..9c15ceadf8 100644 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp @@ -66,9 +66,10 @@ void usage_message() { gflags::SetUsageMessage(usage_message); } +using executorch::runtime::Error; + int main(int argc, char** argv) { - using namespace torch::executor; - runtime_init(); + executorch::runtime::runtime_init(); usage_message(); 
gflags::ParseCommandLineFlags(&argc, &argv, true); bool is_default = @@ -101,7 +102,7 @@ int main(int argc, char** argv) { FLAGS_text_encoder_path, FLAGS_unet_path, FLAGS_vae_path}; // Create stable_diffusion_runner - Runner runner( + example::Runner runner( models_path, FLAGS_num_time_steps, FLAGS_guidance_scale, diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp index fddd3527b5..cc54a80173 100644 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp @@ -22,10 +22,15 @@ #include #include -using namespace ::executorch::extension; +using executorch::extension::from_blob; +using executorch::extension::Module; +using executorch::extension::TensorPtr; +using executorch::extension::llm::time_in_ms; +using executorch::runtime::Error; +using executorch::runtime::MethodMeta; +using executorch::runtime::Result; -namespace torch { -namespace executor { +namespace example { Runner::Runner( const std::vector& models_path, @@ -88,11 +93,11 @@ Error Runner::load() { if (is_loaded()) { return Error::Ok; } - stats_.model_load_start_ms = util::time_in_ms(); + stats_.model_load_start_ms = time_in_ms(); for (auto& module : modules_) { ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("forward")); } - stats_.model_load_end_ms = util::time_in_ms(); + stats_.model_load_end_ms = time_in_ms(); return Error::Ok; } @@ -119,7 +124,7 @@ Error Runner::parse_input_list(std::string& path) { Error Runner::init_tokenizer(const std::string& vocab_json_path) { ET_LOG(Info, "Loading Tokenizer from json"); - stats_.tokenizer_load_start_ms = util::time_in_ms(); + stats_.tokenizer_load_start_ms = time_in_ms(); std::ifstream fin(vocab_json_path); auto update_map = [this](std::string& target, std::regex& re) { std::smatch sm; @@ -159,7 +164,7 @@ Error Runner::init_tokenizer(const std::string& vocab_json_path) { std::string target = text.substr(pos); update_map(target, re_pattern); } - stats_.tokenizer_load_end_ms = util::time_in_ms(); + stats_.tokenizer_load_end_ms = time_in_ms(); return Error::Ok; } @@ -338,15 +343,15 @@ void Runner::step( Error Runner::generate(std::string prompt) { ET_LOG(Info, "Start generating"); - stats_.generate_start_ms = util::time_in_ms(); + stats_.generate_start_ms = time_in_ms(); // Start tokenize - stats_.tokenizer_parsing_start_ms = util::time_in_ms(); + stats_.tokenizer_parsing_start_ms = time_in_ms(); std::vector cond_tokens = tokenize(prompt); cond_tokens.resize(max_tokens_); std::vector uncond_tokens = tokenize(""); uncond_tokens.resize(max_tokens_); - stats_.tokenizer_parsing_end_ms = util::time_in_ms(); + stats_.tokenizer_parsing_end_ms = time_in_ms(); std::vector> method_metas = get_methods_meta(); @@ -374,13 +379,13 @@ Error Runner::generate(std::string prompt) { {1, 77, 1024}, encoder_method_meta.output_tensor_meta(0)->scalar_type()); modules_[0]->set_output(cond_emb_tensor); - long encoder_start = util::time_in_ms(); + long encoder_start = time_in_ms(); auto cond_res = modules_[0]->forward(cond_tokens_tensor); - stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start); + stats_.text_encoder_execution_time += (time_in_ms() - encoder_start); modules_[0]->set_output(uncond_emb_tensor); - encoder_start = util::time_in_ms(); + encoder_start = time_in_ms(); auto uncond_res = modules_[0]->forward(uncond_tokens_tensor); - stats_.text_encoder_execution_time += (util::time_in_ms() - 
encoder_start); + stats_.text_encoder_execution_time += (time_in_ms() - encoder_start); // Initialize unet parameters MethodMeta unet_method_meta = method_metas[1].get(); @@ -451,7 +456,7 @@ Error Runner::generate(std::string prompt) { // Execute unet for (int step_index = 0; step_index < num_time_steps_; step_index++) { - long start_post_process = util::time_in_ms(); + long start_post_process = time_in_ms(); scale_model_input(latent, fp_latent_model_input, sigmas[step_index]); quant_tensor( @@ -461,24 +466,24 @@ Error Runner::generate(std::string prompt) { unet_input_latent_offset_); stats_.unet_aggregate_post_processing_time += - (util::time_in_ms() - start_post_process); + (time_in_ms() - start_post_process); modules_[1]->set_output(noise_pred_text_tensor); - long start_unet_execution = util::time_in_ms(); + long start_unet_execution = time_in_ms(); auto cond_res = modules_[1]->forward( {latent_tensor, time_emb_tensors[step_index], cond_emb_tensor}); stats_.unet_aggregate_execution_time += - (util::time_in_ms() - start_unet_execution); + (time_in_ms() - start_unet_execution); modules_[1]->set_output(noise_pred_uncond_tensor); - start_unet_execution = util::time_in_ms(); + start_unet_execution = time_in_ms(); auto uncond_res = modules_[1]->forward( {latent_tensor, time_emb_tensors[step_index], uncond_emb_tensor}); // results in noise_pred_uncond_vec stats_.unet_aggregate_execution_time += - (util::time_in_ms() - start_unet_execution); + (time_in_ms() - start_unet_execution); // start unet post processing - start_post_process = util::time_in_ms(); + start_post_process = time_in_ms(); dequant_tensor( noise_pred_text, @@ -497,7 +502,7 @@ Error Runner::generate(std::string prompt) { } step(fp_noise_pred_text, sigmas, latent, prev_sample, step_index); stats_.unet_aggregate_post_processing_time += - (util::time_in_ms() - start_post_process); + (time_in_ms() - start_post_process); } // Start VAE @@ -520,10 +525,10 @@ Error Runner::generate(std::string prompt) { quant_tensor(latent, vae_input, vae_input_scale_, vae_input_offset_); modules_[2]->set_output(output_tensor); - long start_vae_execution = util::time_in_ms(); + long start_vae_execution = time_in_ms(); auto vae_res = modules_[2]->forward(vae_input_tensor); - stats_.vae_execution_time = (util::time_in_ms() - start_vae_execution); - stats_.generate_end_ms = util::time_in_ms(); + stats_.vae_execution_time = (time_in_ms() - start_vae_execution); + stats_.generate_end_ms = time_in_ms(); // Dequant uint16 output to fp32 output dequant_tensor(q_out, out, vae_output_scale_, vae_output_offset_); @@ -605,5 +610,4 @@ Error Runner::print_performance() { return Error::Ok; } -} // namespace executor -} // namespace torch +} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h index e081ab80cc..f91efd5b83 100644 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h @@ -17,8 +17,7 @@ #include -namespace torch { -namespace executor { +namespace example { class Runner { public: @@ -77,9 +76,9 @@ class Runner { }; bool is_loaded() const; - Error load(); - Error init_tokenizer(const std::string& vocab_json_path); - Error print_performance(); + executorch::runtime::Error load(); + executorch::runtime::Error init_tokenizer(const std::string& vocab_json_path); + executorch::runtime::Error print_performance(); std::vector tokenize(std::string prompt); 
std::vector gen_latent_from_file(); std::vector gen_random_latent(float sigma); @@ -89,15 +88,16 @@ class Runner { std::vector& sample, std::vector& prev_sample, int step_index); - std::vector> get_methods_meta(); + std::vector> + get_methods_meta(); std::vector get_time_steps(); std::vector get_sigmas(const std::vector& time_steps); void scale_model_input( const std::vector& vec, std::vector& latent_model_input, float sigma); - Error parse_input_list(std::string& path); - Error generate(std::string prompt); + executorch::runtime::Error parse_input_list(std::string& path); + executorch::runtime::Error generate(std::string prompt); void quant_tensor( const std::vector& fp_vec, std::vector& quant_vec, @@ -111,7 +111,7 @@ class Runner { private: Stats stats_; - std::vector> modules_; + std::vector> modules_; std::vector> time_emb_list_; std::unordered_map vocab_to_token_map_; @@ -137,5 +137,4 @@ class Runner { const bool fix_latents_ = false; }; -} // namespace executor -} // namespace torch +} // namespace example
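Note on the convention this patch applies: example code (a stand-in for user code) declares its own namespace and refers to ExecuTorch types through explicit `using` declarations or full qualification, instead of reopening `torch::executor`. The sketch below is illustrative only; the `my_app` namespace and the `run()` function are hypothetical and not part of this change.

    // Hypothetical user code following the same convention this patch adopts
    // for examples/qualcomm (the patch itself uses the `example::` namespace).
    #include <executorch/runtime/core/error.h>

    using executorch::runtime::Error;  // name framework types explicitly

    namespace my_app {  // a namespace owned by the application, not torch:: or executorch::

    Error run() {
      // ... application logic built on the ExecuTorch runtime ...
      return Error::Ok;
    }

    }  // namespace my_app

Before this change, the same code would have been wrapped in `namespace torch { namespace executor { ... } }`, which made example code look like part of the framework.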