diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh
index ae795b12ab..4e1cb99cc0 100644
--- a/.ci/scripts/test_llama.sh
+++ b/.ci/scripts/test_llama.sh
@@ -130,9 +130,9 @@ cleanup_files() {
 prepare_artifacts_upload() {
   if [ -n "$UPLOAD_DIR" ]; then
     echo "Preparing for uploading generated artifacts"
+    zip -j model.zip "${EXPORTED_MODEL_NAME}" tokenizer.bin
     mkdir -p "${UPLOAD_DIR}"
-    zip -j "model.zip" "${MODEL_NAME}" tokenizer.bin
-    cp "model.zip" "${UPLOAD_DIR}"
+    mv model.zip "${UPLOAD_DIR}"
   fi
 }
diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml
index f34c944bc1..18f7f06d0b 100644
--- a/.github/workflows/android-perf.yml
+++ b/.github/workflows/android-perf.yml
@@ -82,11 +82,27 @@ jobs:
       - name: Set parameters
        id: set-parameters
        shell: bash
+        env:
+          # Keep default values separate from the workflow-dispatch inputs. This makes the
+          # defaults available during scheduled runs and allows different defaults for
+          # on-demand and periodic benchmarking.
+          CRON_DEFAULT_MODELS: "stories110M"
+          CRON_DEFAULT_DEVICES: "samsung_galaxy_s2x"
+          CRON_DEFAULT_DELEGATES: "xnnpack"
        run: |
          set -ex
          MODELS="${{ inputs.models }}"
+          if [ -z "$MODELS" ]; then
+            MODELS="$CRON_DEFAULT_MODELS"
+          fi
          DEVICES="${{ inputs.devices }}"
+          if [ -z "$DEVICES" ]; then
+            DEVICES="$CRON_DEFAULT_DEVICES"
+          fi
          DELEGATES="${{ inputs.delegates }}"
+          if [ -z "$DELEGATES" ]; then
+            DELEGATES="$CRON_DEFAULT_DELEGATES"
+          fi

          # Mapping devices to their corresponding device-pool-arn
          declare -A DEVICE_POOL_ARNS
diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py
index c60afc2dd3..353169bc18 100644
--- a/backends/qualcomm/partition/common_defs.py
+++ b/backends/qualcomm/partition/common_defs.py
@@ -16,6 +16,10 @@
     exir_ops.edge.aten.copy.default,
 ]
 
+to_be_implemented_operator = [
+    exir_ops.edge.aten.where.default,
+]
+
 allow_list_operator = [
     _operator.getitem,
 ]
diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py
index c3afc23dae..86028d0d44 100644
--- a/backends/qualcomm/partition/qnn_partitioner.py
+++ b/backends/qualcomm/partition/qnn_partitioner.py
@@ -27,7 +27,11 @@
 from torch.fx.passes.infra.partitioner import Partition
 from torch.fx.passes.operator_support import OperatorSupportBase
 
-from .common_defs import allow_list_operator, not_supported_operator
+from .common_defs import (
+    allow_list_operator,
+    not_supported_operator,
+    to_be_implemented_operator,
+)
 
 
 class QnnOperatorSupport(OperatorSupportBase):
@@ -62,6 +66,12 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
         if node.op != "call_function" or node.target in not_supported_operator:
             return False
 
+        if node.target in to_be_implemented_operator:
+            print(
+                f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped. This op can be supported; please report an issue at https://github.com/pytorch/executorch/issues"
+            )
+            return False
+
         if node.target in allow_list_operator:
             return True
 
diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h
index e69a4937e5..b1a02a6d2e 100644
--- a/backends/vulkan/runtime/api/containers/Tensor.h
+++ b/backends/vulkan/runtime/api/containers/Tensor.h
@@ -277,6 +277,14 @@ class vTensor final {
     return sizes_.size();
   }
 
+  inline const std::vector<int64_t>& strides() const {
+    return strides_;
+  }
+
+  inline const std::vector<int64_t>& unsqueezed_strides() const {
+    return unsqueezed_strides_;
+  }
+
   /*
    * Returns a GPU buffer containing
the sizes of the tensor in WHCN order. * Note that dimensions that are not present in the tensor's sizes are set to diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.glsl index fe69501f9c..9d4b18f0d1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.glsl @@ -1,4 +1,3 @@ - #version 450 core #define PRECISION ${PRECISION} diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl new file mode 100644 index 0000000000..58796879e8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl @@ -0,0 +1,35 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +#include "indexing_utils.h" + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(0, "w", "nchw_buf", DTYPE, STORAGE)} +${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_ubo(2, "ivec4", "in_sizes")} +${layout_declare_ubo(3, "ivec4", "in_strides")} +${layout_declare_ubo(4, "int", "numel")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// This constant is unused in this shader but is kept so that the signature is +// consistent with image_to_nchw. +layout(constant_id = 3) const int UNUSED_packed_dim = W_DIM; + +void main() { + int out_id = int(gl_GlobalInvocationID.x); + if (out_id >= numel) { + return; + } + + ivec4 t_in_idx = from_nchw_buffer_i(out_id, in_sizes); + const int in_id = to_buffer_id(t_in_idx, in_strides); + + nchw_buf[out_id] = t_in[in_id]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml new file mode 100644 index 0000000000..653bda9ccc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +buffer_to_nchw: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int + - VALUE: int8 + shader_variants: + - NAME: buffer_to_nchw diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index d3264e43a2..21eadff0b3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -41,6 +41,21 @@ */ #define alignup4(x) ((x + 3) & -4) +/* + * Input: (W, H, C, N) strides of a tensor + * Returns: the WHCN index of the fastest moving dimension + */ +int find_packed_dim(const ivec4 strides) { + int packed_dim = 0; + for (int i = 0; i <= 3; i++) { + if (strides[i] == 1) { + packed_dim = i; + break; + } + } + return packed_dim; +} + // // (w, h, c, n) Tensor Index <-> Contiguous Buffer Index Conversion // @@ -74,27 +89,49 @@ ivec4 from_nchw_buffer_i(int buf_i, ivec4 sizes) { (buf_i / (sizes.x * sizes.y * sizes.z))); } +int to_nchw_buffer_i(const ivec4 tensor_idx, const ivec4 sizes) { + return tensor_idx.w * sizes.x * sizes.y * sizes.z + + tensor_idx.z * sizes.x * sizes.y + tensor_idx.y * sizes.x + tensor_idx.x; +} + /* * Input: Texel buffer index, (W, H, C, N) strides of a tensor, which dim is * packed along a texel - * Returns: The (x, y, z, n) texel position corresponding to the first element - * of the texel at the specified buffer index + * Returns: The (w, h, c, n) tensor index corresponding to the buffer element */ -ivec4 to_tensor_idx(int buf_i, ivec4 strides, int packed_dim) { +ivec4 to_tensor_idx(int buffer_id, const ivec4 strides, const int packed_dim) { ivec4 idx; for (int i = 3; i >= 0; i--) { if (i != packed_dim) { - idx[i] = buf_i / strides[i]; - buf_i %= strides[i]; + idx[i] = buffer_id / strides[i]; + buffer_id %= strides[i]; } } - idx[packed_dim] = buf_i; + idx[packed_dim] = buffer_id; return idx; } -int to_texel_idx(const ivec4 texel_pos, ivec4 strides) { - return texel_pos.x * strides.x + texel_pos.y * strides.y + - texel_pos.z * strides.z + texel_pos.w * strides.w; +/* + * Input: Texel buffer index, (W, H, C, N) strides of a tensor + * Returns: The (w, h, c, n) tensor index corresponding to the buffer element + * + * This is a convenience overload of the above function. If the packed dim is + * not known, it can be found by finding the first dimension with a stride of 1. + * However, this process adds some overhead, so if performance is a concern then + * the above function should be used instead so that the packed dim is provided. 
+ */ +ivec4 to_tensor_idx(int buffer_id, const ivec4 strides) { + int packed_dim = find_packed_dim(strides); + return to_tensor_idx(buffer_id, strides, packed_dim); +} + +/* + * Input: (w, h, c, n) tensor index, (W, H, C, N) strides of the tensor buffer + * Returns: the buffer index corresponding to the specified tensor index + */ +int to_buffer_id(const ivec4 tensor_idx, ivec4 strides) { + return tensor_idx.x * strides.x + tensor_idx.y * strides.y + + tensor_idx.z * strides.z + tensor_idx.w * strides.w; } // diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl new file mode 100644 index 0000000000..d861972f93 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl @@ -0,0 +1,35 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +#include "indexing_utils.h" + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(1, "r", "nchw_in", DTYPE, STORAGE)} +${layout_declare_ubo(2, "ivec4", "out_sizes")} +${layout_declare_ubo(3, "ivec4", "out_strides")} +${layout_declare_ubo(4, "int", "numel")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// This constant is unused in this shader but is kept so that the signature is +// consistent with nchw_to_image. +layout(constant_id = 3) const int UNUSED_packed_dim = W_DIM; + +void main() { + int out_id = int(gl_GlobalInvocationID.x); + if (out_id >= numel) { + return; + } + + ivec4 out_idx = to_tensor_idx(out_id, out_strides); + const int in_id = to_nchw_buffer_i(out_idx, out_sizes); + + t_out[out_id] = nchw_in[in_id]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml new file mode 100644 index 0000000000..6292ef9333 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +nchw_to_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int + - VALUE: int8 + shader_variants: + - NAME: nchw_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index b35d4b0175..b02613c208 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -26,7 +26,10 @@ void add_staging_to_tensor_node( vkapi::ParamsBindList ubos; if (graph.is_buffer_storage(out_tensor)) { - ubos.append(graph.numel_ubo(out_tensor)); + ubos.append( + {graph.sizes_ubo(out_tensor), + graph.strides_ubo(out_tensor), + graph.numel_ubo(out_tensor)}); } else { ubos.append(graph.sizes_ubo(out_tensor)); } @@ -61,7 +64,10 @@ void add_tensor_to_staging_node( vkapi::ParamsBindList ubos; if (graph.is_buffer_storage(in_tensor)) { - ubos.append(graph.numel_ubo(in_tensor)); + ubos.append( + {graph.sizes_ubo(in_tensor), + graph.strides_ubo(in_tensor), + graph.numel_ubo(in_tensor)}); } else { ubos.append(graph.sizes_ubo(in_tensor)); } @@ -105,7 +111,7 @@ ValueRef prepack( vkapi::ParamsBindList ubos; if (graph.is_buffer_storage(v)) { - ubos.append(graph.numel_ubo(v)); + ubos.append({graph.sizes_ubo(v), graph.strides_ubo(v), graph.numel_ubo(v)}); } else { ubos.append(graph.sizes_ubo(v)); } diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index daec2666f8..294e36b9a8 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -107,7 +107,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( } if (v_dst.storage_type() == utils::kBuffer) { - kernel_name = "buffer_to_buffer"; + kernel_name = "nchw_to_buffer"; add_dtype_suffix(kernel_name, v_dst); return VK_KERNEL_FROM_STR(kernel_name); } @@ -131,7 +131,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( } if (v_src.storage_type() == utils::kBuffer) { - kernel_name = "buffer_to_buffer"; + kernel_name = "buffer_to_nchw"; add_dtype_suffix(kernel_name, v_src); return VK_KERNEL_FROM_STR(kernel_name); } diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 29cd7bf995..e6f2863470 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -23,15 +23,13 @@ void record_nchw_to_buffer_op( vkapi::VulkanBuffer& src_buffer, api::vTensor& v_dst) { vkapi::PipelineBarrier pipeline_barrier{}; - vkapi::SpecVarList specialization_constants = { - SV(v_dst.packed_dim_whcn_idx())}; context->submit_compute_job( get_nchw_to_tensor_shader(v_dst), pipeline_barrier, {uint32_t(v_dst.numel()), 1, 1}, {64, 1, 1}, - specialization_constants, + {}, VK_NULL_HANDLE, 0, v_dst.buffer( @@ -39,6 +37,8 @@ void record_nchw_to_buffer_op( vkapi::PipelineStage::COMPUTE, vkapi::MemoryAccessType::WRITE), src_buffer, + v_dst.sizes_ubo(), + v_dst.strides_ubo(), v_dst.numel_ubo()); } @@ -47,19 +47,18 @@ void record_buffer_to_nchw_op( api::vTensor& v_src, vkapi::VulkanBuffer& dst_buffer) { vkapi::PipelineBarrier pipeline_barrier{}; - vkapi::SpecVarList specialization_constants = { - SV(v_src.packed_dim_whcn_idx())}; - context->submit_compute_job( get_tensor_to_nchw_shader(v_src), pipeline_barrier, {uint32_t(v_src.numel()), 1, 1}, {64, 1, 1}, - specialization_constants, + {}, VK_NULL_HANDLE, 0, dst_buffer, v_src.buffer(pipeline_barrier, 
vkapi::PipelineStage::COMPUTE),
+      v_src.sizes_ubo(),
+      v_src.strides_ubo(),
       v_src.numel_ubo());
 }
 
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml
index bb231420df..02d8503a4d 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml
@@ -47,6 +47,15 @@
+
+
+
+
+
+
+
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java
new file mode 100644
index 0000000000..33b230b1df
--- /dev/null
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchllamademo;
+
+import android.app.Activity;
+import android.content.Intent;
+import android.os.Bundle;
+import android.util.Log;
+import android.widget.TextView;
+import androidx.annotation.NonNull;
+import java.io.FileWriter;
+import java.io.IOException;
+
+public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
+  ModelRunner mModelRunner;
+
+  String mPrompt;
+  TextView mTextView;
+  StatsDump mStatsDump;
+
+  @Override
+  protected void onCreate(Bundle savedInstanceState) {
+    super.onCreate(savedInstanceState);
+    setContentView(R.layout.activity_benchmarking);
+    mTextView = findViewById(R.id.log_view);
+
+    Intent intent = getIntent();
+
+    String modelPath = intent.getStringExtra("model_path");
+    String tokenizerPath = intent.getStringExtra("tokenizer_path");
+
+    float temperature = intent.getFloatExtra("temperature", 0.8f);
+    mPrompt = intent.getStringExtra("prompt");
+    if (mPrompt == null) {
+      mPrompt = "The ultimate answer";
+    }
+
+    mStatsDump = new StatsDump();
+    mModelRunner = new ModelRunner(modelPath, tokenizerPath, temperature, this);
+    mStatsDump.loadStart = System.currentTimeMillis();
+  }
+
+  @Override
+  public void onModelLoaded(int status) {
+    mStatsDump.loadEnd = System.currentTimeMillis();
+    if (status != 0) {
+      Log.e("LlmBenchmarkRunner", "Load failed: " + status);
+      onGenerationStopped();
+      return;
+    }
+    mStatsDump.generateStart = System.currentTimeMillis();
+    mModelRunner.generate(mPrompt);
+  }
+
+  @Override
+  public void onTokenGenerated(String token) {
+    runOnUiThread(
+        () -> {
+          mTextView.append(token);
+        });
+  }
+
+  @Override
+  public void onStats(String stats) {
+    mStatsDump.tokens = stats;
+  }
+
+  @Override
+  public void onGenerationStopped() {
+    mStatsDump.generateEnd = System.currentTimeMillis();
+    runOnUiThread(
+        () -> {
+          mTextView.append(mStatsDump.toString());
+        });
+
+    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
+      writer.write(mStatsDump.toString());
+    } catch (IOException e) {
+      e.printStackTrace();
+    }
+  }
+}
+
+class StatsDump {
+  long loadStart;
+  long loadEnd;
+  long generateStart;
+  long generateEnd;
+  String tokens;
+
+  @NonNull
+  @Override
+  public String toString() {
+    return "loadStart: "
+        + loadStart
+        + "\nloadEnd: "
+        + loadEnd
+        + "\ngenerateStart: "
+        + generateStart
+        + "\ngenerateEnd: "
+        + generateEnd
+        + "\n"
+        + tokens;
+  }
+}
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java
new file mode 100644
index 0000000000..4dc32d1475
--- /dev/null
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchllamademo;
+
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
+import android.os.Message;
+import androidx.annotation.NonNull;
+import org.pytorch.executorch.LlamaCallback;
+import org.pytorch.executorch.LlamaModule;
+
+/** A helper class that encapsulates the model-running logic for this app. */
+public class ModelRunner implements LlamaCallback {
+  LlamaModule mModule = null;
+
+  String mModelFilePath = "";
+  String mTokenizerFilePath = "";
+
+  ModelRunnerCallback mCallback = null;
+
+  HandlerThread mHandlerThread = null;
+  Handler mHandler = null;
+
+  /**
+   * Helper class to separate UI logic from model-runner logic. Automatically handles generate()
+   * requests on a worker thread.
+   *
+   * @param modelFilePath path to the exported model file
+   * @param tokenizerFilePath path to the tokenizer file
+   * @param temperature sampling temperature
+   * @param callback receiver of load and generation events
+   */
+  ModelRunner(
+      String modelFilePath,
+      String tokenizerFilePath,
+      float temperature,
+      ModelRunnerCallback callback) {
+    mModelFilePath = modelFilePath;
+    mTokenizerFilePath = tokenizerFilePath;
+    mCallback = callback;
+
+    mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, temperature);
+    mHandlerThread = new HandlerThread("ModelRunner");
+    mHandlerThread.start();
+    mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);
+
+    mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
+  }
+
+  int generate(String prompt) {
+    Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
+    msg.sendToTarget();
+    return 0;
+  }
+
+  void stop() {
+    mModule.stop();
+  }
+
+  @Override
+  public void onResult(String result) {
+    mCallback.onTokenGenerated(result);
+  }
+
+  @Override
+  public void onStats(float tps) {
+    mCallback.onStats("tokens/second: " + tps);
+  }
+}
+
+class ModelRunnerHandler extends Handler {
+  public static int MESSAGE_LOAD_MODEL = 1;
+  public static int MESSAGE_GENERATE = 2;
+
+  private final ModelRunner mModelRunner;
+
+  public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
+    super(looper);
+    mModelRunner = modelRunner;
+  }
+
+  @Override
+  public void handleMessage(@NonNull android.os.Message msg) {
+    if (msg.what == MESSAGE_LOAD_MODEL) {
+      int status = mModelRunner.mModule.load();
+      mModelRunner.mCallback.onModelLoaded(status);
+    } else if (msg.what == MESSAGE_GENERATE) {
+      mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
+      mModelRunner.mCallback.onGenerationStopped();
+    }
+  }
+}
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java
new file mode 100644
index 0000000000..c8bdc53075
--- /dev/null
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchllamademo;
+
+/**
+ * A helper interface within the app for MainActivity and Benchmarking to handle callbacks from
+ * ModelRunner.
+ */
+public interface ModelRunnerCallback {
+
+  void onModelLoaded(int status);
+
+  void onTokenGenerated(String token);
+
+  void onStats(String stats);
+
+  void onGenerationStopped();
+}
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_benchmarking.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_benchmarking.xml
new file mode 100644
index 0000000000..6e48b5de8b
--- /dev/null
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_benchmarking.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
+
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index eeafa3dee3..56ca1f5873 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -553,27 +553,29 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
 
 def _load_llama_model_metadata(
     weight_type: WeightType,
-    dtype: DType,
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
-    modelArgs: ModelArgs,
+    model_args: ModelArgs,
     metadata_str: Optional[str] = None,
 ):
     is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
         "append_eos_to_prompt": is_fairseq2,  # For language llama, tell the runtime to always append EOS token(s) to prompt.
-        "get_bos_id": 3 if is_fairseq2 else 1,
-        "get_dtype": 5 if dtype == DType.fp16 else 6,
-        "get_eos_id": 3 if is_fairseq2 else 2,
-        "get_head_dim": modelArgs.dim // modelArgs.n_heads,
-        "get_max_batch_size": modelArgs.max_batch_size,
-        "get_max_seq_len": modelArgs.max_seq_len,
+        "get_bos_id": (
+            model_args.bos_idx
+            if model_args.bos_idx is not None
+            else (3 if is_fairseq2 else 1)
+        ),
+        "get_eos_id": (
+            model_args.eos_idx
+            if model_args.eos_idx is not None
+            else (3 if is_fairseq2 else 2)
+        ),
+        "get_max_seq_len": model_args.max_seq_len,
         "get_n_bos": 1,
         "get_n_eos": 2 if is_fairseq2 else 1,
-        "get_n_kv_heads": modelArgs.n_kv_heads,
-        "get_n_layers": modelArgs.n_layers,
-        "get_vocab_size": modelArgs.vocab_size,
+        "get_vocab_size": model_args.vocab_size,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,
@@ -655,7 +657,6 @@ def _load_llama_model(
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
-            dtype,
             use_kv_cache,
             use_sdpa_with_kv_cache,
             enable_dynamic_shape,
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index dacf9eb1fd..99544426fd 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -104,8 +104,8 @@ class ModelArgs:
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
# Additional Model Metadata needed at runtime - bos_idx: int = 1 - eos_idx: int = 3 + bos_idx: Optional[int] = None + eos_idx: Optional[int] = None bos_count: int = -1 # i.e., a single EOS is used as BOS eos_count: int = 2 diff --git a/examples/models/llama2/runner/generation.py b/examples/models/llama2/runner/generation.py index 56a15005ef..404ff4717e 100644 --- a/examples/models/llama2/runner/generation.py +++ b/examples/models/llama2/runner/generation.py @@ -14,11 +14,7 @@ import torch.nn.functional as F from executorch.examples.models.llama2.llama_transformer import ModelArgs -from executorch.examples.models.llama2.tokenizer.tiktoken import ( - Dialog, - Message, - Tokenizer, -) +from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer from executorch.extension.pybindings.portable_lib import _load_for_executorch @@ -28,12 +24,6 @@ class CompletionPrediction(TypedDict, total=False): logprobs: List[float] # not required -class ChatPrediction(TypedDict, total=False): - generation: Message - tokens: List[str] # not required - logprobs: List[float] # not required - - def sample_top_p(probs, p): """ Perform top-p (nucleus) sampling on a probability distribution. @@ -225,72 +215,6 @@ def text_completion( ] return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] - def chat_completion( - self, - dialogs: List[Dialog], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - ) -> List[ChatPrediction]: - """ - Generate assistant responses for a list of conversational dialogs using the language generation model. - - Args: - dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. - temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. - top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. - max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. - If not provided, it's set to the model's maximum sequence length minus 1. - logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. - - Returns: - List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. - - Raises: - AssertionError: If the last message in a dialog is not from the user. - AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order. - - Note: - This method generates assistant responses for the provided conversational dialogs. - It employs nucleus sampling to introduce controlled randomness in text generation. - If logprobs is True, token log probabilities are computed for each generated token. 
- """ - if max_gen_len is None: - max_gen_len = self.model.params.max_seq_len - 1 - - prompt_tokens = [ - self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs - ] - generation_tokens, generation_logprobs = self.generate( - prompt_tokens=prompt_tokens, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, - ) - if logprobs: - return [ - { - "generation": { - "role": "assistant", - "content": self.tokenizer.decode(t), - }, - "tokens": [self.tokenizer.decode([x]) for x in t], - "logprobs": logprobs_i, - } - for t, logprobs_i in zip(generation_tokens, generation_logprobs) - ] - return [ - { - "generation": { - "role": "assistant", - "content": self.tokenizer.decode(t), - }, - } - for t in generation_tokens - ] - def build_args_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp index a44b56d5d3..6bbbc05736 100644 --- a/examples/models/llama2/runner/runner.cpp +++ b/examples/models/llama2/runner/runner.cpp @@ -16,7 +16,7 @@ #include #endif /* ET_USE_TIKTOKEN*/ #include -#include +#include #include #include @@ -228,19 +228,19 @@ Error Runner::generate( tokens_managed.resize({1, static_cast(token_data.size())}); } - // print the token as string, decode it with the Tokenizer object - wrapped_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token))); - - if (shouldStop_) { - break; - } - // data-dependent terminating condition: we have n_eos_ number of EOS if (pos >= num_prompt_tokens && cur_token == eos_id_) { printf("\n"); ET_LOG(Info, "\nReached to the end of generation"); break; } + + // print the token as string, decode it with the Tokenizer object + wrapped_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token))); + + if (shouldStop_) { + break; + } } stats_.inference_end_ms = util::time_in_ms(); printf("\n"); diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl index 9800430b1f..2d0f1d5fe5 100644 --- a/examples/models/llama2/runner/targets.bzl +++ b/examples/models/llama2/runner/targets.bzl @@ -32,6 +32,7 @@ def define_common_targets(): ], exported_deps = [ "//executorch/backends/xnnpack:xnnpack_backend", + "//executorch/extension/llm/runner:metadata_util" + aten_suffix, "//executorch/extension/llm/runner:stats", "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix, "//executorch/extension/llm/runner:text_prefiller" + aten_suffix, diff --git a/extension/data_loader/file_data_loader.cpp b/extension/data_loader/file_data_loader.cpp index 7b041fef00..bf06d0c9be 100644 --- a/extension/data_loader/file_data_loader.cpp +++ b/extension/data_loader/file_data_loader.cpp @@ -49,7 +49,6 @@ static uint8_t* align_pointer(void* ptr, size_t alignment) { addr = (addr | (alignment - 1)) + 1; return reinterpret_cast(addr); } - } // namespace FileDataLoader::~FileDataLoader() { @@ -143,19 +142,6 @@ Result FileDataLoader::load( return FreeableBuffer(nullptr, 0, /*free_fn=*/nullptr); } - // Seek to the right place in the file. - off_t seek_offset = ::lseek(fd_, offset, SEEK_SET); - if (seek_offset != offset) { - ET_LOG( - Error, - "Seeking %s to offset %zu returned %zd: %s", - file_name_, - offset, - (ssize_t)seek_offset, - strerror(errno)); - return Error::AccessFailed; - } - // Allocate memory for the FreeableBuffer. 
size_t alloc_size = size; if (alignment_ > alignof(std::max_align_t)) { @@ -187,9 +173,75 @@ Result FileDataLoader::load( buffer, alloc_size); + auto err = load_into(offset, size, segment_info, aligned_buffer); + if (err != Error::Ok) { + // Free `buffer`, which is what malloc() gave us, not `aligned_buffer`. + std::free(buffer); + return err; + } + + // We can't naively free this pointer, since it may not be what malloc() gave + // us. Pass the offset to the real buffer as context. This is the number of + // bytes that need to be subtracted from the FreeableBuffer::data() pointer to + // find the actual pointer to free. + return FreeableBuffer( + aligned_buffer, + size, + FreeSegment, + /*free_fn_context=*/ + reinterpret_cast( + // Using signed types here because it will produce a signed ptrdiff_t + // value, though for us it will always be non-negative. + reinterpret_cast(aligned_buffer) - + reinterpret_cast(buffer))); +} + +Result FileDataLoader::size() const { + ET_CHECK_OR_RETURN_ERROR( + // Probably had its value moved to another instance. + fd_ >= 0, + InvalidState, + "Uninitialized"); + return file_size_; +} + +__ET_NODISCARD Error FileDataLoader::load_into( + size_t offset, + size_t size, + __ET_UNUSED const SegmentInfo& segment_info, + void* buffer) { + ET_CHECK_OR_RETURN_ERROR( + // Probably had its value moved to another instance. + fd_ >= 0, + InvalidState, + "Uninitialized"); + ET_CHECK_OR_RETURN_ERROR( + offset + size <= file_size_, + InvalidArgument, + "File %s: offset %zu + size %zu > file_size_ %zu", + file_name_, + offset, + size, + file_size_); + ET_CHECK_OR_RETURN_ERROR( + buffer != nullptr, InvalidArgument, "Provided buffer cannot be null"); + + // Seek to the right place in the file. + off_t seek_offset = ::lseek(fd_, offset, SEEK_SET); + if (seek_offset != offset) { + ET_LOG( + Error, + "Seeking %s to offset %zu returned %zd: %s", + file_name_, + offset, + (ssize_t)seek_offset, + strerror(errno)); + return Error::AccessFailed; + } + // Read the data into the aligned address. size_t needed = size; - uint8_t* buf = reinterpret_cast(aligned_buffer); + uint8_t* buf = reinterpret_cast(buffer); while (needed > 0) { // Reads on macos will fail with EINVAL if size > INT32_MAX. ssize_t nread = ::read( @@ -211,37 +263,12 @@ Result FileDataLoader::load( size, offset, nread == 0 ? "EOF" : strerror(errno)); - // Free `buffer`, which is what malloc() gave us, not `aligned_buffer`. - std::free(buffer); return Error::AccessFailed; } needed -= nread; buf += nread; } - - // We can't naively free this pointer, since it may not be what malloc() gave - // us. Pass the offset to the real buffer as context. This is the number of - // bytes that need to be subtracted from the FreeableBuffer::data() pointer to - // find the actual pointer to free. - return FreeableBuffer( - aligned_buffer, - size, - FreeSegment, - /*free_fn_context=*/ - reinterpret_cast( - // Using signed types here because it will produce a signed ptrdiff_t - // value, though for us it will always be non-negative. - reinterpret_cast(aligned_buffer) - - reinterpret_cast(buffer))); -} - -Result FileDataLoader::size() const { - ET_CHECK_OR_RETURN_ERROR( - // Probably had its value moved to another instance. 
- fd_ >= 0, - InvalidState, - "Uninitialized"); - return file_size_; + return Error::Ok; } } // namespace util diff --git a/extension/data_loader/file_data_loader.h b/extension/data_loader/file_data_loader.h index c6ab25933a..b7cfe3a1b9 100644 --- a/extension/data_loader/file_data_loader.h +++ b/extension/data_loader/file_data_loader.h @@ -72,6 +72,12 @@ class FileDataLoader : public DataLoader { __ET_NODISCARD Result size() const override; + __ET_NODISCARD Error load_into( + size_t offset, + size_t size, + __ET_UNUSED const SegmentInfo& segment_info, + void* buffer) override; + private: FileDataLoader( int fd, diff --git a/extension/module/metadata_util.h b/extension/llm/runner/metadata_util.h similarity index 100% rename from extension/module/metadata_util.h rename to extension/llm/runner/metadata_util.h diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 2e37547437..30241169ae 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -44,3 +44,14 @@ def define_common_targets(): "//executorch/extension/runner_util:managed_tensor" + aten_suffix, ], ) + + runtime.cxx_library( + name = "metadata_util" + aten_suffix, + exported_headers = ["metadata_util.h"], + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/extension/module:module" + aten_suffix, + ], + ) diff --git a/extension/module/targets.bzl b/extension/module/targets.bzl index 07020b03a8..61251047dc 100644 --- a/extension/module/targets.bzl +++ b/extension/module/targets.bzl @@ -17,7 +17,6 @@ def define_common_targets(): ], exported_headers = [ "module.h", - "metadata_util.h", ], visibility = [ "@EXECUTORCH_CLIENTS",
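
Note on the new Vulkan buffer staging shaders: nchw_to_buffer and buffer_to_nchw are
plain gathers driven by the sizes, strides, and numel UBOs that Staging.cpp now binds
for buffer-backed tensors, while FileDataLoader::load now delegates seeking and reading
to the new load_into(offset, size, segment_info, buffer) and keeps only allocation and
alignment for itself. Below is a minimal host-side C++ sketch of the shaders' index
math, handy for unit-testing it off-device. This is a reviewer illustration, not code
from the patch: Vec4 is a local stand-in for GLSL's ivec4, and the contiguous strides
in main() are an assumption made for the round-trip check.

    #include <array>
    #include <cassert>

    using Vec4 = std::array<int, 4>; // (W, H, C, N) order, matching the shaders

    // Mirrors from_nchw_buffer_i in indexing_utils.h:
    // NCHW buffer index -> WHCN tensor index.
    Vec4 from_nchw_buffer_i(int buf_i, const Vec4& sizes) {
      return {
          buf_i % sizes[0],
          (buf_i / sizes[0]) % sizes[1],
          (buf_i / (sizes[0] * sizes[1])) % sizes[2],
          buf_i / (sizes[0] * sizes[1] * sizes[2])};
    }

    // Mirrors to_buffer_id: WHCN tensor index -> element offset via the strides UBO.
    int to_buffer_id(const Vec4& idx, const Vec4& strides) {
      return idx[0] * strides[0] + idx[1] * strides[1] + idx[2] * strides[2] +
          idx[3] * strides[3];
    }

    int main() {
      // A contiguous (W, H, C, N) = (4, 3, 2, 1) tensor has strides (1, W, W*H, W*H*C).
      const Vec4 sizes = {4, 3, 2, 1};
      const Vec4 strides = {1, 4, 12, 24};
      for (int i = 0; i < 24; i++) {
        // With contiguous strides the round trip is the identity, so the copy shaders
        // reduce to a straight copy; with non-contiguous strides they perform a gather.
        assert(to_buffer_id(from_nchw_buffer_i(i, sizes), strides) == i);
      }
      return 0;
    }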