Skip to content

Add high-level operator interface #708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK}

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../..
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
#include <cassert>

namespace torchao {
namespace bitpacking {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
#include <cassert>
#include <cstring>

namespace torchao::kernels::cpu::aarch64::linear {
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot::
Expand Down Expand Up @@ -251,7 +252,7 @@ int inline weight_data_size_impl(
}

// Replace n with next multiple of 4 >= n
n = ((n + 3) >> 2) << 2;
n = ((n + 3) / 4) * 4;

return col_size * n;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
#include <cassert>
#include <cstring>

namespace torchao::kernels::cpu::aarch64::linear {
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::
Expand Down Expand Up @@ -324,7 +325,7 @@ int inline weight_data_size_impl(
}

// Replace n with next multiple of 8 >= n
n = ((n + 3) >> 3) << 3;
n = ((n + 7) / 8) * 8;

return col_size * n;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../..
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests -B ${CMAKE_OUT}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#pragma once
#include <torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h>
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <cassert>
#include <functional>
#include <random>
#include <vector>
Expand Down
53 changes: 53 additions & 0 deletions torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

cmake_minimum_required(VERSION 3.19)
project(benchmarks)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_BUILD_TYPE Release)

include(FetchContent)
FetchContent_Declare(googlebenchmark
GIT_REPOSITORY https://github.com/google/benchmark.git
GIT_TAG main) # need main for benchmark::benchmark

set(BENCHMARK_ENABLE_TESTING OFF)
FetchContent_MakeAvailable(
googlebenchmark)

add_compile_options("-Wall" "-Werror")

include(CMakePrintHelpers)
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
include_directories(${TORCHAO_LIBRARIES})

add_library(
dep
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
)

add_executable(benchmark_linear_operator benchmark_linear_operator.cpp)
target_link_libraries(
benchmark_linear_operator
PRIVATE
benchmark::benchmark
dep
)

option(TORCHAO_PARALLEL_OMP "" OFF)
option(TORCHAO_PARALLEL_SINGLE_THREADED "" ON)

if (TORCHAO_PARALLEL_OMP)
message("OpenMP_ROOT: ${OpenMP_ROOT}")
add_definitions(-DTORCHAO_PARALLEL_OMP=1)
find_package(OpenMP REQUIRED)
if(OpenMP_CXX_FOUND)
target_link_libraries(benchmark_linear_operator PUBLIC OpenMP::OpenMP_CXX)
endif()
endif()

if (TORCHAO_PARALLEL_SINGLE_THREADED)
add_definitions(-DTORCHAO_PARALLEL_SINGLE_THREADED=1)
endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <benchmark/benchmark.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h>
#include <torchao/experimental/kernels/cpu/memory.h>
#include <vector>

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
static void channelwise_8bit_activation_groupwise_lowbit_weight(
benchmark::State& state) {
int m = state.range(0);
int n = state.range(1);
int k = state.range(2);
int group_size = state.range(3);
int num_threads = state.range(4);

// OMP appears to cache when repeating the same task in the benchmark
// To prevent this, we benchmark a number of tasks
int num_test_cases = state.range(5);

// Initialize config and tiling params
using namespace torchao::operators::cpu::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;

auto ukernel_config =
get_ukernel_config<weight_nbit, has_weight_zeros, has_bias, has_clamp>();
auto pack_weight_data_tiling_params =
get_default_pack_weight_data_tiling_params(ukernel_config, n);
auto linear_tiling_params =
get_default_linear_tiling_params(ukernel_config, m, n);
auto linear_scheduling_policy =
LinearTileSchedulingPolicy::single_mc_parallel_nc;

// Set number of threads
torchao::set_num_threads(num_threads);
assert(num_threads == torchao::get_num_threads());

// Generate test cases
std::vector<
torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case>
test_cases;
for (int i = 0; i < num_test_cases; ++i) {
test_cases.emplace_back(
torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case::
generate(
m,
k,
n,
group_size,
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp));
}

// Pack test case weights
size_t packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
size_t packed_weight_data_alignment =
get_packed_weight_data_alignment(ukernel_config);

std::vector<std::unique_ptr<char[], void (*)(void*)>> packed_weight_data;
for (int i = 0; i < test_cases.size(); i++) {
packed_weight_data.emplace_back(torchao::make_aligned_byte_array_unique_ptr(
packed_weight_data_alignment, packed_weight_data_size));
pack_weight_data_operator(
ukernel_config,
pack_weight_data_tiling_params,
packed_weight_data[i].get(),
n,
k,
group_size,
test_cases[i].weight_qvals.data(),
test_cases[i].weight_scales.data(),
test_cases[i].weight_zeros.data());
}

// Allocate activation data buffer for test cases
size_t activation_data_buffer_size = get_activation_data_buffer_size(
ukernel_config,
linear_tiling_params,
linear_scheduling_policy,
m,
k,
group_size);
size_t activation_data_buffer_alignment =
get_activation_data_buffer_alignment(ukernel_config);

auto activation_data_buffer = torchao::make_aligned_byte_array_unique_ptr(
activation_data_buffer_alignment, activation_data_buffer_size);

auto output = std::vector<float>(m * n);
for (auto _ : state) {
for (int i = 0; i < test_cases.size(); i++) {
linear_operator(
ukernel_config,
linear_tiling_params,
linear_scheduling_policy,
activation_data_buffer.get(),
output.data(),
m,
n,
k,
group_size,
packed_weight_data[i].get(),
test_cases[i].activations.data(),
test_cases[i].bias.data(),
test_cases[i].clamp_min,
test_cases[i].clamp_max);
}
}
}

#define BENCHMARK_PARAMS \
{ \
/*m*/ {1}, /*n*/ {4096}, /*k*/ {4096}, /*group_size*/ {16, 32, 256}, \
/*num_threads*/ {1, 2, 4, 6, 8}, /*num_test_cases*/ { \
10 \
} \
}

#define BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT( \
weight_nbit) \
BENCHMARK(channelwise_8bit_activation_groupwise_lowbit_weight< \
weight_nbit, \
false /*has_weight_zeros*/, \
false /*has_bias*/, \
false /*has_clamp*/>) \
->ArgsProduct(BENCHMARK_PARAMS) \
->ArgNames( \
{"m", "n", "k", "group_size", "num_threads", "num_test_cases"});

BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT(3);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT(4);

// Run the benchmark
BENCHMARK_MAIN();
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK}

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/benchmarks \
-B ${CMAKE_OUT} \
-DOpenMP_ROOT=$(brew --prefix libomp) \
-DTORCHAO_PARALLEL_OMP=ON

cmake --build ${CMAKE_OUT}

# Run
case "$1" in
linear_operator) ${CMAKE_OUT}/benchmark_linear_operator; ;;
*) echo "Unknown benchmark: $1. Please specify one of: linear_operator."; exit 1; ;;
esac
Loading
Loading